题目链接

题面:

题解:见注释

代码:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#include<set>
#include<vector>
#define ll long long
#define llu unsigned ll
#define int ll
using namespace std;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const int maxn=200100;
int n,q;
int head[maxn],ver[maxn<<1],nt[maxn<<1];
int f[maxn],d[maxn],si[maxn],son[maxn],rk[maxn];
int top[maxn],id[maxn];
int s[maxn],vi[maxn];
int tot=1,cnt=0;
int ans1=0;

void add(int x,int y)
{
    ver[++tot]=y,nt[tot]=head[x],head[x]=tot;
}

void dfs1(int x,int fa)
{
    int max_son=0;
    si[x]=1;
    s[x]=vi[x];
    for(int i=head[x];i;i=nt[i])
    {
        int y=ver[i];
        if(y==fa) continue;
        f[y]=x;
        d[y]=d[x]+1;

        dfs1(y,x);
        si[x]+=si[y];
        s[x]+=s[y];
        if(si[y]>max_son)max_son=si[y],son[x]=y;
    }
    ans1+=s[x]*s[x];
}

void dfs2(int x,int t)
{
    top[x]=t;
    id[x]=++cnt;
    rk[cnt]=x;

    if(!son[x]) return ;
    dfs2(son[x],t);

    for(int i=head[x];i;i=nt[i])
    {
        int y=ver[i];
        if(y!=son[x]&&y!=f[x])
            dfs2(y,y);
    }
}

struct node
{
    int l,r;
    ll sum;
    ll laz;
}t[maxn<<2];

void pushup(int cnt)
{
    t[cnt].sum=t[cnt<<1].sum+t[cnt<<1|1].sum;
}

void pushdown(int cnt)
{
    if(t[cnt].laz)
    {
        t[cnt<<1].sum=(t[cnt<<1].sum+(t[cnt<<1].r-t[cnt<<1].l+1)*t[cnt].laz);
        t[cnt<<1|1].sum=(t[cnt<<1|1].sum+(t[cnt<<1|1].r-t[cnt<<1|1].l+1)*t[cnt].laz);
        t[cnt<<1].laz=(t[cnt<<1].laz+t[cnt].laz);
        t[cnt<<1|1].laz=(t[cnt<<1|1].laz+t[cnt].laz);
        t[cnt].laz=0;
    }
}


void build(int l,int r,int cnt)
{
    t[cnt].l=l,t[cnt].r=r;
    t[cnt].laz=0;
    if(l==r)
    {
        t[cnt].sum=s[rk[l]];
        return ;
    }

    int mid=(l+r)>>1;
    build(l,mid,cnt<<1);
    build(mid+1,r,cnt<<1|1);
    pushup(cnt);
}

void change(int l,int r,ll val,int cnt)
{
    if(l<=t[cnt].l&&t[cnt].r<=r)
    {
        t[cnt].sum=(t[cnt].sum+(t[cnt].r-t[cnt].l+1)*val);
        t[cnt].laz=(t[cnt].laz+val);
        return ;
    }

    pushdown(cnt);
    if(l<=t[cnt<<1].r) change(l,r,val,cnt<<1);
    if(r>=t[cnt<<1|1].l) change(l,r,val,cnt<<1|1);
    pushup(cnt);
}


ll ask(int l,int r,int cnt)
{
    if(l<=t[cnt].l&&t[cnt].r<=r)
    {
        return t[cnt].sum;
    }
    pushdown(cnt);
    ll ans=0;
    if(l<=t[cnt<<1].r) ans+=ask(l,r,cnt<<1);
    if(r>=t[cnt<<1|1].l) ans+=ask(l,r,cnt<<1|1);
    return ans;
}


void changeroad(int x,int y,ll val)
{
    while(top[x]!=top[y])
    {
        if(d[top[x]]<d[top[y]])
            swap(x,y);
        change(id[top[x]],id[x],val,1);
        x=f[top[x]];
    }
    if(id[x]>id[y]) swap(x,y);
    change(id[x],id[y],val,1);
}

int askroad(int x,int y)
{
    ll ans=0;
    while(top[x]!=top[y])
    {
        if(d[top[x]]<d[top[y]])
            swap(x,y);
        ans=(ans+ask(id[top[x]],id[x],1));
        x=f[top[x]];
    }
    if(id[x]>id[y]) swap(x,y);
    ans=(ans+ask(id[x],id[y],1));
    return ans;
}

signed main(void)
{
    scanf("%lld%lld",&n,&q);
    int pos,x,y;
    for(int i=1;i<n;i++)
    {
        scanf("%lld%lld",&x,&y);
        add(x,y);
        add(y,x);
    }
    for(int i=1;i<=n;i++)
        scanf("%lld",&vi[i]);
    dfs1(1,0);
    dfs2(1,1);
    build(1,cnt,1);
    //这里的d[x]的编号是从0开始的,即d[1]=0
    for(int i=1;i<=q;i++)
    {
        scanf("%lld",&pos);
        //我们设s[i] 为 以i为根节点的子树的和。
        //考虑修改,如果我们修改一个点x,那么影响的子树的和只是1--x这些点。
        //1--x上某点被改变后 s[i]*s[i]-->(s[i]+a)*(s[i]+a)-->s[i]*s[i]+2*a*s[i]+a*a
        //其中某点被改变后,新增的部分只有2*a*s[i]+a*a,1--x上的点全部被改变后,新增 2*a*sum of(s[i]) + (d[x]+1)*a*a
        //其中d[x]为1--x上面的点的数量,所以我们只需要用书剖维护一个sum of s[i] 即可实现单点更改的要求
        if(pos==1)
        {
            scanf("%lld%lld",&x,&y);
            ll pm=askroad(1,x);
            ll cm=y-vi[x];
            ans1+=cm*cm*(d[x]+1)+2*cm*pm;
            changeroad(1,x,cm);
            vi[x]=y;
        }
        //考虑换根
        //我们假设现在这条路为 1=x1-x2-x3----xk=x
        //我们设以1为根x1-xk各点的子树的权值和为ai
        //我们设以x为根x1-xk各点的子树的权值和为bi
        //那么有ansx=ans1-sumof(ai*ai)+sumof(bi*bi)
        //可以得到a[i+1]+b[i]=a[1]=b[k] --->都等于所有的点权和
        //可以得到
        //ansx=ans1-a1*a1-sum_of_(ai*ai)_from_2_to_k+bk*bk+sum_of((a1-a[i+1])*(a1-a[i+1))_from_1_to_k-1
        //ansx=ans1-a1*a1-sum_of_(ai*ai)_from_2_to_k+bk*bk+sum_of((a1-ai)*(a1-ai))_from_2_to_k
        //ansx=ans1-sum_of_(ai*ai)_from_2_to_k+sum_of((a1-ai)*(a1-ai))_from_2_to_k
        //ansx=ans1+(k-1)*a1*a1-2*a1*sum_of(ai)_from_2_to_k
        //ansx=ans1+(k-1)*a1*a1+2*a1*a1-2*a1*sum_of(ai)_from_2_to_k-2*a1*a1
        //ansx=ans1+(k+1)*a1*a1-2*a1*sumof(ai)
        //其中ai就是si
        else
        {
            scanf("%lld",&x);
            if(x==1) printf("%lld\n",ans1);
            else
            {
                ll pm=ask(id[1],id[1],1);
                ll cm=askroad(1,x);
                ll ansu=ans1+(d[x]+2)*pm*pm-2*pm*cm;
                printf("%lld\n",ansu);
            }
        }

    }
    return 0;
}