题面:
题解:见注释
代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#include<set>
#include<vector>
#define ll long long
#define llu unsigned ll
#define int ll
using namespace std;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const int maxn=200100;
int n,q;
int head[maxn],ver[maxn<<1],nt[maxn<<1];
int f[maxn],d[maxn],si[maxn],son[maxn],rk[maxn];
int top[maxn],id[maxn];
int s[maxn],vi[maxn];
int tot=1,cnt=0;
int ans1=0;
void add(int x,int y)
{
ver[++tot]=y,nt[tot]=head[x],head[x]=tot;
}
void dfs1(int x,int fa)
{
int max_son=0;
si[x]=1;
s[x]=vi[x];
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
if(y==fa) continue;
f[y]=x;
d[y]=d[x]+1;
dfs1(y,x);
si[x]+=si[y];
s[x]+=s[y];
if(si[y]>max_son)max_son=si[y],son[x]=y;
}
ans1+=s[x]*s[x];
}
void dfs2(int x,int t)
{
top[x]=t;
id[x]=++cnt;
rk[cnt]=x;
if(!son[x]) return ;
dfs2(son[x],t);
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
if(y!=son[x]&&y!=f[x])
dfs2(y,y);
}
}
struct node
{
int l,r;
ll sum;
ll laz;
}t[maxn<<2];
void pushup(int cnt)
{
t[cnt].sum=t[cnt<<1].sum+t[cnt<<1|1].sum;
}
void pushdown(int cnt)
{
if(t[cnt].laz)
{
t[cnt<<1].sum=(t[cnt<<1].sum+(t[cnt<<1].r-t[cnt<<1].l+1)*t[cnt].laz);
t[cnt<<1|1].sum=(t[cnt<<1|1].sum+(t[cnt<<1|1].r-t[cnt<<1|1].l+1)*t[cnt].laz);
t[cnt<<1].laz=(t[cnt<<1].laz+t[cnt].laz);
t[cnt<<1|1].laz=(t[cnt<<1|1].laz+t[cnt].laz);
t[cnt].laz=0;
}
}
void build(int l,int r,int cnt)
{
t[cnt].l=l,t[cnt].r=r;
t[cnt].laz=0;
if(l==r)
{
t[cnt].sum=s[rk[l]];
return ;
}
int mid=(l+r)>>1;
build(l,mid,cnt<<1);
build(mid+1,r,cnt<<1|1);
pushup(cnt);
}
void change(int l,int r,ll val,int cnt)
{
if(l<=t[cnt].l&&t[cnt].r<=r)
{
t[cnt].sum=(t[cnt].sum+(t[cnt].r-t[cnt].l+1)*val);
t[cnt].laz=(t[cnt].laz+val);
return ;
}
pushdown(cnt);
if(l<=t[cnt<<1].r) change(l,r,val,cnt<<1);
if(r>=t[cnt<<1|1].l) change(l,r,val,cnt<<1|1);
pushup(cnt);
}
ll ask(int l,int r,int cnt)
{
if(l<=t[cnt].l&&t[cnt].r<=r)
{
return t[cnt].sum;
}
pushdown(cnt);
ll ans=0;
if(l<=t[cnt<<1].r) ans+=ask(l,r,cnt<<1);
if(r>=t[cnt<<1|1].l) ans+=ask(l,r,cnt<<1|1);
return ans;
}
void changeroad(int x,int y,ll val)
{
while(top[x]!=top[y])
{
if(d[top[x]]<d[top[y]])
swap(x,y);
change(id[top[x]],id[x],val,1);
x=f[top[x]];
}
if(id[x]>id[y]) swap(x,y);
change(id[x],id[y],val,1);
}
int askroad(int x,int y)
{
ll ans=0;
while(top[x]!=top[y])
{
if(d[top[x]]<d[top[y]])
swap(x,y);
ans=(ans+ask(id[top[x]],id[x],1));
x=f[top[x]];
}
if(id[x]>id[y]) swap(x,y);
ans=(ans+ask(id[x],id[y],1));
return ans;
}
signed main(void)
{
scanf("%lld%lld",&n,&q);
int pos,x,y;
for(int i=1;i<n;i++)
{
scanf("%lld%lld",&x,&y);
add(x,y);
add(y,x);
}
for(int i=1;i<=n;i++)
scanf("%lld",&vi[i]);
dfs1(1,0);
dfs2(1,1);
build(1,cnt,1);
//这里的d[x]的编号是从0开始的,即d[1]=0
for(int i=1;i<=q;i++)
{
scanf("%lld",&pos);
//我们设s[i] 为 以i为根节点的子树的和。
//考虑修改,如果我们修改一个点x,那么影响的子树的和只是1--x这些点。
//1--x上某点被改变后 s[i]*s[i]-->(s[i]+a)*(s[i]+a)-->s[i]*s[i]+2*a*s[i]+a*a
//其中某点被改变后,新增的部分只有2*a*s[i]+a*a,1--x上的点全部被改变后,新增 2*a*sum of(s[i]) + (d[x]+1)*a*a
//其中d[x]为1--x上面的点的数量,所以我们只需要用书剖维护一个sum of s[i] 即可实现单点更改的要求
if(pos==1)
{
scanf("%lld%lld",&x,&y);
ll pm=askroad(1,x);
ll cm=y-vi[x];
ans1+=cm*cm*(d[x]+1)+2*cm*pm;
changeroad(1,x,cm);
vi[x]=y;
}
//考虑换根
//我们假设现在这条路为 1=x1-x2-x3----xk=x
//我们设以1为根x1-xk各点的子树的权值和为ai
//我们设以x为根x1-xk各点的子树的权值和为bi
//那么有ansx=ans1-sumof(ai*ai)+sumof(bi*bi)
//可以得到a[i+1]+b[i]=a[1]=b[k] --->都等于所有的点权和
//可以得到
//ansx=ans1-a1*a1-sum_of_(ai*ai)_from_2_to_k+bk*bk+sum_of((a1-a[i+1])*(a1-a[i+1))_from_1_to_k-1
//ansx=ans1-a1*a1-sum_of_(ai*ai)_from_2_to_k+bk*bk+sum_of((a1-ai)*(a1-ai))_from_2_to_k
//ansx=ans1-sum_of_(ai*ai)_from_2_to_k+sum_of((a1-ai)*(a1-ai))_from_2_to_k
//ansx=ans1+(k-1)*a1*a1-2*a1*sum_of(ai)_from_2_to_k
//ansx=ans1+(k-1)*a1*a1+2*a1*a1-2*a1*sum_of(ai)_from_2_to_k-2*a1*a1
//ansx=ans1+(k+1)*a1*a1-2*a1*sumof(ai)
//其中ai就是si
else
{
scanf("%lld",&x);
if(x==1) printf("%lld\n",ans1);
else
{
ll pm=ask(id[1],id[1],1);
ll cm=askroad(1,x);
ll ansu=ans1+(d[x]+2)*pm*pm-2*pm*cm;
printf("%lld\n",ansu);
}
}
}
return 0;
}