会写就行了。先求出以当前 root 为根时 x、y 的 lca:可以发现,它就是 lca(x,y)、lca(x,root)、lca(y,root)(三者都按固定根 1 求得)中深度最大的那个。然后按这个 lca 与 root 的位置关系分情况讨论;同理,query 的时候也是讨论当前 x 与 root 的位置关系。
剩下只需要一个求 k 级祖先的过程:可以用长链剖分做到单次 O(1),不过作者直接写了单次 O(log n) 的倍增求法。
剩下树剖就完了!
代码:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cctype>
#include<vector>
#include<algorithm>
#include<queue>
#include<cmath>
#include<cstdlib>
using namespace std;
// Short numeric type names used throughout the file.
using LL = long long;   // wide integer for values and sums
using LD = long double;
using DB = double;
LL read(){
    char ch=getchar();LL x=0,fl=1;
    for(;!isdigit(ch);ch=getchar())if(ch=='-')fl=-1;
    for(;isdigit(ch);ch=getchar())x=(x<<3)+(x<<1)+(ch-'0');
    return x*fl;
}
// Capacity of per-vertex arrays: presumably the limit on n (1e5) plus a small safety margin.
constexpr int NN = 100000 + 17;
// Debug helper: redirects stdin from "a.in" and stdout to "a.out".
// main() keeps its call commented out, so normal runs use the console streams.
void open(){
    std::freopen("a.in","r",stdin);
    std::freopen("a.out","w",stdout);
}


int n,m;        // vertex count, number of operations
int root;       // current root (changed by operation 1)

// Tree data computed with vertex 1 as the fixed root, indexed by vertex.
int fa[NN];     // parent (0 for vertex 1)
int dep[NN];    // depth (vertex 1 has depth 1)
int siz[NN];    // subtree size
int son[NN];    // heavy child (0 if none)
int top[NN];    // top vertex of the heavy chain containing the vertex
int dfn[NN];    // 1-based position in heavy-first DFS order
int rev[NN];    // inverse of dfn: vertex stored at a DFS position
int up[NN][21]; // up[x][i] = 2^i-th ancestor of x (0 once past vertex 1)
int tim;        // DFS position counter

// Lazy segment tree over DFS positions (range add, range sum).
int len[NN<<2]; // number of positions covered by each node
LL a[NN];       // initial vertex values
LL sum[NN<<2];  // sum of each node's range
LL tag[NN<<2];  // pending add not yet pushed to the children

vector<int> e[NN]; // adjacency lists
// Applies "add x to every covered position" to node rt: the node's sum grows
// by x per position, and x is queued for later push-down to its children.
void set_tag(int rt,LL x){
    tag[rt]+=x;
    sum[rt]+=x*len[rt];
}
void psd(int rt){
    if(tag[rt]){
        set_tag(rt<<1,tag[rt]);
        set_tag(rt<<1|1,tag[rt]);
        tag[rt]=0LL;
    }
}
// Builds node rt over DFS positions [l, r]. It records each node's length,
// and each leaf takes the value of the vertex at that position (a[rev[pos]]).
void build(int rt,int l,int r){
    len[rt]=r-l+1;
    if(l!=r){
        int mid=(l+r)>>1;
        int lc=rt<<1;
        int rc=rt<<1|1;
        build(lc,l,mid);
        build(rc,mid+1,r);
        sum[rt]=sum[lc]+sum[rc];
    }
    else{
        sum[rt]=a[rev[l]];
    }
}
// Adds x to every position of [ll, rr] that lies inside node rt's range [l, r].
void modify(int rt,int l,int r,int ll,int rr,LL x){
    if(l>=ll&&r<=rr){
        // Node fully covered: tag it and stop descending.
        set_tag(rt,x);
        return;
    }
    psd(rt);
    const int mid=(l+r)>>1;
    const int lc=rt<<1;
    const int rc=lc|1;
    if(mid>=ll)modify(lc,l,mid,ll,rr,x);
    if(mid<rr)modify(rc,mid+1,r,ll,rr,x);
    sum[rt]=sum[lc]+sum[rc];
}
// Returns the sum over the positions of [ll, rr] inside node rt's range [l, r].
LL query(int rt,int l,int r,int ll,int rr){
    if(l>=ll&&r<=rr)return sum[rt];
    psd(rt);
    const int mid=(l+r)>>1;
    LL left_part=(ll<=mid)?query(rt<<1,l,mid,ll,rr):0LL;
    LL right_part=(rr>mid)?query(rt<<1|1,mid+1,r,ll,rr):0LL;
    return left_part+right_part;
}

