#妙妙题 #二分 #dfn #lca
题意
- 给定一个无根树,每个结点有自己的颜色,处理两种操作
- 求颜色x的生成树的大小
- 将结点x的颜色改为y
思路
- 对于某一种颜色
- 如果只有一个点,生成树大小是0
- 如果有两个点,生成树大小是这两个点的距离
- 如果在此基础上再加入新的点,会有两种情况
- 一种是新点x加入的位置两旁有其它点u,v,这时候新点加入的贡献是
- 而另一种是新点x加入的位置位于最左侧或者最右侧,此时任取两个点计算代价都是一样的,代价仍然是
- 一种是新点x加入的位置两旁有其它点u,v,这时候新点加入的贡献是
- 最终发现这个公式在两个点的时候也成立,那么对每一个结点维护一个set,然后二分找最近的两个点,如果二分没有结果就用任意两个点,每次添加和修改都要对答案维护
代码
#include<bits/stdc++.h>
#define endl '\n'
using namespace std;
vector<int> G[101010];
set<int> st[101010];
int F[101010][20];
int n;
int a[101010];
int b[101010];
int dfn[101010],ans[101010],dep[101010];
int tim;
void dfs(int x,int fa){
dep[x]=dep[fa]+1;
b[++tim]=x;
dfn[x]=tim;
F[x][0]=fa;
for(int i=1;(1<<i)<=dep[x];i++)
F[x][i]=F[F[x][i-1]][i-1];
for(auto i:G[x]){
if(i==fa) continue;
dfs(i,x);
}
}
int LCA(int a,int b){
if(dep[a]>dep[b]) swap(a,b);
for(int i=dep[b]-dep[a],j=0;i>0;i>>=1,j++){
if(i&1)
b=F[b][j];
}
if(a==b) return a;
int k;
for(k=0;(1<<k)<=dep[a];k++);
for(;k>=0;k--){
if(F[a][k]==F[b][k]) continue;
else{
a=F[a][k];
b=F[b][k];
}
}
return F[a][0];
}
int dis(int u,int v){
u=b[u],v=b[v];
return dep[u]+dep[v]-2*dep[LCA(u,v)];
}
void add(int col,int idx){
if(st[col].size()==0){
st[col].insert(idx);
ans[col]=0;
return ;
}
auto it=st[col].lower_bound(idx);
if (it==st[col].begin()||it==st[col].end()){
auto y=st[col].begin();
auto z=st[col].rbegin();
ans[col]+=(dis(idx,*y)+dis(idx,*z)-dis(*y,*z))/2;
}else{
auto y=it,z=it;
y--;
ans[col]+=(dis(idx,*y)+dis(idx,*z)-dis(*y,*z))/2;
}
st[col].insert(idx);
}
void del(int col,int idx){
if(st[col].size()==1){
st[col].erase(idx);
ans[col]=-1;
return ;
}
st[col].erase(idx);
auto it=st[col].lower_bound(idx);
if (it==st[col].begin()||it==st[col].end()){
auto y=st[col].begin();
auto z=st[col].rbegin();
ans[col]-=(dis(idx,*y)+dis(idx,*z)-dis(*y,*z))/2;
}else{
auto y=it,z=it;
y--;
ans[col]-=(dis(idx,*y)+dis(idx,*z)-dis(*y,*z))/2;
}
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
cin >> n;
for(int i=0;i<n-1;i++){
int u,v;
cin >> u >> v;
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1,0);
memset(ans,-1,sizeof(ans));
for(int i=1;i<=n;i++){
cin >> a[i];
add(a[i],dfn[i]);
}
int m;
cin >> m;
for(int i=0;i<m;i++){
char op;
cin >> op;
if(op=='Q'){
int tmp;
cin >> tmp;
cout << ans[tmp] << endl;
}else{
int x,y;
cin >> x >> y;
del(a[x],dfn[x]);
a[x]=y;
add(a[x],dfn[x]);
}
}
return 0;
}