题目链接

REAL744 树的最大权值

题目描述

小红定义一棵树的权值为:在所有节点字符构成回文串的简单路径中,最长路径的长度(节点数)。

给定一棵 个节点的树,以及 'a' 到 'z' 每种字母的可用数量(总和恰好为 )。你需要将每个字母填入一个树节点,使得该树的权值最大。输出这个最大权值。

思路分析

1. 问题转换

这个问题的核心是,在给定的树结构字符集双重约束下,能构造出的最长回文路径有多长。

2. 约束分析

最大权值(即最长回文路径长度 )受到两个独立因素的限制:

  1. 树结构限制:任何路径的长度都不可能超过树的直径。树的直径是树中任意两点间的最长简单路径。如果直径的长度(以节点数计)为 ,那么必然有
  2. 字符数量限制:要构造一个长度为 的回文串,路径上的字符排布必须是中心对称的。这意味着我们需要:
    • 对相同的字符(例如,路径 a-b-c-b-a 需要两对字符:一对 'a' 和一对 'b')。
    • 个单独的字符(如果路径长度 是奇数,就需要一个中心字符,如 a-b-c-b-a 中的 'c')。

3. 整合策略

既然我们可以自由地将字符分配到节点上,为了让路径尽可能长,我们应该选择树上最长的路径——即直径——作为我们构造回文串的“骨架”。

问题就转化为:我们手头的字符资源,最多能支持在一条长度不超过直径 的路径上构造多长的回文串?

首先,我们盘点字符资源:

  • 可用“对”数:对于每种出现 次的字母,它可以提供 对。总可用对数
  • 可用“单”数:凑对之后剩下的单个字符数量。总可用单数

现在,我们要寻找一个最大的长度 ,它必须同时满足所有约束:

  1. (结构约束)
  2. (需要足够多的字符对)
  3. (如果 是奇数,需要至少一个单字符作中心)

4. 求解方法

我们可以注意到,对于一个候选长度 ,如果它能被构造出来,那么任何比它短且奇偶性相同的长度(如 )也肯定能被构造出来。这种单调性非常适合使用二分查找来高效求解。

算法步骤

  1. 统计字符资源:根据输入的26个字母个数,计算出总的 num_pairsnum_singles
  2. 求树的直径
    • 根据输入的边构建树的邻接表。
    • 使用两次广度优先搜索(BFS)或深度优先搜索(DFS)来找到树的直径 (以节点数计)。 a. 从任意节点(如1号节点)出发,找到离它最远的点 。 b. 从 出发,找到离它最远的点 之间的距离(边数)加1就是直径的节点数
  3. 二分查找答案:在 的范围内二分查找最大可行长度 。对于二分过程中的每一个候选长度 mid
    • 检查 mid 是否满足字符数量限制:mid // 2 <= num_pairs 并且 mid % 2 <= num_singles
    • 如果满足,说明长度 mid 是可行的,我们可以尝试更长的路径,即 low = mid + 1
    • 如果不满足,说明 mid 太长了,必须缩短,即 high = mid - 1
  4. 二分查找结束时,记录下的最大可行长度就是最终答案。

代码

#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
#include <queue>

using namespace std;

// BFS函数,返回从start_node出发能到达的最远节点和距离(边数)
pair<int, int> bfs(int start_node, int n, const vector<vector<int>>& adj) {
    vector<int> dist(n + 1, -1);
    queue<int> q;

    dist[start_node] = 0;
    q.push(start_node);

    int farthest_node = start_node;
    int max_dist = 0;

    while (!q.empty()) {
        int u = q.front();
        q.pop();

        if (dist[u] > max_dist) {
            max_dist = dist[u];
            farthest_node = u;
        }

        for (int v : adj[u]) {
            if (dist[v] == -1) {
                dist[v] = dist[u] + 1;
                q.push(v);
            }
        }
    }
    return {farthest_node, max_dist};
}

