题意
你有一颗树,每个点有一个颜色,有次操作, 将 节点改为颜色 。询问所有颜色为 的点的生成树大小。
分析
我们可以用序来维护每种颜色的生成树大小。
考虑若颜色为的节点只有个、,则其生成树大小就是两个结点在树上的距离,记为。此时,如果加入第个结点,对生成树的贡献能用树上距离算出来:
我们分类讨论:
1.如果在、之间生成树的大小为,加号后面的是加入节点的贡献。
在分别考虑一下其他几种情况可以归纳出,你加入节点的贡献就是。
所以我们可以对相同颜色点的集合按序排序用维护。每次加点或删点就取出左右相邻的个点计算贡献。
代码
#include<bits/stdc++.h> #define ll long long const int N=1e5+5,INF=0x3f3f3f3f,mod=998244353; using namespace std; int n,tot,cnt; int a[N],ans[N]; set<int > s[N]; int dep[N],fa[N][20],dfn[N],in[N],head[N]; struct node { int nxt,to; }e[N<<1]; inline int read() { register int x=0,f=1;char c=getchar(); while(c<'0'||c>'9'){if(c=='-') f=-1;c=getchar();} while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar(); return x*f; } int qpow(int a,int b) { int ans=1; while(b){if(b&1) ans=ans*a%mod;a=a*a%mod;b>>=1;} return ans; } void adde(int u,int v) { e[++cnt].nxt=head[u]; e[cnt].to=v; head[u]=cnt; } void dfs(int u, int fath) { dfn[++tot]=u,in[u]=tot; dep[u]=dep[fath]+1; fa[u][0]=fath; for(int i=1;i<20;i++) fa[u][i]=fa[fa[u][i-1]][i-1]; for(int i=head[u];i;i=e[i].nxt) if(e[i].to!=fath) dfs(e[i].to,u); } int lca(int u,int v) { if(dep[u]<dep[v]) swap(u,v); for(int i=19;i>=0;i--) if(dep[fa[u][i]]>=dep[v]) u=fa[u][i]; if(u==v) return u; for(int i=19;i>=0;i--) if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i]; return fa[u][0]; } int dis(int u, int v) { u=dfn[u],v=dfn[v]; return dep[u]+dep[v]-2*dep[lca(u,v)]; } void add(int x,int c) { if(s[c].size()==0) { s[c].insert(x);ans[c]=0; return; } auto it=s[c].lower_bound(x); if(it==s[c].begin()||it==s[c].end()) { auto y=s[c].begin(); auto z=s[c].rbegin(); ans[c]+=(dis(x,*y)+dis(x,*z)-dis(*y,*z))/2; } else { auto y=it,z=it;y--; ans[c]+=(dis(x,*y)+dis(x,*z)-dis(*y,*z))/2; } s[c].insert(x); } void del(int x,int c) { if (s[c].size()==1) { s[c].erase(x); ans[c]=-1; return; } s[c].erase(x); auto it=s[c].lower_bound(x); if(it==s[c].begin()||it==s[c].end()) { auto y=s[c].begin(); auto z=s[c].rbegin(); ans[c]-=(dis(x,*y)+dis(x,*z)-dis(*y,*z))/2; } else { auto y=it,z=it; y--; ans[c]-=(dis(x,*y)+dis(x,*z)-dis(*y,*z))/2; } } int main() { n=read(); for (int i=1;i<n;i++) { int u=read(),v=read(); adde(u,v);adde(v,u); } dfs(1,0); memset(ans,-1,sizeof(ans)); for(int i=1;i<=n;++i) { a[i]=read(); add(in[i],a[i]); } int m=read(); char op[5]; while (m--) { scanf("%s",op); if(op[0]=='U') { int u=read(),c=read(); del(in[u],a[u]); a[u]=c; add(in[u],a[u]); } else { int c=read(); printf("%d\n",ans[c]); } } return 0; }