Description

给出一棵树,每个节点都有颜色。对树有两种操作:
U x c 更新操作,将x节点的颜色变成c
Q c 查询所有颜色为c的节点形成的子图有多少条边

Solution

完全没思路。。。看了其他聚聚的题解
通过找规律可以看出加入每条边x的贡献是
因此可以用一个set来维护每个颜色的序列,然后每次修改时在set里取点计算贡献。
关于距离 可以用 来得到。
太菜了,5555555555555

Code

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5 + 5;
int a[N], dep[N], fa[N][25], ans[N];
vector<int> G[N];
set<int> st[N];
int n, tot, dfn[N], in[N], out[N];
void dfs(int u, int par){
  dfn[++tot] = u, in[u] = tot;
  dep[u] = dep[par] + 1;
  fa[u][0] = par;
  for (int i = 1; i < 20; ++i) fa[u][i] = fa[fa[u][i-1]][i-1];
  for (auto &v : G[u]) {
    if (v == par) continue;
    dfs(v, u);
  }
}
int lca(int u, int v){
    if (dep[u] < dep[v]) swap(u, v);
    for (int i = 19; i >= 0; --i){
      if (dep[fa[u][i]] >= dep[v])
        u = fa[u][i];
    }
    if (u == v)
      return u;
    for (int i = 19; i >= 0; --i){
      if (fa[u][i] != fa[v][i])
        u = fa[u][i], v = fa[v][i];
    }
    return fa[u][0];
}
int dis(int u, int v){
    u = dfn[u], v = dfn[v];
    return dep[u] + dep[v] - 2 * dep[lca(u, v)];
}
void add(int x, int c){
    if (st[c].size() == 0){
        st[c].insert(x); ans[c] = 0;
        return;
    }
    auto it = st[c].lower_bound(x);
    if (it == st[c].begin() || it == st[c].end()){
        auto y = st[c].begin();
        auto z = st[c].rbegin();
        ans[c] += (dis(x, *y)+dis(x, *z)-dis(*y, *z))/2;
    }else{
        auto y = it, z = it; y--;
        ans[c] += (dis(x, *y)+dis(x, *z)-dis(*y, *z))/2;
    }
    st[c].insert(x);
}
void del(int x, int c){
    if (st[c].size() == 1){
        st[c].erase(x); ans[c] = -1;
        return;
    }
    st[c].erase(x);
    auto it = st[c].lower_bound(x);
    if (it == st[c].begin() || it == st[c].end()){
        auto y = st[c].begin();
        auto z = st[c].rbegin();
        ans[c] -= (dis(x, *y)+dis(x, *z)-dis(*y, *z))/2;
    }else{
        auto y = it, z = it; y--;
        ans[c] -= (dis(x, *y)+dis(x, *z)-dis(*y, *z))/2;
    }
}
int main() {
  ios::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr);
  int n, m; cin >> n;
  for(int i = 0; i < n - 1; i++) {
    int u, v; cin >> u >> v;
    G[u].push_back(v), G[v].push_back(u);
  }
  dfs(1, 0);
  memset(ans, -1, sizeof(ans));
  for(int i = 1; i <= n; i++) {
    cin >> a[i];
    add(in[i], a[i]);
  }
  cin >> m;
  while(m--) {
    char op[5]; cin >> op;
    if(op[0] == 'U') {
      int x, y; cin >> x >> y;
      del(in[x], a[x]); a[x] = y;
      add(in[x], a[x]);
    } else {
      int x; cin >> x;
      cout << ans[x] << '\n';
    }
  }
  return 0;
}