imsuck's library

This documentation is automatically generated by competitive-verifier/competitive-verifier

View the Project on GitHub imsuck/library

:heavy_check_mark: test/yosupo/tree/tree_path_composite_sum.test.cpp

Depends on: tree/rerooting.hpp, math/modint.hpp

Code

#define PROBLEM "https://judge.yosupo.jp/problem/tree_path_composite_sum"

#include <bits/stdc++.h>
using namespace std;

#include "tree/rerooting.hpp"
#include "math/modint.hpp"

using mint = modint998;

// Pair of modular accumulators combined componentwise: op adds the first
// entries together and the second entries together; id() is {0, 0}.
// (main() uses entry 0 as a running sum and entry 1 as a count.)
struct Monoid {
    using T = array<mint, 2>;
    using Cost = array<mint, 2>;

    static T id() { return T{mint(0), mint(0)}; }
    static T op(const T &l, const T &r) {
        T acc = l;
        acc[0] += r[0];
        acc[1] += r[1];
        return acc;
    }
};

int main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    int n;
    cin >> n;
    vector<mint> a(n);
    for (mint &i : a) cin >> i;
    Rerooting<Monoid> dp(n);
    for (int i = 0, u, v, b, c; i < n - 1; i++) {
        cin >> u >> v >> b >> c;
        dp.add_edge(u, v, {b, c});
    }
    using T = Monoid::T;
    using Cost = Monoid::Cost;
    auto res = dp.run(
        [](const T &x, int, int, const Cost &c) -> T {
            return {x[0] * c[0] + x[1] * c[1], x[1]};
        },
        [&a](const T &x, int v) -> T { return {x[0] + a[v], x[1] + 1}; }
    );
    for (int i = 0; i < n; i++) cout << res[i][0] << " \n"[i == n - 1];
}
#line 1 "test/yosupo/tree/tree_path_composite_sum.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/tree_path_composite_sum"

#include <bits/stdc++.h>
using namespace std;

#line 2 "tree/rerooting.hpp"

/// Rerooting DP: evaluates a tree DP for every choice of root, using O(n)
/// calls to M::op plus the user callbacks in total.
///
/// Requirements on M:
///   - M::T    : DP value type (must be default-constructible, see run()).
///   - M::Cost : payload stored on each undirected edge.
///   - static T M::id()                   : identity element of op.
///   - static T M::op(const T&, const T&) : merges neighbour contributions.
///     It must be associative (prefix/suffix products are used). Children
///     are merged in adjacency order in the bottom-up pass. The parent-side
///     value is merged first in the top-down pass. So op is effectively
///     assumed commutative as well.
///
/// The graph built with add_edge() is expected to be a tree on [0, n).
/// Vertex 0 is the initial root.
template<class M> struct Rerooting {
    using T = typename M::T;
    using Cost = typename M::Cost;

    /// Creates an edgeless graph on n vertices.
    Rerooting(int n) : g(n) {}

    /// Adds the undirected edge {u, v}; both directions carry payload c.
    void add_edge(int u, int v, const Cost &c) {
        g[u].emplace_back(v, c);
        g[v].emplace_back(u, c);
    }

    /// Runs the rerooting DP.
    ///
    /// @param f1  f1(x, to, from, cost) -> T. Carries x, the value of the
    ///            component on the `from` side, across edge (from, to).
    /// @param f2  f2(x, v) -> T. Finalizes the merged neighbour value x at v.
    /// @return    res[v] = f2(op over all neighbours w of f1(value of w's side,
    ///            v, w, cost), v), i.e. the DP of the whole tree rooted at v.
    ///            Returns an empty vector when there are no vertices.
    ///            Vertices unreachable from 0 (only possible if the input is
    ///            not a tree) keep a value-initialized T.
    ///
    /// Both passes are iterative, so deep trees (e.g. long paths) cannot
    /// overflow the call stack.
    template<class F1, class F2> vector<T> run(F1 &&f1, F2 &&f2) {
        const int n = (int)g.size();
        if (n == 0) return {};

        // Preorder from root 0: every vertex appears after its parent.
        vector<int> par(n, -1), order;
        order.reserve(n);
        vector<int> stk{0};
        while (!stk.empty()) {
            const int v = stk.back();
            stk.pop_back();
            order.push_back(v);
            for (const auto &e : g[v]) {
                if (e.first == par[v]) continue;
                par[e.first] = v;
                stk.push_back(e.first);
            }
        }

        // Bottom-up: dp_sub[v] = f2(op over children c of f1(dp_sub[c], v, c, cost), v).
        vector<T> dp_sub(n, M::id());
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            const int v = *it;
            for (const auto &[c, cost] : g[v]) {
                if (c == par[v]) continue;
                dp_sub[v] = M::op(dp_sub[v], f1(dp_sub[c], v, c, cost));
            }
            dp_sub[v] = f2(dp_sub[v], v);
        }

        // Top-down: up[v] is the value of the part of the tree on par[v]'s
        // side of edge (par[v], v), already carried to v by f1 (id for root).
        vector<T> dp_all(n), up(n, M::id());
        vector<T> ds, head, tail;  // scratch buffers reused across vertices
        for (const int v : order) {
            ds.assign(1, up[v]);
            for (const auto &[c, cost] : g[v])
                if (c != par[v]) ds.push_back(f1(dp_sub[c], v, c, cost));
            const int d = (int)ds.size();
            // head[i] = ds[0] op ... op ds[i-1]; tail[i] = ds[i] op ... op ds[d-1].
            head.assign(d + 1, M::id());
            tail.assign(d + 1, M::id());
            for (int i = 0; i < d; i++) head[i + 1] = M::op(head[i], ds[i]);
            for (int i = d; i--;) tail[i] = M::op(ds[i], tail[i + 1]);
            dp_all[v] = f2(head[d], v);
            // The k-th child (1-based, ds[k]) receives everything except ds[k].
            int k = 1;
            for (const auto &[c, cost] : g[v]) {
                if (c == par[v]) continue;
                const T rest = f2(M::op(head[k], tail[k + 1]), v);
                up[c] = f1(rest, c, v, cost);
                k++;
            }
        }
        return dp_all;
    }

  private:
    vector<vector<pair<int, Cost>>> g;  // adjacency: (neighbour, edge payload)
};
#line 2 "math/modint.hpp"

