题目链接

牛牛的糖果树

题目描述

给定一棵 个节点的树,根节点为 1。每个节点上都有一颗带颜色的糖果。

我们可以选择任意一个节点 ,吃掉其为根的整个子树中的糖果。但在吃之前,需要遵循一个规则:找出子树中出现次数最多的颜色(可能有多种),并将所有这些颜色的糖果全部扔掉。

我们的目标是计算,在扔掉糖果后,一次能吃到的所有剩余糖果的颜色值的异或和最大是多少。注意,如果一种颜色的糖果有多个,在计算异或和时也要计算多次。

解题思路

题目要求我们对树上的每一个子树都进行一次查询,并找出最优解。一个朴素的想法是,遍历每一个节点 ,然后对以 为根的子树进行一次完整的遍历,统计颜色频率、计算异或和,最后求出答案。这种方法的总时间复杂度为 ,对于 的数据规模来说太慢了。

为了优化这个过程,我们需要一种更高效的算法来处理子树查询,这就是 DSU on Tree (在国内也被称为“树上启发式合并”或 Sack 算法)。

DSU on Tree 核心思想

DSU on Tree 是一种通过优化暴力来处理树上子树查询问题的算法。其核心思想是:当计算完一个节点 的所有子树后,我们将其中一棵子树的计算结果(例如颜色频率等信息)直接“继承”过来,然后只将其他子树(以及节点 本身)的信息暴力合并进去。

为了让暴力合并的总代价最小,我们每次选择继承重儿子(subtree size 最大的儿子)的信息,并将所有轻儿子(非重儿子的其他儿子)的信息暴力合并。可以证明,通过这种方式,每个节点被暴力合并的次数不会超过 次,从而将总时间复杂度优化到

算法步骤

  1. 预处理 (DFS Pass 1):

    • 我们需要一次 DFS 来建立父子关系,并计算每个节点的 subtree_size
    • 在计算完一个节点所有儿子的 subtree_size 后,我们可以确定它的重儿子(size 最大的那个儿子)。
  2. 计算答案 (DFS Pass 2):

    • 这是 DSU on Tree 的核心。我们进行另一次 DFS,函数 solve(u, keep)keep 是一个布尔值,表示在处理完节点 后是否保留其子树的统计信息。
    • 处理轻儿子: 对 的所有轻儿子 v,递归调用 solve(v, false)false 表示处理完 v 的子树后,清空其统计信息。
    • 处理重儿子: 对 的重儿子 hc,递归调用 solve(hc, true)true 表示处理完 hc 的子树后,保留其统计信息。这样,当前全局的统计信息就正好是 hc 子树的信息。
    • 合并轻儿子和当前节点: 现在,我们将节点 以及它所有轻儿子子树的信息暴力合并到当前的全局统计信息中。
    • 计算 的答案: 合并完成后,全局统计信息就代表了整个 子树的信息。此时我们根据这些信息计算出选择子树 时的答案,并更新全局最大值。
    • 清理: 如果 keepfalse(即 是一个轻儿子),我们需要清空刚刚为 子树计算出的所有统计信息,以确保不影响其兄弟节点的计算。

维护统计信息

为了在合并和计算时效率更高,我们需要维护以下几个关键信息:

  • freq[c]: 颜色 c 的出现次数。
  • count_of_counts[k]: 出现次数为 k 的颜色有多少种。
  • max_freq: 当前出现次数的最大值。
  • total_xor_sum: 当前子树中所有颜色(未剔除前)的总异或和。
  • xor_sum_by_freq[k]: 所有出现次数为 k 的颜色的异或和。

通过维护这些信息,我们可以在 的时间内添加或删除一个节点,并快速计算出当前子树的答案:

  • 要移除的颜色的异或和 xor_to_remove:如果 max_freq 是奇数,则为 xor_sum_by_freq[max_freq];如果 max_freq 是偶数,则为 0(因为 c ^ c ^ ... ^ c (偶数次) = 0)。
  • 最终答案为 total_xor_sum ^ xor_to_remove

代码

#include <iostream>
#include <vector>
#include <map>
#include <algorithm>

using namespace std;

vector<int> adj[100005];
int color[100005];
int sz[100005];
int heavy_child[100005];
long long global_max_ans = 0;

// Global state for DSU on Tree
map<int, int> freq;
map<int, int> count_of_counts;
map<int, long long> xor_sum_by_freq;
int max_freq = 0;
long long total_xor_sum = 0;

