题目:
给一棵无根树。每个点有一个颜色。个操作。
:将结点的颜色修改为;
:问所有颜色为的结点的生成树大小。(若不存在颜色为的结点输出)
做法:
这题好神奇。可以用序来维护相同颜色结点的生成树。
考虑若颜色的只有个结点、,则其生成树大小就是两个结点在原树上的距离,记为。此时,如果加入第个结点,对生成树的贡献能用树上距离算出来:
①若在、之间:(有如下2种情况)
不在路径上和在路径上。
②若不在、之间:(有如下2种情况)
以上所有情况对生成树的贡献都是
所以我们可以对相同颜色点的集合按序排序维护(用)。每次加点或删点就取出左右相邻的个点计算贡献,若不存在左或右相邻,就用第②种情况,在一边取个点计算贡献,式子都是一样的。
至于的计算(应该算常识了吧)。以为根算出每个点的深度。。(最近公共祖先)
代码:
#include <bits/stdc++.h> #define IOS ios::sync_with_stdio(false), cin.tie(0) #define debug(a) cout << #a ": " << a << endl using namespace std; typedef long long ll; const int N = 1e5 + 7; int n, a[N], ans[N]; vector<int> g[N]; set<int> s[N]; int dep[N], fa[N][20], dfn[N], tag, in[N]; void dfs(int u, int p){ dfn[++tag] = u, in[u] = tag; dep[u] = dep[p] + 1; fa[u][0] = p; for (int i = 1; i < 20; ++i) fa[u][i] = fa[fa[u][i-1]][i-1]; for (auto &v : g[u]) if (v != p) dfs(v, u); } int lca(int u, int v){ if (dep[u] < dep[v]) swap(u, v); for (int i = 19; i >= 0; --i){ if (dep[fa[u][i]] >= dep[v]) u = fa[u][i]; } if (u == v) return u; for (int i = 19; i >= 0; --i){ if (fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i]; } return fa[u][0]; } int dis(int u, int v){ u = dfn[u], v = dfn[v]; return dep[u]+dep[v]-2*dep[lca(u, v)]; } void add(int x, int c){ if (s[c].size() == 0){ s[c].insert(x); ans[c] = 0; return; } auto it = s[c].lower_bound(x); if (it == s[c].begin() || it == s[c].end()){ auto y = s[c].begin(); auto z = s[c].rbegin(); ans[c] += (dis(x, *y)+dis(x, *z)-dis(*y, *z))/2; }else{ auto y = it, z = it; y--; ans[c] += (dis(x, *y)+dis(x, *z)-dis(*y, *z))/2; } s[c].insert(x); } void del(int x, int c){ if (s[c].size() == 1){ s[c].erase(x); ans[c] = -1; return; } s[c].erase(x); auto it = s[c].lower_bound(x); if (it == s[c].begin() || it == s[c].end()){ auto y = s[c].begin(); auto z = s[c].rbegin(); ans[c] -= (dis(x, *y)+dis(x, *z)-dis(*y, *z))/2; }else{ auto y = it, z = it; y--; ans[c] -= (dis(x, *y)+dis(x, *z)-dis(*y, *z))/2; } } int main(void){ scanf("%d", &n); for (int i = 0; i < n-1; ++i){ int u, v; scanf("%d%d", &u, &v); g[u].push_back(v); g[v].push_back(u); } dfs(1, 0); memset(ans, -1, sizeof ans); for (int i = 1; i <= n; ++i){ scanf("%d", &a[i]); add(in[i], a[i]); } int m; scanf("%d", &m); char op[5]; while (m--){ scanf("%s", op); if (op[0] == 'U'){ int u, c; scanf("%d%d", &u, &c); del(in[u], a[u]); a[u] = c; add(in[u], a[u]); }else{ int c; scanf("%d", &c); printf("%d\n", ans[c]); } } return 0; }