题目链接
题目描述
小红定义一棵树的权值为:在所有节点字符构成回文串的简单路径中,最长路径的长度(节点数)。
给定一棵 个节点的树,以及 'a' 到 'z' 每种字母的可用数量(总和恰好为
)。你需要将每个字母填入一个树节点,使得该树的权值最大。输出这个最大权值。
思路分析
1. 问题转换
这个问题的核心是,在给定的树结构和字符集双重约束下,能构造出的最长回文路径有多长。
2. 约束分析
最大权值(即最长回文路径长度 )受到两个独立因素的限制:
- 树结构限制:任何路径的长度都不可能超过树的直径。树的直径是树中任意两点间的最长简单路径。如果直径的长度(以节点数计)为
,那么必然有
。
- 字符数量限制:要构造一个长度为
的回文串,路径上的字符排布必须是中心对称的。这意味着我们需要:
对相同的字符(例如,路径
a-b-c-b-a
需要两对字符:一对 'a' 和一对 'b')。个单独的字符(如果路径长度
是奇数,就需要一个中心字符,如
a-b-c-b-a
中的 'c')。
3. 整合策略
既然我们可以自由地将字符分配到节点上,为了让路径尽可能长,我们应该选择树上最长的路径——即直径——作为我们构造回文串的“骨架”。
问题就转化为:我们手头的字符资源,最多能支持在一条长度不超过直径 的路径上构造多长的回文串?
首先,我们盘点字符资源:
- 可用“对”数:对于每种出现
次的字母,它可以提供
对。总可用对数
。
- 可用“单”数:凑对之后剩下的单个字符数量。总可用单数
。
现在,我们要寻找一个最大的长度 ,它必须同时满足所有约束:
(结构约束)
(需要足够多的字符对)
(如果
是奇数,需要至少一个单字符作中心)
4. 求解方法
我们可以注意到,对于一个候选长度 ,如果它能被构造出来,那么任何比它短且奇偶性相同的长度(如
)也肯定能被构造出来。这种单调性非常适合使用二分查找来高效求解。
算法步骤:
- 统计字符资源:根据输入的26个字母个数,计算出总的
num_pairs
和num_singles
。 - 求树的直径:
- 根据输入的边构建树的邻接表。
- 使用两次广度优先搜索(BFS)或深度优先搜索(DFS)来找到树的直径
(以节点数计)。 a. 从任意节点(如1号节点)出发,找到离它最远的点
。 b. 从
出发,找到离它最远的点
。
和
之间的距离(边数)加1就是直径的节点数
。
- 二分查找答案:在
的范围内二分查找最大可行长度
。对于二分过程中的每一个候选长度
mid
:- 检查
mid
是否满足字符数量限制:mid // 2 <= num_pairs
并且mid % 2 <= num_singles
。 - 如果满足,说明长度
mid
是可行的,我们可以尝试更长的路径,即low = mid + 1
。 - 如果不满足,说明
mid
太长了,必须缩短,即high = mid - 1
。
- 检查
- 二分查找结束时,记录下的最大可行长度就是最终答案。
代码
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
#include <queue>
using namespace std;
// BFS函数,返回从start_node出发能到达的最远节点和距离(边数)
pair<int, int> bfs(int start_node, int n, const vector<vector<int>>& adj) {
vector<int> dist(n + 1, -1);
queue<int> q;
dist[start_node] = 0;
q.push(start_node);
int farthest_node = start_node;
int max_dist = 0;
while (!q.empty()) {
int u = q.front();
q.pop();
if (dist[u] > max_dist) {
max_dist = dist[u];
farthest_node = u;
}
for (int v : adj[u]) {
if (dist[v] == -1) {
dist[v] = dist[u] + 1;
q.push(v);
}
}
}
return {farthest_node, max_dist};
}
bool check(int len, int num_pairs, int num_singles) {
int pairs_needed = len / 2;
int singles_needed = len % 2;
return pairs_needed <= num_pairs && singles_needed <= num_singles;
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int num_pairs = 0;
int num_singles = 0;
for (int i = 0; i < 26; ++i) {
int count;
cin >> count;
num_pairs += count / 2;
num_singles += count % 2;
}
int n;
cin >> n;
if (n <= 1) {
cout << (check(1, num_pairs, num_singles) ? 1 : 0) << endl;
return 0;
}
vector<vector<int>> adj(n + 1);
for (int i = 0; i < n - 1; ++i) {
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
pair<int, int> farthest1 = bfs(1, n, adj);
pair<int, int> diameter_info = bfs(farthest1.first, n, adj);
int diameter = diameter_info.second + 1;
int low = 1, high = diameter, ans = 0;
while (low <= high) {
int mid = low + (high - low) / 2;
if (check(mid, num_pairs, num_singles)) {
ans = mid;
low = mid + 1;
} else {
high = mid - 1;
}
}
cout << ans << endl;
return 0;
}
import java.util.*;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int numPairs = 0;
int numSingles = 0;
for (int i = 0; i < 26; i++) {
int count = sc.nextInt();
numPairs += count / 2;
numSingles += count % 2;
}
int n = sc.nextInt();
if (n <= 1) {
System.out.println(check(1, numPairs, numSingles) ? 1 : 0);
return;
}
List<List<Integer>> adj = new ArrayList<>();
for (int i = 0; i <= n; i++) {
adj.add(new ArrayList<>());
}
for (int i = 0; i < n - 1; i++) {
int u = sc.nextInt();
int v = sc.nextInt();
adj.get(u).add(v);
adj.get(v).add(u);
}
int[] farthest1 = bfs(1, n, adj);
int[] diameterInfo = bfs(farthest1[0], n, adj);
int diameter = diameterInfo[1] + 1;
int low = 1, high = diameter, ans = 0;
while (low <= high) {
int mid = low + (high - low) / 2;
if (check(mid, numPairs, numSingles)) {
ans = mid;
low = mid + 1;
} else {
high = mid - 1;
}
}
System.out.println(ans);
}
private static int[] bfs(int startNode, int n, List<List<Integer>> adj) {
int[] dist = new int[n + 1];
Arrays.fill(dist, -1);
Queue<Integer> q = new LinkedList<>();
dist[startNode] = 0;
q.add(startNode);
int farthestNode = startNode;
int maxDist = 0;
while (!q.isEmpty()) {
int u = q.poll();
if (dist[u] > maxDist) {
maxDist = dist[u];
farthestNode = u;
}
for (int v : adj.get(u)) {
if (dist[v] == -1) {
dist[v] = dist[u] + 1;
q.add(v);
}
}
}
return new int[]{farthestNode, maxDist};
}
private static boolean check(int len, int numPairs, int numSingles) {
int pairsNeeded = len / 2;
int singlesNeeded = len % 2;
return pairsNeeded <= numPairs && singlesNeeded <= numSingles;
}
}
import sys
from collections import deque
# 设置递归深度以防万一(虽然BFS不需要)
sys.setrecursionlimit(100000)
def bfs(start_node, n, adj):
dist = [-1] * (n + 1)
q = deque([(start_node, 0)])
dist[start_node] = 0
farthest_node = start_node
max_dist = 0
while q:
u, d = q.popleft()
if d > max_dist:
max_dist = d
farthest_node = u
for v in adj[u]:
if dist[v] == -1:
dist[v] = d + 1
q.append((v, d + 1))
return farthest_node, max_dist
def check(length, num_pairs, num_singles):
pairs_needed = length // 2
singles_needed = length % 2
return pairs_needed <= num_pairs and singles_needed <= num_singles
def solve():
counts = list(map(int, sys.stdin.readline().split()))
num_pairs = sum(c // 2 for c in counts)
num_singles = sum(c % 2 for c in counts)
n_str = sys.stdin.readline()
if not n_str: return
n = int(n_str)
if n == 0:
print(0)
return
if n == 1:
print(1 if check(1, num_pairs, num_singles) else 0)
return
adj = [[] for _ in range(n + 1)]
for _ in range(n - 1):
line = sys.stdin.readline()
if not line: break
u, v = map(int, line.split())
adj[u].append(v)
adj[v].append(u)
farthest_node_1, _ = bfs(1, n, adj)
_, diameter_edges = bfs(farthest_node_1, n, adj)
diameter_nodes = diameter_edges + 1
low, high = 1, diameter_nodes
ans = 0
while low <= high:
mid = (low + high) // 2
if check(mid, num_pairs, num_singles):
ans = mid
low = mid + 1
else:
high = mid - 1
print(ans)
solve()
算法及复杂度
- 算法:两次BFS求树的直径 + 二分查找
- 时间复杂度:
。构建邻接表需要
,两次BFS求直径是
(因为树的边数是
),二分查找的范围是
,最多进行
次,每次检查是
的。因此,总时间复杂度由建图和BFS主导,为
。
- 空间复杂度:
,主要用于存储树的邻接表以及BFS中使用的距离数组和队列。