会写就行了。我们先求出以 root 为根时 x 和 y 的 lca:它实际上是在以 1 为根的原树上求出的 lca(x,y)、lca(x,root)、lca(y,root) 这三者中深度最大的那个。于是我们可以分情况讨论这个 lca 和 root 的位置关系;同理,query 的时候也要讨论当前 x 和 root 的关系。
剩下就只需要一个求 k 级祖先的过程。这个可以用长链剖分做到单次 O(1),但是作者直接写了倍增,单次 O(log n)。
剩下树剖就完了!
代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
#include<queue>
#include<cmath>
#include<cstdlib>
// Standard-library names (vector, swap, ...) are used unqualified throughout this file.
using namespace std;
// Shorthand type names implemented as macros.
// LL is used for node values, segment-tree sums and query answers.
#define LL long long
// LD and DB are not referenced by the code visible in this file.
#define LD long double
#define DB double
// Reads the next decimal integer from stdin and returns it.
//
// Every character before the first digit is skipped; if a '-' is seen while
// skipping, the result is negated. Reading stops at the first non-digit after
// the number (that character is consumed).
//
// If end-of-file is reached before any digit, 0 is returned instead of
// looping forever. The character is kept as an int so that EOF stays
// distinguishable from real bytes, and digits are detected by range
// comparison so no <cctype> call is made on a possibly negative value.
long long read(){
    int ch=getchar();
    long long sign=1;
    // Skip to the first digit, remembering whether a minus sign was passed.
    while(ch!=EOF&&(ch<'0'||ch>'9')){
        if(ch=='-')sign=-1;
        ch=getchar();
    }
    long long value=0;
    // Accumulate consecutive digits; EOF is outside '0'..'9' and ends the loop.
    while(ch>='0'&&ch<='9'){
        value=value*10+(ch-'0');
        ch=getchar();
    }
    return value*sign;
}
// Capacity of the per-node arrays below; the segment-tree arrays are sized NN<<2.
const int NN=100000+17;
// Redirects stdin to the file "a.in" and stdout to the file "a.out".
// The only call site in main() is commented out.
void open(){
    const char* input_path="a.in";
    const char* output_path="a.out";
    freopen(input_path,"r",stdin);
    freopen(output_path,"w",stdout);
}
// n: number of nodes; m: number of operations; root: current root (changed by type-1 operations).
int n,m,root;
// Per-node data computed by dfs()/get_top() on the tree rooted at node 1:
// fa = parent, dep = depth, siz = subtree size, son = child with the largest subtree,
// top = first node of the node's chain, dfn = position in the get_top() visiting order,
// rev = inverse of dfn (position -> node).
int fa[NN],dep[NN],siz[NN],son[NN],top[NN],dfn[NN],rev[NN];
// up[x][i]: the node reached from x by 2^i parent steps (0 once the walk leaves the tree).
int up[NN][21];
// Counter that hands out dfn positions in get_top().
int tim;
// len[rt]: number of positions covered by segment-tree node rt.
int len[NN<<2];
// a[i]: initial value of node i; sum[rt]/tag[rt]: range sum and pending lazy addition of segment-tree node rt.
LL a[NN],sum[NN<<2],tag[NN<<2];
// Adjacency lists of the undirected tree.
vector<int> e[NN];
// Applies "add x to every position" to segment-tree node rt:
// records x as pending for the children and raises the node's sum by x per covered position.
void set_tag(int rt,LL x){
    tag[rt]+=x;
    sum[rt]+=x*len[rt];
}
void psd(int rt){
if(tag[rt]){
set_tag(rt<<1,tag[rt]);
set_tag(rt<<1|1,tag[rt]);
tag[rt]=0LL;
}
}
// Builds segment-tree node rt over positions [l, r].
// A leaf at position l takes the initial value of the node stored at that
// position (a[rev[l]]); an inner node's sum is the sum of its two children.
void build(int rt,int l,int r){
    len[rt]=r-l+1;
    if(l!=r){
        const int mid=(l+r)>>1;
        const int left_child=rt<<1;
        const int right_child=rt<<1|1;
        build(left_child,l,mid);
        build(right_child,mid+1,r);
        sum[rt]=sum[left_child]+sum[right_child];
    }
    else{
        sum[rt]=a[rev[l]];
    }
}
// Adds x to every position in [ll, rr].
// rt is the current segment-tree node and [l, r] the range it covers.
void modify(int rt,int l,int r,int ll,int rr,LL x){
    const bool fully_covered=(l>=ll)&&(r<=rr);
    if(fully_covered){
        set_tag(rt,x);
        return;
    }
    // Children must see earlier pending additions before being updated.
    psd(rt);
    const int mid=(l+r)>>1;
    const int left_child=rt<<1;
    const int right_child=rt<<1|1;
    if(mid>=ll)modify(left_child,l,mid,ll,rr,x);
    if(mid<rr)modify(right_child,mid+1,r,ll,rr,x);
    sum[rt]=sum[left_child]+sum[right_child];
}
// Returns the sum of the positions in [ll, rr].
// rt is the current segment-tree node and [l, r] the range it covers.
LL query(int rt,int l,int r,int ll,int rr){
    if(l>=ll&&r<=rr){
        return sum[rt];
    }
    // Children must see pending additions before their sums are read.
    psd(rt);
    const int mid=(l+r)>>1;
    LL total=0LL;
    if(mid>=ll){
        total+=query(rt<<1,l,mid,ll,rr);
    }
    if(mid<rr){
        total+=query(rt<<1|1,mid+1,r,ll,rr);
    }
    return total;
}
// First traversal from node x, whose parent is ff (0 for the starting node).
// Fills fa, dep, siz, son and the binary-lifting table up for x and for
// everything below it.
void dfs(int x,int ff){
    up[x][0]=ff;
    fa[x]=ff;
    dep[x]=dep[ff]+1;
    siz[x]=1;
    // Ancestors of x are already finished, so their table rows are available.
    for(int level=0;level<20;level++){
        up[x][level+1]=up[up[x][level]][level];
    }
    const int degree=e[x].size();
    for(int idx=0;idx<degree;idx++){
        const int child=e[x][idx];
        if(child==ff)continue;
        dfs(child,x);
        siz[x]+=siz[child];
        // Keep the child with the strictly largest subtree seen so far.
        if(siz[son[x]]<siz[child])son[x]=child;
    }
}
// Second traversal: assigns x to the chain whose first node is now_top and
// gives x the next dfn position. son[x] is visited first so that it continues
// the same chain; every other child starts a chain of its own.
void get_top(int x,int now_top){
    top[x]=now_top;
    tim++;
    dfn[x]=tim;
    rev[tim]=x;
    if(son[x]!=0){
        get_top(son[x],now_top);
    }
    const int degree=e[x].size();
    for(int idx=0;idx<degree;idx++){
        const int child=e[x][idx];
        if(child==fa[x]||child==son[x])continue;
        get_top(child,child);
    }
}
// Returns the lowest common ancestor of x and y in the tree rooted at node 1.
// While the two nodes are on different chains, the one whose chain starts
// deeper jumps to the parent of its chain's first node.
int lca(int x,int y){
    for(;top[x]!=top[y];x=fa[top[x]]){
        if(dep[top[y]]>dep[top[x]])swap(x,y);
    }
    // Same chain now: the shallower node is the answer.
    if(dep[x]<dep[y])return x;
    return y;
}
// Returns the node reached from x by k parent steps, using the binary-lifting
// table: for every set bit i (0..20) of k, jump 2^i steps.
int get_kth(int x,int k){
    int node=x;
    for(int bit=0;bit<21;bit++){
        if((k>>bit)&1){
            node=up[node][bit];
        }
    }
    return node;
}
// Returns whichever of x and y is deeper; y when the depths are equal.
int get_max(int x,int y){
    if(dep[x]>dep[y]){
        return x;
    }
    return y;
}
// Returns 1 when dfn[y] lies in the position range of x's subtree
// [dfn[x], dfn[x]+siz[x]-1], i.e. y is x or a descendant of x; otherwise 0.
int chk_in(int x,int y){
    if(dfn[y]<dfn[x]){
        return 0;
    }
    return dfn[y]<dfn[x]+siz[x];
}
// Adds val to every position in [l, r] of the whole segment tree; an empty
// range (l > r) is ignored.
void add(int l,int r,LL val){
    if(l>r)return;
    modify(1,1,n,l,r,val);
}
// Returns the sum over positions [l, r] of the whole segment tree, or 0 for
// an empty range (l > r).
LL ask(int l,int r){
    if(l>r){
        return 0LL;
    }
    return query(1,1,n,l,r);
}
int main(){
//open();
n=read();
m=read();
for(int i=1;i<=n;i++)a[i]=read();
for(int i=1;i<n;i++){
int x=read(),y=read();
e[x].push_back(y);
e[y].push_back(x);
}
root=1;
dfs(1,0);
get_top(1,1);
build(1,1,n);
while(m--){
int opt=read();
if(opt==1){
int x=read();
root=x;
}
else if(opt==2){
int x=read(),y=read();
LL val=read();
int pos=get_max(lca(x,y),get_max(lca(x,root),lca(y,root)));
if(pos==root||x==root||y==root){
add(1,n,val);
continue;
}
if(chk_in(pos,root)){
add(1,n,val);
pos=get_kth(root,dep[root]-dep[pos]-1);
add(dfn[pos],dfn[pos]+siz[pos]-1,-val);
}
else{
add(dfn[pos],dfn[pos]+siz[pos]-1,val);
}
}
else{
int x=read();
if(x==root){
printf("%lld\n",ask(1,n));
}
else if(chk_in(x,root)){
int pos=get_kth(root,dep[root]-dep[x]-1);
printf("%lld\n",ask(1,n)-ask(dfn[pos],dfn[pos]+siz[pos]-1));
}
else{
printf("%lld\n",ask(dfn[x],dfn[x]+siz[x]-1));
}
}
}
return 0;
}

京公网安备 11010502036488号