题目链接:[国家集训队]Tree IIv


LCT维护树上链加法和乘法。

我们像线段树一样做一个加法标记,和一个乘法标记即可。转移和线段树一样的。

因为保证边是合法的,所以我们link和cut时就不必判断了。


AC代码:

#pragma GCC optimize(2)
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=3e5+10,mod=51061;
int n,q,a[N];
struct Link_Cut_Tree{
	int cnt,st[N];
	struct node{int ch[2],fa,val,re,lm,la,sz;}t[N];
	inline void push_up(int p){
		t[p].val=(t[t[p].ch[0]].val+t[t[p].ch[1]].val+a[p])%mod;
		t[p].sz=t[t[p].ch[0]].sz+t[t[p].ch[1]].sz+1;
	}
	inline void push_re(int p){swap(t[p].ch[0],t[p].ch[1]); t[p].re^=1;}
	inline void pushm(int p,int c){
		t[p].val=t[p].val*c%mod; a[p]=a[p]*c%mod;
		t[p].lm=(t[p].lm*c)%mod; t[p].la=(t[p].la*c)%mod;
	}
	inline void pusha(int p,int c){
		t[p].val=(t[p].val+c*t[p].sz)%mod; a[p]=(a[p]+c)%mod; t[p].la=(t[p].la+c)%mod;
	}
	inline void push_down(int p){
		if(t[p].lm!=1)	pushm(t[p].ch[0],t[p].lm),pushm(t[p].ch[1],t[p].lm),t[p].lm=1;
		if(t[p].la)	pusha(t[p].ch[0],t[p].la),pusha(t[p].ch[1],t[p].la),t[p].la=0;
		if(t[p].re){
			if(t[p].ch[0])	push_re(t[p].ch[0]);
			if(t[p].ch[1])	push_re(t[p].ch[1]); t[p].re^=1;
		}
	}
	inline bool isroot(int x){return t[t[x].fa].ch[0]!=x&&t[t[x].fa].ch[1]!=x;}
	inline void rotate(int x){
		int y=t[x].fa,z=t[y].fa,k=t[y].ch[1]==x,w=t[x].ch[!k];
		if(!isroot(y))	t[z].ch[t[z].ch[1]==y]=x; t[x].ch[!k]=y; t[y].ch[k]=w;
		if(w)	t[w].fa=y; t[y].fa=x; t[x].fa=z;	push_up(y);
	}
	inline void splay(int x){
		cnt=1;	st[cnt]=x; int y=x;
		while(!isroot(y))	st[++cnt]=y=t[y].fa;
		while(cnt)	push_down(st[cnt--]);
		while(!isroot(x)){
			int y=t[x].fa,z=t[y].fa;
			if(!isroot(y))	rotate((t[y].ch[0]==x)^(t[z].ch[0]==y)?x:y);	rotate(x);
		}push_up(x);
	}
	inline void access(int x){
		for(int y=0;x;x=t[y=x].fa) splay(x),t[x].ch[1]=y,push_up(x); 
	}
	inline void makeroot(int x){access(x); splay(x); push_re(x);}
	inline void split(int x,int y){makeroot(x); access(y); splay(y);}
	inline void link(int x,int y){makeroot(x); 	t[x].fa=y;}
	inline void cut(int x,int y){split(x,y);	t[x].fa=t[y].ch[0]=0;}
}tr;
signed main(){
	cin>>n>>q; char op[2];
	for(int i=1,a,b;i<n;i++)	scanf("%lld %lld",&a,&b),tr.link(a,b);
	for(int i=1;i<=n;i++)	a[i]=tr.t[i].sz=tr.t[i].lm=1;
	while(q--){
		scanf("%s",op);
		if(op[0]=='+'){
			int u,v,c;	scanf("%lld %lld %lld",&u,&v,&c); tr.split(u,v); tr.pusha(v,c);
		}else if(op[0]=='-'){
			int u1,v1,u2,v2;	scanf("%lld %lld %lld %lld",&u1,&v1,&u2,&v2);
			tr.cut(u1,v1);	tr.link(u2,v2);
		}else if(op[0]=='*'){
			int u,v,c;	scanf("%lld %lld %lld",&u,&v,&c); tr.split(u,v); tr.pushm(v,c);
		}else{
			int u,v;	scanf("%lld %lld",&u,&v);
			tr.split(u,v);	printf("%lld\n",tr.t[v].val);
		}
	}
	return 0;
}