牛牛的糖果树
题意
一棵 个节点的有根树(根为
),每个节点有一个颜色
。牛牛选择一棵子树,先扔掉其中出现次数最多的颜色的所有糖果(若有多个颜色并列最多则全部扔掉),然后吃剩下的糖果。吃到的糖果颜色依次异或(同色多个糖果也逐个异或),求所有子树选择中异或和的最大值。
思路
DSU on tree(树上启发式合并)
先抓住一个关键性质:颜色 出现
次时,异或贡献为
(
次),等于
(
为奇)或
(
为偶)。
因此,对于一棵子树:
- 设
= 所有节点颜色的异或和
- 设
= 出现次数最多的颜色的频率
- 若
为偶数,则被扔掉的颜色异或贡献为
,答案就是
- 若
为奇数,则被扔掉的颜色异或贡献为「所有频率恰好等于
的不同颜色的异或」,从
中异或掉即可
所以对每棵子树我们需要维护三样东西:
cnt[c]:颜色的出现次数
color_xor_at_freq[f]:频率恰为的所有不同颜色的异或
num_at_freq[f]:频率恰为的不同颜色个数(用于维护
)
每次增删一个节点时,更新上述三个结构, 即可。
暴力枚举每棵子树统计这些信息需要 。使用 DSU on tree 优化:按重儿子保留、轻儿子暴力加回的策略,总复杂度降为
。
代码
#include <bits/stdc++.h>
using namespace std;
int main(){
int n;
scanf("%d", &n);
vector<int> col(n+1);
for(int i = 1; i <= n; i++) scanf("%d", &col[i]);
vector<vector<int>> adj(n+1);
for(int i = 0; i < n-1; i++){
int u, v;
scanf("%d%d", &u, &v);
adj[u].push_back(v);
adj[v].push_back(u);
}
vector<int> sz(n+1,0), heavy(n+1,-1), par(n+1,0);
vector<int> order;
order.reserve(n);
vector<bool> vis(n+1, false);
queue<int> q;
q.push(1); vis[1] = true;
while(!q.empty()){
int u = q.front(); q.pop();
order.push_back(u);
for(int v : adj[u]) if(!vis[v]){
vis[v] = true; par[v] = u; q.push(v);
}
}
for(int i = (int)order.size()-1; i >= 0; i--){
int u = order[i]; sz[u] = 1;
int mx = 0;
for(int v : adj[u]) if(v != par[u]){
sz[u] += sz[v];
if(sz[v] > mx){ mx = sz[v]; heavy[u] = v; }
}
}
unordered_map<int,int> cnt, cxf, naf;
int txor = 0, mf = 0, ans = 0;
auto add = [&](int u){
int c = col[u], of = cnt[c]++, nf = of+1;
txor ^= c;
if(of > 0){ cxf[of] ^= c; naf[of]--; }
cxf[nf] ^= c; naf[nf]++;
if(nf > mf) mf = nf;
};
auto rem = [&](int u){
int c = col[u], of = cnt[c]--, nf = of-1;
txor ^= c;
cxf[of] ^= c; naf[of]--;
if(nf > 0){ cxf[nf] ^= c; naf[nf]++; }
while(mf > 0 && naf[mf] == 0) mf--;
};
auto getSub = [&](int root, auto& self) -> vector<int> {
vector<int> res;
stack<int> stk; stk.push(root);
while(!stk.empty()){
int u = stk.top(); stk.pop();
res.push_back(u);
for(int v : adj[u]) if(v != par[u]) stk.push(v);
}
return res;
};
function<void(int,bool)> dfs = [&](int u, bool keep){
for(int v : adj[u]) if(v != par[u] && v != heavy[u]) dfs(v, false);
if(heavy[u] != -1) dfs(heavy[u], true);
add(u);
for(int v : adj[u]) if(v != par[u] && v != heavy[u])
for(int x : getSub(v, getSub)) add(x);
int rx = (mf % 2 == 1) ? cxf[mf] : 0;
ans = max(ans, txor ^ rx);
if(!keep) for(int x : getSub(u, getSub)) rem(x);
};
dfs(1, false);
printf("%d\n", ans);
}
import java.util.*;
import java.io.*;
public class Main {
static int[] col, par, sz, heavy;
static List<List<Integer>> adj;
static HashMap<Integer,Integer> cnt = new HashMap<>(), cxf = new HashMap<>(), naf = new HashMap<>();
static int txor = 0, mf = 0, ans = 0;
public static void main(String[] args) throws Exception {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
int n = Integer.parseInt(br.readLine().trim());
StringTokenizer st = new StringTokenizer(br.readLine().trim());
col = new int[n+1];
for (int i = 1; i <= n; i++) col[i] = Integer.parseInt(st.nextToken());
adj = new ArrayList<>();
for (int i = 0; i <= n; i++) adj.add(new ArrayList<>());
for (int i = 0; i < n-1; i++) {
st = new StringTokenizer(br.readLine().trim());
int u = Integer.parseInt(st.nextToken()), v = Integer.parseInt(st.nextToken());
adj.get(u).add(v); adj.get(v).add(u);
}
par = new int[n+1]; sz = new int[n+1]; heavy = new int[n+1];
Arrays.fill(heavy, -1);
int[] order = new int[n];
boolean[] vis = new boolean[n+1];
int head = 0, tail = 0;
order[tail++] = 1; vis[1] = true;
while (head < tail) {
int u = order[head++];
for (int v : adj.get(u)) if (!vis[v]) { vis[v] = true; par[v] = u; order[tail++] = v; }
}
for (int i = n-1; i >= 0; i--) {
int u = order[i]; sz[u] = 1; int mx = 0;
for (int v : adj.get(u)) if (v != par[u]) { sz[u] += sz[v]; if (sz[v] > mx) { mx = sz[v]; heavy[u] = v; } }
}
dfs(1, false);
System.out.println(ans);
}
static void add(int u) {
int c = col[u], of = cnt.getOrDefault(c,0), nf = of+1;
cnt.put(c, nf); txor ^= c;
if (of > 0) { cxf.merge(of, c, (a,b)->a^b); naf.merge(of, -1, Integer::sum); }
cxf.merge(nf, c, (a,b)->a^b); naf.merge(nf, 1, Integer::sum);
if (nf > mf) mf = nf;
}
static void rem(int u) {
int c = col[u], of = cnt.get(c), nf = of-1;
cnt.put(c, nf); txor ^= c;
cxf.merge(of, c, (a,b)->a^b); naf.merge(of, -1, Integer::sum);
if (nf > 0) { cxf.merge(nf, c, (a,b)->a^b); naf.merge(nf, 1, Integer::sum); }
while (mf > 0 && naf.getOrDefault(mf,0) == 0) mf--;
}
static List<Integer> getSub(int root) {
List<Integer> res = new ArrayList<>();
Deque<Integer> stk = new ArrayDeque<>(); stk.push(root);
while (!stk.isEmpty()) {
int u = stk.pop(); res.add(u);
for (int v : adj.get(u)) if (v != par[u]) stk.push(v);
}
return res;
}
static void dfs(int u, boolean keep) {
for (int v : adj.get(u)) if (v != par[u] && v != heavy[u]) dfs(v, false);
if (heavy[u] != -1) dfs(heavy[u], true);
add(u);
for (int v : adj.get(u)) if (v != par[u] && v != heavy[u]) for (int x : getSub(v)) add(x);
int rx = (mf % 2 == 1) ? cxf.getOrDefault(mf,0) : 0;
ans = Math.max(ans, txor ^ rx);
if (!keep) for (int x : getSub(u)) rem(x);
}
}
复杂度
- 时间复杂度:
。DSU on tree 保证每个节点被加入/删除
次,每次操作
。
- 空间复杂度:
,存储树结构、颜色频率哈希表等。

京公网安备 11010502036488号