安保系统最大警戒值

题意

一座建筑的安保系统是一棵二叉树,每个节点是一个传感器,有一个警戒值。规则:如果激活某个传感器,它的所有相邻节点(父节点和子节点)都不能激活。求能获得的最大警戒值总和。

输入是二叉树的层序遍历数组,不存在的节点用 0 占位。

思路

经典的"树上打家劫舍"问题。关键约束是:选了一个节点,就不能选它的父/子节点——也就是说,树上相邻的两个节点不能同时选。

怎么做?自然想到树形 DP。

对每个节点,只有两种状态:选它或者不选它

  • 选当前节点:那它的两个孩子一定不能选,所以收益 = 当前节点值 + 左孩子"不选"的最优解 + 右孩子"不选"的最优解。
  • 不选当前节点:两个孩子可选可不选,取各自最优,所以收益 = max(左孩子选, 左孩子不选) + max(右孩子选, 右孩子不选)。

定义 为选节点 时以 为根的子树最大收益, 为不选节点 时的最大收益,转移方程:

$$

$$

答案就是

实现上有个小技巧: 题目给的是层序遍历数组,下标 的左孩子是 ,右孩子是 。这意味着我们不需要真的建树——直接在数组上从后往前扫一遍就行了,因为下标大的一定是下标小的后代,先处理下标大的节点就等于先处理了孩子。

拿样例验证一下:树的根是 78,左子树有 44、98、54 等,右子树有 73、51、87 等。选 98 + 53 + 87 + 73 + 40 + 40... 手算一下最优选法能得到 391,和答案一致。

复杂度

  • 时间:,每个节点只处理一次
  • 空间:,存储 DP 数组

代码

#include <bits/stdc++.h>
using namespace std;

int main(){
    int n;
    scanf("%d", &n);
    vector<int> arr(n);
    for(int i = 0; i < n; i++) scanf("%d", &arr[i]);

    vector<bool> exists(n, false);
    for(int i = 0; i < n; i++)
        if(arr[i] != 0) exists[i] = true;
    exists[0] = true;

    vector<int> L(n, -1), R(n, -1);
    for(int i = 0; i < n; i++){
        if(!exists[i]) continue;
        int l = 2*i+1, r = 2*i+2;
        if(l < n && exists[l]) L[i] = l;
        if(r < n && exists[r]) R[i] = r;
    }

    vector<long long> rob(n, 0), notRob(n, 0);
    for(int i = n-1; i >= 0; i--){
        if(!exists[i]) continue;
        long long lR = 0, lN = 0, rR = 0, rN = 0;
        if(L[i] != -1){ lR = rob[L[i]]; lN = notRob[L[i]]; }
        if(R[i] != -1){ rR = rob[R[i]]; rN = notRob[R[i]]; }
        rob[i] = arr[i] + lN + rN;
        notRob[i] = max(lR, lN) + max(rR, rN);
    }

    printf("%lld\n", max(rob[0], notRob[0]));
    return 0;
}
import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int[] arr = new int[n];
        for (int i = 0; i < n; i++) arr[i] = sc.nextInt();

        boolean[] exists = new boolean[n];
        for (int i = 0; i < n; i++)
            if (arr[i] != 0) exists[i] = true;
        exists[0] = true;

        int[] L = new int[n], R = new int[n];
        Arrays.fill(L, -1);
        Arrays.fill(R, -1);
        for (int i = 0; i < n; i++) {
            if (!exists[i]) continue;
            int l = 2*i+1, r = 2*i+2;
            if (l < n && exists[l]) L[i] = l;
            if (r < n && exists[r]) R[i] = r;
        }

        long[] rob = new long[n], notRob = new long[n];
        for (int i = n-1; i >= 0; i--) {
            if (!exists[i]) continue;
            long lR = 0, lN = 0, rR = 0, rN = 0;
            if (L[i] != -1) { lR = rob[L[i]]; lN = notRob[L[i]]; }
            if (R[i] != -1) { rR = rob[R[i]]; rN = notRob[R[i]]; }
            rob[i] = arr[i] + lN + rN;
            notRob[i] = Math.max(lR, lN) + Math.max(rR, rN);
        }

        System.out.println(Math.max(rob[0], notRob[0]));
    }
}
import sys
input = sys.stdin.readline

def main():
    n = int(input())
    arr = list(map(int, input().split()))

    exists = [False] * n
    for i in range(n):
        if arr[i] != 0:
            exists[i] = True
    exists[0] = True

    L = [-1] * n
    R = [-1] * n
    for i in range(n):
        if not exists[i]:
            continue
        l, r = 2*i+1, 2*i+2
        if l < n and exists[l]:
            L[i] = l
        if r < n and exists[r]:
            R[i] = r

    rob = [0] * n
    not_rob = [0] * n
    for i in range(n-1, -1, -1):
        if not exists[i]:
            continue
        lR = lN = rR = rN = 0
        if L[i] != -1:
            lR, lN = rob[L[i]], not_rob[L[i]]
        if R[i] != -1:
            rR, rN = rob[R[i]], not_rob[R[i]]
        rob[i] = arr[i] + lN + rN
        not_rob[i] = max(lR, lN) + max(rR, rN)

    print(max(rob[0], not_rob[0]))

main()
const readline = require('readline');
const rl = readline.createInterface({ input: process.stdin });
const lines = [];
rl.on('line', line => lines.push(line.trim()));
rl.on('close', () => {
    const n = parseInt(lines[0]);
    const arr = lines[1].split(' ').map(Number);

    const exists = new Array(n).fill(false);
    for (let i = 0; i < n; i++)
        if (arr[i] !== 0) exists[i] = true;
    exists[0] = true;

    const L = new Array(n).fill(-1);
    const R = new Array(n).fill(-1);
    for (let i = 0; i < n; i++) {
        if (!exists[i]) continue;
        const l = 2*i+1, r = 2*i+2;
        if (l < n && exists[l]) L[i] = l;
        if (r < n && exists[r]) R[i] = r;
    }

    const rob = new Array(n).fill(0);
    const notRob = new Array(n).fill(0);
    for (let i = n-1; i >= 0; i--) {
        if (!exists[i]) continue;
        let lR = 0, lN = 0, rR = 0, rN = 0;
        if (L[i] !== -1) { lR = rob[L[i]]; lN = notRob[L[i]]; }
        if (R[i] !== -1) { rR = rob[R[i]]; rN = notRob[R[i]]; }
        rob[i] = arr[i] + lN + rN;
        notRob[i] = Math.max(lR, lN) + Math.max(rR, rN);
    }

    console.log(Math.max(rob[0], notRob[0]));
});