题目大意

给定一棵大小为n的树,有m次操作,分为三种:

  • (1 x w) 位置x上的权值+w,同时所有位置的权值加上w-dis(x,y);(dis(x,y)为从x到y的边数)
  • (2 x) 如果x位置的权值>0,那么设为0。
  • (3 x) 输出x位置的权值

解题思路

这题有很多种做法,我再这里用的是魔改版的树链剖分(刚学会)。

操作1

  • 我们在画图分析后,可以得到这样一个简单的式子,来转化操作1:
    (LCA:最近公共祖先,应该都了解怎么求,的深度)

  • 对于任意的来说,都应该是个常数,所以我们可以专门记录。
    接下来,我们设第次的操作节点,修改的节点为
    可以得到,我们每执行一次操作1,就可以在每个结点的权值加上(这里的是操作1的次数)
    那么,我们的式子就是

  • 但是求得这个就是较为棘手的了,需要开脑洞:
    对于每次操作1,我们可以从根到x这条路径上所有的结点权值+2,然后在询问时,求就是求根到y这条路径上结点的权值之和。

操作2

对于这个归零的操作,我们用一个额外的dt数组记录。
如果当前x的权值大于0,需要修改,那么我们就使dt[x]减去x的权值。

操作3:计算后输出即可

AC代码

#include<bits/stdc++.h>
using namespace std;
const int N=5e4+10;
struct node
{
    int l,r,s;
    long long x,y;
} tr[N<<2];
struct link
{
    int x,y;
} e[N<<1];
int a[N],d[N],f[N],h[N],id[N],s[N],tot[N],top[N],z[N],n,t,r,q,mod;
long long dt[N],sum,num;
/////////////////////////////////////////
void pushup(int x)
{
    tr[x].y=tr[x<<1].y+tr[x<<1|1].y;
}
void pushdown(int x)
{
    if(!tr[x].x) return;
    tr[x<<1].y=tr[x<<1].y+tr[x].x*(1ll*tr[x<<1].s);
    tr[x<<1|1].y=tr[x<<1|1].y+tr[x].x*(1ll*tr[x<<1|1].s);
    tr[x<<1].x=tr[x<<1].x+tr[x].x;
    tr[x<<1|1].x=tr[x<<1|1].x+tr[x].x;
    tr[x].x=0;
}
void build(int x,int l,int r)
{
    tr[x].l=l,tr[x].r=r,tr[x].s=r-l+1;
    if(l==r)
    {
        tr[x].y=1ll*a[l];
        return;
    }
    int m=(l+r)/2;
    build(x<<1,l,m);
    build(x<<1|1,m+1,r);
    pushup(x);
}
/////////////////////////////////////////
void add(int x,int l,int r,int y)
{
    if(l<=tr[x].l && tr[x].r<=r)
    {
        tr[x].y+=(1ll*tr[x].s)*(1ll*y);
        tr[x].x=(tr[x].x+1ll*y);
        return;
    }
    pushdown(x);
    if(tr[x<<1].r>=l) add(x<<1,l,r,y);
    if(tr[x<<1|1].l<=r) add(x<<1|1,l,r,y);
    pushup(x);
}
void add1(int x,int y)
{
    e[++t].x=y,e[t].y=h[x],h[x]=t;
}
void add2(int x,int y,int z)
{
    while(top[x]!=top[y])
    {
        if(d[top[x]]<d[top[y]]) swap(x,y);
        add(1,id[top[x]],id[x],z);
        x=f[top[x]];
    }
    if(d[x]>d[y]) swap(x,y);
    add(1,id[x],id[y],z);
}
/////////////////////////////////////////
long long query(int x,int l,int r)
{
    if(l<=tr[x].l && tr[x].r<=r) return tr[x].y;
    pushdown(x);
    long long ans=0;
    if(tr[x<<1].r>=l) ans=(ans+query(x<<1,l,r));
    if(tr[x<<1|1].l<=r) ans=(ans+query(x<<1|1,l,r));
    return ans;
}
int dfs(int x,int y,int z)
{
    d[x]=z,f[x]=y,tot[x]=1;
    int u,v=-1,i;
    for(i=h[x];i;i=e[i].y)
    {
        u=e[i].x;
        if(u==y) continue;
        tot[x]+=dfs(u,x,z+1);
        if(tot[u]>v) v=tot[u],s[x]=u;
    }
    return tot[x];
}
void ddfs(int x,int y)
{
    id[x]=++q,a[q]=z[x],top[x]=y;
    if(!s[x]) return;
    ddfs(s[x],y);
    int z,i;
    for(i=h[x];i;i=e[i].y)
    {
        z=e[i].x;
        if(z==f[x]||z==s[x]) continue;
        ddfs(z,z);
    }
}
long long find(int x,int y)
{
    long long ans=0;
    while(top[x]!=top[y])
    {
        if(d[top[x]]<d[top[y]]) swap(x,y);
        ans=(ans+query(1,id[top[x]],id[x]));
        x=f[top[x]];
    }
    if(d[x]>d[y]) swap(x,y);
    ans=(ans+query(1,id[x],id[y]));
    return ans;
}
/////////////////////////////////////////
int main()
{
    int T,m,x,y,z,ch,i;
    long long ans;
    scanf("%d",&T);
    while(T--)
    {
        memset(dt,0,sizeof(dt));
        memset(h,0,sizeof(h));
        memset(s,0,sizeof(s));
        memset(tr,0,sizeof(tr));
        t=q=sum=num=0,r=1;
        scanf("%d%d",&n,&m);
        for(i=1;i<n;i++)
        {
            scanf("%d%d",&x,&y);
            add1(x,y);add1(y,x);
        }
        dfs(r,0,1);ddfs(r,r);build(1,1,n);
        while(m--)
        {
            scanf("%d",&ch);
            if(ch==1)
            {
                scanf("%d%d",&x,&y);
                sum+=1ll*y-1ll*d[x],num++;
                add2(r,x,2);
            }
            else if(ch==2)
            {
                scanf("%d",&x);
                ans=sum-1ll*d[x]*num+find(r,x)+dt[x];
                if(ans>=0) dt[x]-=ans;
            }
            else
            {
                scanf("%d",&x);
                ans=sum-1ll*d[x]*num+find(r,x)+dt[x];
                printf("%lld\n",ans);
            }
        }
    }
}