小红的小红树
[题目链接](https://www.nowcoder.com/practice/e0c409d06d0e4fef94e93f34215a0df4)
思路
乍一看这道题,很多同学会以为是"树上最大匹配",但仔细读题会发现关键区别:操作只是把一对白色结点中的一个染红,另一个保持白色,可以反复参与后续操作。
朴素想法的陷阱
最容易犯的错误是认为"每个红色结点必须有一个永远保持白色的邻居"。但实际上,操作是有时序的:先把 A 染红(此时 B 是白色的帮手),之后再把 B 也染红(用 B 的另一个白色邻居 C 当帮手)。A 虽然失去了白色帮手 B,但 A 已经是红色了,不会"被撤销"。
连通分量视角
把问题抽象一下:设被染红的结点集合为 。在树上,如果两个红色结点之间的边权值和是质数,它们就属于同一个红色连通分量(通过质数边相连的红色结点链)。
关键观察:一个红色连通分量是合法的,当且仅当分量中至少有一个结点拥有一个不在 中的质数和邻居。
为什么?对于这样的分量,我们可以从"深处"往"出口"的方向依次染色。最后一个染色的结点使用外部白色帮手。分量中其他结点在被染色时,它们的"下一个即将被染色的邻居"还是白色的,可以充当帮手。
树形 DP 设计
以结点 为根做树形 DP,为每个结点
定义三个状态:
| 状态 | 含义 |
|---|---|
转移方程
设 表示
是否为质数。
(
白色): 每个孩子独立选最优状态。如果孩子
选了状态 2(红色未满足),那
必须和白色的
之间有质数边,
就充当了外部帮手。
$$
(
红色,分量未满足): 所有和
有质数边的孩子,如果是红色就会合并进
的分量。要保持分量"未满足":
- 有质数边的孩子必须也是红色且未满足(状态 2),否则要么白色帮手满足了
,要么合并进已满足的子分量;
- 无质数边的孩子照常取最优。
$$
若某个质数边孩子的 无效,则
也无效。
(
红色,分量已满足): 需要至少一个质数边孩子提供满足——要么是白色的(给
当外部帮手),要么是红色已满足的(合并后整个分量满足)。
先算出每个孩子取三个状态中最优值的总和 。如果恰好所有质数边孩子都更愿意选状态 2(未满足),就需要把代价最小的一个"掰回"状态 0 或 1。
答案
根结点没有父亲,不能用状态 2。答案为 。
复杂度
- 时间:
,其中
是权值和的上界,质数判定为
。
- 空间:
。
代码
#include <bits/stdc++.h>
using namespace std;
bool isPrime(long long x) {
if (x < 2) return false;
if (x == 2) return true;
if (x % 2 == 0) return false;
for (long long i = 3; i * i <= x; i += 2)
if (x % i == 0) return false;
return true;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n;
cin >> n;
vector<long long> w(n);
for (int i = 0; i < n; i++) cin >> w[i];
vector<vector<int>> adj(n);
for (int i = 0; i < n - 1; i++) {
int u, v;
cin >> u >> v;
u--; v--;
adj[u].push_back(v);
adj[v].push_back(u);
}
const int NEG = -1000000;
vector<array<int,3>> dp(n);
vector<bool> visited(n, false);
vector<int> par(n, -1);
vector<int> order;
stack<int> stk;
stk.push(0);
visited[0] = true;
while (!stk.empty()) {
int u = stk.top(); stk.pop();
order.push_back(u);
for (int v : adj[u]) {
if (!visited[v]) {
visited[v] = true;
par[v] = u;
stk.push(v);
}
}
}
for (int i = (int)order.size() - 1; i >= 0; i--) {
int u = order[i];
vector<int> children;
for (int v : adj[u]) {
if (v != par[u]) children.push_back(v);
}
if (children.empty()) {
dp[u] = {0, NEG, 1};
continue;
}
// dp[u][0]
int sum0 = 0;
for (int v : children) {
int best = max(dp[v][0], dp[v][1]);
if (isPrime(w[u] + w[v])) best = max(best, dp[v][2]);
sum0 += best;
}
dp[u][0] = sum0;
// dp[u][2]
int sum2 = 0;
bool valid2 = true;
for (int v : children) {
if (isPrime(w[u] + w[v])) {
if (dp[v][2] <= NEG) valid2 = false;
sum2 += dp[v][2];
} else {
sum2 += max(dp[v][0], dp[v][1]);
}
}
dp[u][2] = valid2 ? 1 + sum2 : NEG;
// dp[u][1]
bool hasPrimeChild = false;
int sumBase = 0;
bool allPrimeState2 = true;
int minCost = INT_MAX;
for (int v : children) {
if (isPrime(w[u] + w[v])) {
hasPrimeChild = true;
int best01 = max(dp[v][0], dp[v][1]);
int best012 = max(best01, dp[v][2]);
sumBase += best012;
if (best012 > best01 && dp[v][2] > NEG) {
minCost = min(minCost, dp[v][2] - best01);
} else {
allPrimeState2 = false;
}
} else {
sumBase += max(dp[v][0], dp[v][1]);
}
}
if (!hasPrimeChild) {
dp[u][1] = NEG;
} else if (!allPrimeState2) {
dp[u][1] = 1 + sumBase;
} else {
dp[u][1] = (minCost < INT_MAX) ? 1 + sumBase - minCost : NEG;
}
}
cout << max(dp[0][0], max(dp[0][1], 0)) << endl;
return 0;
}
import java.util.*;
import java.io.*;
public class Main {
static boolean isPrime(long x) {
if (x < 2) return false;
if (x == 2) return true;
if (x % 2 == 0) return false;
for (long i = 3; i * i <= x; i += 2)
if (x % i == 0) return false;
return true;
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
int n = Integer.parseInt(br.readLine().trim());
StringTokenizer st = new StringTokenizer(br.readLine().trim());
long[] w = new long[n];
for (int i = 0; i < n; i++) w[i] = Long.parseLong(st.nextToken());
List<List<Integer>> adj = new ArrayList<>();
for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
for (int i = 0; i < n - 1; i++) {
st = new StringTokenizer(br.readLine().trim());
int u = Integer.parseInt(st.nextToken()) - 1;
int v = Integer.parseInt(st.nextToken()) - 1;
adj.get(u).add(v);
adj.get(v).add(u);
}
final int NEG = -1000000;
int[][] dp = new int[n][3];
boolean[] visited = new boolean[n];
int[] par = new int[n];
Arrays.fill(par, -1);
int[] order = new int[n];
int idx = 0;
Deque<Integer> stack = new ArrayDeque<>();
stack.push(0);
visited[0] = true;
while (!stack.isEmpty()) {
int u = stack.pop();
order[idx++] = u;
for (int v : adj.get(u)) {
if (!visited[v]) {
visited[v] = true;
par[v] = u;
stack.push(v);
}
}
}
for (int i = n - 1; i >= 0; i--) {
int u = order[i];
List<Integer> children = new ArrayList<>();
for (int v : adj.get(u)) {
if (v != par[u]) children.add(v);
}
if (children.isEmpty()) {
dp[u][0] = 0;
dp[u][1] = NEG;
dp[u][2] = 1;
continue;
}
// dp[u][0]
int sum0 = 0;
for (int v : children) {
int best = Math.max(dp[v][0], dp[v][1]);
if (isPrime(w[u] + w[v])) best = Math.max(best, dp[v][2]);
sum0 += best;
}
dp[u][0] = sum0;
// dp[u][2]
int sum2 = 0;
boolean valid2 = true;
for (int v : children) {
if (isPrime(w[u] + w[v])) {
if (dp[v][2] <= NEG) valid2 = false;
sum2 += dp[v][2];
} else {
sum2 += Math.max(dp[v][0], dp[v][1]);
}
}
dp[u][2] = valid2 ? 1 + sum2 : NEG;
// dp[u][1]
boolean hasPrimeChild = false;
int sumBase = 0;
boolean allPrimeState2 = true;
int minCost = Integer.MAX_VALUE;
for (int v : children) {
if (isPrime(w[u] + w[v])) {
hasPrimeChild = true;
int best01 = Math.max(dp[v][0], dp[v][1]);
int best012 = Math.max(best01, dp[v][2]);
sumBase += best012;
if (best012 > best01 && dp[v][2] > NEG) {
minCost = Math.min(minCost, dp[v][2] - best01);
} else {
allPrimeState2 = false;
}
} else {
sumBase += Math.max(dp[v][0], dp[v][1]);
}
}
if (!hasPrimeChild) {
dp[u][1] = NEG;
} else if (!allPrimeState2) {
dp[u][1] = 1 + sumBase;
} else {
dp[u][1] = (minCost < Integer.MAX_VALUE) ? 1 + sumBase - minCost : NEG;
}
}
System.out.println(Math.max(dp[0][0], Math.max(dp[0][1], 0)));
}
}

京公网安备 11010502036488号