赛时代码如下:(n * logn)
#include <bits/stdc++.h>
//#define int long long
using namespace std;
const int maxn = 2e6 + 20;
int t, n, d[maxn], f[maxn][20], ans[maxn];
vector<int> e[maxn];
// st表求祖先
void dfs1(int x, int fx) {
f[x][0] = fx;
for(int i = 1;f[x][i-1];++i) f[x][i] = f[f[x][i-1]][i-1];
for(int y : e[x]) {
if(y == fx) continue;
dfs1(y, x);
}
}
// 求x的第k级祖先
int ask(int x,int k) {
for(;k;k-=k&(-k)){
x = f[x][__lg(k&-k)];
}
return x;
}
// 统计子树权值和
void dfs2(int x, int fx) {
for(int y : e[x]) {
if(y == fx) continue;
dfs2(y, x);
ans[x] += ans[y];
}
}
signed main() {
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
cin >> n;
// 建边
for(int i = 1; i < n; ++i) {
int u,v;
cin>>u>>v;
e[u].push_back(v);
e[v].push_back(u);
}
for(int i = 1; i <= n; ++i) cin >> d[i];
dfs1(1, 0); // 预处理祖先
for(int i = 1; i <= n; ++i) {// 进行所有"区间"操作
++ans[i];
--ans[ask(i, d[i]+1)];
}
dfs2(1, 0); // 统计子树权值和
for(int i = 1; i <= n; ++i) cout << ans[i] << " \n"[i==n];
return 0;
}
O(n) 写法如下:
#include <bits/stdc++.h>
//#define int long long
using namespace std;
const int maxn = 2e6 + 20;
int t, n, d[maxn], ans[maxn];
vector<int> e[maxn], path;
// 进行所有"区间"操作
void dfs1(int x, int fx) {
path.push_back(x);
int pos = max(0, (int)(path.size()-1) - d[x] - 1); // 注意: 长为d[x]的路径有d[x]+1个点
--ans[path[pos]]; // l[x] = path[pos]
++ans[x]; // r[x] = x
for(int y : e[x]) {
if(y == fx) continue;
dfs1(y, x);
}
path.pop_back();
}
// 统计子树权值和
void dfs2(int x, int fx) {
for(int y : e[x]) {
if(y == fx) continue;
dfs2(y, x);
ans[x] += ans[y];
}
}
signed main() {
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
cin >> n;
// 建边
for(int i = 1; i < n; ++i) {
int u,v;
cin>>u>>v;
e[u].push_back(v);
e[v].push_back(u);
}
for(int i = 1; i <= n; ++i) cin >> d[i];
path.push_back(0);
dfs1(1, 0); // 进行所有"区间"操作
dfs2(1, 0); // 统计子树权值和
for(int i = 1; i <= n; ++i) cout << ans[i] << " \n"[i==n];
return 0;
}