题意

你有一颗树,每个点有一个颜色,有次操作, 节点改为颜色 询问所有颜色为 的点的生成树大小。

分析

我们可以用序来维护每种颜色的生成树大小。
考虑若颜色为的节点只有,则其生成树大小就是两个结点在树上的距离,记为。此时,如果加入第个结点对生成树的贡献能用树上距离算出来:

我们分类讨论:
1.如果之间生成树的大小为,加号后面的是加入节点的贡献。
在分别考虑一下其他几种情况可以归纳出,你加入节点的贡献就是
所以我们可以对相同颜色点的集合按序排序用维护。每次加点或删点就取出左右相邻的个点计算贡献。

代码

#include<bits/stdc++.h>
#define ll long long
const int N=1e5+5,INF=0x3f3f3f3f,mod=998244353;
using namespace std;

int n,tot,cnt;
int a[N],ans[N];
set<int > s[N];
int dep[N],fa[N][20],dfn[N],in[N],head[N];
struct node
{
    int nxt,to;
}e[N<<1];

inline int read()
{
    register int x=0,f=1;char c=getchar();
    while(c<'0'||c>'9'){if(c=='-') f=-1;c=getchar();}
    while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
    return x*f;
}

int qpow(int a,int b)
{
    int ans=1;
    while(b){if(b&1) ans=ans*a%mod;a=a*a%mod;b>>=1;}
    return ans;
}

void adde(int u,int v)
{
    e[++cnt].nxt=head[u];
    e[cnt].to=v;
    head[u]=cnt;
}

void dfs(int u, int fath)
{
    dfn[++tot]=u,in[u]=tot;
    dep[u]=dep[fath]+1;
    fa[u][0]=fath;
    for(int i=1;i<20;i++) fa[u][i]=fa[fa[u][i-1]][i-1];
    for(int i=head[u];i;i=e[i].nxt) if(e[i].to!=fath) dfs(e[i].to,u);
}

int lca(int u,int v)
{
    if(dep[u]<dep[v]) swap(u,v);
    for(int i=19;i>=0;i--)
        if(dep[fa[u][i]]>=dep[v]) u=fa[u][i];
    if(u==v) return u;
    for(int i=19;i>=0;i--)
        if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];
    return fa[u][0];
}

int dis(int u, int v)
{
    u=dfn[u],v=dfn[v];
    return dep[u]+dep[v]-2*dep[lca(u,v)];
}

void add(int x,int c)
{
    if(s[c].size()==0)
    {
        s[c].insert(x);ans[c]=0;
        return;
    }
    auto it=s[c].lower_bound(x);
    if(it==s[c].begin()||it==s[c].end())
    {
        auto y=s[c].begin(); 
        auto z=s[c].rbegin();
        ans[c]+=(dis(x,*y)+dis(x,*z)-dis(*y,*z))/2;
    }
    else
    {
        auto y=it,z=it;y--;
        ans[c]+=(dis(x,*y)+dis(x,*z)-dis(*y,*z))/2;
    }
    s[c].insert(x);
}

void del(int x,int c)
{
    if (s[c].size()==1)
    {
        s[c].erase(x); 
        ans[c]=-1;
        return;
    }
    s[c].erase(x);
    auto it=s[c].lower_bound(x);
    if(it==s[c].begin()||it==s[c].end())
    {
        auto y=s[c].begin(); 
        auto z=s[c].rbegin();
        ans[c]-=(dis(x,*y)+dis(x,*z)-dis(*y,*z))/2;
    }
    else 
    {
        auto y=it,z=it;
        y--;
        ans[c]-=(dis(x,*y)+dis(x,*z)-dis(*y,*z))/2;
    }
}
int main()
{
    n=read();
    for (int i=1;i<n;i++)
    {
        int u=read(),v=read();
        adde(u,v);adde(v,u);
    }
    dfs(1,0);
    memset(ans,-1,sizeof(ans));
    for(int i=1;i<=n;++i)
    {
        a[i]=read();
        add(in[i],a[i]);
    } 
    int m=read();
    char op[5];
    while (m--)
    {
        scanf("%s",op);
        if(op[0]=='U')
        {
            int u=read(),c=read();
            del(in[u],a[u]);
            a[u]=c;
            add(in[u],a[u]);
        }
        else
        {
            int c=read(); 
            printf("%d\n",ans[c]);
        }
    }
    return 0;
}