33#include " ../template.hpp"
44#include " ../edge_cd_asserts.hpp"
55#include " ../../../library/trees/edge_cd.hpp"
6- # include " ../../../library/math/mod_int.hpp "
6+ const int mod = 998244353 ;
77int main () {
88 cin.tie (0 )->sync_with_stdio (0 );
99 int n;
1010 cin >> n;
1111 vector<int > a (n);
12- vector<mint > res (n);
12+ vector<int > res (n);
1313 for (int i = 0 ; i < n; i++) {
1414 cin >> a[i];
1515 res[i] = a[i];
1616 }
1717 vector<vi> adj (n);
18- vector<mint > b (n - 1 ), c (n - 1 );
18+ vector<int > b (n - 1 ), c (n - 1 );
1919 vector<pair<int , int >> par (n, {-1 , -1 });
2020 vector<vector<int >> base_adj (n);
2121 {
2222 vector<vector<pair<int , int >>> adj_with_id (n);
2323 for (int i = 0 ; i < n - 1 ; i++) {
2424 int u, v;
25- cin >> u >> v >> b[i]. x >> c[i]. x ;
25+ cin >> u >> v >> b[i] >> c[i];
2626 adj[u].push_back (v);
2727 adj[v].push_back (u);
2828 base_adj[u].push_back (v);
2929 base_adj[v].push_back (u);
3030 adj_with_id[u].emplace_back (v, i);
3131 adj_with_id[v].emplace_back (u, i);
32- res[u] = res[u] + b[i] * a[v] + c[i];
33- res[v] = res[v] + b[i] * a[u] + c[i];
32+ res[u] = ( res[u] + 1LL * b[i] * a[v] + c[i]) % mod ;
33+ res[v] = ( res[v] + 1LL * b[i] * a[u] + c[i]) % mod ;
3434 }
3535 auto dfs = [&](auto && self, int u) -> void {
3636 for (auto [v, e_id] : adj_with_id[u])
@@ -49,30 +49,35 @@ int main() {
4949 edge_cd (adj,
5050 [&](const vector<vi>& cd_adj, int cent,
5151 int split) -> void {
52- array<vector<array<mint , 3 >>, 2 > all_backwards;
53- array<mint , 2 > sum_forward = {0 , 0 };
52+ array<vector<array<int , 3 >>, 2 > all_backwards;
53+ array<int , 2 > sum_forward = {0 , 0 };
5454 array<int , 2 > cnt_nodes = {0 , 0 };
5555 auto dfs = [&](auto && self, int u, int p,
56- array<mint , 2 > forwards,
57- array<mint , 2 > backwards,
56+ array<int , 2 > forwards,
57+ array<int , 2 > backwards,
5858 int side) -> void {
5959 all_backwards[side].push_back (
6060 {u, backwards[0 ], backwards[1 ]});
61- sum_forward[side] = sum_forward[side] +
62- forwards[0 ] * a[u] + forwards[1 ];
61+ sum_forward[side] =
62+ (sum_forward[side] + 1LL * forwards[0 ] * a[u] +
63+ forwards[1 ]) %
64+ mod;
6365 cnt_nodes[side]++;
6466 for (int v : cd_adj[u]) {
6567 if (v == p) continue ;
6668 int e_id = edge_id (u, v);
6769 // f(x) = ax+b
6870 // g(x) = cx+d
6971 // f(g(x)) = a(cx+d)+b = acx+ad+b
70- array<mint, 2 > curr_forw = {
71- forwards[0 ] * b[e_id],
72- forwards[0 ] * c[e_id] + forwards[1 ]};
73- array<mint, 2 > curr_backw = {
74- backwards[0 ] * b[e_id],
75- backwards[1 ] * b[e_id] + c[e_id]};
72+ array<int , 2 > curr_forw = {
73+ int (1LL * forwards[0 ] * b[e_id] % mod),
74+ int (
75+ (1LL * forwards[0 ] * c[e_id] + forwards[1 ]) %
76+ mod)};
77+ array<int , 2 > curr_backw = {
78+ int (1LL * backwards[0 ] * b[e_id] % mod),
79+ int ((1LL * backwards[1 ] * b[e_id] + c[e_id]) %
80+ mod)};
7681 self (self, v, u, curr_forw, curr_backw, side);
7782 }
7883 };
@@ -84,13 +89,14 @@ int main() {
8489 for (int side = 0 ; side < 2 ; side++) {
8590 for (
8691 auto [u, curr_b, curr_c] : all_backwards[side]) {
87- res[u.x ] = res[u.x ] +
88- curr_b * sum_forward[!side] +
89- curr_c * cnt_nodes[!side];
92+ res[u] =
93+ (res[u] + 1LL * curr_b * sum_forward[!side] +
94+ 1LL * curr_c * cnt_nodes[!side]) %
95+ mod;
9096 }
9197 }
9298 });
93- for (int i = 0 ; i < n; i++) cout << res[i]. x << ' ' ;
99+ for (int i = 0 ; i < n; i++) cout << res[i] << ' ' ;
94100 cout << ' \n ' ;
95101 return 0 ;
96102}
0 commit comments