给出一颗树,结点权值为 求:
思路
本题为点分治模板题
以重心为根,用 solve(x) 解决 子树内贡献
每次 solve(x) 时,首先得到 经过该点 和 不经过该点 的贡献总和 calc(x, fa, 0)
这个过程首先利用 dfs_dis(x, fa, 0) 得到以 为根的链信息再将链两两合并,得到
的路径贡献
排除 不经过该点, 即排除 的情况,只需要
先向下走一步,然后统计答案,及
calc(x, fa, 1)
注意到以 为根后会将
删去,则每条路径有且仅会被统计一次,故答案正确。
对于本题,由于求 时若暴力枚举,
calc 复杂度会为 总复杂度为
需要先排序处理,calc 复杂度为 总复杂度为
注意: 求重心时 S = sz[x] 每次都要更新,否则复杂度不对
代码
#include <bits/stdc++.h>
using namespace std;
const int N = 2e5 + 10;
#define rep(i, s, t) for (int i = (int)(s); i <= (int)(t); ++i)
const int mod = 998244353;
int root, mx[N], sz[N];
int n, m, S, tot;
bool vis[N];
vector<int> G[N];
bool chkmax(int &x, int y) {
if (x < y) return x = y, 1; return false;
}
int read() {
int x = 0, ch = getchar();
while (ch < '0' || ch > '9') ch = getchar();
while (ch >='0' && ch <='9')
x = x * 10 + (ch ^ 48), ch = getchar();
return x;
}
#define int long long
int a[N], ans;
pair<int, int> q[(int)(1e6) + 10];
void dfs_root(int x, int fa) {
mx[x] = 0, sz[x] = 1;
for (int y : G[x]) {
if (y == fa || vis[y]) continue;
dfs_root(y, x);
sz[x] += sz[y];
chkmax(mx[x], sz[y]);
}
chkmax(mx[x], S - sz[x]);
if (mx[x] < mx[root]) root = x;
}
void dfs_dis(int x, int fa, int dis) {
q[++ tot] = make_pair(a[x], dis);
for (int y : G[x]) {
if (y == fa || vis[y]) continue;
dfs_dis(y, x, dis + 1);
}
}
int Mod(int x) {
return (x % mod + mod) % mod;
}
int calc(int x, int fa, int org) {
int res = 0;
tot = 0;
dfs_dis(x, fa, org);
sort(q + 1, q + tot + 1);
int dis_sum = 0, dis_prefix = 0;
rep(i, 1, tot) {
dis_sum += q[i].second;
}
rep(i, 1, tot) {
dis_prefix += q[i].second;
res = Mod(res + q[i].first * (dis_sum - dis_prefix));
res = Mod(res + (tot - i) * q[i].second * q[i].first);
}
return Mod(res + res);
}
void solve(int x) {
tot = 0;
vis[x] = 1;
ans += calc(x, 0, 0);
for (int y : G[x]) {
if (vis[y]) continue;
ans = Mod(ans - calc(y, x, 1));
mx[root = 0] = S = n;
S = sz[x];
dfs_root(y, x);
solve(root);
}
}
signed main() {
n = read();
rep(i, 1, n) a[i] = read();
rep(i, 1, n - 1) {
int x = read(), y = read();
G[x].push_back(y);
G[y].push_back(x);
}
mx[root = 0] = S = n;
dfs_root(1, 0);
solve(root);
printf("%lld", ans);
}

京公网安备 11010502036488号