题目链接

Prompt上下文信息精简

题目描述

在一次 Prompt 工程实践中,我们把一段 token 序列抽象成一棵二叉树。树中每个结点都有一个整数权值(可正可负,也可能为 0)。请在这棵树中选出一棵“价值最大”的子树,并把这棵子树按“完全二叉树的层序数组”形式输出。

  • 子树的价值:定义为它所包含的所有结点权值之和。
  • 剪枝规则:允许对某个结点“剪掉”对总和贡献为负的整棵子树(即可以只要左子树、或只要右子树、或两者都要;被剪掉的位置在输出中以 null 占位)。
  • 输入/输出:输入是一棵“用层序数组表示的完全二叉树”,缺失位置用 null 占位;输出也使用相同规则表示挑选出的那棵最优子树,并且去除末尾多余的尾部 null

输入描述

  • 一行:用方括号包裹的一维数组,表示树的层序遍历;缺失结点用 null。例如:[5,-1,3,null,null,4,7]
  • 约定:数组下标从 0 开始;对下标 i,左孩子为 2*i+1,右孩子为 2*i+2

输出描述

  • 一行:选择的“最大和子树”的层序数组表示(仍以 null 占位),并删除末尾无意义的连续 null

解题思路

本题的核心是寻找一棵二叉树中“价值最大”的子树。这里的“子树”定义比较特殊:它允许我们从任意节点开始,向下延伸时可以选择性地剪掉那些总和为负的左子树或右子树,以最大化当前子树的总和。最终,我们需要找到全局的最大和,并输出形成这个最大和的树结构。

这个问题可以通过一次后序遍历树形动态规划来解决。我们需要从下往上计算信息,为每个节点做出决策。

1. 状态定义与计算

对于树中的任意一个节点 node,我们需要计算以它为根的、经过“剪枝”后可能达到的最大价值。这个值就是 node 自身的权值,加上其左、右子树经过“剪枝”后提供的正贡献 同时,我们需要一个全局变量 global_max_sum 来记录在整棵树中遇到的所有 max_sum_at_node 的最大值,因为最优子树的根不一定是整棵树的根。

2. 第一次遍历 (后序遍历 - 计算最大和)

我们可以定义一个递归函数,比如 calculate_max_sum(node),它执行后序遍历:

  1. 基本情况: 如果 node 为空,返回 0。
  2. 递归: 分别递归计算左、右子树的最大价值 left_max_sumright_max_sum
  3. 计算当前节点的最大价值: 根据上述公式计算 current_max_sum
  4. 更新全局最大值: 用 current_max_sum 更新 global_max_sum,并记录下是哪个节点(root_of_max_subtree)产生了全局最大值。
  5. 返回值: 函数返回 current_max_sum,供其父节点使用。

3. 第二次遍历 (前序遍历 - 构建输出树)

在第一次遍历之后,我们已经知道了全局最大和以及最优子树的根节点。现在我们需要根据这个根节点,构建出最终要输出的树结构。

我们可以定义另一个递归函数,比如 build_output_tree(original_node, output_node_index, output_array),它执行前序遍历:

  1. root_of_max_subtree 开始。
  2. 将当前 original_node 的值放入 output_arrayoutput_node_index 位置。
  3. 决策剪枝:
    • 检查其左子树。在第一次遍历中,我们可以用一个哈希表记录下每个节点计算出的 max_sum。如果左子树的 max_sum > 0,说明它有正贡献,我们保留它并递归构建。否则,我们“剪掉”它。
    • 对右子树做同样的操作。
  4. 这个过程会根据第一次遍历的计算结果,在 output_array 中逐步构建出被剪枝后的树的层序表示。

4. 输入与输出处理

  • 输入: 输入是一个表示完全二叉树的层序数组。我们需要先将这个数组转换成我们熟悉的链式节点表示的树结构,以便于进行递归遍历。
  • 输出: 第二次遍历生成的 output_array 就是最终的结果,但需要移除末尾连续的 null

代码

#include <iostream>
#include <vector>
#include <string>
#include <sstream>
#include <algorithm>
#include <map>
#include <climits>

