题目:
给一棵无根树。每个点有一个颜色。
个操作。
:将结点
的颜色修改为
;
:问所有颜色为
的结点的生成树大小。(若不存在颜色为
的结点输出
)
做法:
这题好神奇。可以用序来维护相同颜色结点的生成树。
考虑若颜色的只有
个结点
、
,则其生成树大小就是两个结点在原树上的距离,记为
。此时,如果加入第
个结点
,
对生成树的贡献能用树上距离算出来:
①若在
、
之间:(有如下2种情况)
不在
路径上和
在
路径上。
②若不在
、
之间:(有如下2种情况)
以上所有情况对生成树的贡献都是
所以我们可以对相同颜色点的集合按序排序维护(用
)。每次加点或删点就取出左右相邻的
个点计算贡献,若不存在左或右相邻,就用第②种情况,在一边取
个点计算贡献,式子都是一样的。
至于的计算(
应该算常识了吧)。以为根算出每个点的深度
。
。(
最近公共祖先)
代码:
#include <bits/stdc++.h>
#define IOS ios::sync_with_stdio(false), cin.tie(0)
#define debug(a) cout << #a ": " << a << endl
using namespace std;
typedef long long ll;
const int N = 1e5 + 7;
int n, a[N], ans[N];
vector<int> g[N];
set<int> s[N];
int dep[N], fa[N][20], dfn[N], tag, in[N];
void dfs(int u, int p){
dfn[++tag] = u, in[u] = tag;
dep[u] = dep[p] + 1;
fa[u][0] = p;
for (int i = 1; i < 20; ++i) fa[u][i] = fa[fa[u][i-1]][i-1];
for (auto &v : g[u]) if (v != p) dfs(v, u);
}
int lca(int u, int v){
if (dep[u] < dep[v]) swap(u, v);
for (int i = 19; i >= 0; --i){
if (dep[fa[u][i]] >= dep[v]) u = fa[u][i];
}
if (u == v) return u;
for (int i = 19; i >= 0; --i){
if (fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i];
}
return fa[u][0];
}
int dis(int u, int v){
u = dfn[u], v = dfn[v];
return dep[u]+dep[v]-2*dep[lca(u, v)];
}
void add(int x, int c){
if (s[c].size() == 0){
s[c].insert(x); ans[c] = 0;
return;
}
auto it = s[c].lower_bound(x);
if (it == s[c].begin() || it == s[c].end()){
auto y = s[c].begin();
auto z = s[c].rbegin();
ans[c] += (dis(x, *y)+dis(x, *z)-dis(*y, *z))/2;
}else{
auto y = it, z = it; y--;
ans[c] += (dis(x, *y)+dis(x, *z)-dis(*y, *z))/2;
}
s[c].insert(x);
}
void del(int x, int c){
if (s[c].size() == 1){
s[c].erase(x); ans[c] = -1;
return;
}
s[c].erase(x);
auto it = s[c].lower_bound(x);
if (it == s[c].begin() || it == s[c].end()){
auto y = s[c].begin();
auto z = s[c].rbegin();
ans[c] -= (dis(x, *y)+dis(x, *z)-dis(*y, *z))/2;
}else{
auto y = it, z = it; y--;
ans[c] -= (dis(x, *y)+dis(x, *z)-dis(*y, *z))/2;
}
}
int main(void){
scanf("%d", &n);
for (int i = 0; i < n-1; ++i){
int u, v; scanf("%d%d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1, 0);
memset(ans, -1, sizeof ans);
for (int i = 1; i <= n; ++i){
scanf("%d", &a[i]);
add(in[i], a[i]);
}
int m; scanf("%d", &m);
char op[5];
while (m--){
scanf("%s", op);
if (op[0] == 'U'){
int u, c; scanf("%d%d", &u, &c);
del(in[u], a[u]); a[u] = c;
add(in[u], a[u]);
}else{
int c; scanf("%d", &c);
printf("%d\n", ans[c]);
}
}
return 0;
}
京公网安备 11010502036488号