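只要会写就行了。我们先求出以 root 为根时 x、y 的 lca：可以发现它就是 lca(x,y)、lca(x,root)、lca(y,root)（都在以 1 为根的树上求）三者中深度最大的那个。于是可以分情况讨论这个 lca 和 root 的关系：lca 就是 root 时，整棵树都要修改；root 在 lca 的子树内时，先修改整棵树，再减去 lca 往 root 方向那个儿子的子树；否则直接修改 lca 的子树。查询时同理，讨论当前的 x 和 root 的关系即可。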
剩下就只需要一个求 k 级祖先的过程（用来找到 lca 往 root 方向的那个儿子）。这可以用长链剖分来做，但作者直接写了 O(log n) 的倍增求法。
其余部分用树链剖分（求 lca）加线段树（子树在 dfs 序上是连续的一段）维护就完了！
代码:
// Tree with a movable root (initially vertex 1). Operations read from stdin:
//   1 x      : make x the current root
//   2 x y v  : add v to every vertex in the subtree (w.r.t. the current root)
//              of the lca (w.r.t. the current root) of x and y
//   other x  : print the sum over the subtree (w.r.t. the current root) of x
// All structures are built once for the fixed root 1; a subtree under the
// current root is translated into at most two ranges of the fixed dfn order.
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <queue>
#include <utility>
#include <vector>

using LL = long long;
using LD = long double;
using DB = double;

constexpr int NN = 100000 + 17;  // vertex capacity (plus slack)
constexpr int LOG = 21;          // binary-lifting levels 0..20

// Reads one (optionally negative) integer from stdin.
// `ch` is an int so EOF is distinguishable and isdigit() only ever receives
// unsigned-char values or EOF; at EOF the reader stops instead of spinning.
LL read() {
    int ch = std::getchar();
    LL x = 0, fl = 1;
    while (ch != EOF && !std::isdigit(ch)) {
        if (ch == '-') fl = -1;
        ch = std::getchar();
    }
    while (ch != EOF && std::isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    return x * fl;
}

// Local debugging helper: redirect stdio to a.in / a.out.
void open() {
    std::freopen("a.in", "r", stdin);
    std::freopen("a.out", "w", stdout);
}

int n, m, root;
// Heavy-light decomposition data for the tree rooted at vertex 1.
int fa[NN], dep[NN], siz[NN], son[NN], top[NN], dfn[NN], rev[NN];
int up[NN][LOG];  // up[x][i]: 2^i-th ancestor of x (0 once past vertex 1)
int tim;          // dfn counter
// Lazy segment tree over dfn positions 1..n: range add, range sum.
int len[NN << 2];
LL a[NN], sum[NN << 2], tag[NN << 2];
std::vector<int> e[NN];

// Applies "+x to every element" to segment-tree node rt.
void set_tag(int rt, LL x) {
    sum[rt] += 1LL * len[rt] * x;
    tag[rt] += x;
}

// Pushes node rt's pending addition down to its two children.
void psd(int rt) {
    if (tag[rt]) {
        set_tag(rt << 1, tag[rt]);
        set_tag(rt << 1 | 1, tag[rt]);
        tag[rt] = 0LL;
    }
}

// Builds node rt over dfn range [l, r] from the initial vertex values a[].
void build(int rt, int l, int r) {
    len[rt] = r - l + 1;
    if (l == r) {
        sum[rt] = a[rev[l]];
        return;
    }
    const int mid = (l + r) >> 1;
    build(rt << 1, l, mid);
    build(rt << 1 | 1, mid + 1, r);
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}

// Adds x to every position in [ll, rr] within node rt covering [l, r].
void modify(int rt, int l, int r, int ll, int rr, LL x) {
    if (ll <= l && r <= rr) {
        set_tag(rt, x);
        return;
    }
    psd(rt);
    const int mid = (l + r) >> 1;
    if (ll <= mid) modify(rt << 1, l, mid, ll, rr, x);
    if (rr > mid) modify(rt << 1 | 1, mid + 1, r, ll, rr, x);
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}

// Sum of positions [ll, rr] within node rt covering [l, r].
LL query(int rt, int l, int r, int ll, int rr) {
    if (ll <= l && r <= rr) return sum[rt];
    psd(rt);
    const int mid = (l + r) >> 1;
    LL res = 0LL;
    if (ll <= mid) res += query(rt << 1, l, mid, ll, rr);
    if (rr > mid) res += query(rt << 1 | 1, mid + 1, r, ll, rr);
    return res;
}

// Fills fa, dep, up, siz and son for the tree hanging from x (parent ff).
// Iterative so that a path-shaped tree cannot overflow the call stack.
// Children are scanned in adjacency order and son[] keeps the FIRST child of
// maximum size, exactly as a recursive post-order traversal would.
void dfs(int x, int ff) {
    std::vector<int> order;  // preorder: every parent precedes its children
    order.reserve(n);
    std::vector<int> stk(1, x);
    fa[x] = ff;
    while (!stk.empty()) {
        const int u = stk.back();
        stk.pop_back();
        order.push_back(u);
        // fa[u] was popped earlier, so its dep/up are already final.
        up[u][0] = fa[u];
        dep[u] = dep[fa[u]] + 1;
        for (int i = 1; i < LOG; ++i) up[u][i] = up[up[u][i - 1]][i - 1];
        for (const int v : e[u]) {
            if (v != fa[u]) {
                fa[v] = u;
                stk.push_back(v);
            }
        }
    }
    // Reverse preorder visits every child before its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int u = *it;
        siz[u] = 1;
        for (const int v : e[u]) {
            if (v != fa[u]) {
                siz[u] += siz[v];
                if (siz[v] > siz[son[u]]) son[u] = v;
            }
        }
    }
}

// Assigns top (chain head), dfn (preorder index) and rev (dfn -> vertex),
// starting from x with chain head now_top. The heavy child is visited first,
// so each heavy chain is contiguous in dfn, and the subtree of v occupies
// [dfn[v], dfn[v] + siz[v] - 1]. Iterative; same order as the recursive form.
void get_top(int x, int now_top) {
    std::vector<std::pair<int, int> > stk;
    stk.push_back(std::make_pair(x, now_top));
    while (!stk.empty()) {
        const int u = stk.back().first;
        const int t = stk.back().second;
        stk.pop_back();
        top[u] = t;
        dfn[u] = ++tim;
        rev[tim] = u;
        // Light children start their own chains. Push them in reverse so they
        // pop in adjacency order, after the heavy child (pushed last, popped
        // first) has been fully processed.
        for (auto it = e[u].rbegin(); it != e[u].rend(); ++it) {
            const int v = *it;
            if (v != fa[u] && v != son[u]) stk.push_back(std::make_pair(v, v));
        }
        if (son[u]) stk.push_back(std::make_pair(son[u], t));
    }
}

// Lowest common ancestor of x and y in the tree rooted at vertex 1.
int lca(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        x = fa[top[x]];
    }
    return (dep[x] < dep[y]) ? x : y;
}

// k-th ancestor of x (0 if k >= dep[x]); expects 0 <= k < 2^LOG.
int get_kth(int x, int k) {
    for (int i = 0; i < LOG; ++i) {
        if ((k >> i) & 1) x = up[x][i];
    }
    return x;
}

// The deeper of x and y (y on ties).
int get_max(int x, int y) {
    return (dep[x] > dep[y]) ? x : y;
}

// Nonzero iff y lies in the subtree of x (root 1), x itself included.
int chk_in(int x, int y) {
    return dfn[x] <= dfn[y] && dfn[y] <= dfn[x] + siz[x] - 1;
}

// Adds val on dfn range [l, r]; an empty range is a no-op.
void add(int l, int r, LL val) {
    if (l <= r) modify(1, 1, n, l, r, val);
}

// Sum over dfn range [l, r]; an empty range yields 0.
LL ask(int l, int r) {
    if (l <= r) return query(1, 1, n, l, r);
    return 0LL;
}

// lca of x and y when the tree is rooted at the current `root`: the deepest
// (w.r.t. root 1) of lca(x, y), lca(x, root) and lca(y, root).
int rooted_lca(int x, int y) {
    return get_max(lca(x, y), get_max(lca(x, root), lca(y, root)));
}

// For v a proper ancestor (root 1) of d: the child of v on the path to d.
int child_toward(int v, int d) {
    return get_kth(d, dep[d] - dep[v] - 1);
}

// Adds val to every vertex in the subtree of v under the current root.
void subtree_add(int v, LL val) {
    if (v == root) {
        add(1, n, val);
    } else if (chk_in(v, root)) {
        // The current root hangs below v: v's subtree is the whole tree minus
        // the (root-1) subtree of v's child towards the current root.
        const int c = child_toward(v, root);
        add(1, n, val);
        add(dfn[c], dfn[c] + siz[c] - 1, -val);
    } else {
        add(dfn[v], dfn[v] + siz[v] - 1, val);
    }
}

// Sum over the subtree of v under the current root.
LL subtree_sum(int v) {
    if (v == root) return ask(1, n);
    if (chk_in(v, root)) {
        const int c = child_toward(v, root);
        return ask(1, n) - ask(dfn[c], dfn[c] + siz[c] - 1);
    }
    return ask(dfn[v], dfn[v] + siz[v] - 1);
}

int main() {
    // open();
    n = read();
    m = read();
    for (int i = 1; i <= n; ++i) a[i] = read();
    for (int i = 1; i < n; ++i) {
        const int x = read(), y = read();
        e[x].push_back(y);
        e[y].push_back(x);
    }
    root = 1;
    dfs(1, 0);
    get_top(1, 1);
    build(1, 1, n);
    while (m--) {
        const int opt = read();
        if (opt == 1) {
            root = read();
        } else if (opt == 2) {
            const int x = read(), y = read();
            const LL val = read();
            // x == root or y == root already makes rooted_lca(x, y) == root.
            subtree_add(rooted_lca(x, y), val);
        } else {
            const int x = read();
            std::printf("%lld\n", subtree_sum(x));
        }
    }
    return 0;
}