using namespace std;

// 树节点
struct TreeNode {
    int val;
    TreeNode *left;
    TreeNode *right;
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};

// 全局变量来跟踪最大和及其根节点
long long global_max_sum = LLONG_MIN;
TreeNode* root_of_max_subtree = nullptr;
map<TreeNode*, long long> node_sums; // 存储每个节点为根的最大子树和

// 从层序数组构建树
TreeNode* build_tree_from_level_order(const vector<string>& arr) {
    if (arr.empty() || arr[0] == "null") {
        return nullptr;
    }
    vector<TreeNode*> nodes;
    for (const string& s : arr) {
        if (s == "null") {
            nodes.push_back(nullptr);
        } else {
            nodes.push_back(new TreeNode(stoi(s)));
        }
    }
    for (int i = 0; i < nodes.size(); ++i) {
        if (nodes[i]) {
            int left_idx = 2 * i + 1;
            int right_idx = 2 * i + 2;
            if (left_idx < nodes.size()) nodes[i]->left = nodes[left_idx];
            if (right_idx < nodes.size()) nodes[i]->right = nodes[right_idx];
        }
    }
    return nodes[0];
}

// 后序遍历计算最大子树和
long long calculate_max_sum(TreeNode* node) {
    if (!node) return 0;

    long long left_sum = calculate_max_sum(node->left);
    long long right_sum = calculate_max_sum(node->right);

    long long current_max_sum = node->val + max(0LL, left_sum) + max(0LL, right_sum);
    
    node_sums[node] = current_max_sum;

    if (current_max_sum > global_max_sum) {
        global_max_sum = current_max_sum;
        root_of_max_subtree = node;
    }
    return current_max_sum;
}

// 前序遍历构建输出数组
void build_output(TreeNode* node, int index, vector<string>& output) {
    if (!node) return;
    if (index >= output.size()) output.resize(index + 1, "null");
    output[index] = to_string(node->val);
    
    if (node->left && node_sums[node->left] > 0) {
        build_output(node->left, 2 * index + 1, output);
    }
    if (node->right && node_sums[node->right] > 0) {
        build_output(node->right, 2 * index + 2, output);
    }
}

int main() {
    string line;
    getline(cin, line);
    line = line.substr(1, line.length() - 2); // 去掉方括号
    replace(line.begin(), line.end(), ',', ' ');
    
    stringstream ss(line);
    string item;
    vector<string> arr;
    while (ss >> item) {
        arr.push_back(item);
    }

    TreeNode* root = build_tree_from_level_order(arr);
    
    if(!root) {
        cout << "[]" << endl;
        return 0;
    }

    calculate_max_sum(root);

    vector<string> output;
    if (root_of_max_subtree) {
        build_output(root_of_max_subtree, 0, output);
    }
    
    // 删除末尾的 null
    while (!output.empty() && output.back() == "null") {
        output.pop_back();
    }

    cout << "[";
    for (int i = 0; i < output.size(); ++i) {
        cout << output[i] << (i == output.size() - 1 ? "" : ",");
    }
    cout << "]" << endl;

    return 0;
}
import java.util.*;

class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode(int val) { this.val = val; }
}

public class Main {
    static long globalMaxSum = Long.MIN_VALUE;
    static TreeNode rootOfMaxSubtree = null;
    static Map<TreeNode, Long> nodeSums = new HashMap<>();

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        String line = sc.nextLine();
        line = line.substring(1, line.length() - 1);
        if (line.isEmpty()) {
            System.out.println("[]");
            return;
        }
        String[] parts = line.split(",");
        
        TreeNode root = buildTreeFromLevelOrder(parts);

        calculateMaxSum(root);

        List<String> output = new ArrayList<>();
        if (rootOfMaxSubtree != null) {
            buildOutput(rootOfMaxSubtree, 0, output);
        }

        while (!output.isEmpty() && output.get(output.size() - 1).equals("null")) {
            output.remove(output.size() - 1);
        }