// clang-format off
/// Integer modulo the compile-time constant m, kept reduced in [0, m).
///
/// m must satisfy 1 <= m < 2^31. operator+= / operator-= detect wrap-around
/// by reinterpreting a uint32_t intermediate as int, and inv() runs extended
/// Euclid on int. Both are only sound below 2^31, so the static_assert
/// enforces both bounds stated in its message.
template<uint32_t m> struct modint {
    static_assert(1 <= m && m < (1u << 31), "Modulus must be in the range [1;2^31)");

    using mint = modint;
    static constexpr bool is_simple = true;

    static constexpr uint32_t mod() noexcept { return m; }
    constexpr modint() noexcept = default;
    // Reduces any signed 64-bit value into [0, m); negative inputs wrap upward.
    constexpr modint(int64_t v) noexcept : _v(uint32_t((v %= m) < 0 ? v + m : v)) {}
    // Wraps v without reduction; the caller must guarantee v < m.
    constexpr static mint raw(uint32_t v) noexcept { mint x; return x._v = v, x; }
    // Explicit conversion yields the reduced representative in [0, m).
    template<class T> constexpr explicit operator T() const noexcept { return _v; }

    constexpr mint &operator++() noexcept { return _v = ++_v == mod() ? 0 : _v, *this; }
    constexpr mint &operator--() noexcept { --(_v ? _v : _v = mod()); return *this; }
    constexpr mint operator++(int) noexcept { return exchange(*this, ++mint(*this)); }
    constexpr mint operator--(int) noexcept { return exchange(*this, --mint(*this)); }

    // Computes _v + rhs._v - m in uint32_t. A negative int reinterpretation
    // means the sum was below m, so m is added back.
    constexpr mint &operator+=(mint rhs) noexcept {
        return _v = int(_v += rhs._v - mod()) < 0 ? _v + mod() : _v, *this;
    }
    constexpr mint &operator-=(mint rhs) noexcept {
        return _v = int(_v -= rhs._v) < 0 ? _v + mod() : _v, *this;
    }
    constexpr mint &operator*=(mint rhs) noexcept {
        return _v = uint64_t(_v) * rhs._v % mod(), *this;
    }
    // Requires rhs to be invertible modulo m (see inv()).
    constexpr mint &operator/=(mint rhs) noexcept {
        return *this = *this * rhs.inv();
    }

    constexpr friend mint operator+(mint l, mint r) noexcept { return l += r; }
    constexpr friend mint operator-(mint l, mint r) noexcept { return l -= r; }
    constexpr friend mint operator*(mint l, mint r) noexcept { return l *= r; }
    constexpr friend mint operator/(mint l, mint r) noexcept { return l /= r; }

    constexpr mint operator+() const noexcept { return *this; }
    constexpr mint operator-() const noexcept { return raw(_v ? mod() - _v : 0); }

    constexpr friend bool operator==(mint l, mint r) noexcept { return l._v == r._v; }
    constexpr friend bool operator!=(mint l, mint r) noexcept { return l._v != r._v; }
    constexpr friend bool operator<(mint l, mint r) noexcept { return l._v < r._v; }

    // Binary exponentiation: returns (*this)^n.
    constexpr mint pow(uint64_t n) const noexcept {
        mint b = *this, res = 1;
        while (n) n & 1 ? res *= b : 0, b *= b, n >>= 1;
        return res;
    }

    // Modular inverse via the extended Euclidean algorithm.
    // Asserts gcd(_v, m) == 1, so it fails for a non-invertible value.
    constexpr mint inv() const noexcept {
        int a = _v, b = mod(), x = 1, y = 0;
        while (b) {
            x = exchange(y, x - a / b * y);
            a = exchange(b, a % b);
        }
        assert(a == 1);
        return x;
    }

    // Reads a signed 64-bit integer and reduces it modulo m.
    friend istream &operator>>(istream &is, mint &x) {
        int64_t v{};
        return is >> v, x = v, is;
    }
    friend ostream &operator<<(ostream &os, const mint &x) { return os << x._v; }

  private:
    uint32_t _v = 0;  // invariant: 0 <= _v < m
};
using modint107 = modint<1'000'000'007>;
using modint998 = modint<998'244'353>;
// clang-format on
#line 8 "test/yosupo/tree/tree_path_composite_sum.test.cpp"