void dfs_size(int u, int p) {
    sz[u] = 1;
    int max_sz = 0;
    heavy_child[u] = -1;
    for (int v : adj[u]) {
        if (v == p) continue;
        dfs_size(v, u);
        sz[u] += sz[v];
        if (sz[v] > max_sz) {
            max_sz = sz[v];
            heavy_child[u] = v;
        }
    }
}

void update_node(int c, int op) {
    total_xor_sum ^= c;

    int old_f = freq[c];
    if (old_f > 0) {
        count_of_counts[old_f]--;
        xor_sum_by_freq[old_f] ^= c;
        if (count_of_counts[old_f] == 0 && old_f == max_freq) {
            max_freq--;
        }
    }

    int new_f = old_f + op;
    freq[c] = new_f;
    if (new_f > 0) {
        count_of_counts[new_f]++;
        xor_sum_by_freq[new_f] ^= c;
        if (new_f > max_freq) {
            max_freq = new_f;
        }
    }
}

void update_subtree(int u, int p, int op) {
    update_node(color[u], op);
    for (int v : adj[u]) {
        if (v != p) {
            update_subtree(v, u, op);
        }
    }
}

void solve(int u, int p, bool keep) {
    for (int v : adj[u]) {
        if (v != p && v != heavy_child[u]) {
            solve(v, u, false);
        }
    }

    if (heavy_child[u] != -1) {
        solve(heavy_child[u], u, true);
    }

    update_node(color[u], 1);
    for (int v : adj[u]) {
        if (v != p && v != heavy_child[u]) {
            update_subtree(v, u, 1);
        }
    }

    long long xor_to_remove = 0;
    if (max_freq > 0 && (max_freq % 2) != 0) {
        xor_to_remove = xor_sum_by_freq[max_freq];
    }
    global_max_ans = max(global_max_ans, total_xor_sum ^ xor_to_remove);

    if (!keep) {
        update_subtree(u, p, -1);
    }
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;
    for (int i = 1; i <= n; ++i) {
        cin >> color[i];
    }
    for (int i = 0; i < n - 1; ++i) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    dfs_size(1, 0);
    solve(1, 0, false);

    cout << global_max_ans << endl;

    return 0;
}
import java.util.*;

public class Main {
    static List<Integer>[] adj;
    static int[] color;
    static int[] sz;
    static int[] heavyChild;
    static long globalMaxAns = 0;

    // Global state for DSU on Tree
    static Map<Integer, Integer> freq = new HashMap<>();
    static Map<Integer, Integer> countOfCounts = new HashMap<>();
    static Map<Integer, Long> xorSumByFreq = new HashMap<>();
    static int maxFreq = 0;
    static long totalXorSum = 0;

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();

        color = new int[n + 1];
        adj = new ArrayList[n + 1];
        sz = new int[n + 1];
        heavyChild = new int[n + 1];

        for (int i = 1; i <= n; i++) {
            color[i] = sc.nextInt();
            adj[i] = new ArrayList<>();
        }

        for (int i = 0; i < n - 1; i++) {
            int u = sc.nextInt();
            int v = sc.nextInt();
            adj[u].add(v);
            adj[v].add(u);
        }

        dfsSize(1, 0);
        solve(1, 0, false);

        System.out.println(globalMaxAns);
    }

    static void dfsSize(int u, int p) {
        sz[u] = 1;
        int maxSz = 0;
        heavyChild[u] = -1;
        for (int v : adj[u]) {
            if (v == p) continue;
            // Hack to remove parent from adjacency list to make it a directed tree
            adj[v].remove(Integer.valueOf(u));
            dfsSize(v, u);
            sz[u] += sz[v];
            if (sz[v] > maxSz) {
                maxSz = sz[v];
                heavyChild[u] = v;
            }
        }
    }
    
    static void updateNode(int c, int op) {
        totalXorSum ^= c;

        int oldF = freq.getOrDefault(c, 0);
        if (oldF > 0) {
            countOfCounts.put(oldF, countOfCounts.get(oldF) - 1);
            xorSumByFreq.put(oldF, xorSumByFreq.getOrDefault(oldF, 0L) ^ c);
            if (countOfCounts.get(oldF) == 0 && oldF == maxFreq) {
                maxFreq--;
            }
        }

        int newF = oldF + op;
        freq.put(c, newF);
        if (newF > 0) {
            countOfCounts.put(newF, countOfCounts.getOrDefault(newF, 0) + 1);
            xorSumByFreq.put(newF, xorSumByFreq.getOrDefault(newF, 0L) ^ c);
            if (newF > maxFreq) {
                maxFreq = newF;
            }
        }
    }

    static void updateSubtree(int u, int op) {
        updateNode(color[u], op);
        for (int v : adj[u]) {
            updateSubtree(v, op);
        }
    }

    static void solve(int u, int p, boolean keep) {
        for (int v : adj[u]) {
            if (v != heavyChild[u]) {
                solve(v, u, false);
            }
        }

        if (heavyChild[u] != -1) {
            solve(heavyChild[u], u, true);
        }

        updateNode(color[u], 1);
        for (int v : adj[u]) {
            if (v != heavyChild[u]) {
                updateSubtree(v, 1);
            }
        }

        long xorToRemove = 0;
        if (maxFreq > 0 && (maxFreq % 2) != 0) {
            xorToRemove = xorSumByFreq.getOrDefault(maxFreq, 0L);
        }
        globalMaxAns = Math.max(globalMaxAns, totalXorSum ^ xorToRemove);

        if (!keep) {
            updateSubtree(u, -1);
        }
    }
}
import sys
from collections import defaultdict

