题目链接
题目描述
给定一棵 个节点的树,根节点为 1。每个节点上都有一颗带颜色的糖果。
我们可以选择任意一个节点 ,吃掉其为根的整个子树中的糖果。但在吃之前,需要遵循一个规则:找出子树中出现次数最多的颜色(可能有多种),并将所有这些颜色的糖果全部扔掉。
我们的目标是计算,在扔掉糖果后,一次能吃到的所有剩余糖果的颜色值的异或和最大是多少。注意,如果一种颜色的糖果有多个,在计算异或和时也要计算多次。
解题思路
题目要求我们对树上的每一个子树都进行一次查询,并找出最优解。一个朴素的想法是,遍历每一个节点 ,然后对以
为根的子树进行一次完整的遍历,统计颜色频率、计算异或和,最后求出答案。这种方法的总时间复杂度为
,对于
的数据规模来说太慢了。
为了优化这个过程,我们需要一种更高效的算法来处理子树查询,这就是 DSU on Tree (在国内也被称为“树上启发式合并”或 Sack 算法)。
DSU on Tree 核心思想
DSU on Tree 是一种通过优化暴力来处理树上子树查询问题的算法。其核心思想是:当计算完一个节点 的所有子树后,我们将其中一棵子树的计算结果(例如颜色频率等信息)直接“继承”过来,然后只将其他子树(以及节点
本身)的信息暴力合并进去。
为了让暴力合并的总代价最小,我们每次选择继承重儿子(subtree size 最大的儿子)的信息,并将所有轻儿子(非重儿子的其他儿子)的信息暴力合并。可以证明,通过这种方式,每个节点被暴力合并的次数不会超过 次,从而将总时间复杂度优化到
。
算法步骤
-
预处理 (DFS Pass 1):
- 我们需要一次 DFS 来建立父子关系,并计算每个节点的
subtree_size
。 - 在计算完一个节点所有儿子的
subtree_size
后,我们可以确定它的重儿子(size 最大的那个儿子)。
- 我们需要一次 DFS 来建立父子关系,并计算每个节点的
-
计算答案 (DFS Pass 2):
- 这是 DSU on Tree 的核心。我们进行另一次 DFS,函数
solve(u, keep)
中keep
是一个布尔值,表示在处理完节点后是否保留其子树的统计信息。
- 处理轻儿子: 对
的所有轻儿子
v
,递归调用solve(v, false)
。false
表示处理完v
的子树后,清空其统计信息。 - 处理重儿子: 对
的重儿子
hc
,递归调用solve(hc, true)
。true
表示处理完hc
的子树后,保留其统计信息。这样,当前全局的统计信息就正好是hc
子树的信息。 - 合并轻儿子和当前节点: 现在,我们将节点
以及它所有轻儿子子树的信息暴力合并到当前的全局统计信息中。
- 计算
的答案: 合并完成后,全局统计信息就代表了整个
子树的信息。此时我们根据这些信息计算出选择子树
时的答案,并更新全局最大值。
- 清理: 如果
keep
为false
(即是一个轻儿子),我们需要清空刚刚为
子树计算出的所有统计信息,以确保不影响其兄弟节点的计算。
- 这是 DSU on Tree 的核心。我们进行另一次 DFS,函数
维护统计信息
为了在合并和计算时效率更高,我们需要维护以下几个关键信息:
freq[c]
: 颜色c
的出现次数。count_of_counts[k]
: 出现次数为k
的颜色有多少种。max_freq
: 当前出现次数的最大值。total_xor_sum
: 当前子树中所有颜色(未剔除前)的总异或和。xor_sum_by_freq[k]
: 所有出现次数为k
的颜色的异或和。
通过维护这些信息,我们可以在 的时间内添加或删除一个节点,并快速计算出当前子树的答案:
- 要移除的颜色的异或和
xor_to_remove
:如果max_freq
是奇数,则为xor_sum_by_freq[max_freq]
;如果max_freq
是偶数,则为 0(因为c ^ c ^ ... ^ c
(偶数次) = 0)。 - 最终答案为
total_xor_sum ^ xor_to_remove
。
代码
#include <iostream>
#include <vector>
#include <map>
#include <algorithm>
using namespace std;
vector<int> adj[100005];
int color[100005];
int sz[100005];
int heavy_child[100005];
long long global_max_ans = 0;
// Global state for DSU on Tree
map<int, int> freq;
map<int, int> count_of_counts;
map<int, long long> xor_sum_by_freq;
int max_freq = 0;
long long total_xor_sum = 0;
void dfs_size(int u, int p) {
sz[u] = 1;
int max_sz = 0;
heavy_child[u] = -1;
for (int v : adj[u]) {
if (v == p) continue;
dfs_size(v, u);
sz[u] += sz[v];
if (sz[v] > max_sz) {
max_sz = sz[v];
heavy_child[u] = v;
}
}
}
void update_node(int c, int op) {
total_xor_sum ^= c;
int old_f = freq[c];
if (old_f > 0) {
count_of_counts[old_f]--;
xor_sum_by_freq[old_f] ^= c;
if (count_of_counts[old_f] == 0 && old_f == max_freq) {
max_freq--;
}
}
int new_f = old_f + op;
freq[c] = new_f;
if (new_f > 0) {
count_of_counts[new_f]++;
xor_sum_by_freq[new_f] ^= c;
if (new_f > max_freq) {
max_freq = new_f;
}
}
}
void update_subtree(int u, int p, int op) {
update_node(color[u], op);
for (int v : adj[u]) {
if (v != p) {
update_subtree(v, u, op);
}
}
}
void solve(int u, int p, bool keep) {
for (int v : adj[u]) {
if (v != p && v != heavy_child[u]) {
solve(v, u, false);
}
}
if (heavy_child[u] != -1) {
solve(heavy_child[u], u, true);
}
update_node(color[u], 1);
for (int v : adj[u]) {
if (v != p && v != heavy_child[u]) {
update_subtree(v, u, 1);
}
}
long long xor_to_remove = 0;
if (max_freq > 0 && (max_freq % 2) != 0) {
xor_to_remove = xor_sum_by_freq[max_freq];
}
global_max_ans = max(global_max_ans, total_xor_sum ^ xor_to_remove);
if (!keep) {
update_subtree(u, p, -1);
}
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n;
cin >> n;
for (int i = 1; i <= n; ++i) {
cin >> color[i];
}
for (int i = 0; i < n - 1; ++i) {
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
dfs_size(1, 0);
solve(1, 0, false);
cout << global_max_ans << endl;
return 0;
}
import java.util.*;
public class Main {
static List<Integer>[] adj;
static int[] color;
static int[] sz;
static int[] heavyChild;
static long globalMaxAns = 0;
// Global state for DSU on Tree
static Map<Integer, Integer> freq = new HashMap<>();
static Map<Integer, Integer> countOfCounts = new HashMap<>();
static Map<Integer, Long> xorSumByFreq = new HashMap<>();
static int maxFreq = 0;
static long totalXorSum = 0;
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
color = new int[n + 1];
adj = new ArrayList[n + 1];
sz = new int[n + 1];
heavyChild = new int[n + 1];
for (int i = 1; i <= n; i++) {
color[i] = sc.nextInt();
adj[i] = new ArrayList<>();
}
for (int i = 0; i < n - 1; i++) {
int u = sc.nextInt();
int v = sc.nextInt();
adj[u].add(v);
adj[v].add(u);
}
dfsSize(1, 0);
solve(1, 0, false);
System.out.println(globalMaxAns);
}
static void dfsSize(int u, int p) {
sz[u] = 1;
int maxSz = 0;
heavyChild[u] = -1;
for (int v : adj[u]) {
if (v == p) continue;
// Hack to remove parent from adjacency list to make it a directed tree
adj[v].remove(Integer.valueOf(u));
dfsSize(v, u);
sz[u] += sz[v];
if (sz[v] > maxSz) {
maxSz = sz[v];
heavyChild[u] = v;
}
}
}
static void updateNode(int c, int op) {
totalXorSum ^= c;
int oldF = freq.getOrDefault(c, 0);
if (oldF > 0) {
countOfCounts.put(oldF, countOfCounts.get(oldF) - 1);
xorSumByFreq.put(oldF, xorSumByFreq.getOrDefault(oldF, 0L) ^ c);
if (countOfCounts.get(oldF) == 0 && oldF == maxFreq) {
maxFreq--;
}
}
int newF = oldF + op;
freq.put(c, newF);
if (newF > 0) {
countOfCounts.put(newF, countOfCounts.getOrDefault(newF, 0) + 1);
xorSumByFreq.put(newF, xorSumByFreq.getOrDefault(newF, 0L) ^ c);
if (newF > maxFreq) {
maxFreq = newF;
}
}
}
static void updateSubtree(int u, int op) {
updateNode(color[u], op);
for (int v : adj[u]) {
updateSubtree(v, op);
}
}
static void solve(int u, int p, boolean keep) {
for (int v : adj[u]) {
if (v != heavyChild[u]) {
solve(v, u, false);
}
}
if (heavyChild[u] != -1) {
solve(heavyChild[u], u, true);
}
updateNode(color[u], 1);
for (int v : adj[u]) {
if (v != heavyChild[u]) {
updateSubtree(v, 1);
}
}
long xorToRemove = 0;
if (maxFreq > 0 && (maxFreq % 2) != 0) {
xorToRemove = xorSumByFreq.getOrDefault(maxFreq, 0L);
}
globalMaxAns = Math.max(globalMaxAns, totalXorSum ^ xorToRemove);
if (!keep) {
updateSubtree(u, -1);
}
}
}
import sys
from collections import defaultdict
# It's recommended to increase recursion limit for deep trees in Python
sys.setrecursionlimit(200005)
def solve():
n_str = sys.stdin.readline()
if not n_str: return
n = int(n_str)
colors_input = list(map(int, sys.stdin.readline().split()))
color = [0] * (n + 1)
for i in range(n):
color[i+1] = colors_input[i]
adj = [[] for _ in range(n + 1)]
for _ in range(n - 1):
u, v = map(int, sys.stdin.readline().split())
adj[u].append(v)
adj[v].append(u)
sz = [0] * (n + 1)
heavy_child = [-1] * (n + 1)
parent = [0] * (n + 1)
# Global state
global_max_ans = 0
freq = defaultdict(int)
count_of_counts = defaultdict(int)
xor_sum_by_freq = defaultdict(int)
max_freq = 0
total_xor_sum = 0
def dfs_size(u, p):
sz[u] = 1
parent[u] = p
max_sz = 0
heavy_child[u] = -1
children_to_remove = []
for i, v in enumerate(adj[u]):
if v == p:
children_to_remove.append(i)
continue
dfs_size(v, u)
sz[u] += sz[v]
if sz[v] > max_sz:
max_sz = sz[v]
heavy_child[u] = v
# Remove parent from adj list to make it a directed tree
for i in sorted(children_to_remove, reverse=True):
adj[u].pop(i)
def update_node(c, op):
nonlocal max_freq, total_xor_sum
total_xor_sum ^= c
old_f = freq[c]
if old_f > 0:
count_of_counts[old_f] -= 1
xor_sum_by_freq[old_f] ^= c
if count_of_counts[old_f] == 0 and old_f == max_freq:
max_freq -= 1
new_f = old_f + op
freq[c] = new_f
if new_f > 0:
count_of_counts[new_f] += 1
xor_sum_by_freq[new_f] ^= c
if new_f > max_freq:
max_freq = new_f
def update_subtree(u, op):
update_node(color[u], op)
for v in adj[u]:
update_subtree(v, op)
def dsu_solve(u, keep):
nonlocal global_max_ans
for v in adj[u]:
if v != heavy_child[u]:
dsu_solve(v, False)
if heavy_child[u] != -1:
dsu_solve(heavy_child[u], True)
update_node(color[u], 1)
for v in adj[u]:
if v != heavy_child[u]:
update_subtree(v, 1)
xor_to_remove = 0
if max_freq > 0 and (max_freq % 2) != 0:
xor_to_remove = xor_sum_by_freq[max_freq]
current_ans = total_xor_sum ^ xor_to_remove
global_max_ans = max(global_max_ans, current_ans)
if not keep:
update_subtree(u, -1)
dfs_size(1, 0)
dsu_solve(1, False)
print(global_max_ans)
solve()
算法及复杂度
- 算法:DSU on Tree (树上启发式合并 / Sack)
- 时间复杂度:
。预处理 DFS 是
。在核心的
solve
函数中,由于重链剖分的性质,每个节点最多位于条轻边路径上,因此每个节点被作为轻子树暴力合并的次数是
次。每次合并操作(添加/删除节点)是
(
map
操作平均为或
,其中 K 是 map 大小)。总复杂度为
。
- 空间复杂度:
。用于存储树、颜色、子树大小以及 DSU on Tree 算法中所需的各种
map
。在最坏情况下(所有颜色都不同),map
的大小可能达到。