题目链接

小红的小红树

题目描述

给定一棵有 个节点的树,每个节点有自己的权值。初始时,所有节点都是白色的。

小红可以执行操作:选择两个相邻且都为白色的节点,如果它们的权值之和是一个质数,她就可以选择其中一个染成红色。

问:最多可以染红多少个节点?

解题思路

1. 分析操作的性质

  • 一次操作的核心条件是存在一条边 (u, v),使得节点 uv 都是白色的,且它们的权值和 weights[u] + weights[v] 是一个质数。我们称这样的边为“质数边”。
  • 操作执行后,uv 中的一个节点会被染成红色。
  • 关键的约束是:一个节点被染红后,就不能再参与任何操作

2. 将问题转化为计数问题

  • 考虑任意一条质数边 (u, v)。只要 uv 都是白色的,我们就可以在这条边上执行一次操作,将 uv 染红。
  • 一旦操作完成(例如,u 被染红),u 就不再是白色节点。这意味着,这条边 (u, v) 不能再被用于未来的任何操作,因为它的一端不再满足“都为白色”的条件。
  • 因此,每一条质数边最多只能贡献一次染红操作
  • 每一次染红操作,都会恰好产生一个红色的节点。
  • 为了最大化染红节点的数量,我们应该执行尽可能多的操作。由于每条质数边最多只能使用一次,最大操作次数就等于图中质数边的总数。
  • 所以,最多可以染红的节点数 = 树中质数边的总数量

3. 算法实现

  • 问题简化为:遍历树中的每一条边 (u, v),检查 weights[u] + weights[v] 是否为质数,并统计满足条件的边的数量。
  • 我们可以使用埃氏筛法(Sieve of Eratosthenes)预处理出一个质数表,以便快速查询一个数是否为质数。
  • 然后,通过深度优先搜索(DFS)或广度优先搜索(BFS)遍历整棵树。在遍历过程中,对于访问到的每一条边,都进行一次质数和的判断。
  • 最终的计数值就是答案。

代码

#include <iostream>
#include <vector>
#include <numeric>
#include <functional>
#include <algorithm>

using namespace std;

const int MAX_SUM = 200001;
vector<bool> is_prime(MAX_SUM, true);
vector<vector<int>> adj;
vector<int> weights;
int ans;

void sieve() {
    is_prime[0] = is_prime[1] = false;
    for (int p = 2; p * p < MAX_SUM; ++p) {
        if (is_prime[p]) {
            for (int i = p * p; i < MAX_SUM; i += p)
                is_prime[i] = false;
        }
    }
}

void dfs(int u, int p) {
    for (int v : adj[u]) {
        if (v != p) {
            if (weights[u] + weights[v] < MAX_SUM && is_prime[weights[u] + weights[v]]) {
                ans++;
            }
            dfs(v, u);
        }
    }
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    sieve();

    int n;
    cin >> n;
    weights.resize(n + 1);
    for (int i = 1; i <= n; ++i) {
        cin >> weights[i];
    }

    adj.resize(n + 1);
    vector<pair<int, int>> edges;
    for (int i = 0; i < n - 1; ++i) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    
    ans = 0;
    dfs(1, 0);
    cout << ans << endl;

    return 0;
}
import java.util.*;

public class Main {
    static final int MAX_SUM = 200001;
    static boolean[] isPrime = new boolean[MAX_SUM];
    static List<List<Integer>> adj;
    static int[] weights;
    static int ans;

    private static void sieve() {
        Arrays.fill(isPrime, true);
        isPrime[0] = isPrime[1] = false;
        for (int p = 2; p * p < MAX_SUM; p++) {
            if (isPrime[p]) {
                for (int i = p * p; i < MAX_SUM; i += p)
                    isPrime[i] = false;
            }
        }
    }
    
    private static void dfs(int u, int p) {
        for (int v : adj.get(u)) {
            if (v != p) {
                if (weights[u] + weights[v] < MAX_SUM && isPrime[weights[u] + weights[v]]) {
                    ans++;
                }
                dfs(v, u);
            }
        }
    }

    public static void main(String[] args) {
        sieve();
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        
        weights = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            weights[i] = sc.nextInt();
        }

        adj = new ArrayList<>();
        for (int i = 0; i <= n; i++) {
            adj.add(new ArrayList<>());
        }
        
        for (int i = 0; i < n - 1; i++) {
            int u = sc.nextInt();
            int v = sc.nextInt();
            adj.get(u).add(v);
            adj.get(v).add(u);
        }
        
        ans = 0;
        dfs(1, 0);
        
        System.out.println(ans);
    }
}
import sys

# Increase recursion limit for deep graphs
sys.setrecursionlimit(200005)

MAX_SUM = 200001
is_prime = [True] * MAX_SUM
adj = []
weights = []
ans = 0

def sieve():
    global is_prime
    is_prime[0] = is_prime[1] = False
    for p in range(2, int(MAX_SUM**0.5) + 1):
        if is_prime[p]:
            for i in range(p * p, MAX_SUM, p):
                is_prime[i] = False

def dfs(u, p):
    global ans
    for v in adj[u]:
        if v != p:
            if weights[u] + weights[v] < MAX_SUM and is_prime[weights[u] + weights[v]]:
                ans += 1
            dfs(v, u)

def solve():
    global adj, weights, ans
    try:
        n_str = sys.stdin.readline().strip()
        if not n_str: return
        n = int(n_str)
        
        weights = [0] + list(map(int, sys.stdin.readline().strip().split()))
        
        adj = [[] for _ in range(n + 1)]
        for _ in range(n - 1):
            u, v = map(int, sys.stdin.readline().strip().split())
            adj[u].append(v)
            adj[v].append(u)
            
    except (IOError, ValueError):
        return

    sieve()
    
    ans = 0
    dfs(1, 0)
    
    print(ans)

solve()

算法及复杂度

  • 算法:埃氏筛 + 深度优先搜索(DFS)

  • 时间复杂度,其中 是节点数, 是节点权值的最大值(本题中为 )。

    • 埃氏筛预处理质数的时间复杂度为
    • DFS 遍历树的每条边恰好一次,时间复杂度为
    • 总的时间复杂度由两者相加决定。
  • 空间复杂度

    • 用于存储树的邻接表和递归栈。
    • 用于存储质数表。