树上启发式合并
这个是啥?
先说一下并查集的启发式合并:
并查集的启发式合并就是把集合小的并到集合大的上去。(按秩合并是把低的并到高的上面去)
于是树上的也差不多: 就这样一个优化的思路。把大的并到小的上去
树上:
什么是大的:重儿子的那个子树。
小的: 轻儿子的那几个子树。
就是算答案的时候把轻儿子的贡献并到重儿子的贡献上去于是就有了当前根节点的答案。
说说例题吧:
cf600E
题目大意:
给出一棵树,根节点为1;每个节点都有一个颜色、求以每个子树上颜色最多的颜色代号和。就是 如当前子树上1号颜色最多,答案就是1 。如果1号、2号颜色的节点一样多的话,答案是1+2 = 3;
思路:
col[maxn]数组表示 i 节点的颜色。
cnt[maxn]数组表示当前编号为 i 颜色的数量
于是就dfs计算节点的答案。
计算答案的时候当前根节点的答案就是把轻节点的子树合并到重儿子的子树上。也就是大的并到小的上面。
启发式合并板子一般为
void dfs2(int x,int fa,int f)
{
for (int i = 0; i < vv[x].size(); i ++ )
{
int v = vv[x][i];
if(v == fa || v == son[x])
continue;
dfs2(v,x,1);
}
if(son[x])
{
dfs2(son[x],x,0);
}
solve(x,fa,1);
if(f)
{
del();
}
}
这个板子需要说一下。
参数里 f 的作用: 标记 当前节点是否为轻节点,因为只有一个cnt数组记录col的数量,所以得先把轻节点的子树的贡献删去。f就是用来标记是否需要删去的。
为什么先递归轻儿子?因为重儿子的贡献不用删去。
代码:
下面代码里把大的并到小的上面,然后顺便计算答案的函数为count(x,fa,val);
val 的作用:需要合并,合并的时候就是再遍历一边轻儿子的子树,然后颜色数量加1,因为有的时候要删去贡献,也就是删去cnt数组里的值。删除的时候儿子的节点的颜色数量-1.此时val为-1。
flag的作用: 因为合并的时候不用遍历这个根节点的重儿子。flag就是用来标记他的重儿子的。为什么不直接写成 v != son[x] ?因为他的轻儿子的重儿子还是要遍历的。记得!!!
好了不废话了,上代码:
#include<stdio.h>
#include<algorithm>
#include<vector>
using namespace std;
typedef long long ll;
const int maxn = 1e5+5;
int col[maxn];
int cnt[maxn];
int num[maxn];
vector<int> vv[maxn];
int son[maxn];
ll ans[maxn];
void dfs(int x,int fa)
{
num[x] = 1;
for (int i = 0; i < vv[x].size(); i ++ )
{
int v = vv[x][i];
if(v == fa)
continue;
dfs(v,x);
num[x] += num[v];
if(num[v] > num[son[x]])
son[x] = v;
}
}
ll sum = 0;
int maxc = 0;
int flag = 0;
void solve(int x,int fa,int val)
{
cnt[col[x]] += val;
if(cnt[col[x]] > maxc)
{
maxc = cnt[col[x]];
sum = col[x];
}
else if(cnt[col[x]] == maxc)
{
sum += 1ll * col[x];
}
for (int i =0 ; i < vv[x].size(); i ++ )
{
int v = vv[x][i];
if(v == fa || v == flag)
continue;
solve(v,x,val);
}
}
void dfs2(int x,int fa,int f)
{
for (int i = 0; i < vv[x].size(); i ++ )
{
int v = vv[x][i];
if(v == fa || v == son[x])
continue;
dfs2(v,x,1);
}
if(son[x])
{
dfs2(son[x],x,0);
flag = son[x];
}
solve(x,fa,1);
flag = 0;
ans[x] = sum;
if(f)
{
solve(x,fa,-1);
sum = maxc = 0;
}
}
int main()
{
int n;
scanf("%d",&n);
for (int i = 1; i <= n; i ++ )
{
scanf("%d",&col[i]);
}
for (int i = 1; i < n; i ++ )
{
int x,y;
scanf("%d%d",&x,&y);
vv[x].push_back(y);
vv[y].push_back(x);
}
dfs(1,0);
dfs2(1,0,0);
for(int i = 1; i <= n; i ++ )
{
printf("%lld ",ans[i]);
}
printf("\n");
}
看的视频链接:膜大佬