Description
给出一棵树,每个节点都有颜色。对树有两种操作:
U x c 更新操作,将x节点的颜色变成c
Q c 查询所有颜色为c的节点形成的子图有多少条边
Solution
完全没思路。。。看了其他聚聚的题解
通过找规律可以看出加入每条边x的贡献是
因此可以用一个set来维护每个颜色的序列,然后每次修改时在set里取点计算贡献。
关于距离 可以用 来得到。
太菜了,5555555555555
Code
#include <bits/stdc++.h> using namespace std; typedef long long ll; const int N = 1e5 + 5; int a[N], dep[N], fa[N][25], ans[N]; vector<int> G[N]; set<int> st[N]; int n, tot, dfn[N], in[N], out[N]; void dfs(int u, int par){ dfn[++tot] = u, in[u] = tot; dep[u] = dep[par] + 1; fa[u][0] = par; for (int i = 1; i < 20; ++i) fa[u][i] = fa[fa[u][i-1]][i-1]; for (auto &v : G[u]) { if (v == par) continue; dfs(v, u); } } int lca(int u, int v){ if (dep[u] < dep[v]) swap(u, v); for (int i = 19; i >= 0; --i){ if (dep[fa[u][i]] >= dep[v]) u = fa[u][i]; } if (u == v) return u; for (int i = 19; i >= 0; --i){ if (fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i]; } return fa[u][0]; } int dis(int u, int v){ u = dfn[u], v = dfn[v]; return dep[u] + dep[v] - 2 * dep[lca(u, v)]; } void add(int x, int c){ if (st[c].size() == 0){ st[c].insert(x); ans[c] = 0; return; } auto it = st[c].lower_bound(x); if (it == st[c].begin() || it == st[c].end()){ auto y = st[c].begin(); auto z = st[c].rbegin(); ans[c] += (dis(x, *y)+dis(x, *z)-dis(*y, *z))/2; }else{ auto y = it, z = it; y--; ans[c] += (dis(x, *y)+dis(x, *z)-dis(*y, *z))/2; } st[c].insert(x); } void del(int x, int c){ if (st[c].size() == 1){ st[c].erase(x); ans[c] = -1; return; } st[c].erase(x); auto it = st[c].lower_bound(x); if (it == st[c].begin() || it == st[c].end()){ auto y = st[c].begin(); auto z = st[c].rbegin(); ans[c] -= (dis(x, *y)+dis(x, *z)-dis(*y, *z))/2; }else{ auto y = it, z = it; y--; ans[c] -= (dis(x, *y)+dis(x, *z)-dis(*y, *z))/2; } } int main() { ios::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr); int n, m; cin >> n; for(int i = 0; i < n - 1; i++) { int u, v; cin >> u >> v; G[u].push_back(v), G[v].push_back(u); } dfs(1, 0); memset(ans, -1, sizeof(ans)); for(int i = 1; i <= n; i++) { cin >> a[i]; add(in[i], a[i]); } cin >> m; while(m--) { char op[5]; cin >> op; if(op[0] == 'U') { int x, y; cin >> x >> y; del(in[x], a[x]); a[x] = y; add(in[x], a[x]); } else { int x; cin >> x; cout << ans[x] << '\n'; } } return 0; }