牛牛的糖果树

题意

一棵 个节点的有根树(根为 ),每个节点有一个颜色 。牛牛选择一棵子树,先扔掉其中出现次数最多的颜色的所有糖果(若有多个颜色并列最多则全部扔掉),然后吃剩下的糖果。吃到的糖果颜色依次异或(同色多个糖果也逐个异或),求所有子树选择中异或和的最大值。

思路

DSU on tree(树上启发式合并)

先抓住一个关键性质:颜色 出现 次时,异或贡献为 次),等于 为奇)或 为偶)。

因此,对于一棵子树:

  • = 所有节点颜色的异或和
  • = 出现次数最多的颜色的频率
  • 为偶数,则被扔掉的颜色异或贡献为 ,答案就是
  • 为奇数,则被扔掉的颜色异或贡献为「所有频率恰好等于 的不同颜色的异或」,从 中异或掉即可

所以对每棵子树我们需要维护三样东西:

  1. cnt[c]:颜色 的出现次数
  2. color_xor_at_freq[f]:频率恰为 的所有不同颜色的异或
  3. num_at_freq[f]:频率恰为 的不同颜色个数(用于维护

每次增删一个节点时,更新上述三个结构, 即可。

暴力枚举每棵子树统计这些信息需要 。使用 DSU on tree 优化:按重儿子保留、轻儿子暴力加回的策略,总复杂度降为

代码

#include <bits/stdc++.h>
using namespace std;
int main(){
    int n;
    scanf("%d", &n);
    vector<int> col(n+1);
    for(int i = 1; i <= n; i++) scanf("%d", &col[i]);
    vector<vector<int>> adj(n+1);
    for(int i = 0; i < n-1; i++){
        int u, v;
        scanf("%d%d", &u, &v);
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    vector<int> sz(n+1,0), heavy(n+1,-1), par(n+1,0);
    vector<int> order;
    order.reserve(n);
    vector<bool> vis(n+1, false);
    queue<int> q;
    q.push(1); vis[1] = true;
    while(!q.empty()){
        int u = q.front(); q.pop();
        order.push_back(u);
        for(int v : adj[u]) if(!vis[v]){
            vis[v] = true; par[v] = u; q.push(v);
        }
    }
    for(int i = (int)order.size()-1; i >= 0; i--){
        int u = order[i]; sz[u] = 1;
        int mx = 0;
        for(int v : adj[u]) if(v != par[u]){
            sz[u] += sz[v];
            if(sz[v] > mx){ mx = sz[v]; heavy[u] = v; }
        }
    }
    unordered_map<int,int> cnt, cxf, naf;
    int txor = 0, mf = 0, ans = 0;
    auto add = [&](int u){
        int c = col[u], of = cnt[c]++, nf = of+1;
        txor ^= c;
        if(of > 0){ cxf[of] ^= c; naf[of]--; }
        cxf[nf] ^= c; naf[nf]++;
        if(nf > mf) mf = nf;
    };
    auto rem = [&](int u){
        int c = col[u], of = cnt[c]--, nf = of-1;
        txor ^= c;
        cxf[of] ^= c; naf[of]--;
        if(nf > 0){ cxf[nf] ^= c; naf[nf]++; }
        while(mf > 0 && naf[mf] == 0) mf--;
    };
    auto getSub = [&](int root, auto& self) -> vector<int> {
        vector<int> res;
        stack<int> stk; stk.push(root);
        while(!stk.empty()){
            int u = stk.top(); stk.pop();
            res.push_back(u);
            for(int v : adj[u]) if(v != par[u]) stk.push(v);
        }
        return res;
    };
    function<void(int,bool)> dfs = [&](int u, bool keep){
        for(int v : adj[u]) if(v != par[u] && v != heavy[u]) dfs(v, false);
        if(heavy[u] != -1) dfs(heavy[u], true);
        add(u);
        for(int v : adj[u]) if(v != par[u] && v != heavy[u])
            for(int x : getSub(v, getSub)) add(x);
        int rx = (mf % 2 == 1) ? cxf[mf] : 0;
        ans = max(ans, txor ^ rx);
        if(!keep) for(int x : getSub(u, getSub)) rem(x);
    };
    dfs(1, false);
    printf("%d\n", ans);
}
import java.util.*;
import java.io.*;

public class Main {
    static int[] col, par, sz, heavy;
    static List<List<Integer>> adj;
    static HashMap<Integer,Integer> cnt = new HashMap<>(), cxf = new HashMap<>(), naf = new HashMap<>();
    static int txor = 0, mf = 0, ans = 0;

    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(br.readLine().trim());
        StringTokenizer st = new StringTokenizer(br.readLine().trim());
        col = new int[n+1];
        for (int i = 1; i <= n; i++) col[i] = Integer.parseInt(st.nextToken());
        adj = new ArrayList<>();
        for (int i = 0; i <= n; i++) adj.add(new ArrayList<>());
        for (int i = 0; i < n-1; i++) {
            st = new StringTokenizer(br.readLine().trim());
            int u = Integer.parseInt(st.nextToken()), v = Integer.parseInt(st.nextToken());
            adj.get(u).add(v); adj.get(v).add(u);
        }
        par = new int[n+1]; sz = new int[n+1]; heavy = new int[n+1];
        Arrays.fill(heavy, -1);
        int[] order = new int[n];
        boolean[] vis = new boolean[n+1];
        int head = 0, tail = 0;
        order[tail++] = 1; vis[1] = true;
        while (head < tail) {
            int u = order[head++];
            for (int v : adj.get(u)) if (!vis[v]) { vis[v] = true; par[v] = u; order[tail++] = v; }
        }
        for (int i = n-1; i >= 0; i--) {
            int u = order[i]; sz[u] = 1; int mx = 0;
            for (int v : adj.get(u)) if (v != par[u]) { sz[u] += sz[v]; if (sz[v] > mx) { mx = sz[v]; heavy[u] = v; } }
        }
        dfs(1, false);
        System.out.println(ans);
    }

    static void add(int u) {
        int c = col[u], of = cnt.getOrDefault(c,0), nf = of+1;
        cnt.put(c, nf); txor ^= c;
        if (of > 0) { cxf.merge(of, c, (a,b)->a^b); naf.merge(of, -1, Integer::sum); }
        cxf.merge(nf, c, (a,b)->a^b); naf.merge(nf, 1, Integer::sum);
        if (nf > mf) mf = nf;
    }

    static void rem(int u) {
        int c = col[u], of = cnt.get(c), nf = of-1;
        cnt.put(c, nf); txor ^= c;
        cxf.merge(of, c, (a,b)->a^b); naf.merge(of, -1, Integer::sum);
        if (nf > 0) { cxf.merge(nf, c, (a,b)->a^b); naf.merge(nf, 1, Integer::sum); }
        while (mf > 0 && naf.getOrDefault(mf,0) == 0) mf--;
    }

    static List<Integer> getSub(int root) {
        List<Integer> res = new ArrayList<>();
        Deque<Integer> stk = new ArrayDeque<>(); stk.push(root);
        while (!stk.isEmpty()) {
            int u = stk.pop(); res.add(u);
            for (int v : adj.get(u)) if (v != par[u]) stk.push(v);
        }
        return res;
    }

    static void dfs(int u, boolean keep) {
        for (int v : adj.get(u)) if (v != par[u] && v != heavy[u]) dfs(v, false);
        if (heavy[u] != -1) dfs(heavy[u], true);
        add(u);
        for (int v : adj.get(u)) if (v != par[u] && v != heavy[u]) for (int x : getSub(v)) add(x);
        int rx = (mf % 2 == 1) ? cxf.getOrDefault(mf,0) : 0;
        ans = Math.max(ans, txor ^ rx);
        if (!keep) for (int x : getSub(u)) rem(x);
    }
}

复杂度

  • 时间复杂度。DSU on tree 保证每个节点被加入/删除 次,每次操作
  • 空间复杂度,存储树结构、颜色频率哈希表等。