题3 - 树上路径

题目支持3种操作
1.将以u为根的子树内节点(包括u)的权值加val
2.将(u, v)路径上的节点权值加val
3.询问(u, v)路径上节点的权值两两相乘的和

思路:很明显,唯一有难度的就是操作3。
我们换位思考,要计算一个数组内元素两两相乘之和,其实就等于(元素之和的平方-元素平方之和)/2,即对于数组a,求 [(i=1nai)2i=1nai2]/2 [(\sum_{i=1}^n a_i)^2-\sum_{i=1}^n a_i^2]/2
因此我们不仅要维护区间和,还要维护区间内元素平方之和,特别推导一下修改时平方和的变化即可。

代码:

#include<bits/stdc++.h>
#define ll long long
#define L(i,j,k) for(ll i=(j);i<=(k);i++)
#define R(i,j,k) for(ll i=(j);i>=(k);i--)
#define inf 0x3f3f3f3f3f3f3f3f
#define fi first
#define se second
#define MS(i,j) memset(i,j,sizeof (i))
const ll N=1e6+10,M=5,mod=1e9+7,mmod=mod-1;
const double pi=acos(-1);
using namespace std;
ll gcd(ll x,ll y){if(y==0) return x;return gcd(y,x%y);}
ll fksm(ll a,ll b){ll r=1;if(b<0)b+=mod-1;for(a%=mod;b;b>>=1){if(b&1)r=r*a%mod;a=a*a%mod;}return r;}//a 分母; b MOD-2
ll lowbit(ll x){return x&(-x);}

ll m,n,t,x,y,z,l,r,k,p,pp,nx,ny,ansx,ansy,lim,num,sum,pos,tot,root,block,key,cnt,minn,maxx,ans;
ll a[N],head[N],dx[5]={0,0,-1,1},dy[5]={-1,1,0,0};
double dans;
bool vis,flag;
char mapp,zz[20];
struct qq{ll x,y;}q;
struct tree{ll l,r,tag,sum,mul;}trs[N*4];
struct Tree{ll fa,dep,dfn,siz,son,top,w;}tr[N];
struct Trp{ll l,r,fat,dep,n,w;}trp;
struct E{ll to,nxt,w;}eg[N*2];
struct matrix{ll n,m[M][M];};
struct complx{
	double r,i;
	complx(){}
	complx(double r,double i):r(r),i(i){}
	complx operator+(const complx& rhs)const{return complx (r+rhs.r,i+rhs.i);}
	complx operator-(const complx& rhs)const{return complx (r-rhs.r,i-rhs.i);}
	complx operator*(const complx& rhs)const{return complx (r*rhs.r-i*rhs.i,i*rhs.r+r*rhs.i);}
	void operator+=(const complx& rhs){r+=rhs.r,i+=rhs.i;}
	void operator*=(const complx& rhs){r=r*rhs.r-i*rhs.i,i=r*rhs.i+i*rhs.r;}
	void operator/=(const double& x){r/=x,i/=x;}
	complx conj(){return complx(r,-i);}
}; 

bool cmp(qq u,qq v){
    return u.x>v.x;
}
bool cmp1(qq u,qq v){
    return u.x<v.x;
}
bool cmpl(ll u,ll v){return u>v;}
struct cmps{bool operator()(ll u,ll v){
    return u>v;
}};//shun序

ll two=fksm(2,mod-2);
pair<ll,ll>pre;
vector<qq>v[N];//v.assign(m,vector<ll>(n));
//priority_queue<ll,vector<ll>,cmps>sp;
deque<qq>sq;
map<ll,ll>mp;
bitset<M>bi;

void add(ll u,ll v,ll w){
	eg[++cnt].to=v;
	eg[cnt].nxt=head[u];
	eg[cnt].w=w;
	head[u]=cnt;
}

void push_up(ll k){
	trs[k].sum=trs[k*2].sum+trs[k*2+1].sum;trs[k].sum%=mod;
	trs[k].mul=trs[k*2].mul+trs[k*2+1].mul;trs[k].mul%=mod;
}

void push_down(ll k){
	if(trs[k].tag){
		ll l=k*2,r=k*2+1;
		ll lenl=trs[l].r-trs[l].l+1,lenr=trs[r].r-trs[r].l+1;
		trs[l].tag+=trs[k].tag;trs[l].tag%=mod;
		trs[r].tag+=trs[k].tag;trs[r].tag%=mod;
		trs[l].mul+=2*trs[l].sum*trs[k].tag%mod+lenl*trs[k].tag%mod*trs[k].tag%mod;trs[l].mul%=mod;
		trs[r].mul+=2*trs[r].sum*trs[k].tag%mod+lenr*trs[k].tag%mod*trs[k].tag%mod;trs[r].mul%=mod;
		trs[l].sum+=lenl*trs[k].tag%mod;trs[l].sum%=mod;
		trs[r].sum+=lenr*trs[k].tag%mod;trs[r].sum%=mod;
		trs[k].tag=0; 
	}
}

