题目链接
很早就想学的模板,由于懒拖到现在(其实是菜)
树链剖分其实是将树形结构处理成线性序列然后用数据结构来维护树的一个东西。可以解决很多树上问题。
具体重要的三个函数为:
一.dfs1
第一遍dfs遍历这颗树,处理出每个点的子树大小(包含自己)为 siz,节点的父亲节点为 fa,节点的深度为 de,还有每个节点的重儿子(siz值最大的儿子)
void dfs1(int now,int pre,int d){
siz[now]=1;
fa[now]=pre;
de[now]=d;
int cnt=-1;
for(auto k:v[now]){
if(k==pre)continue;
dfs1(k,now,d+1);
siz[now]+=siz[k];
if(siz[k]>cnt){
cnt=siz[k];
son[now]=k;
}
}
}
二.dfs2
处理出每个节点所在链的顶节点记为 to,和该节点所在dfs序列的标号,和dfs序列的每个元素的值,每个节点先处理重儿子节点.
void dfs2(int now,int pre){
to[now]=pre;
a[now]=++cnt;
t[cnt]=x[now];
t[cnt]%=p;
if(!son[now])return ;
dfs2(son[now],pre);
for(auto k:v[now]){
if(k==fa[now]||k==son[now])continue;
dfs2(k,k);
}
}
三.找最近公共祖先
对两个点,当他们不在一条链上时,我们对他们所在链顶端深度较大的那个点就行跳跃,跳到顶的父亲节点,不断重复。
然后返回深度较小的点即为lca。
int find_lca(int l,int r){
int res=0;
while(to[l]!=to[r]){
if(de[to[l]]<de[to[r]])swap(l,r);
l=fa[to[l]];
}
if(de[l]>de[r])swap(l,r);
return l;
}
对洛谷那题来说就是加上了线段树来维护区间的一些信息,重要的还是树链剖分的部分
下面是ac的代码
#include<bits/stdc++.h>
#define LL long long
#define fi first
#define se second
#define mp make_pair
#define pb push_back
using namespace std;
LL gcd(LL a,LL b){return b?gcd(b,a%b):a;}
LL lcm(LL a,LL b){return a/gcd(a,b)*b;}
LL powmod(LL a,LL b,LL MOD){LL ans=1;while(b){if(b%2)ans=ans*a%MOD;a=a*a%MOD;b/=2;}return ans;}
const int N = 1e5 +11;
int siz[N],fa[N],son[N],a[N],cnt,to[N],t[N],de[N];
int T[N<<2];
int laz[N<<2];
vector<int>v[N];
int n,m,p,r,x[N];
void dfs1(int now,int pre,int d){
siz[now]=1;
fa[now]=pre;
de[now]=d;
int cnt=-1;
for(auto k:v[now]){
if(k==pre)continue;
dfs1(k,now,d+1);
siz[now]+=siz[k];
if(siz[k]>cnt){
cnt=siz[k];
son[now]=k;
}
}
}
void dfs2(int now,int pre){
to[now]=pre;
a[now]=++cnt;
t[cnt]=x[now];
t[cnt]%=p;
if(!son[now])return ;
dfs2(son[now],pre);
for(auto k:v[now]){
if(k==fa[now]||k==son[now])continue;
dfs2(k,k);
}
}
void build(int now,int l,int r){
if(l==r){
T[now]=t[l];
return ;
}
int mid=l+r>>1;
build(now<<1,l,mid);
build(now<<1|1,mid+1,r);
T[now]=T[now<<1]+T[now<<1|1];
T[now]%=p;
}
void pd(int now,int l,int r){
if(laz[now]){
int mid=l+r>>1;
laz[now<<1]+=laz[now];
laz[now<<1|1]+=laz[now];
T[now<<1]+=(1ll*laz[now]*(mid-l+1))%p;
T[now<<1|1]+=(1ll*laz[now]*(r-mid))%p;
T[now<<1]%=p;
T[now<<1|1]%=p;
laz[now<<1]%=p;
laz[now<<1|1]%=p;
laz[now]=0;
}
}
int get(int now,int l,int r,int x,int y){
if(l>=x&&r<=y){
return (T[now])%p;
}
pd(now,l,r);
int mid=l+r>>1;
int res=0;
if(x<=mid)res+=get(now<<1,l,mid,x,y);
if(y>mid)res+=get(now<<1|1,mid+1,r,x,y);
res%=p;
return res;
}
void ud(int now,int l,int r,int x,int y,int d){
if(l>=x&&r<=y){
laz[now]+=d;
laz[now]%=p;
T[now]+=(1ll*d*(r-l+1)%p);
T[now]%=p;
return;
}
pd(now,l,r);
int mid=l+r>>1;
if(x<=mid)ud(now<<1,l,mid,x,y,d);
if(y>mid)ud(now<<1|1,mid+1,r,x,y,d);
T[now]=T[now<<1]+T[now<<1|1];
T[now]%=p;
}
int gao(int l,int r){
int res=0;
while(to[l]!=to[r]){
if(de[to[l]]<de[to[r]])swap(l,r);
res+=get(1,1,cnt,a[to[l]],a[l]);
res%=p;
l=fa[to[l]];
}
if(de[l]>de[r])swap(l,r);
res+=get(1,1,cnt,a[l],a[r]);
return res%p;
}
void gan(int l,int r,LL k){
k%=p;
while(to[l]!=to[r]){
if(de[to[l]]<de[to[r]])swap(l,r);
ud(1,1,cnt,a[to[l]],a[l],k);
l=fa[to[l]];
}
if(de[l]>de[r])swap(l,r);
ud(1,1,cnt,a[l],a[r],k);
return ;
}
int main(){
ios::sync_with_stdio(false);
cin>>n>>m>>r>>p;
for(int i=1;i<=n;i++)cin>>x[i];
for(int i=1;i<n;i++){
int s,t;
cin>>s>>t;
v[s].pb(t);
v[t].pb(s);
}
dfs1(r,0,0);
dfs2(r,r);
cout<<endl;
build(1,1,cnt);
for(int i=1;i<=m;i++){
int ope;
int X,Y,Z;
cin>>ope;
if(ope==1){
cin>>X>>Y>>Z;
gan(X,Y,Z);
}else if(ope==2){
cin>>X>>Y;
cout<<gao(X,Y)<<endl;
}else if(ope==3){
cin>>X>>Z;
ud(1,1,cnt,a[X],a[X]+siz[X]-1,Z);
}else{
cin>>X;
cout<<get(1,1,cnt,a[X],a[X]+siz[X]-1)<<endl;
}
}
return 0;
}