题意
有一颗 个节点的树,以 为根节点,每个点有一个颜色 。设子树 中颜色出现次数最多的颜色集合为 ,记 。现在要求 。
其中,。
分析
这种题叫做 ,也就是树上启发式合并。
让我们先考虑暴力做法。
就是以每个节点,对子树进行 ,然后开一个桶记录颜色出现次数,最后把颜色出现次数最多的颜色加起来。这样子做复杂度是 的。
这复杂度显然是不可接受的嘛!暴力差就差在,它计算了很多重复的东西!如果我们能让重复的东西尽量减少计算,复杂度就能够得到提升了!
看到这道题,有的同学可能一下子想的是树上莫队。
确实,莫队算法就是用来优化这些有重复计算的东西的。
不过,更优秀的算法是用启发式合并,复杂度可以做到 。
一句话解释这个算法,就是保留重儿子的结果,暴力迭代轻儿子。
重儿子是什么?
如果你学过树链剖分,就能一下子知道了。不过没学过也没关系。重儿子就是这个节点所有儿子中 最大的点。如图:
的 最大,所以 是 的重儿子。
记 为颜色 的出现次数。
我们要让重复求的东西尽量少,但是子树之间又互相独立,于是我们只能钦点一个子树来保留 的值。既然重儿子如此牛逼,那我们就钦点重儿子吧!
假设我们现在要求 ,我们已经保留了 子树的 数组。
那我们从 开始遍历一遍子树,如果遇到 ,就继续往下 ,求出 。如果遇到 ,那么就可以 了,因为之前求过 了。再求一遍不是智障了吗??
这样子,单次求 的复杂度是 的。
写成代码的话长这样:
关于复杂度
不妨考虑每个点会被访问多少次。
如果一个节点到根节点有 条轻边,那么这个节点会被访问 次。
由于一个节点到根节点的轻边数量不超过 条。
于是总的复杂度为
代码如下
#include <bits/stdc++.h> #define N 100005 using namespace std; typedef long long LL; typedef unsigned long long uLL; LL z = 1; int read(){ int x, f = 1; char ch; while(ch = getchar(), ch < '0' || ch > '9') if(ch == '-') f = -1; x = ch - '0'; while(ch = getchar(), ch >= '0' && ch <= '9') x = x * 10 + ch - 48; return x * f; } struct node{ int a, b, n; }d[N * 2]; int fa[N], siz[N], son[N], h[N], v[N], cnt; int tot[N], Son, maxn; LL ans[N], sum; void cr(int a, int b){ d[++cnt].a = a; d[cnt].b = b; d[cnt].n = h[a]; h[a] = cnt; } void dfs1(int a){ int i, b; siz[a] = 1; for(i = h[a]; i; i = d[i].n){ b = d[i].b; if(b == fa[a]) continue; fa[b] = a; dfs1(b); siz[a] += siz[b]; if(siz[b] >= siz[son[a]]) son[a] = b;//找到重儿子 } } void add(int a, int c){//遍历 a 的子树,求出 ans[a] int i, b; tot[v[a]] += c;//更新 tot 数组 if(maxn < tot[v[a]]) maxn = tot[v[a]], sum = v[a]; else if(maxn == tot[v[a]]) sum += v[a];//这一步是在更新 sum 和 maxn for(i = h[a]; i; i = d[i].n){ b = d[i].b; if(b == fa[a] || b == Son) continue;//遇到重儿子就return,所以只遍历轻儿子 add(b, c); } } void dsu(int a, int flag){ int i, b; for(i = h[a]; i; i = d[i].n){ b = d[i].b; if(b != fa[a] && b != son[a]) dsu(b, 1);//先求轻儿子 } if(son[a]) dsu(son[a], 0), Son = son[a]; //再求重儿子 add(a, 1); Son = 0; ans[a] = sum;//求出 ans[a],同时把重儿子标记去除(没去除的话无法清空 tot 数组 if(flag) add(a, -1), sum = 0, maxn = 0;//如果当前节点是轻儿子,就清空 tot 数组并且重置 sum 和 maxn } int main(){ int i, j, n, m, a, b; n = read(); for(i = 1; i <= n; i++) v[i] = read(); for(i = 1; i < n; i++){ a = read(); b = read(); cr(a, b); cr(b, a); } dfs1(1); dsu(1, 0); for(i = 1; i <= n; i++) printf("%lld ", ans[i]); return 0; }