# It's recommended to increase recursion limit for deep trees in Python
sys.setrecursionlimit(200005)

def solve():
    n_str = sys.stdin.readline()
    if not n_str: return
    n = int(n_str)
    
    colors_input = list(map(int, sys.stdin.readline().split()))
    color = [0] * (n + 1)
    for i in range(n):
        color[i+1] = colors_input[i]

    adj = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        u, v = map(int, sys.stdin.readline().split())
        adj[u].append(v)
        adj[v].append(u)

    sz = [0] * (n + 1)
    heavy_child = [-1] * (n + 1)
    parent = [0] * (n + 1)
    
    # Global state
    global_max_ans = 0
    freq = defaultdict(int)
    count_of_counts = defaultdict(int)
    xor_sum_by_freq = defaultdict(int)
    max_freq = 0
    total_xor_sum = 0
    
    def dfs_size(u, p):
        sz[u] = 1
        parent[u] = p
        max_sz = 0
        heavy_child[u] = -1
        
        children_to_remove = []
        for i, v in enumerate(adj[u]):
            if v == p:
                children_to_remove.append(i)
                continue
            dfs_size(v, u)
            sz[u] += sz[v]
            if sz[v] > max_sz:
                max_sz = sz[v]
                heavy_child[u] = v
        # Remove parent from adj list to make it a directed tree
        for i in sorted(children_to_remove, reverse=True):
             adj[u].pop(i)

    def update_node(c, op):
        nonlocal max_freq, total_xor_sum
        total_xor_sum ^= c
        
        old_f = freq[c]
        if old_f > 0:
            count_of_counts[old_f] -= 1
            xor_sum_by_freq[old_f] ^= c
            if count_of_counts[old_f] == 0 and old_f == max_freq:
                max_freq -= 1
        
        new_f = old_f + op
        freq[c] = new_f
        if new_f > 0:
            count_of_counts[new_f] += 1
            xor_sum_by_freq[new_f] ^= c
            if new_f > max_freq:
                max_freq = new_f

    def update_subtree(u, op):
        update_node(color[u], op)
        for v in adj[u]:
            update_subtree(v, op)

    def dsu_solve(u, keep):
        nonlocal global_max_ans
        
        for v in adj[u]:
            if v != heavy_child[u]:
                dsu_solve(v, False)

        if heavy_child[u] != -1:
            dsu_solve(heavy_child[u], True)

        update_node(color[u], 1)
        for v in adj[u]:
            if v != heavy_child[u]:
                update_subtree(v, 1)

        xor_to_remove = 0
        if max_freq > 0 and (max_freq % 2) != 0:
            xor_to_remove = xor_sum_by_freq[max_freq]
        
        current_ans = total_xor_sum ^ xor_to_remove
        global_max_ans = max(global_max_ans, current_ans)

        if not keep:
            update_subtree(u, -1)

    dfs_size(1, 0)
    dsu_solve(1, False)
    print(global_max_ans)

solve()

算法及复杂度

  • 算法:DSU on Tree (树上启发式合并 / Sack)
  • 时间复杂度。预处理 DFS 是 。在核心的 solve 函数中,由于重链剖分的性质,每个节点最多位于 条轻边路径上,因此每个节点被作为轻子树暴力合并的次数是 次。每次合并操作(添加/删除节点)是 map 操作平均为 ,其中 K 是 map 大小)。总复杂度为
  • 空间复杂度。用于存储树、颜色、子树大小以及 DSU on Tree 算法中所需的各种 map。在最坏情况下(所有颜色都不同),map 的大小可能达到