旗鼓相当的对手

题目链接:牛客 旗鼓相当的对手

Description

给定一个 个点的树,每个节点有点权。
如果 ,并且 不等于 且树上距离为 ,那么点 的答案就会加上
求所有点的答案。
数据范围

Solution

长链剖分板子。 表示子树深度为的节点数 表示子树深度为的节点权值和

Code

// Author: wlzhouzhuan
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#include <bits/stdc++.h>
using namespace std;

#define ll long long
#define ull unsigned long long
#define rint register int
#define rep(i, l, r) for (rint i = l; i <= r; i++)
#define per(i, l, r) for (rint i = l; i >= r; i--)
#define mset(s, _) memset(s, _, sizeof(s))
#define pb push_back
#define pii pair <int, int>
#define mp(a, b) make_pair(a, b)

inline int read() {
  int x = 0, neg = 1; char op = getchar();
  while (!isdigit(op)) { if (op == '-') neg = -1; op = getchar(); }
  while (isdigit(op)) { x = 10 * x + op - '0'; op = getchar(); }
  return neg * x;
}
inline void print(int x) {
  if (x < 0) { putchar('-'); x = -x; }
  if (x >= 10) print(x / 10);
  putchar(x % 10 + '0');
}

const int N = 100005;
vector <int> adj[N];
void add(int u, int v) { adj[u].pb(v); }
int n, k;
int col[N];

int heavy[N], len[N];
void dfs1(int u, int fa) {
  for (auto v: adj[u]) {
    if (v == fa) continue;
    dfs1(v, u);
    if (len[v] > len[heavy[u]]) heavy[u] = v;
  }
  len[u] = len[heavy[u]] + 1;
}
// f[i][j] 表示u子树深度为j的节点数
// g[i][j] 表示u子树深度为j的节点权值和 
ll _tmp1[N], *id1 = _tmp1, *f[N];
ll _tmp2[N], *id2 = _tmp2, *g[N];
ll ans[N];
int all;
void dfs2(int u, int fa) {
  f[u][0] = 1, g[u][0] = col[u];
  if (heavy[u]) {
    f[heavy[u]] = f[u] + 1;
    g[heavy[u]] = g[u] + 1;
    dfs2(heavy[u], u);
  }
  for (auto v: adj[u]) {
    if (v == fa || v == heavy[u]) continue;
    f[v] = id1, id1 += len[v];
    g[v] = id2, id2 += len[v];
    dfs2(v, u);
    for (rint j = 0; j < len[v]; j++) {
      if (j >= 0 && k - 1 - j >= 1 && k - 1 - j < len[u]) {
        ans[u] += f[v][j] * g[u][k - 1 - j];
        ans[u] += g[v][j] * f[u][k - 1 - j];
      }
    }
    for (rint j = 1; j <= len[v]; j++) {
      f[u][j] += f[v][j - 1];
      g[u][j] += g[v][j - 1];
    }
  }
}

int main() {
  scanf("%d%d", &n, &k);
  for (rint i = 1; i <= n; i++) {
    col[i] = read();
  }
  for (rint i = 1; i < n; i++) {
    int u = read(), v = read();
    add(u, v), add(v, u);
  }  
  int root = 1;
  dfs1(root, 0);
  f[root] = id1, id1 += len[root];
  g[root] = id2, id2 += len[root];
  dfs2(root, 0);
  for (rint i = 1; i <= n; i++) {
    printf("%lld ", ans[i]); 
  }
  return 0;
}