给出一颗树,结点权值为 求:

思路

本题为点分治模板题

以重心为根,用 solve(x) 解决 子树内贡献

每次 solve(x) 时,首先得到 经过该点不经过该点 的贡献总和 calc(x, fa, 0)

这个过程首先利用 dfs_dis(x, fa, 0) 得到以 为根的链信息再将链两两合并,得到 的路径贡献

排除 不经过该点, 即排除 的情况,只需要 先向下走一步,然后统计答案,及 calc(x, fa, 1)

注意到以 为根后会将 删去,则每条路径有且仅会被统计一次,故答案正确。

对于本题,由于求 时若暴力枚举,calc 复杂度会为 总复杂度为

需要先排序处理,calc 复杂度为 总复杂度为

注意: 求重心时 S = sz[x] 每次都要更新,否则复杂度不对

代码

#include <bits/stdc++.h>
using namespace std;

const int N = 2e5 + 10;

#define rep(i, s, t) for (int i = (int)(s); i <= (int)(t); ++i)

const int mod = 998244353;

int root, mx[N], sz[N];
int n, m, S, tot;
bool vis[N];
vector<int> G[N];

bool chkmax(int &x, int y) {
  if (x < y) return x = y, 1; return false;
}

int read() {
  int x = 0, ch = getchar();
  while (ch < '0' || ch > '9') ch = getchar();
  while (ch >='0' && ch <='9')
    x = x * 10 + (ch ^ 48), ch = getchar();
  return x;
}

#define int long long
int a[N], ans;
pair<int, int> q[(int)(1e6) + 10];

void dfs_root(int x, int fa) {
  mx[x] = 0, sz[x] = 1;
  for (int y : G[x]) {
    if (y == fa || vis[y]) continue;
    dfs_root(y, x);
    sz[x] += sz[y];
    chkmax(mx[x], sz[y]);
  }
  chkmax(mx[x], S - sz[x]);
  if (mx[x] < mx[root]) root = x;
}

void dfs_dis(int x, int fa, int dis) {
  q[++ tot] = make_pair(a[x], dis);
  for (int y : G[x]) {
    if (y == fa || vis[y]) continue;
    dfs_dis(y, x, dis + 1);
  }
}

int Mod(int x) {
  return (x % mod + mod) % mod;
}

int calc(int x, int fa, int org) {
  int res = 0;
  tot = 0;
  dfs_dis(x, fa, org);
  sort(q + 1, q + tot + 1);
  int dis_sum = 0, dis_prefix = 0;
  rep(i, 1, tot) {
    dis_sum += q[i].second;
  }
  rep(i, 1, tot) {
    dis_prefix += q[i].second;
    res = Mod(res + q[i].first * (dis_sum - dis_prefix));
    res = Mod(res + (tot - i) * q[i].second * q[i].first);
  }
  return Mod(res + res);
}

void solve(int x) {
  tot = 0;
  vis[x] = 1;
  ans += calc(x, 0, 0);
  for (int y : G[x]) {
    if (vis[y]) continue;
    ans = Mod(ans - calc(y, x, 1));
    mx[root = 0] = S = n;
    S = sz[x];
    dfs_root(y, x);
    solve(root);
  }
}

signed main() {
  n = read();
  rep(i, 1, n) a[i] = read();
  rep(i, 1, n - 1) {
    int x = read(), y = read();
    G[x].push_back(y);
    G[y].push_back(x);
  }
  mx[root = 0] = S = n;
  dfs_root(1, 0);
  solve(root);
  printf("%lld", ans);
}