// First DFS from the fixed root. For vertex x with parent ff it fills the
// parent, depth, subtree size, heavy child and binary-lifting table.
void dfs(int x,int ff){
    up[x][0]=ff;
    fa[x]=ff;
    dep[x]=dep[ff]+1;
    siz[x]=1;
    for(int j=1;j<21;j++){
        int half=up[x][j-1];
        up[x][j]=up[half][j-1];
    }
    for(int y:e[x]){
        if(y==ff)continue;
        dfs(y,x);
        siz[x]+=siz[y];
        // The first child with the strictly largest subtree becomes the heavy child.
        if(siz[son[x]]<siz[y])son[x]=y;
    }
}
// Second DFS: assigns chain tops and DFS positions. The heavy child is visited
// first, so every heavy chain occupies consecutive dfn values.
void get_top(int x,int now_top){
    top[x]=now_top;
    rev[++tim]=x;
    dfn[x]=tim;
    if(son[x])get_top(son[x],now_top);
    for(int y:e[x]){
        if(y==fa[x]||y==son[x])continue;
        get_top(y,y);   // each light child starts its own chain
    }
}
// Lowest common ancestor of x and y with respect to the fixed root (vertex 1),
// found by repeatedly lifting the endpoint whose chain top is deeper.
int lca(int x,int y){
    for(;top[x]!=top[y];){
        if(dep[top[y]]>dep[top[x]]){
            int t=x;
            x=y;
            y=t;
        }
        x=fa[top[x]];
    }
    return (dep[y]>dep[x])?x:y;
}
// Returns the k-th ancestor of x via binary lifting (only bits 0..20 of k are used).
// Walking past vertex 1 yields 0.
int get_kth(int x,int k){
    for(int bit=0;bit<=20;++bit){
        if((k>>bit)&1)x=up[x][bit];
    }
    return x;
}
// Returns whichever of x, y is deeper in the fixed-root tree; ties return y.
int get_max(int x,int y){
    if(dep[x]>dep[y])return x;
    return y;
}
// Returns 1 if y lies in x's subtree under the fixed root (x itself included), else 0.
// A subtree occupies the DFS range [dfn[x], dfn[x]+siz[x]-1].
int chk_in(int x,int y){
    int lo=dfn[x];
    int hi=dfn[x]+siz[x]-1;
    return lo<=dfn[y]&&dfn[y]<=hi;
}
// Adds val to every DFS position in [l, r]; an empty range (l > r) is a no-op.
void add(int l,int r,LL val){
    if(l>r)return;
    modify(1,1,n,l,r,val);
}
// Sum over DFS positions [l, r]; an empty range (l > r) yields 0.
LL ask(int l,int r){
    return (l>r)?0LL:query(1,1,n,l,r);
}

int main(){
    //open();
    n=read();
    m=read();
    for(int i=1;i<=n;i++)a[i]=read();
    for(int i=1;i<n;i++){
        int x=read(),y=read();
        e[x].push_back(y);
        e[y].push_back(x);
    }
    root=1;
    dfs(1,0);
    get_top(1,1);
    build(1,1,n);
    while(m--){
        int opt=read();
        if(opt==1){
            int x=read();
            root=x;
        }
        else if(opt==2){
            int x=read(),y=read();
            LL val=read();
            int pos=get_max(lca(x,y),get_max(lca(x,root),lca(y,root)));
            if(pos==root||x==root||y==root){
                add(1,n,val);
                continue;
            }
            if(chk_in(pos,root)){
                add(1,n,val);
                pos=get_kth(root,dep[root]-dep[pos]-1);
                add(dfn[pos],dfn[pos]+siz[pos]-1,-val);
            }
            else{
                add(dfn[pos],dfn[pos]+siz[pos]-1,val);
            }
        }
        else{
            int x=read();
            if(x==root){
                printf("%lld\n",ask(1,n));
            }
            else if(chk_in(x,root)){
                int pos=get_kth(root,dep[root]-dep[x]-1);
                printf("%lld\n",ask(1,n)-ask(dfn[pos],dfn[pos]+siz[pos]-1));
            }
            else{
                printf("%lld\n",ask(dfn[x],dfn[x]+siz[x]-1));
            }
        }
    }
    return 0;
}