题目链接

安保系统最大警戒值

题目描述

安保系统的拓扑结构是一个二叉树,每个传感器(节点)都有一个警戒值。你需要制定一个激活方案,使得总警戒值最大。

规则:如果一个传感器被激活,那么与它直接相连的父节点和所有子节点都必须保持关闭状态。

输入:一个表示二叉树层序遍历的数组,其中 0 代表 null 节点。

解题思路

这是一个经典的树形动态规划问题,也被称为“打家劫舍 III”(House Robber III)。核心思想是对于树中的每一个节点,我们都面临一个选择:激活不激活,并且这个选择会影响其相邻节点的选择。

问题与挑战:栈溢出

一个直接的想法是先根据层序遍历数组构建出真实的树结构,然后用递归的深度优先搜索(DFS)来解决。然而,题目的数据范围 N 高达 10^6。如果输入的树是一个深度很大的链状结构,递归深度也会同样巨大,这将导致栈溢出 (Stack Overflow),从而引发段错误。

优化方案:基于数组的迭代式动态规划

为了避免递归带来的栈溢出问题,我们可以利用输入是层序遍历数组这一特性,直接在数组上进行动态规划,而无需显式地构建树。

  1. 层序数组的性质:在一个从 0 开始索引的层序数组中,索引为 i 的节点的左右子节点分别位于 2*i + 12*i + 2
  2. 自底向上的计算顺序:如果我们从数组的末尾 n-1 向前遍历到 0,就可以保证在计算节点 i 的状态时,其子节点的状态已经被计算出来。这等价于对树进行了一次后序遍历。
  3. DP 状态定义:我们创建一个 DP 数组 dpdp[i] 存储一个包含两个值的状态:
    • dp[i].first (或 dp[i][0]): 不激活索引 i 对应的节点时,其子树能获得的最大警戒值。
    • dp[i].second (或 dp[i][1]): 激活索引 i 对应的节点时,其子树能获得的最大警戒值。
  4. 状态转移方程(在 in-10 的循环中):
    • left_res = dp[2*i + 1], right_res = dp[2*i + 2] (注意处理越界情况)
    • dp[i].second = nums[i] + left_res.first + right_res.first (激活 i,则子节点不能激活)
    • dp[i].first = max(left_res.first, left_res.second) + max(right_res.first, right_res.second) (不激活 i,则子节点可选最优状态)
  5. 最终答案:遍历结束后,max(dp[0].first, dp[0].second) 就是最终答案。

代码

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

using ll = long long;
using pll = pair<ll, ll>;

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;
    if (n == 0) {
        cout << 0 << endl;
        return 0;
    }
    vector<ll> nums(n);
    for (int i = 0; i < n; ++i) {
        cin >> nums[i];
    }

    vector<pll> dp(n, {0, 0});

    for (int i = n - 1; i >= 0; --i) {
        if (nums[i] == 0) {
            continue;
        }

        pll left_res = {0, 0};
        int left_child_idx = 2 * i + 1;
        if (left_child_idx < n) {
            left_res = dp[left_child_idx];
        }

        pll right_res = {0, 0};
        int right_child_idx = 2 * i + 2;
        if (right_child_idx < n) {
            right_res = dp[right_child_idx];
        }

        // 不激活当前节点
        ll not_activated_max = max(left_res.first, left_res.second) + max(right_res.first, right_res.second);
        
        // 激活当前节点
        ll activated_max = nums[i] + left_res.first + right_res.first;

        dp[i] = {not_activated_max, activated_max};
    }

    cout << max(dp[0].first, dp[0].second) << endl;

    return 0;
}
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        if (n == 0) {
            System.out.println(0);
            return;
        }
        long[] nums = new long[n];
        for (int i = 0; i < n; i++) {
            nums[i] = sc.nextLong();
        }

        // dp[i][0]:不激活, dp[i][1]:激活
        long[][] dp = new long[n][2];

        for (int i = n - 1; i >= 0; i--) {
            if (nums[i] == 0) {
                continue;
            }

            long[] leftRes = {0, 0};
            int leftChildIdx = 2 * i + 1;
            if (leftChildIdx < n) {
                leftRes = dp[leftChildIdx];
            }

            long[] rightRes = {0, 0};
            int rightChildIdx = 2 * i + 2;
            if (rightChildIdx < n) {
                rightRes = dp[rightChildIdx];
            }

            // 不激活当前节点
            dp[i][0] = Math.max(leftRes[0], leftRes[1]) + Math.max(rightRes[0], rightRes[1]);
            
            // 激活当前节点
            dp[i][1] = nums[i] + leftRes[0] + rightRes[0];
        }

        System.out.println(Math.max(dp[0][0], dp[0][1]));
    }
}
def main():
    n = int(input())
    if n == 0:
        print(0)
        return
    
    nums = list(map(int, input().split()))
    
    # dp[i] = (不激活的最大值, 激活的最大值)
    dp = [(0, 0)] * n
    
    for i in range(n - 1, -1, -1):
        if nums[i] == 0:
            continue
            
        left_res = (0, 0)
        left_child_idx = 2 * i + 1
        if left_child_idx < n:
            left_res = dp[left_child_idx]
            
        right_res = (0, 0)
        right_child_idx = 2 * i + 2
        if right_child_idx < n:
            right_res = dp[right_child_idx]
            
        # 不激活当前节点
        not_activated_max = max(left_res) + max(right_res)
        
        # 激活当前节点
        activated_max = nums[i] + left_res[0] + right_res[0]
        
        dp[i] = (not_activated_max, activated_max)
        
    print(max(dp[0]))

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:树形动态规划 (Tree DP) - 迭代实现
  • 时间复杂度:,其中 是传感器的数量。我们只需要对输入数组进行一次反向遍历来填充 DP 数组。
  • 空间复杂度:,主要空间开销来自于存储 DP 数组。