小红的小红树

[题目链接](https://www.nowcoder.com/practice/e0c409d06d0e4fef94e93f34215a0df4)

思路

乍一看这道题,很多同学会以为是"树上最大匹配",但仔细读题会发现关键区别:操作只是把一对白色结点中的一个染红,另一个保持白色,可以反复参与后续操作。

朴素想法的陷阱

最容易犯的错误是认为"每个红色结点必须有一个永远保持白色的邻居"。但实际上,操作是有时序的:先把 A 染红(此时 B 是白色的帮手),之后再把 B 也染红(用 B 的另一个白色邻居 C 当帮手)。A 虽然失去了白色帮手 B,但 A 已经是红色了,不会"被撤销"。

连通分量视角

把问题抽象一下:设被染红的结点集合为 。在树上,如果两个红色结点之间的边权值和是质数,它们就属于同一个红色连通分量(通过质数边相连的红色结点链)。

关键观察:一个红色连通分量是合法的,当且仅当分量中至少有一个结点拥有一个不在 中的质数和邻居

为什么?对于这样的分量,我们可以从"深处"往"出口"的方向依次染色。最后一个染色的结点使用外部白色帮手。分量中其他结点在被染色时,它们的"下一个即将被染色的邻居"还是白色的,可以充当帮手。

树形 DP 设计

以结点 为根做树形 DP,为每个结点 定义三个状态:

状态 含义
保持白色,子树中所有红色连通分量均已被满足
被染红, 所在的连通分量已被满足,其他分量也已被满足
被染红, 所在的连通分量尚未被满足(需要由父结点一侧提供帮助)

转移方程

表示 是否为质数。

白色): 每个孩子独立选最优状态。如果孩子 选了状态 2(红色未满足),那 必须和白色的 之间有质数边, 就充当了外部帮手。

$$

红色,分量未满足): 所有和 有质数边的孩子,如果是红色就会合并进 的分量。要保持分量"未满足":

  • 有质数边的孩子必须也是红色且未满足(状态 2),否则要么白色帮手满足了 ,要么合并进已满足的子分量;
  • 无质数边的孩子照常取最优。

$$

若某个质数边孩子的 无效,则 也无效。

红色,分量已满足): 需要至少一个质数边孩子提供满足——要么是白色的(给 当外部帮手),要么是红色已满足的(合并后整个分量满足)。

先算出每个孩子取三个状态中最优值的总和 。如果恰好所有质数边孩子都更愿意选状态 2(未满足),就需要把代价最小的一个"掰回"状态 0 或 1。

答案

根结点没有父亲,不能用状态 2。答案为

复杂度

  • 时间,其中 是权值和的上界,质数判定为
  • 空间

代码

#include <bits/stdc++.h>
using namespace std;

bool isPrime(long long x) {
    if (x < 2) return false;
    if (x == 2) return true;
    if (x % 2 == 0) return false;
    for (long long i = 3; i * i <= x; i += 2)
        if (x % i == 0) return false;
    return true;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    cin >> n;
    vector<long long> w(n);
    for (int i = 0; i < n; i++) cin >> w[i];

    vector<vector<int>> adj(n);
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        cin >> u >> v;
        u--; v--;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    const int NEG = -1000000;
    vector<array<int,3>> dp(n);
    vector<bool> visited(n, false);
    vector<int> par(n, -1);

    vector<int> order;
    stack<int> stk;
    stk.push(0);
    visited[0] = true;
    while (!stk.empty()) {
        int u = stk.top(); stk.pop();
        order.push_back(u);
        for (int v : adj[u]) {
            if (!visited[v]) {
                visited[v] = true;
                par[v] = u;
                stk.push(v);
            }
        }
    }

    for (int i = (int)order.size() - 1; i >= 0; i--) {
        int u = order[i];
        vector<int> children;
        for (int v : adj[u]) {
            if (v != par[u]) children.push_back(v);
        }

        if (children.empty()) {
            dp[u] = {0, NEG, 1};
            continue;
        }

        // dp[u][0]
        int sum0 = 0;
        for (int v : children) {
            int best = max(dp[v][0], dp[v][1]);
            if (isPrime(w[u] + w[v])) best = max(best, dp[v][2]);
            sum0 += best;
        }
        dp[u][0] = sum0;

        // dp[u][2]
        int sum2 = 0;
        bool valid2 = true;
        for (int v : children) {
            if (isPrime(w[u] + w[v])) {
                if (dp[v][2] <= NEG) valid2 = false;
                sum2 += dp[v][2];
            } else {
                sum2 += max(dp[v][0], dp[v][1]);
            }
        }
        dp[u][2] = valid2 ? 1 + sum2 : NEG;

        // dp[u][1]
        bool hasPrimeChild = false;
        int sumBase = 0;
        bool allPrimeState2 = true;
        int minCost = INT_MAX;

        for (int v : children) {
            if (isPrime(w[u] + w[v])) {
                hasPrimeChild = true;
                int best01 = max(dp[v][0], dp[v][1]);
                int best012 = max(best01, dp[v][2]);
                sumBase += best012;
                if (best012 > best01 && dp[v][2] > NEG) {
                    minCost = min(minCost, dp[v][2] - best01);
                } else {
                    allPrimeState2 = false;
                }
            } else {
                sumBase += max(dp[v][0], dp[v][1]);
            }
        }

        if (!hasPrimeChild) {
            dp[u][1] = NEG;
        } else if (!allPrimeState2) {
            dp[u][1] = 1 + sumBase;
        } else {
            dp[u][1] = (minCost < INT_MAX) ? 1 + sumBase - minCost : NEG;
        }
    }

    cout << max(dp[0][0], max(dp[0][1], 0)) << endl;
    return 0;
}
import java.util.*;
import java.io.*;

