题目链接
题目描述
在一次 Prompt 工程实践中,我们把一段 token 序列抽象成一棵二叉树。树中每个结点都有一个整数权值(可正可负,也可能为 0)。请在这棵树中选出一棵“价值最大”的子树,并把这棵子树按“完全二叉树的层序数组”形式输出。
- 子树的价值:定义为它所包含的所有结点权值之和。
- 剪枝规则:允许对某个结点“剪掉”对总和贡献为负的整棵子树(即可以只要左子树、或只要右子树、或两者都要;被剪掉的位置在输出中以
null占位)。 - 输入/输出:输入是一棵“用层序数组表示的完全二叉树”,缺失位置用
null占位;输出也使用相同规则表示挑选出的那棵最优子树,并且去除末尾多余的尾部null。
输入描述
- 一行:用方括号包裹的一维数组,表示树的层序遍历;缺失结点用
null。例如:[5,-1,3,null,null,4,7]。 - 约定:数组下标从 0 开始;对下标 i,左孩子为
2*i+1,右孩子为2*i+2。
输出描述
- 一行:选择的“最大和子树”的层序数组表示(仍以
null占位),并删除末尾无意义的连续null。
解题思路
本题的核心是寻找一棵二叉树中“价值最大”的子树。这里的“子树”定义比较特殊:它允许我们从任意节点开始,向下延伸时可以选择性地剪掉那些总和为负的左子树或右子树,以最大化当前子树的总和。最终,我们需要找到全局的最大和,并输出形成这个最大和的树结构。
这个问题可以通过一次后序遍历和树形动态规划来解决。我们需要从下往上计算信息,为每个节点做出决策。
1. 状态定义与计算
对于树中的任意一个节点 node,我们需要计算以它为根的、经过“剪枝”后可能达到的最大价值。这个值就是 node 自身的权值,加上其左、右子树经过“剪枝”后提供的正贡献。
同时,我们需要一个全局变量
global_max_sum 来记录在整棵树中遇到的所有 max_sum_at_node 的最大值,因为最优子树的根不一定是整棵树的根。
2. 第一次遍历 (后序遍历 - 计算最大和)
我们可以定义一个递归函数,比如 calculate_max_sum(node),它执行后序遍历:
- 基本情况: 如果
node为空,返回 0。 - 递归: 分别递归计算左、右子树的最大价值
left_max_sum和right_max_sum。 - 计算当前节点的最大价值: 根据上述公式计算
current_max_sum。 - 更新全局最大值: 用
current_max_sum更新global_max_sum,并记录下是哪个节点(root_of_max_subtree)产生了全局最大值。 - 返回值: 函数返回
current_max_sum,供其父节点使用。
3. 第二次遍历 (前序遍历 - 构建输出树)
在第一次遍历之后,我们已经知道了全局最大和以及最优子树的根节点。现在我们需要根据这个根节点,构建出最终要输出的树结构。
我们可以定义另一个递归函数,比如 build_output_tree(original_node, output_node_index, output_array),它执行前序遍历:
- 从
root_of_max_subtree开始。 - 将当前
original_node的值放入output_array的output_node_index位置。 - 决策剪枝:
- 检查其左子树。在第一次遍历中,我们可以用一个哈希表记录下每个节点计算出的
max_sum。如果左子树的max_sum > 0,说明它有正贡献,我们保留它并递归构建。否则,我们“剪掉”它。 - 对右子树做同样的操作。
- 检查其左子树。在第一次遍历中,我们可以用一个哈希表记录下每个节点计算出的
- 这个过程会根据第一次遍历的计算结果,在
output_array中逐步构建出被剪枝后的树的层序表示。
4. 输入与输出处理
- 输入: 输入是一个表示完全二叉树的层序数组。我们需要先将这个数组转换成我们熟悉的链式节点表示的树结构,以便于进行递归遍历。
- 输出: 第二次遍历生成的
output_array就是最终的结果,但需要移除末尾连续的null。
代码
#include <iostream>
#include <vector>
#include <string>
#include <sstream>
#include <algorithm>
#include <map>
#include <climits>
using namespace std;
// 树节点
struct TreeNode {
int val;
TreeNode *left;
TreeNode *right;
TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};
// 全局变量来跟踪最大和及其根节点
long long global_max_sum = LLONG_MIN;
TreeNode* root_of_max_subtree = nullptr;
map<TreeNode*, long long> node_sums; // 存储每个节点为根的最大子树和
// 从层序数组构建树
TreeNode* build_tree_from_level_order(const vector<string>& arr) {
if (arr.empty() || arr[0] == "null") {
return nullptr;
}
vector<TreeNode*> nodes;
for (const string& s : arr) {
if (s == "null") {
nodes.push_back(nullptr);
} else {
nodes.push_back(new TreeNode(stoi(s)));
}
}
for (int i = 0; i < nodes.size(); ++i) {
if (nodes[i]) {
int left_idx = 2 * i + 1;
int right_idx = 2 * i + 2;
if (left_idx < nodes.size()) nodes[i]->left = nodes[left_idx];
if (right_idx < nodes.size()) nodes[i]->right = nodes[right_idx];
}
}
return nodes[0];
}
// 后序遍历计算最大子树和
long long calculate_max_sum(TreeNode* node) {
if (!node) return 0;
long long left_sum = calculate_max_sum(node->left);
long long right_sum = calculate_max_sum(node->right);
long long current_max_sum = node->val + max(0LL, left_sum) + max(0LL, right_sum);
node_sums[node] = current_max_sum;
if (current_max_sum > global_max_sum) {
global_max_sum = current_max_sum;
root_of_max_subtree = node;
}
return current_max_sum;
}
// 前序遍历构建输出数组
void build_output(TreeNode* node, int index, vector<string>& output) {
if (!node) return;
if (index >= output.size()) output.resize(index + 1, "null");
output[index] = to_string(node->val);
if (node->left && node_sums[node->left] > 0) {
build_output(node->left, 2 * index + 1, output);
}
if (node->right && node_sums[node->right] > 0) {
build_output(node->right, 2 * index + 2, output);
}
}
int main() {
string line;
getline(cin, line);
line = line.substr(1, line.length() - 2); // 去掉方括号
replace(line.begin(), line.end(), ',', ' ');
stringstream ss(line);
string item;
vector<string> arr;
while (ss >> item) {
arr.push_back(item);
}
TreeNode* root = build_tree_from_level_order(arr);
if(!root) {
cout << "[]" << endl;
return 0;
}
calculate_max_sum(root);
vector<string> output;
if (root_of_max_subtree) {
build_output(root_of_max_subtree, 0, output);
}
// 删除末尾的 null
while (!output.empty() && output.back() == "null") {
output.pop_back();
}
cout << "[";
for (int i = 0; i < output.size(); ++i) {
cout << output[i] << (i == output.size() - 1 ? "" : ",");
}
cout << "]" << endl;
return 0;
}
import java.util.*;
class TreeNode {
int val;
TreeNode left;
TreeNode right;
TreeNode(int val) { this.val = val; }
}
public class Main {
static long globalMaxSum = Long.MIN_VALUE;
static TreeNode rootOfMaxSubtree = null;
static Map<TreeNode, Long> nodeSums = new HashMap<>();
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
String line = sc.nextLine();
line = line.substring(1, line.length() - 1);
if (line.isEmpty()) {
System.out.println("[]");
return;
}
String[] parts = line.split(",");
TreeNode root = buildTreeFromLevelOrder(parts);
calculateMaxSum(root);
List<String> output = new ArrayList<>();
if (rootOfMaxSubtree != null) {
buildOutput(rootOfMaxSubtree, 0, output);
}
while (!output.isEmpty() && output.get(output.size() - 1).equals("null")) {
output.remove(output.size() - 1);
}
System.out.println("[" + String.join(",", output) + "]");
}
private static TreeNode buildTreeFromLevelOrder(String[] arr) {
if (arr.length == 0 || arr[0].equals("null")) return null;
List<TreeNode> nodes = new ArrayList<>();
for (String s : arr) {
nodes.add(s.equals("null") ? null : new TreeNode(Integer.parseInt(s)));
}
for (int i = 0; i < nodes.size(); i++) {
if (nodes.get(i) != null) {
int leftIdx = 2 * i + 1;
int rightIdx = 2 * i + 2;
if (leftIdx < nodes.size()) nodes.get(i).left = nodes.get(leftIdx);
if (rightIdx < nodes.size()) nodes.get(i).right = nodes.get(rightIdx);
}
}
return nodes.get(0);
}
private static long calculateMaxSum(TreeNode node) {
if (node == null) return 0;
long leftSum = calculateMaxSum(node.left);
long rightSum = calculateMaxSum(node.right);
long currentMaxSum = node.val + Math.max(0, leftSum) + Math.max(0, rightSum);
nodeSums.put(node, currentMaxSum);
if (currentMaxSum > globalMaxSum) {
globalMaxSum = currentMaxSum;
rootOfMaxSubtree = node;
}
return currentMaxSum;
}
private static void buildOutput(TreeNode node, int index, List<String> output) {
if (node == null) return;
while (output.size() <= index) output.add("null");
output.set(index, String.valueOf(node.val));
if (node.left != null && nodeSums.get(node.left) > 0) {
buildOutput(node.left, 2 * index + 1, output);
}
if (node.right != null && nodeSums.get(node.right) > 0) {
buildOutput(node.right, 2 * index + 2, output);
}
}
}
import json
class TreeNode:
def __init__(self, val=0, left=None, right=None):
self.val = val
self.left = left
self.right = right
# 全局变量
global_max_sum = -float('inf')
root_of_max_subtree = None
node_sums = {}
def build_tree_from_level_order(arr):
if not arr or arr[0] is None:
return None
nodes = [TreeNode(val) if val is not None else None for val in arr]
for i, node in enumerate(nodes):
if node:
left_idx, right_idx = 2 * i + 1, 2 * i + 2
if left_idx < len(nodes):
node.left = nodes[left_idx]
if right_idx < len(nodes):
node.right = nodes[right_idx]
return nodes[0]
def calculate_max_sum(node):
global global_max_sum, root_of_max_subtree
if not node:
return 0
left_sum = calculate_max_sum(node.left)
right_sum = calculate_max_sum(node.right)
current_max_sum = node.val + max(0, left_sum) + max(0, right_sum)
node_sums[node] = current_max_sum
if current_max_sum > global_max_sum:
global_max_sum = current_max_sum
root_of_max_subtree = node
return current_max_sum
def build_output(node, index, output):
if not node:
return
if index >= len(output):
output.extend([None] * (index - len(output) + 1))
output[index] = node.val
if node.left and node_sums.get(node.left, 0) > 0:
build_output(node.left, 2 * index + 1, output)
if node.right and node_sums.get(node.right, 0) > 0:
build_output(node.right, 2 * index + 2, output)
def solve():
line = input()
# 使用 json 库解析带 null 的数组
# json.loads 可以直接处理 json 中的 null,会自动转为 python 的 None
arr = json.loads(line)
root = build_tree_from_level_order(arr)
if not root:
print("[]")
return
calculate_max_sum(root)
output = []
if root_of_max_subtree:
build_output(root_of_max_subtree, 0, output)
# 删除末尾的 null
while output and output[-1] is None:
output.pop()
# 格式化输出,json.dumps 会自动将 python 的 None 转为 json 的 null
print(json.dumps(output, separators=(',', ':')))
solve()
算法及复杂度
- 算法: 树形动态规划 + 树的遍历
- 时间复杂度:
- 其中
是树中的节点数。需要两次完整的树遍历。
- 空间复杂度:
- 需要
的空间来构建链式树、递归调用栈和存储输出结果的数组。

京公网安备 11010502036488号