Lomsat gelral

之前没有记录过dsu on tree,挑了一个板子题记录一下

DSU on tree(Disjoint Set Union,树上启发式合并)

  • 思想:利用每个节点到根节点路径上的轻边数复杂度是 l o g log log级别的,同时只有当每次遍历轻边时才会将整棵轻边连接的子树额外暴力遍历一遍,所以每个节点被暴力遍历的次数就是 l o g log log级别的,这样就保证了整体复杂度在 l o g log log级别啦!(默认点修改为O(1))
  • 三种 d f s dfs dfs
    • dfs0:处理出重儿子
    • dfs1:对子树进行暴力操作,分为两种状态,一种在增添信息时使用,一种在删除信息时使用
    • dfs2:dsu on tree主体
  • 讲讲dfs2,它也分为两种状态,一种是保留信息的,一种是不保留信息的。大致框架:
    • 对每个节点,先处理所有非重儿子节点,并且不保留信息
    • 然后处理重儿子,并且保留信息
    • 再暴力增添所有非重儿子信息,同时加上当前节点本身的信息,这样就得到当前节点答案啦!
    • 最后若此dfs2是“不保留型”,则用dfs1暴力删除刚刚所有的信息

题意:

给了一个带颜色的树,对每个节点,求出占领它的颜色的权值和。占领的意思就是在这个节点的子树中某种颜色出现次数最多(可以有多个)。

思路:标准dsu on tree板子题

全局维护每种颜色数量,某种数量颜色的权值和,以及最大数量颜色的数量是多少(有点点绕)即可。

附:一篇dsu on tree入门文章

代码

#include "bits/stdc++.h"
#define hhh printf("hhh\n")
#define see(x) (cerr<<(#x)<<'='<<(x)<<endl)
using namespace std;
typedef long long ll;
typedef pair<int,int> pr;
inline int read() {int x=0,f=1;char c=getchar();while(c!='-'&&(c<'0'||c>'9'))c=getchar();if(c=='-')f=-1,c=getchar();while(c>='0'&&c<='9')x=x*10+c-'0',c=getchar();return f*x;}

const int maxn = 3e5+7;
const int inf = 0x3f3f3f3f;
const int mod = 1e9+7;

int n;
int c[maxn], num[maxn], mx;
int head[maxn], to[maxn], nxt[maxn], tot;
int son[maxn], sz[maxn];
ll sum[maxn], ans[maxn];

inline void add_edge(int u, int v) {
    ++tot; to[tot]=v; nxt[tot]=head[u]; head[u]=tot;
    ++tot; to[tot]=u; nxt[tot]=head[v]; head[v]=tot;
}

void dfs0(int u, int f) {
    sz[u]=1;
    for(int i=head[u]; i; i=nxt[i]) {
        int v=to[i]; if(v==f) continue;
        dfs0(v,u); sz[u]+=sz[v];
        if(sz[v]>sz[son[u]]) son[u]=v;
    }
}

void dfs1(int u, int f, int ff) {
    if(ff) {
        num[c[u]]++;
        sum[num[c[u]]-1]-=c[u], sum[num[c[u]]]+=c[u];
        if(sum[mx+1]) mx++;
    }
    else {
        num[c[u]]--;
        sum[num[c[u]]+1]-=c[u], sum[num[c[u]]]+=c[u];
        if(!sum[mx]) mx--;
    }
    for(int i=head[u]; i; i=nxt[i]) {
        int v=to[i]; if(v==f) continue;
        dfs1(v,u,ff);
    }
}

void dfs2(int u, int f, int keep) {
    for(int i=head[u]; i; i=nxt[i]) {
        int v=to[i]; if(v==f||v==son[u]) continue;
        dfs2(v,u,0);
    }
    if(son[u]) dfs2(son[u],u,1);
    num[c[u]]++;
    sum[num[c[u]]-1]-=c[u], sum[num[c[u]]]+=c[u];
    if(sum[mx+1]) mx++;
    for(int i=head[u]; i; i=nxt[i]) {
        int v=to[i]; if(v==f||v==son[u]) continue;
        dfs1(v,u,1);
    }
    ans[u]=sum[mx];
    if(!keep) {
        num[c[u]]--;
        sum[num[c[u]]+1]-=c[u], sum[num[c[u]]]+=c[u];
        if(!sum[mx]) mx--;
        for(int i=head[u]; i; i=nxt[i]) {
            int v=to[i]; if(v==f) continue;
            dfs1(v,u,0);
        }
    }
}

int main() {
    n=read();
    for(int i=1; i<=n; ++i) c[i]=read();
    for(int i=1; i<n; ++i) add_edge(read(),read());
    dfs0(1,0);
    dfs2(1,0,1);
    for(int i=1; i<=n; ++i) printf("%lld%c", ans[i], " \n"[i==n]);
}