题目:

给一棵无根树。每个点有一个颜色个操作。
:将结点的颜色修改为
:问所有颜色为的结点的生成树大小。(若不存在颜色为的结点输出)


做法:

这题好神奇。可以用序来维护相同颜色结点的生成树。
考虑若颜色的只有个结点,则其生成树大小就是两个结点在原树上的距离,记为。此时,如果加入第个结点对生成树的贡献能用树上距离算出来:
①若之间:(有如下2种情况)
不在路径上和路径上。
图片说明
②若不在之间:(有如下2种情况)
图片说明
以上所有情况对生成树的贡献都是
所以我们可以对相同颜色点的集合按序排序维护(用)。每次加点或删点就取出左右相邻的个点计算贡献,若不存在左或右相邻,就用第②种情况,在一边取个点计算贡献,式子都是一样的。
至于的计算(应该算常识了吧)。以为根算出每个点的深度。(最近公共祖先)


代码:

#include <bits/stdc++.h>
#define IOS ios::sync_with_stdio(false), cin.tie(0)
#define debug(a) cout << #a ": " << a << endl
using namespace std;
typedef long long ll;
const int N = 1e5 + 7;
int n, a[N], ans[N];
vector<int> g[N];
set<int> s[N];
int dep[N], fa[N][20], dfn[N], tag, in[N];
void dfs(int u, int p){
    dfn[++tag] = u, in[u] = tag;
    dep[u] = dep[p] + 1;
    fa[u][0] = p;
    for (int i = 1; i < 20; ++i) fa[u][i] = fa[fa[u][i-1]][i-1];
    for (auto &v : g[u]) if (v != p) dfs(v, u);
}
int lca(int u, int v){
    if (dep[u] < dep[v]) swap(u, v);
    for (int i = 19; i >= 0; --i){
        if (dep[fa[u][i]] >= dep[v]) u = fa[u][i];
    }
    if (u == v) return u;
    for (int i = 19; i >= 0; --i){
        if (fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i];
    }
    return fa[u][0];
}
int dis(int u, int v){
    u = dfn[u], v = dfn[v];
    return dep[u]+dep[v]-2*dep[lca(u, v)];
}
void add(int x, int c){
    if (s[c].size() == 0){
        s[c].insert(x); ans[c] = 0;
        return;
    }
    auto it = s[c].lower_bound(x);
    if (it == s[c].begin() || it == s[c].end()){
        auto y = s[c].begin(); 
        auto z = s[c].rbegin();
        ans[c] += (dis(x, *y)+dis(x, *z)-dis(*y, *z))/2;
    }else{
        auto y = it, z = it; y--;
        ans[c] += (dis(x, *y)+dis(x, *z)-dis(*y, *z))/2;
    }
    s[c].insert(x);
}
void del(int x, int c){
    if (s[c].size() == 1){
        s[c].erase(x); ans[c] = -1;
        return;
    }
    s[c].erase(x);
    auto it = s[c].lower_bound(x);
    if (it == s[c].begin() || it == s[c].end()){
        auto y = s[c].begin(); 
        auto z = s[c].rbegin();
        ans[c] -= (dis(x, *y)+dis(x, *z)-dis(*y, *z))/2;
    }else{
        auto y = it, z = it; y--;
        ans[c] -= (dis(x, *y)+dis(x, *z)-dis(*y, *z))/2;
    }
}
int main(void){
    scanf("%d", &n);
    for (int i = 0; i < n-1; ++i){
        int u, v; scanf("%d%d", &u, &v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(1, 0);
    memset(ans, -1, sizeof ans);
    for (int i = 1; i <= n; ++i){
        scanf("%d", &a[i]);
        add(in[i], a[i]);
    } 
    int m; scanf("%d", &m);
    char op[5];
    while (m--){
        scanf("%s", op);
        if (op[0] == 'U'){
            int u, c; scanf("%d%d", &u, &c);
            del(in[u], a[u]); a[u] = c;
            add(in[u], a[u]);
        }else{
            int c; scanf("%d", &c);
            printf("%d\n", ans[c]);
        }
    }
    return 0;
}