题目链接:[国家集训队]Tree IIv
LCT维护树上链加法和乘法。
我们像线段树一样做一个加法标记,和一个乘法标记即可。转移和线段树一样的。
因为保证边是合法的,所以我们link和cut时就不必判断了。
AC代码:
#pragma GCC optimize(2)
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=3e5+10,mod=51061;
int n,q,a[N];
struct Link_Cut_Tree{
int cnt,st[N];
struct node{int ch[2],fa,val,re,lm,la,sz;}t[N];
inline void push_up(int p){
t[p].val=(t[t[p].ch[0]].val+t[t[p].ch[1]].val+a[p])%mod;
t[p].sz=t[t[p].ch[0]].sz+t[t[p].ch[1]].sz+1;
}
inline void push_re(int p){swap(t[p].ch[0],t[p].ch[1]); t[p].re^=1;}
inline void pushm(int p,int c){
t[p].val=t[p].val*c%mod; a[p]=a[p]*c%mod;
t[p].lm=(t[p].lm*c)%mod; t[p].la=(t[p].la*c)%mod;
}
inline void pusha(int p,int c){
t[p].val=(t[p].val+c*t[p].sz)%mod; a[p]=(a[p]+c)%mod; t[p].la=(t[p].la+c)%mod;
}
inline void push_down(int p){
if(t[p].lm!=1) pushm(t[p].ch[0],t[p].lm),pushm(t[p].ch[1],t[p].lm),t[p].lm=1;
if(t[p].la) pusha(t[p].ch[0],t[p].la),pusha(t[p].ch[1],t[p].la),t[p].la=0;
if(t[p].re){
if(t[p].ch[0]) push_re(t[p].ch[0]);
if(t[p].ch[1]) push_re(t[p].ch[1]); t[p].re^=1;
}
}
inline bool isroot(int x){return t[t[x].fa].ch[0]!=x&&t[t[x].fa].ch[1]!=x;}
inline void rotate(int x){
int y=t[x].fa,z=t[y].fa,k=t[y].ch[1]==x,w=t[x].ch[!k];
if(!isroot(y)) t[z].ch[t[z].ch[1]==y]=x; t[x].ch[!k]=y; t[y].ch[k]=w;
if(w) t[w].fa=y; t[y].fa=x; t[x].fa=z; push_up(y);
}
inline void splay(int x){
cnt=1; st[cnt]=x; int y=x;
while(!isroot(y)) st[++cnt]=y=t[y].fa;
while(cnt) push_down(st[cnt--]);
while(!isroot(x)){
int y=t[x].fa,z=t[y].fa;
if(!isroot(y)) rotate((t[y].ch[0]==x)^(t[z].ch[0]==y)?x:y); rotate(x);
}push_up(x);
}
inline void access(int x){
for(int y=0;x;x=t[y=x].fa) splay(x),t[x].ch[1]=y,push_up(x);
}
inline void makeroot(int x){access(x); splay(x); push_re(x);}
inline void split(int x,int y){makeroot(x); access(y); splay(y);}
inline void link(int x,int y){makeroot(x); t[x].fa=y;}
inline void cut(int x,int y){split(x,y); t[x].fa=t[y].ch[0]=0;}
}tr;
signed main(){
cin>>n>>q; char op[2];
for(int i=1,a,b;i<n;i++) scanf("%lld %lld",&a,&b),tr.link(a,b);
for(int i=1;i<=n;i++) a[i]=tr.t[i].sz=tr.t[i].lm=1;
while(q--){
scanf("%s",op);
if(op[0]=='+'){
int u,v,c; scanf("%lld %lld %lld",&u,&v,&c); tr.split(u,v); tr.pusha(v,c);
}else if(op[0]=='-'){
int u1,v1,u2,v2; scanf("%lld %lld %lld %lld",&u1,&v1,&u2,&v2);
tr.cut(u1,v1); tr.link(u2,v2);
}else if(op[0]=='*'){
int u,v,c; scanf("%lld %lld %lld",&u,&v,&c); tr.split(u,v); tr.pushm(v,c);
}else{
int u,v; scanf("%lld %lld",&u,&v);
tr.split(u,v); printf("%lld\n",tr.t[v].val);
}
}
return 0;
}