小 Q 与树
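给定一棵 $n$ 个点的树,点 $i$ 的权值为 $a_i$,每条边的长度都为 $1$,要求 $\sum_{u=1}^{n}\sum_{v=1}^{n}\min(a_u,a_v)\cdot\operatorname{dis}(u,v) \bmod 998244353$。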
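如果考虑 dsu on tree,则是枚举点对的最近公共祖先 $rt=\operatorname{lca}(u,v)$,此时 $\operatorname{dis}(u,v)=dep_u+dep_v-2\,dep_{rt}$。对当前点 $u$ 与集合中已有的点 $v$,分两种情况统计答案:
$a_v \ge a_u$,则 $\min(a_u,a_v)=a_u$,则我们只要知道集合中这样的点有多少个,以及它们的深度和即可,
设这样的点的个数为 $c$,$S=\sum dep_v$,则 $\sum_v a_u\,(dep_u+dep_v-2\,dep_{rt})$ 等价于 $a_u\left(S+c\,(dep_u-2\,dep_{rt})\right)$。
$a_v < a_u$,则 $\min(a_u,a_v)=a_v$,则我们只要知道这些点的 $\sum a_v\,dep_v$ 以及 $\sum a_v$ 即可求得答案,
上式 $\sum_v a_v\,(dep_u+dep_v-2\,dep_{rt})$ 等价于 $\sum_v a_v\,dep_v+(dep_u-2\,dep_{rt})\sum_v a_v$。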
所以可以对点权离散化,然后用以离散化后的点权为下标的线段树来维护上面需要的四个值(点数 $c$、$\sum dep_v$、$\sum a_v\,dep_v$、$\sum a_v$),即可进行 dsu on tree,整体复杂度 $O(n\log^2 n)$。
由于上面的统计中每个无序点对 $\{u,v\}$ 只被计算了一次,而所求的和式是对有序点对 $(u,v)$ 求和($u=v$ 时距离为 $0$,不产生贡献),所以还要对上述计算完后的答案乘以 $2$ 即可。
// Little Q and the Tree: DSU on tree (small-to-large) + a segment tree over values.
//
// Reads n, the vertex values v_1..v_n and n-1 undirected edges (each of length 1),
// then prints  sum_{u=1..n} sum_{w=1..n} min(v_u, v_w) * dist(u, w)  mod 998244353.
//
// Every unordered pair {u, w} with u != w is counted exactly once, at its LCA, during
// the DSU-on-tree pass; the total is doubled at the end to cover ordered pairs
// (pairs with u == w have distance 0 and contribute nothing).
#include <algorithm>
#include <cstdio>

const int N = 200000 + 10, mod = 998244353;

// Adjacency lists as linked edge lists; every undirected edge is stored twice.
int head[N], to[N << 1], nex[N << 1], cnt = 1;
// son: heavy child (0 for a leaf); sz: subtree size; dep: depth (the root has depth 1);
// tin/tout: DFS-order interval covering a subtree; rk[t]: vertex visited at DFS time t.
int son[N], sz[N], tin[N], tout[N], rk[N], dep[N], tot;
// Segment tree indexed by value rank 1..m. For the vertices currently inserted whose
// rank falls in a node's range, the node stores (all reduced mod p):
//   segCnt    - how many such vertices there are
//   segDep    - sum of their depths
//   segValDep - sum of value * depth
//   segVal    - sum of values
int segCnt[N << 2], segDep[N << 2], segValDep[N << 2], segVal[N << 2];
// a[i]: value rank of vertex i after discretization; b[k]: k-th smallest distinct value.
int a[N], b[N], n, m;
// Sum over the unordered pairs processed so far (mod p).
int ans;

// Modular helpers; add()/sub() require both operands to already lie in [0, mod).
inline int add(int x, int y) { return x + y < mod ? x + y : x + y - mod; }
inline int sub(int x, int y) { return x >= y ? x - y : x - y + mod; }
inline int mul(int x, int y) { return static_cast<int>(1ll * x * y % mod); }

// Reduces an arbitrary int (possibly >= mod or negative) into [0, mod), so raw input
// values can be fed safely into add()/sub().
inline int norm(int x) {
    x %= mod;
    return x < 0 ? x + mod : x;
}

// Appends the directed edge x -> y to x's edge list.
void Add(int x, int y) {
    to[cnt] = y;
    nex[cnt] = head[x];
    head[x] = cnt++;
}

// First pass: computes depth, subtree size, heavy child and DFS-order interval.
void dfs(int u, int fa) {
    dep[u] = dep[fa] + 1;
    sz[u] = 1;
    tin[u] = ++tot;
    rk[tot] = u;
    for (int e = head[u]; e; e = nex[e]) {
        const int v = to[e];
        if (v == fa) {
            continue;
        }
        dfs(v, u);
        sz[u] += sz[v];
        if (!son[u] || sz[v] > sz[son[u]]) {
            son[u] = v;
        }
    }
    tout[u] = tot;
}

// Recomputes node rt from its two children.
void push_up(int rt) {
    const int ls = rt << 1, rs = rt << 1 | 1;
    segCnt[rt] = add(segCnt[ls], segCnt[rs]);
    segDep[rt] = add(segDep[ls], segDep[rs]);
    segValDep[rt] = add(segValDep[ls], segValDep[rs]);
    segVal[rt] = add(segVal[ls], segVal[rs]);
}

// Inserts (op == 1) or removes (any other op) one vertex of value rank x and depth d.
// Node rt covers the ranks [lo, hi].
void update(int rt, int lo, int hi, int x, int d, int op) {
    if (lo == hi) {
        const int w = norm(b[x]);
        const int wd = mul(w, d);
        if (op == 1) {
            segCnt[rt] += 1;
            segDep[rt] = add(segDep[rt], d);
            segValDep[rt] = add(segValDep[rt], wd);
            segVal[rt] = add(segVal[rt], w);
        } else {
            segCnt[rt] -= 1;
            segDep[rt] = sub(segDep[rt], d);
            segValDep[rt] = sub(segValDep[rt], wd);
            segVal[rt] = sub(segVal[rt], w);
        }
        return;
    }
    const int mid = (lo + hi) >> 1;
    if (x <= mid) {
        update(rt << 1, lo, mid, x, d, op);
    } else {
        update(rt << 1 | 1, mid + 1, hi, x, d, op);
    }
    push_up(rt);
}

// The four segment-tree sums, aggregated over a range of value ranks (all mod p).
struct Agg {
    int count, depSum, valDepSum, valSum;
};

// Adds into acc the sums of the ranks in [L, R] that lie under node rt (covering [lo, hi]).
void query(int rt, int lo, int hi, int L, int R, Agg &acc) {
    if (L <= lo && hi <= R) {
        acc.count = add(acc.count, segCnt[rt]);
        acc.depSum = add(acc.depSum, segDep[rt]);
        acc.valDepSum = add(acc.valDepSum, segValDep[rt]);
        acc.valSum = add(acc.valSum, segVal[rt]);
        return;
    }
    const int mid = (lo + hi) >> 1;
    if (L <= mid) {
        query(rt << 1, lo, mid, L, R, acc);
    }
    if (R > mid) {
        query(rt << 1 | 1, mid + 1, hi, L, R, acc);
    }
}

// Adds to ans the contribution of every pair (u, w), where w is a vertex currently in
// the segment tree. The caller guarantees lca(u, w) == top for all of them, so
// dist(u, w) = dep[u] + dep[w] - 2 * dep[top].
void contribute(int u, int top) {
    const int off = sub(dep[u], 2 * dep[top]);  // dep[u] - 2*dep[top] (mod p)
    const int val = norm(b[a[u]]);

    // Stored vertices with value >= val: the min is val, giving
    // val * (sum dep[w] + count * off).
    Agg ge = {0, 0, 0, 0};
    query(1, 1, m, a[u], m, ge);
    ans = add(ans, mul(val, add(ge.depSum, mul(ge.count, off))));

    // Stored vertices with value < val: the min is their own value v_w, giving
    // sum(v_w * dep[w]) + off * sum(v_w).
    if (a[u] > 1) {
        Agg lt = {0, 0, 0, 0};
        query(1, 1, m, 1, a[u] - 1, lt);
        ans = add(ans, add(lt.valDepSum, mul(lt.valSum, off)));
    }
}

// Applies update(..., op) to every vertex of u's subtree (a contiguous DFS interval).
void modifySubtree(int u, int op) {
    for (int t = tin[u]; t <= tout[u]; ++t) {
        update(1, 1, m, a[rk[t]], dep[rk[t]], op);
    }
}

// DSU on tree. The segment tree is empty on entry. On return it holds exactly the
// vertices of u's subtree if keep is true, and is empty again otherwise.
// NOTE(review): recursion depth equals the tree height (up to n); path-shaped inputs
// may need a larger stack on some judges.
void dfs(int u, int fa, bool keep) {
    // Light children first; each one clears the tree again before returning.
    for (int e = head[u]; e; e = nex[e]) {
        const int v = to[e];
        if (v != fa && v != son[u]) {
            dfs(v, u, false);
        }
    }
    // Heavy child last; its vertices stay in the tree.
    if (son[u]) {
        dfs(son[u], u, true);
    }
    // Pair each light-subtree vertex with everything already stored (their LCA is u),
    // and only then insert that light subtree.
    for (int e = head[u]; e; e = nex[e]) {
        const int v = to[e];
        if (v == fa || v == son[u]) {
            continue;
        }
        for (int t = tin[v]; t <= tout[v]; ++t) {
            contribute(rk[t], u);
        }
        modifySubtree(v, 1);
    }
    // Pair u itself with the rest of its subtree, then insert u.
    contribute(u, u);
    update(1, 1, m, a[u], dep[u], 1);
    if (!keep) {
        modifySubtree(u, -1);
    }
}

int main() {
    if (std::scanf("%d", &n) != 1) {
        return 0;  // no input
    }
    if (n <= 0) {
        std::printf("0\n");  // empty tree: the sum is empty
        return 0;
    }
    for (int i = 1; i <= n; i++) {
        std::scanf("%d", &a[i]);
        b[i] = a[i];
    }
    // Discretize values: b[1..m] holds the sorted distinct values, a[i] becomes a rank.
    std::sort(b + 1, b + 1 + n);
    m = static_cast<int>(std::unique(b + 1, b + 1 + n) - (b + 1));
    for (int i = 1; i <= n; i++) {
        a[i] = static_cast<int>(std::lower_bound(b + 1, b + 1 + m, a[i]) - b);
    }
    for (int i = 1, x, y; i < n; i++) {
        std::scanf("%d %d", &x, &y);
        Add(x, y);
        Add(y, x);
    }
    dfs(1, 0);
    dfs(1, 0, true);
    // Each unordered pair was counted once; the requested sum is over ordered pairs.
    std::printf("%d\n", mul(2, ans));
    return 0;
}