public class Main {
    static boolean isPrime(long x) {
        if (x < 2) return false;
        if (x == 2) return true;
        if (x % 2 == 0) return false;
        for (long i = 3; i * i <= x; i += 2)
            if (x % i == 0) return false;
        return true;
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(br.readLine().trim());
        StringTokenizer st = new StringTokenizer(br.readLine().trim());
        long[] w = new long[n];
        for (int i = 0; i < n; i++) w[i] = Long.parseLong(st.nextToken());

        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int i = 0; i < n - 1; i++) {
            st = new StringTokenizer(br.readLine().trim());
            int u = Integer.parseInt(st.nextToken()) - 1;
            int v = Integer.parseInt(st.nextToken()) - 1;
            adj.get(u).add(v);
            adj.get(v).add(u);
        }

        final int NEG = -1000000;
        int[][] dp = new int[n][3];
        boolean[] visited = new boolean[n];
        int[] par = new int[n];
        Arrays.fill(par, -1);

        int[] order = new int[n];
        int idx = 0;
        Deque<Integer> stack = new ArrayDeque<>();
        stack.push(0);
        visited[0] = true;
        while (!stack.isEmpty()) {
            int u = stack.pop();
            order[idx++] = u;
            for (int v : adj.get(u)) {
                if (!visited[v]) {
                    visited[v] = true;
                    par[v] = u;
                    stack.push(v);
                }
            }
        }

        for (int i = n - 1; i >= 0; i--) {
            int u = order[i];
            List<Integer> children = new ArrayList<>();
            for (int v : adj.get(u)) {
                if (v != par[u]) children.add(v);
            }

            if (children.isEmpty()) {
                dp[u][0] = 0;
                dp[u][1] = NEG;
                dp[u][2] = 1;
                continue;
            }

            // dp[u][0]
            int sum0 = 0;
            for (int v : children) {
                int best = Math.max(dp[v][0], dp[v][1]);
                if (isPrime(w[u] + w[v])) best = Math.max(best, dp[v][2]);
                sum0 += best;
            }
            dp[u][0] = sum0;

            // dp[u][2]
            int sum2 = 0;
            boolean valid2 = true;
            for (int v : children) {
                if (isPrime(w[u] + w[v])) {
                    if (dp[v][2] <= NEG) valid2 = false;
                    sum2 += dp[v][2];
                } else {
                    sum2 += Math.max(dp[v][0], dp[v][1]);
                }
            }
            dp[u][2] = valid2 ? 1 + sum2 : NEG;

            // dp[u][1]
            boolean hasPrimeChild = false;
            int sumBase = 0;
            boolean allPrimeState2 = true;
            int minCost = Integer.MAX_VALUE;

            for (int v : children) {
                if (isPrime(w[u] + w[v])) {
                    hasPrimeChild = true;
                    int best01 = Math.max(dp[v][0], dp[v][1]);
                    int best012 = Math.max(best01, dp[v][2]);
                    sumBase += best012;
                    if (best012 > best01 && dp[v][2] > NEG) {
                        minCost = Math.min(minCost, dp[v][2] - best01);
                    } else {
                        allPrimeState2 = false;
                    }
                } else {
                    sumBase += Math.max(dp[v][0], dp[v][1]);
                }
            }

            if (!hasPrimeChild) {
                dp[u][1] = NEG;
            } else if (!allPrimeState2) {
                dp[u][1] = 1 + sumBase;
            } else {
                dp[u][1] = (minCost < Integer.MAX_VALUE) ? 1 + sumBase - minCost : NEG;
            }
        }

        System.out.println(Math.max(dp[0][0], Math.max(dp[0][1], 0)));
    }
}