树句节狗提
正解部分
先题意转换:
求 子树内 所有与 x 距离 至少为 K 的点的权值和
↓
子树内 的和 − 子树内 与 x 距离小于等于 K−1 的点的权值和.
实现部分
树状数组 中下标为 深度, 按 dfs序 顺序将点权加入 树状数组, 保证 深度不超过 k−1 这个条件,
把询问 离线, 转化为差分操作, 按 dfs序遍历到子树的根时减, 离开子树时再加, 可以得到 子树内深度不超过 K−1 的值 .
- 离线可以使用 vector 或者 链表 .
- 在差分右端点时注意边界 !!!
#include<bits/stdc++.h>
#define fi first
#define se second
#define reg register
#define pb push_back
typedef long long ll;
typedef std::pair<int, int> pr;
int read(){
char c;
int s = 0, flag = 1;
while((c=getchar()) && !isdigit(c))
if(c == '-'){ flag = -1, c = getchar(); break ; }
while(isdigit(c)) s = s*10 + c-'0', c = getchar();
return s * flag;
}
const int maxn = 2525015;
int N;
int Q_;
int dft;
int num0;
int A[maxn];
int Mp[maxn];
int dep[maxn];
int dfn[maxn];
int head[maxn];
int sum_n[maxn];
ll Ans[maxn];
std::vector <pr> que[maxn];
struct Edge{ int nxt, to; } edge[maxn << 1];
void Add(int from, int to){
edge[++ num0] = (Edge){ head[from], to };
head[from] = num0;
}
struct Bit_Tree{
ll v[maxn];
void Add(int k, int p){ while(k<=N)v[k]+=p,k+=k&-k; }
ll Query(int k){ ll s = 0; while(k)s+=v[k],k-=k&-k; return s; }
} bit_1;
void DFS(int k, int fa){
dfn[k] = ++ dft, sum_n[k] = 1;
dep[k] = dep[fa] + 1, Mp[dft] = k;
for(reg int i = head[k]; i; i = edge[i].nxt){
int to = edge[i].to;
if(to == fa) continue ;
DFS(to, k), sum_n[k] += sum_n[to];
}
}
void Calc_ans(){
ll Tmp_1 = 0;
for(reg int i = 1; i <= N; i ++){
int x = Mp[i];
bit_1.Add(dep[x], A[x]);
Tmp_1 += A[x];
int size = que[i].size();
for(reg int j = 0; j < size; j ++){
int q_id = que[i][j].fi, dep_k = que[i][j].se;
if(q_id >= 0) Ans[q_id] += Tmp_1 - bit_1.Query(dep_k);
else Ans[-q_id] += bit_1.Query(dep_k) - Tmp_1;
}
}
}
void print(int q, long long* ans, int lim);
int main(){
N = read();
for(reg int i = 1; i <= N; i ++) A[i] = read();
for(reg int i = 2; i <= N; i ++){
int x = read();
Add(x, i), Add(i, x);
}
DFS(1, 0); Q_ = read();
for(reg int i = 1; i <= Q_; i ++){
int x = read(), k = read();
k = std::min(N, dep[x]+k-1);
que[dfn[x]-1].pb(pr(-i, k));
que[dfn[x]+sum_n[x]-1].pb(pr(i, k));
}
Calc_ans();
print(Q_, Ans, read());
return 0;
}
void print(int q, long long* ans, int lim) {
for(int i = 1; i <= q; ) {
long long res = 0;
for(int j = i; j <= std::min(q, i + lim - 1); j ++) res ^= ans[j];
i += lim;
printf("%lld\n", res);
}
return ;
}