        System.out.println("[" + String.join(",", output) + "]");
    }

    private static TreeNode buildTreeFromLevelOrder(String[] arr) {
        if (arr.length == 0 || arr[0].equals("null")) return null;
        
        List<TreeNode> nodes = new ArrayList<>();
        for (String s : arr) {
            nodes.add(s.equals("null") ? null : new TreeNode(Integer.parseInt(s)));
        }

        for (int i = 0; i < nodes.size(); i++) {
            if (nodes.get(i) != null) {
                int leftIdx = 2 * i + 1;
                int rightIdx = 2 * i + 2;
                if (leftIdx < nodes.size()) nodes.get(i).left = nodes.get(leftIdx);
                if (rightIdx < nodes.size()) nodes.get(i).right = nodes.get(rightIdx);
            }
        }
        return nodes.get(0);
    }

    private static long calculateMaxSum(TreeNode node) {
        if (node == null) return 0;

        long leftSum = calculateMaxSum(node.left);
        long rightSum = calculateMaxSum(node.right);

        long currentMaxSum = node.val + Math.max(0, leftSum) + Math.max(0, rightSum);
        nodeSums.put(node, currentMaxSum);

        if (currentMaxSum > globalMaxSum) {
            globalMaxSum = currentMaxSum;
            rootOfMaxSubtree = node;
        }
        return currentMaxSum;
    }

    private static void buildOutput(TreeNode node, int index, List<String> output) {
        if (node == null) return;
        while (output.size() <= index) output.add("null");
        output.set(index, String.valueOf(node.val));

        if (node.left != null && nodeSums.get(node.left) > 0) {
            buildOutput(node.left, 2 * index + 1, output);
        }
        if (node.right != null && nodeSums.get(node.right) > 0) {
            buildOutput(node.right, 2 * index + 2, output);
        }
    }
}
import json

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

# 全局变量
global_max_sum = -float('inf')
root_of_max_subtree = None
node_sums = {}

def build_tree_from_level_order(arr):
    if not arr or arr[0] is None:
        return None
    
    nodes = [TreeNode(val) if val is not None else None for val in arr]
    for i, node in enumerate(nodes):
        if node:
            left_idx, right_idx = 2 * i + 1, 2 * i + 2
            if left_idx < len(nodes):
                node.left = nodes[left_idx]
            if right_idx < len(nodes):
                node.right = nodes[right_idx]
    return nodes[0]

def calculate_max_sum(node):
    global global_max_sum, root_of_max_subtree
    if not node:
        return 0
    
    left_sum = calculate_max_sum(node.left)
    right_sum = calculate_max_sum(node.right)
    
    current_max_sum = node.val + max(0, left_sum) + max(0, right_sum)
    node_sums[node] = current_max_sum
    
    if current_max_sum > global_max_sum:
        global_max_sum = current_max_sum
        root_of_max_subtree = node
        
    return current_max_sum

def build_output(node, index, output):
    if not node:
        return
    
    if index >= len(output):
        output.extend([None] * (index - len(output) + 1))
    output[index] = node.val
    
    if node.left and node_sums.get(node.left, 0) > 0:
        build_output(node.left, 2 * index + 1, output)
    if node.right and node_sums.get(node.right, 0) > 0:
        build_output(node.right, 2 * index + 2, output)

def solve():
    line = input()
    # 使用 json 库解析带 null 的数组
    # json.loads 可以直接处理 json 中的 null,会自动转为 python 的 None
    arr = json.loads(line)
    
    root = build_tree_from_level_order(arr)
    
    if not root:
        print("[]")
        return

    calculate_max_sum(root)
    
    output = []
    if root_of_max_subtree:
        build_output(root_of_max_subtree, 0, output)
        
    # 删除末尾的 null
    while output and output[-1] is None:
        output.pop()
        
    # 格式化输出,json.dumps 会自动将 python 的 None 转为 json 的 null
    print(json.dumps(output, separators=(',', ':')))

solve()

算法及复杂度

  • 算法: 树形动态规划 + 树的遍历
  • 时间复杂度: - 其中 是树中的节点数。需要两次完整的树遍历。
  • 空间复杂度: - 需要 的空间来构建链式树、递归调用栈和存储输出结果的数组。