Prompt上下文信息精简

题目分析

给定一棵用层序数组表示的二叉树,每个节点有整数权值(可正、可负、可零)。可以"剪枝"任意子树(将其置为 null),目标是使剩余节点的权值之和最大。输出剪枝后的层序数组(去掉末尾的 null)。

思路

树形DP + 贪心剪枝

核心思想:对每个节点,自底向上计算其子树在最优剪枝下能获得的最大权值和。

定义 为以节点 为根、经过最优剪枝后的子树权值和:

$$

即:如果某个子树的最优和为负,直接剪掉(贡献 0);否则保留。

选择最优根:遍历所有节点,找 最大的节点作为输出子树的根。这样处理了根节点本身为负、但某棵子树更优的情况。

构建输出:从最优根出发 BFS,对每个节点只保留 的子节点。由于输出要求是层序数组格式,需要将子树重新映射到从 0 开始的完全二叉树索引上(根为 0,左孩子为 ,右孩子为 ),中间的空位填 null,末尾 null 去掉。

以示例 [1,-2,3,-4,-5,6,7] 为例:

  • (两个子树都剪掉)
  • (剪掉左子树)
  • 最优根为节点 0,输出 [1,null,3,null,null,6,7]

代码

from collections import deque

def solve():
    line = input().strip()
    inner = line[1:-1].strip()
    if not inner:
        print("[]")
        return

    tokens = [t.strip() for t in inner.split(',')]
    nodes = []
    for t in tokens:
        if t == 'null':
            nodes.append(None)
        else:
            nodes.append(int(t))

    n = len(nodes)
    if n == 0:
        print("[]")
        return

    # 自底向上计算每个节点的最优子树和
    max_sum = [0] * n
    for i in range(n - 1, -1, -1):
        if nodes[i] is None:
            continue
        val = nodes[i]
        left = 2 * i + 1
        right = 2 * i + 2
        ls = max(0, max_sum[left]) if left < n and nodes[left] is not None else 0
        rs = max(0, max_sum[right]) if right < n and nodes[right] is not None else 0
        max_sum[i] = val + ls + rs

    # 找最优根
    best_val = float('-inf')
    best_root = 0
    for i in range(n):
        if nodes[i] is not None and max_sum[i] > best_val:
            best_val = max_sum[i]
            best_root = i

    # BFS 构建输出树,映射到新的层序索引
    tree = {}
    queue = deque()
    queue.append((best_root, 0))

    while queue:
        orig_idx, new_idx = queue.popleft()
        tree[new_idx] = nodes[orig_idx]

        left = 2 * orig_idx + 1
        right = 2 * orig_idx + 2

        if left < n and nodes[left] is not None and max_sum[left] > 0:
            queue.append((left, 2 * new_idx + 1))
        if right < n and nodes[right] is not None and max_sum[right] > 0:
            queue.append((right, 2 * new_idx + 2))

    if not tree:
        print("[]")
        return

    max_idx = max(tree.keys())
    result = []
    for i in range(max_idx + 1):
        if i in tree:
            result.append(str(tree[i]))
        else:
            result.append("null")

    while result and result[-1] == "null":
        result.pop()

    print("[" + ",".join(result) + "]")

solve()

复杂度分析

  • 时间复杂度,其中 为数组长度。自底向上遍历一次计算 ,再遍历一次找最优根,最后 BFS 构建输出,均为线性。
  • 空间复杂度,存储 数组和输出树的字典。