题目
这题用线段树过不了，一定要用树状数组。（注意：下面给出的代码实现使用的是动态开点线段树。）
ra[u][d]：维护点分树上 u 的子树中、与 u 距离为 d 的所有点的价值和。
rb[u][d]：维护点分树上 u 的子树中、与 u 在点分树上的父亲 p[u] 距离为 d 的所有点的价值和。
Code
#include<bits/stdc++.h>
// Midpoint of the current segment-tree range; only valid where `l` and `r` are in scope.
#define mid ((l+r)>>1)
// N: array bound for vertex-indexed tables (vertex ids are 1-based; index 0 is the "no vertex" sentinel).
// M: size of the shared dynamic segment-tree node pool used by s/L/R.
// NOTE(review): ins() allocates with ++cnt and never checks against M; presumably sized for the problem's limits.
const int N=100002,M=5600002;
// Adjacency-list edge: `to` is the target vertex, `ne` the next edge of the same source (0 terminates the list).
// Each undirected edge is stored twice (see add(x,y),add(y,x) in main), hence N<<1 slots.
struct node{
int to,ne;
}e[N<<1];
// Global state (static storage, so zero-initialized):
//   cnt        - id of the last allocated segment-tree node (0 means "empty tree")
//   tot        - number of directed edges stored in e[]
//   sz, mx     - component-subtree size and largest remaining piece, computed by getrt()
//   dep, fa    - depth and binary-lifting table in the original tree rooted at 1; fa[v][k] is the 2^k-th ancestor
//   x, y, i, op- scratch variables used only by main()
//   p          - parent of a vertex in the centroid tree (0 for the top-level centroid)
//   ra, rb     - per-vertex roots of distance-keyed segment trees:
//                ra[u] over u's centroid subtree keyed by distance to u,
//                rb[u] over u's centroid subtree keyed by distance to p[u]
//   d          - scratch distances filled by getdis()/getdis1()
//   w          - initial vertex values
//   h          - head edge index of each vertex's adjacency list
//   s, L, R    - segment-tree node sums and left/right child ids
//   sum, rt    - size of the component being split and the best centroid candidate found so far
//   n, m, las  - vertex count, query count, last printed answer (query inputs are XOR-ed with it)
int cnt,tot,sz[N],dep[N],x,y,i,op,mx[N],fa[N][17],p[N],ra[N],rb[N],d[N],w[N],h[N],s[M],L[M],R[M],sum,rt,n,m,las;
// vis[v] is set once v has been chosen as a centroid (see solve()); such vertices split later components.
bool vis[N];
// Buffered single-byte reader over stdin.
// Returns the next byte as a value in 0..255, or EOF once the input is exhausted.
// The result is an int (not char) so EOF stays distinct from a real 0xFF byte and
// remains detectable on platforms where plain char is unsigned.
inline int gc(){
    static char buf[100000],*p1=buf,*p2=buf;
    if (p1==p2) {
        // Buffer drained: refill it; an empty refill means end of input.
        p2=(p1=buf)+fread(buf,1,sizeof buf,stdin);
        if (p1==p2) return EOF;
    }
    return (unsigned char)*p1++;
}
// Reads the next non-negative decimal integer from the input via gc().
// Non-digit characters before the number are skipped (a '-' sign is ignored).
// If input ends before any digit appears, returns 0 instead of looping forever
// waiting for a digit that will never come.
inline int rd(){
    int x=0;
    int ch=gc();  // int, so an EOF result from gc() is not truncated
    while (ch!=EOF && (ch<'0'||ch>'9')) ch=gc();
    for (;ch>='0'&&ch<='9';ch=gc()) x=x*10+(ch-'0');
    return x;
}
// Writes the decimal digits of a to stdout (no newline).
// Low-order digits are peeled into a small buffer, then emitted most-significant first.
inline void wri(int a){
    char pending[12];
    int len = 0;
    while (a >= 10) {
        pending[len++] = (char)('0' | (a % 10));
        a /= 10;
    }
    putchar('0' | (a % 10));  // leading (or only) digit
    while (len > 0) putchar(pending[--len]);
}
// Writes a followed by a newline.
inline void wln(int a){
    wri(a);
    putchar('\n');
}
// Returns the larger of x and y.
inline int max(int x,int y){
    if (x < y) return y;
    return x;
}
// Lowest common ancestor of x and y via the binary-lifting table fa[][] and depths dep[].
inline int lca(int x,int y){
    // Make x the deeper vertex.
    if (dep[x] < dep[y]) {
        const int tmp = x;
        x = y;
        y = tmp;
    }
    // Lift x to y's depth (fa[..]==0 has depth 0, so it never overshoots).
    for (int k = 16; k >= 0; --k) {
        const int up = fa[x][k];
        if (dep[up] >= dep[y]) x = up;
    }
    if (x == y) return x;
    // Lift both while their ancestors differ; they end just below the LCA.
    for (int k = 16; k >= 0; --k) {
        if (fa[x][k] == fa[y][k]) continue;
        x = fa[x][k];
        y = fa[y][k];
    }
    return fa[x][0];
}
// Number of edges on the tree path between x and y.
inline int dis(int x,int y){
    const int anc = lca(x, y);
    return (dep[x] - dep[anc]) + (dep[y] - dep[anc]);
}
// Appends the directed edge x->y to x's adjacency list (pushed at the head).
inline void add(int x,int y){
    ++tot;
    e[tot].to = y;
    e[tot].ne = h[x];
    h[x] = tot;
}
// Adds y at key x in the dynamic segment tree rooted at t, covering range [l, r].
// Missing nodes on the root-to-leaf path are allocated from the pool (++cnt),
// and every node on that path has its sum increased by y.
inline void ins(int &t,int l,int r,int x,int y){
    int *slot = &t;  // where the current node id lives (root ref or a child link)
    for (;;) {
        if (*slot == 0) *slot = ++cnt;
        const int cur = *slot;
        s[cur] += y;
        if (l == r) break;
        const int c = (l + r) >> 1;
        if (x <= c) {
            slot = &L[cur];
            r = c;
        } else {
            slot = &R[cur];
            l = c + 1;
        }
    }
}
// Sum of the values stored at keys in [x, y] in the dynamic segment tree t covering [l, r].
// Absent subtrees (id 0) contribute nothing.
inline int query(int t,int l,int r,int x,int y){
    if (t == 0) return 0;
    if (l >= x && y >= r) return s[t];  // node range fully inside the query
    const int c = (l + r) >> 1;
    const int left = (x <= c) ? query(L[t], l, c, x, y) : 0;
    const int right = (c < y) ? query(R[t], c + 1, r, x, y) : 0;
    return left + right;
}
// Roots the original tree: records dep[u] and the 2^k-th ancestors fa[u][k], then recurses into children.
inline void dfs(int u,int f){
    dep[u] = dep[f] + 1;
    fa[u][0] = f;
    for (int k = 1; k < 17; ++k) {
        const int up = fa[u][k - 1];
        fa[u][k] = fa[up][k - 1];
    }
    for (int i = h[u]; i; i = e[i].ne) {
        const int v = e[i].to;
        if (v == f) continue;
        dfs(v, u);
    }
}
inline void getrt(int u,int fa){
sz[u]=1,mx[u]=0;
for (int i=h[u],v;i;i=e[i].ne)
if ((v=e[i].to)!=fa && !vis[v]){
getrt(v,u);
sz[u]+=sz[v];
mx[u]=max(mx[u],sz[v]);
}
mx[u]=max(mx[u],sum-sz[u]);
if (mx[u]<mx[rt]) rt=u;
}
// Shared DFS behind getdis()/getdis1(): inserts every vertex reachable from u without
// crossing `fa` or an already-chosen centroid into the segment tree `root`, keyed by
// d[] and weighted by w[]. d[u] must already hold u's distance to the reference vertex;
// children get d[u]+1.
inline void collect_dist(int &root,int u,int fa){
    ins(root,0,n,d[u],w[u]);
    for (int i=h[u],v;i;i=e[i].ne)
        if ((v=e[i].to)!=fa && !vis[v]) d[v]=d[u]+1,collect_dist(root,v,u);
}
// Fills ra[t] with the vertices of the component around u, keyed by d[].
inline void getdis(int t,int u,int fa){
    collect_dist(ra[t],u,fa);
}
// Fills rb[t] with the vertices of the component around u, keyed by d[].
inline void getdis1(int t,int u,int fa){
    collect_dist(rb[t],u,fa);
}
// Builds the centroid tree below centroid u.
// ra[u] receives every vertex of u's component keyed by distance to u. For each
// neighbouring sub-component, its centroid c is found, rb[c] receives that
// sub-component keyed by distance to u, p[c] is set to u, and c is processed recursively.
inline void solve(int u){
    vis[u] = 1;
    d[u] = 0;
    getdis(u, u, 0);
    for (int i = h[u]; i; i = e[i].ne) {
        const int v = e[i].to;
        if (vis[v]) continue;
        sum = sz[v];
        rt = 0;
        getrt(v, u);
        const int c = rt;
        d[v] = 1;  // v is one edge away from u
        getdis1(c, v, u);
        p[c] = u;
        solve(c);
    }
}
// Sum of values of all vertices within distance k of x.
// Starts with x's own centroid subtree, then climbs the centroid tree. At each ancestor
// it adds the vertices within k-D of that ancestor (D = dis(x, ancestor)) and subtracts
// those lying in the child subtree on x's side, which rb[child] holds keyed by distance
// to the ancestor.
inline int ask(int x,int k){
    int total = query(ra[x], 0, n, 0, k);
    for (int child = x, anc = p[x]; anc; child = anc, anc = p[anc]) {
        const int D = dis(x, anc);
        if (D > k) continue;
        total += query(ra[anc], 0, n, 0, k - D) - query(rb[child], 0, n, 0, k - D);
    }
    return total;
}
// Sets x's value to y.
// The current value is read back from ra[x] at distance 0; the difference is then added
// to ra[x] and, for every centroid-tree ancestor, to ra[ancestor] and rb[child on x's side]
// at key dis(x, ancestor). Note: w[x] itself is not modified.
inline void update(int x,int y){
    const int delta = y - query(ra[x], 0, n, 0, 0);
    ins(ra[x], 0, n, 0, delta);
    for (int child = x, anc = p[x]; anc; child = anc, anc = p[anc]) {
        const int D = dis(x, anc);
        ins(ra[anc], 0, n, D, delta);
        ins(rb[child], 0, n, D, delta);
    }
}
int main(){
n=rd(),m=rd();
for (i=1;i<=n;i++) w[i]=rd();
for (i=1;i<n;i++) x=rd(),y=rd(),add(x,y),add(y,x);
dfs(1,0);
sum=mx[0]=n,getrt(1,0);
solve(rt);
for (;m--;){
op=rd(),x=rd()^las,y=rd()^las;
if (!op) wln(las=ask(x,y));
else update(x,y);
}
}