bool check(int len, int num_pairs, int num_singles) {
    int pairs_needed = len / 2;
    int singles_needed = len % 2;
    return pairs_needed <= num_pairs && singles_needed <= num_singles;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int num_pairs = 0;
    int num_singles = 0;
    for (int i = 0; i < 26; ++i) {
        int count;
        cin >> count;
        num_pairs += count / 2;
        num_singles += count % 2;
    }

    int n;
    cin >> n;
    if (n <= 1) {
        cout << (check(1, num_pairs, num_singles) ? 1 : 0) << endl;
        return 0;
    }
    vector<vector<int>> adj(n + 1);
    for (int i = 0; i < n - 1; ++i) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    pair<int, int> farthest1 = bfs(1, n, adj);
    pair<int, int> diameter_info = bfs(farthest1.first, n, adj);
    int diameter = diameter_info.second + 1;

    int low = 1, high = diameter, ans = 0;
    while (low <= high) {
        int mid = low + (high - low) / 2;
        if (check(mid, num_pairs, num_singles)) {
            ans = mid;
            low = mid + 1;
        } else {
            high = mid - 1;
        }
    }

    cout << ans << endl;

    return 0;
}
import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int numPairs = 0;
        int numSingles = 0;
        for (int i = 0; i < 26; i++) {
            int count = sc.nextInt();
            numPairs += count / 2;
            numSingles += count % 2;
        }

        int n = sc.nextInt();
        if (n <= 1) {
            System.out.println(check(1, numPairs, numSingles) ? 1 : 0);
            return;
        }

        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i <= n; i++) {
            adj.add(new ArrayList<>());
        }
        for (int i = 0; i < n - 1; i++) {
            int u = sc.nextInt();
            int v = sc.nextInt();
            adj.get(u).add(v);
            adj.get(v).add(u);
        }

        int[] farthest1 = bfs(1, n, adj);
        int[] diameterInfo = bfs(farthest1[0], n, adj);
        int diameter = diameterInfo[1] + 1;

        int low = 1, high = diameter, ans = 0;
        while (low <= high) {
            int mid = low + (high - low) / 2;
            if (check(mid, numPairs, numSingles)) {
                ans = mid;
                low = mid + 1;
            } else {
                high = mid - 1;
            }
        }
        System.out.println(ans);
    }

    private static int[] bfs(int startNode, int n, List<List<Integer>> adj) {
        int[] dist = new int[n + 1];
        Arrays.fill(dist, -1);
        Queue<Integer> q = new LinkedList<>();

        dist[startNode] = 0;
        q.add(startNode);

        int farthestNode = startNode;
        int maxDist = 0;

        while (!q.isEmpty()) {
            int u = q.poll();
            if (dist[u] > maxDist) {
                maxDist = dist[u];
                farthestNode = u;
            }
            for (int v : adj.get(u)) {
                if (dist[v] == -1) {
                    dist[v] = dist[u] + 1;
                    q.add(v);
                }
            }
        }
        return new int[]{farthestNode, maxDist};
    }

    private static boolean check(int len, int numPairs, int numSingles) {
        int pairsNeeded = len / 2;
        int singlesNeeded = len % 2;
        return pairsNeeded <= numPairs && singlesNeeded <= numSingles;
    }
}
import sys
from collections import deque

# 设置递归深度以防万一(虽然BFS不需要)
sys.setrecursionlimit(100000)

def bfs(start_node, n, adj):
    dist = [-1] * (n + 1)
    q = deque([(start_node, 0)])
    dist[start_node] = 0
    
    farthest_node = start_node
    max_dist = 0
    
    while q:
        u, d = q.popleft()
        if d > max_dist:
            max_dist = d
            farthest_node = u
            
        for v in adj[u]:
            if dist[v] == -1:
                dist[v] = d + 1
                q.append((v, d + 1))
                
    return farthest_node, max_dist

def check(length, num_pairs, num_singles):
    pairs_needed = length // 2
    singles_needed = length % 2
    return pairs_needed <= num_pairs and singles_needed <= num_singles

def solve():
    counts = list(map(int, sys.stdin.readline().split()))
    num_pairs = sum(c // 2 for c in counts)
    num_singles = sum(c % 2 for c in counts)
    
    n_str = sys.stdin.readline()
    if not n_str: return
    n = int(n_str)

    if n == 0:
        print(0)
        return
    if n == 1:
        print(1 if check(1, num_pairs, num_singles) else 0)
        return
        
    adj = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        line = sys.stdin.readline()
        if not line: break
        u, v = map(int, line.split())
        adj[u].append(v)
        adj[v].append(u)
        
    farthest_node_1, _ = bfs(1, n, adj)
    _, diameter_edges = bfs(farthest_node_1, n, adj)
    diameter_nodes = diameter_edges + 1
    
    low, high = 1, diameter_nodes
    ans = 0
    while low <= high:
        mid = (low + high) // 2
        if check(mid, num_pairs, num_singles):
            ans = mid
            low = mid + 1
        else:
            high = mid - 1
            
    print(ans)

solve()

算法及复杂度

  • 算法:两次BFS求树的直径 + 二分查找
  • 时间复杂度。构建邻接表需要 ,两次BFS求直径是 (因为树的边数是 ),二分查找的范围是 ,最多进行 次,每次检查是 的。因此,总时间复杂度由建图和BFS主导,为
  • 空间复杂度,主要用于存储树的邻接表以及BFS中使用的距离数组和队列。