题目大意:给出一棵树以及每个结点的权值,初始时根为1。有3种操作:

一是将根换为x;

二是给出两个节点u,v,把包含这两个点的最小子树中每个节点权值加上x;

三是查询以u为根的权值和。


题解:不考虑换根时,需要的算法是裸的模板线段树+dfs序+LCA,考虑到换根,我们不必真的换,只需要对操作2和3进行分类讨论。

对于第一种操作,直接将根换一下即可。

对于第二种操作,设t为(u,v)的LCA,分为两大种情况,一是根不在t的子树内,二是根在t的子树内。

第一种情况,比较简单,直接用线段树区间修改下t的子树的权值。

                                    

                                                                        根不在t的子树内

第二种情况,又分为两类,一是当Root为根时,(u,v)的LCA为Root,如图

                                                   

                                                                        根在t的子树内

那么显然此时需要修改的就是整棵树了。

二是当Root为根时,(u,v)的LCA为不为Root,而是为a,如图

                                               

此时我们需要修改的部分是,图中红色线以上部分,于是我们可以先将整体全部修改,再将b的子树反向修改(加x变为减x),从而达到不修改b的子树的目的。

问题来了,怎么确定a和b呢?先说a,通过观察可以看出,a无非是LCA(u,Root)或LCA(v,Root),到底是哪个,我们由深度判断,更深的点就是a。图中可以看出,a的深度是大于t((Root,v)的LCA)的。

在说b,我们从Root开始,不断向上找父亲就可以,这里可以用到log(n)的优化,看程序不难理解。


最后是第三种操作,相似的也分为两大种情况。一是根不在x的子树内,二是根在x的子树内。

不在时,很简单,输出x的子树权值和即可。

在时,同操作二的第二种情况

                                                    

答案为总共的权值减去b的子树的权值,b的求法和上面b的求法一致。

注意一种特殊情况,x==Root时,不需要再减,输出总共的权值即可。


代码:

#include<bits/stdc++.h>
#define N 100010
#define INF 1e9
#define LL long long
#define cal(x) query(1,1,n,st[x],en[x])
using namespace std;
vector<int>G[N];
int st[N],en[N],L[N],fa[N],anc[N][20],val[N];
LL d[N<<2],a[N<<2];
int n,m,t,r;

void preprocess()
{
    for (int i=1;i<=n;i++)
    {
        anc[i][0]=fa[i];
        for (int j=1;(1<<j)<=n;j++) anc[i][j]=-1;
    }
    for (int j=1;(1<<j)<=n;j++)
        for (int i=1;i<=n;i++)
        if (anc[i][j-1]!=-1)
        {
            int x=anc[i][j-1];
            anc[i][j]=anc[x][j-1];
        }
}

int LCA(int p,int q)
{
    int log,i;
    if (L[p]<L[q]) swap(p,q);
    for (log=1;(1<<log)<=L[p];log++);log--;
    int ans=-INF;
    for (int i=log;i>=0;i--)
        if (L[p]-(1<<i)>=L[q])p=anc[p][i];
    if (p==q)return p;
    for (int i=log;i>=0;i--)
        if (anc[p][i]!=-1 && anc[p][i]!=anc[q][i])
    {
        p=anc[p][i];
        q=anc[q][i];
    }
    return fa[p];//LCA为fa[p]
}

void dfs(int x,int p)
{
    st[x]=++t; d[t]=x;
    for (int i=0;i<G[x].size();i++)
    {
        int v=G[x][i];
        if (v!=p)
        {
            fa[v]=x;
            L[v]=L[x]+1;
            dfs(v,x);
        }
    }
    en[x]=t;
}

void push_down(int x,int l,int r,int t)
{
    if (d[x]!=0)
	{
		d[x*2]+=d[x];a[x*2]+=d[x]*(t-l+1);
		d[x*2+1]+=d[x];a[x*2+1]+=d[x]*(r-t);
		d[x]=0;
	}
}

LL query(int x,int l,int r,int fl,int fr)
{
	if (l==fl && r==fr)	return (a[x]);
	int t=(l+r)>>1;
	push_down(x,l,r,t);
	if (fr<=t)return (query(x*2,l,t,fl,fr));else
	if (fl>t) return (query(x*2+1,t+1,r,fl,fr));else
	return query(x*2,l,t,fl,t)+query(x*2+1,t+1,r,t+1,fr);
}

void updata(int x,int l,int r,int fl,int fr,int y)
{
	if (l==fl && r==fr)
	{
		d[x]+=y;
		a[x]+=(LL) y*(r-l+1);
	}else
	{
		int t=(l+r)>>1;
		push_down(x,l,r,t);
		if (fr<=t)	updata(x*2,l,t,fl,fr,y);else
		if (fl>t)	updata(x*2+1,t+1,r,fl,fr,y);else
		{
			updata(x*2,l,t,fl,t,y);
			updata(x*2+1,t+1,r,t+1,fr,y);
		}
		a[x]=a[x*2]+a[x*2+1];
	}
}

void build(int x,int l,int r)
{
    if (l==r) a[x]=val[d[l]];
    else
    {
        int t=(l+r)>>1;
        build (x<<1,l,t);
        build (x<<1|1,t+1,r);
        a[x]=a[x<<1]+a[x<<1|1];
    }
}

int lis(int x,int y)
{
    for (int i=19;i>=0;i--) if (anc[x][i]>=0 && L[anc[x][i]]>L[y]) x=anc[x][i];
    return x;
}

int main()
{
    scanf("%d%d",&n,&m);
    for (int i=1;i<=n;i++) scanf("%d",&val[i]);
    for(int i=1;i<n;i++)
    {
        int j,k;
        scanf("%d%d",&j,&k);
        G[j].push_back(k);G[k].push_back(j);
    }
    dfs(1,-1);
    preprocess();
    build(1,1,n);
    memset(d,0,sizeof d);
    r=1;
    for(int i=1;i<=m;i++)
    {
        int op,u,v,x;
        scanf("%d",&op);
        if (op==1)
        {
            scanf("%d",&x);
            r=x;
        }else
        if (op==2)
        {
            scanf("%d%d%d",&u,&v,&x);
            int lca=LCA(u,v);
            if (st[r]<st[lca] || en[lca]<st[r])
            {
                updata(1,1,n,st[lca],en[lca],x);
            }else
            {
                int a=LCA(u,r),b=LCA(r,v);
                t=L[a]>L[b]?a:b;
                updata(1,1,n,1,n,x);
                if (t!=r)
                {
                    t=lis(r,t);
                    updata(1,1,n,st[t],en[t],-x);
                }
            }
        }else
        {
            scanf("%d",&x);
            if (st[r]<st[x] || en[x]<st[r]) printf("%I64d\n",cal(x));else
            {
                if (x==r) printf("%I64d\n",cal(1));else
                printf("%I64d\n",cal(1)-cal(lis(r,x)));
            }
        }
    }
}