using mint = modint998;

// Pair of modular accumulators combined componentwise: op adds the first
// entries together and the second entries together; id() is {0, 0}.
// (main() uses entry 0 as a running sum and entry 1 as a count.)
struct Monoid {
    using T = array<mint, 2>;
    using Cost = array<mint, 2>;

    static T id() { return T{mint(0), mint(0)}; }
    static T op(const T &l, const T &r) {
        T acc = l;
        acc[0] += r[0];
        acc[1] += r[1];
        return acc;
    }
};

int main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    int n;
    cin >> n;
    vector<mint> a(n);
    for (mint &i : a) cin >> i;
    Rerooting<Monoid> dp(n);
    for (int i = 0, u, v, b, c; i < n - 1; i++) {
        cin >> u >> v >> b >> c;
        dp.add_edge(u, v, {b, c});
    }
    using T = Monoid::T;
    using Cost = Monoid::Cost;
    auto res = dp.run(
        [](const T &x, int, int, const Cost &c) -> T {
            return {x[0] * c[0] + x[1] * c[1], x[1]};
        },
        [&a](const T &x, int v) -> T { return {x[0] + a[v], x[1] + 1}; }
    );
    for (int i = 0; i < n; i++) cout << res[i][0] << " \n"[i == n - 1];
}

Test cases

Env Name Status Elapsed Memory
g++ example_00 :heavy_check_mark: AC 4 ms 4 MB
g++ example_01 :heavy_check_mark: AC 4 ms 4 MB
g++ n_hundreds_00 :heavy_check_mark: AC 4 ms 4 MB
g++ n_hundreds_01 :heavy_check_mark: AC 4 ms 4 MB
g++ tiny_00 :heavy_check_mark: AC 4 ms 4 MB
g++ tiny_01 :heavy_check_mark: AC 4 ms 4 MB
g++ typical_tree_max_00 :heavy_check_mark: AC 205 ms 71 MB
g++ typical_tree_max_01 :heavy_check_mark: AC 112 ms 26 MB
g++ typical_tree_max_02 :heavy_check_mark: AC 141 ms 24 MB
g++ typical_tree_max_03 :heavy_check_mark: AC 141 ms 20 MB
g++ typical_tree_max_04 :heavy_check_mark: AC 194 ms 60 MB
g++ typical_tree_max_05 :heavy_check_mark: AC 160 ms 53 MB
g++ typical_tree_max_06 :heavy_check_mark: AC 167 ms 52 MB
g++ typical_tree_max_07 :heavy_check_mark: AC 204 ms 67 MB
g++ typical_tree_max_08 :heavy_check_mark: AC 128 ms 22 MB
g++ typical_tree_max_09 :heavy_check_mark: AC 128 ms 22 MB
clang++ example_00 :heavy_check_mark: AC 4 ms 4 MB
clang++ example_01 :heavy_check_mark: AC 4 ms 4 MB
clang++ n_hundreds_00 :heavy_check_mark: AC 4 ms 4 MB
clang++ n_hundreds_01 :heavy_check_mark: AC 4 ms 4 MB
clang++ tiny_00 :heavy_check_mark: AC 4 ms 4 MB
clang++ tiny_01 :heavy_check_mark: AC 4 ms 4 MB
clang++ typical_tree_max_00 :heavy_check_mark: AC 184 ms 52 MB
clang++ typical_tree_max_01 :heavy_check_mark: AC 112 ms 26 MB
clang++ typical_tree_max_02 :heavy_check_mark: AC 135 ms 24 MB
clang++ typical_tree_max_03 :heavy_check_mark: AC 137 ms 20 MB
clang++ typical_tree_max_04 :heavy_check_mark: AC 172 ms 48 MB
clang++ typical_tree_max_05 :heavy_check_mark: AC 152 ms 42 MB
clang++ typical_tree_max_06 :heavy_check_mark: AC 153 ms 41 MB
clang++ typical_tree_max_07 :heavy_check_mark: AC 208 ms 49 MB
clang++ typical_tree_max_08 :heavy_check_mark: AC 123 ms 22 MB
clang++ typical_tree_max_09 :heavy_check_mark: AC 126 ms 22 MB
Back to top page