void bd_tree(ll k,ll l,ll r){
	trs[k].tag=0;
	trs[k].l=l,trs[k].r=r;
	if(l==r){
		trs[k].sum=a[l]%mod;
		trs[k].mul=a[l]*a[l]%mod;
		return;
	}
	ll mid=(l+r)/2;
	bd_tree(k*2,l,mid);
	bd_tree(k*2+1,mid+1,r);
	push_up(k);
}

qq query(ll k,ll pl,ll pr){
	qq ml={0,0},mr={0,0};
	if(trs[k].l>=pl&&trs[k].r<=pr){
		return {trs[k].sum,trs[k].mul};
	}
	push_down(k);
	ll mid=(trs[k].l+trs[k].r)/2;
	if(mid>=pl)ml=query(k*2,pl,pr);
	if(mid+1<=pr)mr=query(k*2+1,pl,pr);
	return {(ml.x+mr.x)%mod,(ml.y+mr.y)%mod};
}

void modify(ll k,ll pl,ll pr,ll val){//[pl,pr]改为val 
	if(trs[k].l>=pl&&trs[k].r<=pr){
		trs[k].mul+=2*trs[k].sum*val%mod+(trs[k].r-trs[k].l+1)*val%mod*val%mod;trs[k].mul%=mod;
		trs[k].sum+=(trs[k].r-trs[k].l+1)*val%mod;trs[k].sum%=mod;
		trs[k].tag+=val;trs[k].tag%=mod;
		return;
	}
	push_down(k);
	ll mid=(trs[k].l+trs[k].r)/2;
	if(mid>=pl)modify(k*2,pl,pr,val);
	if(mid+1<=pr)modify(k*2+1,pl,pr,val);
	push_up(k);
}

void dfs1(ll x,ll ac){
	tr[x].fa=ac;
	tr[x].dep=tr[tr[x].fa].dep+1;
	tr[x].siz=1;
	ll k=head[x];
	while(k){
		ll y=eg[k].to;
		if(y!=ac){
			dfs1(y,x);
			tr[x].siz+=tr[y].siz;
			if(!tr[x].son||tr[y].siz>tr[tr[x].son].siz)tr[x].son=y;
		}
		k=eg[k].nxt;
	}
}

void dfs2(ll x,ll pos){
	tr[x].dfn=++tot;
	tr[x].top=pos;
	a[tot]=tr[x].w;
	if(!tr[x].son)return;
	dfs2(tr[x].son,pos);
	ll k=head[x];
	while(k){
		ll y=eg[k].to;
		if(y!=tr[x].fa&&y!=tr[x].son)dfs2(y,y);
		k=eg[k].nxt;
	} 
}

void mchain(ll x,ll y,ll val){
	while(tr[x].top!=tr[y].top){
		if(tr[tr[x].top].dep<tr[tr[y].top].dep)modify(1,tr[tr[y].top].dfn,tr[y].dfn,val),y=tr[tr[y].top].fa;
		else modify(1,tr[tr[x].top].dfn,tr[x].dfn,val),x=tr[tr[x].top].fa;
	}
	if(tr[x].dep>tr[y].dep)swap(x,y);
	modify(1,tr[x].dfn,tr[y].dfn,val);
}

qq qchain(ll x,ll y){
	qq res={0,0},tmp;
	while(tr[x].top!=tr[y].top){
		if(tr[tr[x].top].dep<tr[tr[y].top].dep)tmp=query(1,tr[tr[y].top].dfn,tr[y].dfn),y=tr[tr[y].top].fa;
		else tmp=query(1,tr[tr[x].top].dfn,tr[x].dfn),x=tr[tr[x].top].fa;
		res={(res.x+tmp.x)%mod,(res.y+tmp.y)%mod};
	}
	if(tr[x].dep>tr[y].dep)swap(x,y);
	tmp=query(1,tr[x].dfn,tr[y].dfn);
	res={(res.x+tmp.x)%mod,(res.y+tmp.y)%mod};
	return res;
}

int main(){
	scanf("%lld%lld",&n,&m);
	L(i,1,n)scanf("%lld",&tr[i].w);
	cnt=0;
	L(i,2,n){
		scanf("%lld%lld",&x,&y);
		add(x,y,0);
		add(y,x,0);
	}
	tot=0;
	dfs1(1,0);
	dfs2(1,1);
	bd_tree(1,1,n);
	//L(i,1,n)printf("%lld ",tr[i].dfn);printf("\n");
	
	L(i,1,m){
		scanf("%lld",&k);
		if(k==1){
			scanf("%lld%lld",&x,&y);
			modify(1,tr[x].dfn,tr[x].dfn+tr[x].siz-1,y);
		}
		else if(k==2){
			scanf("%lld%lld%lld",&x,&y,&z);
			mchain(x,y,z);
		}
		else{
			scanf("%lld%lld",&x,&y);
			qq tmp=qchain(x,y);
			ans=(tmp.x*tmp.x%mod-tmp.y+mod)%mod*two%mod;
			printf("%lld\n",ans);
		}
	}
}