题意
你有一颗树,每个点有一个颜色,有
次操作,
将
节点改为颜色
。
询问所有颜色为
的点的生成树大小。
分析
我们可以用序来维护每种颜色的生成树大小。
考虑若颜色为的节点只有
个
、
,则其生成树大小就是两个结点在树上的距离,记为
。此时,如果加入第
个结点
,
对生成树的贡献能用树上距离算出来:
我们分类讨论:
1.如果在
、
之间生成树的大小为
,加号后面的是加入
节点的贡献。
在分别考虑一下其他几种情况可以归纳出,你加入节点的贡献就是
。
所以我们可以对相同颜色点的集合按序排序用
维护。每次加点或删点就取出左右相邻的
个点计算贡献。
代码
#include<bits/stdc++.h>
#define ll long long
const int N=1e5+5,INF=0x3f3f3f3f,mod=998244353;
using namespace std;
int n,tot,cnt;
int a[N],ans[N];
set<int > s[N];
int dep[N],fa[N][20],dfn[N],in[N],head[N];
struct node
{
int nxt,to;
}e[N<<1];
inline int read()
{
register int x=0,f=1;char c=getchar();
while(c<'0'||c>'9'){if(c=='-') f=-1;c=getchar();}
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
return x*f;
}
int qpow(int a,int b)
{
int ans=1;
while(b){if(b&1) ans=ans*a%mod;a=a*a%mod;b>>=1;}
return ans;
}
void adde(int u,int v)
{
e[++cnt].nxt=head[u];
e[cnt].to=v;
head[u]=cnt;
}
void dfs(int u, int fath)
{
dfn[++tot]=u,in[u]=tot;
dep[u]=dep[fath]+1;
fa[u][0]=fath;
for(int i=1;i<20;i++) fa[u][i]=fa[fa[u][i-1]][i-1];
for(int i=head[u];i;i=e[i].nxt) if(e[i].to!=fath) dfs(e[i].to,u);
}
int lca(int u,int v)
{
if(dep[u]<dep[v]) swap(u,v);
for(int i=19;i>=0;i--)
if(dep[fa[u][i]]>=dep[v]) u=fa[u][i];
if(u==v) return u;
for(int i=19;i>=0;i--)
if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];
return fa[u][0];
}
int dis(int u, int v)
{
u=dfn[u],v=dfn[v];
return dep[u]+dep[v]-2*dep[lca(u,v)];
}
void add(int x,int c)
{
if(s[c].size()==0)
{
s[c].insert(x);ans[c]=0;
return;
}
auto it=s[c].lower_bound(x);
if(it==s[c].begin()||it==s[c].end())
{
auto y=s[c].begin();
auto z=s[c].rbegin();
ans[c]+=(dis(x,*y)+dis(x,*z)-dis(*y,*z))/2;
}
else
{
auto y=it,z=it;y--;
ans[c]+=(dis(x,*y)+dis(x,*z)-dis(*y,*z))/2;
}
s[c].insert(x);
}
void del(int x,int c)
{
if (s[c].size()==1)
{
s[c].erase(x);
ans[c]=-1;
return;
}
s[c].erase(x);
auto it=s[c].lower_bound(x);
if(it==s[c].begin()||it==s[c].end())
{
auto y=s[c].begin();
auto z=s[c].rbegin();
ans[c]-=(dis(x,*y)+dis(x,*z)-dis(*y,*z))/2;
}
else
{
auto y=it,z=it;
y--;
ans[c]-=(dis(x,*y)+dis(x,*z)-dis(*y,*z))/2;
}
}
int main()
{
n=read();
for (int i=1;i<n;i++)
{
int u=read(),v=read();
adde(u,v);adde(v,u);
}
dfs(1,0);
memset(ans,-1,sizeof(ans));
for(int i=1;i<=n;++i)
{
a[i]=read();
add(in[i],a[i]);
}
int m=read();
char op[5];
while (m--)
{
scanf("%s",op);
if(op[0]=='U')
{
int u=read(),c=read();
del(in[u],a[u]);
a[u]=c;
add(in[u],a[u]);
}
else
{
int c=read();
printf("%d\n",ans[c]);
}
}
return 0;
}
京公网安备 11010502036488号