题目大意
给定一棵大小为n的树,有m次操作,分为三种:
- (1 x w) 位置x上的权值+w,同时所有位置的权值加上w-dis(x,y);(dis(x,y)为从x到y的边数)
- (2 x) 如果x位置的权值>0,那么设为0。
- (3 x) 输出x位置的权值
解题思路
这题有很多种做法,我再这里用的是魔改版的树链剖分(刚学会)。
操作1
我们在画图分析后,可以得到这样一个简单的式子,来转化操作1:
(LCA:最近公共祖先,应该都了解怎么求,为的深度)对于任意的来说,都应该是个常数,所以我们可以专门记录。
接下来,我们设第次的操作节点为,修改的节点为,为;
可以得到,我们每执行一次操作1,就可以在每个结点的权值加上(这里的是操作1的次数)。
那么,我们的式子就是。但是求得这个就是较为棘手的了,需要开脑洞:
对于每次操作1,我们可以从根到x这条路径上所有的结点权值+2,然后在询问时,求就是求根到y这条路径上结点的权值之和。
操作2
对于这个归零的操作,我们用一个额外的dt数组记录。
如果当前x的权值大于0,需要修改,那么我们就使dt[x]减去x的权值。
操作3:计算后输出即可
AC代码
#include<bits/stdc++.h> using namespace std; const int N=5e4+10; struct node { int l,r,s; long long x,y; } tr[N<<2]; struct link { int x,y; } e[N<<1]; int a[N],d[N],f[N],h[N],id[N],s[N],tot[N],top[N],z[N],n,t,r,q,mod; long long dt[N],sum,num; ///////////////////////////////////////// void pushup(int x) { tr[x].y=tr[x<<1].y+tr[x<<1|1].y; } void pushdown(int x) { if(!tr[x].x) return; tr[x<<1].y=tr[x<<1].y+tr[x].x*(1ll*tr[x<<1].s); tr[x<<1|1].y=tr[x<<1|1].y+tr[x].x*(1ll*tr[x<<1|1].s); tr[x<<1].x=tr[x<<1].x+tr[x].x; tr[x<<1|1].x=tr[x<<1|1].x+tr[x].x; tr[x].x=0; } void build(int x,int l,int r) { tr[x].l=l,tr[x].r=r,tr[x].s=r-l+1; if(l==r) { tr[x].y=1ll*a[l]; return; } int m=(l+r)/2; build(x<<1,l,m); build(x<<1|1,m+1,r); pushup(x); } ///////////////////////////////////////// void add(int x,int l,int r,int y) { if(l<=tr[x].l && tr[x].r<=r) { tr[x].y+=(1ll*tr[x].s)*(1ll*y); tr[x].x=(tr[x].x+1ll*y); return; } pushdown(x); if(tr[x<<1].r>=l) add(x<<1,l,r,y); if(tr[x<<1|1].l<=r) add(x<<1|1,l,r,y); pushup(x); } void add1(int x,int y) { e[++t].x=y,e[t].y=h[x],h[x]=t; } void add2(int x,int y,int z) { while(top[x]!=top[y]) { if(d[top[x]]<d[top[y]]) swap(x,y); add(1,id[top[x]],id[x],z); x=f[top[x]]; } if(d[x]>d[y]) swap(x,y); add(1,id[x],id[y],z); } ///////////////////////////////////////// long long query(int x,int l,int r) { if(l<=tr[x].l && tr[x].r<=r) return tr[x].y; pushdown(x); long long ans=0; if(tr[x<<1].r>=l) ans=(ans+query(x<<1,l,r)); if(tr[x<<1|1].l<=r) ans=(ans+query(x<<1|1,l,r)); return ans; } int dfs(int x,int y,int z) { d[x]=z,f[x]=y,tot[x]=1; int u,v=-1,i; for(i=h[x];i;i=e[i].y) { u=e[i].x; if(u==y) continue; tot[x]+=dfs(u,x,z+1); if(tot[u]>v) v=tot[u],s[x]=u; } return tot[x]; } void ddfs(int x,int y) { id[x]=++q,a[q]=z[x],top[x]=y; if(!s[x]) return; ddfs(s[x],y); int z,i; for(i=h[x];i;i=e[i].y) { z=e[i].x; if(z==f[x]||z==s[x]) continue; ddfs(z,z); } } long long find(int x,int y) { long long ans=0; while(top[x]!=top[y]) { if(d[top[x]]<d[top[y]]) swap(x,y); ans=(ans+query(1,id[top[x]],id[x])); x=f[top[x]]; } if(d[x]>d[y]) swap(x,y); ans=(ans+query(1,id[x],id[y])); return ans; } ///////////////////////////////////////// int main() { int T,m,x,y,z,ch,i; long long ans; scanf("%d",&T); while(T--) { memset(dt,0,sizeof(dt)); memset(h,0,sizeof(h)); memset(s,0,sizeof(s)); memset(tr,0,sizeof(tr)); t=q=sum=num=0,r=1; scanf("%d%d",&n,&m); for(i=1;i<n;i++) { scanf("%d%d",&x,&y); add1(x,y);add1(y,x); } dfs(r,0,1);ddfs(r,r);build(1,1,n); while(m--) { scanf("%d",&ch); if(ch==1) { scanf("%d%d",&x,&y); sum+=1ll*y-1ll*d[x],num++; add2(r,x,2); } else if(ch==2) { scanf("%d",&x); ans=sum-1ll*d[x]*num+find(r,x)+dt[x]; if(ans>=0) dt[x]-=ans; } else { scanf("%d",&x); ans=sum-1ll*d[x]*num+find(r,x)+dt[x]; printf("%lld\n",ans); } } } }