一是将根换为x;
二是给出两个节点u,v,把包含这两个点的最小子树中每个节点权值加上x;
三是查询以u为根的权值和。
题解:不考虑换根时,需要的算法是裸的模板线段树+dfs序+LCA,考虑到换根,我们不必真的换,只需要对操作2和3进行分类讨论。
对于第一种操作,直接将根换一下即可。
对于第二种操作,设t为(u,v)的LCA,分为两大种情况,一是根不在t的子树内,二是根在t的子树内。
第一种情况,比较简单,直接用线段树区间修改下t的子树的权值。
根不在t的子树内
第二种情况,又分为两类,一是当Root为根时,(u,v)的LCA为Root,如图
根在t的子树内
那么显然此时需要修改的就是整棵树了。
二是当Root为根时,(u,v)的LCA为不为Root,而是为a,如图
此时我们需要修改的部分是,图中红色线以上部分,于是我们可以先将整体全部修改,再将b的子树反向修改(加x变为减x),从而达到不修改b的子树的目的。
问题来了,怎么确定a和b呢?先说a,通过观察可以看出,a无非是LCA(u,Root)或LCA(v,Root),到底是哪个,我们由深度判断,更深的点就是a。图中可以看出,a的深度是大于t((Root,v)的LCA)的。
在说b,我们从Root开始,不断向上找父亲就可以,这里可以用到log(n)的优化,看程序不难理解。
最后是第三种操作,相似的也分为两大种情况。一是根不在x的子树内,二是根在x的子树内。
不在时,很简单,输出x的子树权值和即可。
在时,同操作二的第二种情况
答案为总共的权值减去b的子树的权值,b的求法和上面b的求法一致。
注意一种特殊情况,x==Root时,不需要再减,输出总共的权值即可。
代码:
#include<bits/stdc++.h>
#define N 100010
#define INF 1e9
#define LL long long
#define cal(x) query(1,1,n,st[x],en[x])
using namespace std;
vector<int>G[N];
int st[N],en[N],L[N],fa[N],anc[N][20],val[N];
LL d[N<<2],a[N<<2];
int n,m,t,r;
void preprocess()
{
for (int i=1;i<=n;i++)
{
anc[i][0]=fa[i];
for (int j=1;(1<<j)<=n;j++) anc[i][j]=-1;
}
for (int j=1;(1<<j)<=n;j++)
for (int i=1;i<=n;i++)
if (anc[i][j-1]!=-1)
{
int x=anc[i][j-1];
anc[i][j]=anc[x][j-1];
}
}
int LCA(int p,int q)
{
int log,i;
if (L[p]<L[q]) swap(p,q);
for (log=1;(1<<log)<=L[p];log++);log--;
int ans=-INF;
for (int i=log;i>=0;i--)
if (L[p]-(1<<i)>=L[q])p=anc[p][i];
if (p==q)return p;
for (int i=log;i>=0;i--)
if (anc[p][i]!=-1 && anc[p][i]!=anc[q][i])
{
p=anc[p][i];
q=anc[q][i];
}
return fa[p];//LCA为fa[p]
}
void dfs(int x,int p)
{
st[x]=++t; d[t]=x;
for (int i=0;i<G[x].size();i++)
{
int v=G[x][i];
if (v!=p)
{
fa[v]=x;
L[v]=L[x]+1;
dfs(v,x);
}
}
en[x]=t;
}
void push_down(int x,int l,int r,int t)
{
if (d[x]!=0)
{
d[x*2]+=d[x];a[x*2]+=d[x]*(t-l+1);
d[x*2+1]+=d[x];a[x*2+1]+=d[x]*(r-t);
d[x]=0;
}
}
LL query(int x,int l,int r,int fl,int fr)
{
if (l==fl && r==fr) return (a[x]);
int t=(l+r)>>1;
push_down(x,l,r,t);
if (fr<=t)return (query(x*2,l,t,fl,fr));else
if (fl>t) return (query(x*2+1,t+1,r,fl,fr));else
return query(x*2,l,t,fl,t)+query(x*2+1,t+1,r,t+1,fr);
}
void updata(int x,int l,int r,int fl,int fr,int y)
{
if (l==fl && r==fr)
{
d[x]+=y;
a[x]+=(LL) y*(r-l+1);
}else
{
int t=(l+r)>>1;
push_down(x,l,r,t);
if (fr<=t) updata(x*2,l,t,fl,fr,y);else
if (fl>t) updata(x*2+1,t+1,r,fl,fr,y);else
{
updata(x*2,l,t,fl,t,y);
updata(x*2+1,t+1,r,t+1,fr,y);
}
a[x]=a[x*2]+a[x*2+1];
}
}
void build(int x,int l,int r)
{
if (l==r) a[x]=val[d[l]];
else
{
int t=(l+r)>>1;
build (x<<1,l,t);
build (x<<1|1,t+1,r);
a[x]=a[x<<1]+a[x<<1|1];
}
}
int lis(int x,int y)
{
for (int i=19;i>=0;i--) if (anc[x][i]>=0 && L[anc[x][i]]>L[y]) x=anc[x][i];
return x;
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++) scanf("%d",&val[i]);
for(int i=1;i<n;i++)
{
int j,k;
scanf("%d%d",&j,&k);
G[j].push_back(k);G[k].push_back(j);
}
dfs(1,-1);
preprocess();
build(1,1,n);
memset(d,0,sizeof d);
r=1;
for(int i=1;i<=m;i++)
{
int op,u,v,x;
scanf("%d",&op);
if (op==1)
{
scanf("%d",&x);
r=x;
}else
if (op==2)
{
scanf("%d%d%d",&u,&v,&x);
int lca=LCA(u,v);
if (st[r]<st[lca] || en[lca]<st[r])
{
updata(1,1,n,st[lca],en[lca],x);
}else
{
int a=LCA(u,r),b=LCA(r,v);
t=L[a]>L[b]?a:b;
updata(1,1,n,1,n,x);
if (t!=r)
{
t=lis(r,t);
updata(1,1,n,st[t],en[t],-x);
}
}
}else
{
scanf("%d",&x);
if (st[r]<st[x] || en[x]<st[r]) printf("%I64d\n",cal(x));else
{
if (x==r) printf("%I64d\n",cal(1));else
printf("%I64d\n",cal(1)-cal(lis(r,x)));
}
}
}
}