旗鼓相当的对手
题目链接:牛客 旗鼓相当的对手
Description
给定一个 个点的树,每个节点有点权。
如果 为 和 的,并且 和 不等于 且树上距离为 ,那么点 的答案就会加上 。
求所有点的答案。
数据范围。
Solution
长链剖分板子。 表示子树深度为的节点数 表示子树深度为的节点权值和
Code
// Author: wlzhouzhuan #pragma GCC optimize(2) #pragma GCC optimize(3) #include <bits/stdc++.h> using namespace std; #define ll long long #define ull unsigned long long #define rint register int #define rep(i, l, r) for (rint i = l; i <= r; i++) #define per(i, l, r) for (rint i = l; i >= r; i--) #define mset(s, _) memset(s, _, sizeof(s)) #define pb push_back #define pii pair <int, int> #define mp(a, b) make_pair(a, b) inline int read() { int x = 0, neg = 1; char op = getchar(); while (!isdigit(op)) { if (op == '-') neg = -1; op = getchar(); } while (isdigit(op)) { x = 10 * x + op - '0'; op = getchar(); } return neg * x; } inline void print(int x) { if (x < 0) { putchar('-'); x = -x; } if (x >= 10) print(x / 10); putchar(x % 10 + '0'); } const int N = 100005; vector <int> adj[N]; void add(int u, int v) { adj[u].pb(v); } int n, k; int col[N]; int heavy[N], len[N]; void dfs1(int u, int fa) { for (auto v: adj[u]) { if (v == fa) continue; dfs1(v, u); if (len[v] > len[heavy[u]]) heavy[u] = v; } len[u] = len[heavy[u]] + 1; } // f[i][j] 表示u子树深度为j的节点数 // g[i][j] 表示u子树深度为j的节点权值和 ll _tmp1[N], *id1 = _tmp1, *f[N]; ll _tmp2[N], *id2 = _tmp2, *g[N]; ll ans[N]; int all; void dfs2(int u, int fa) { f[u][0] = 1, g[u][0] = col[u]; if (heavy[u]) { f[heavy[u]] = f[u] + 1; g[heavy[u]] = g[u] + 1; dfs2(heavy[u], u); } for (auto v: adj[u]) { if (v == fa || v == heavy[u]) continue; f[v] = id1, id1 += len[v]; g[v] = id2, id2 += len[v]; dfs2(v, u); for (rint j = 0; j < len[v]; j++) { if (j >= 0 && k - 1 - j >= 1 && k - 1 - j < len[u]) { ans[u] += f[v][j] * g[u][k - 1 - j]; ans[u] += g[v][j] * f[u][k - 1 - j]; } } for (rint j = 1; j <= len[v]; j++) { f[u][j] += f[v][j - 1]; g[u][j] += g[v][j - 1]; } } } int main() { scanf("%d%d", &n, &k); for (rint i = 1; i <= n; i++) { col[i] = read(); } for (rint i = 1; i < n; i++) { int u = read(), v = read(); add(u, v), add(v, u); } int root = 1; dfs1(root, 0); f[root] = id1, id1 += len[root]; g[root] = id2, id2 += len[root]; dfs2(root, 0); for (rint i = 1; i <= n; i++) { printf("%lld ", ans[i]); } return 0; }