PEEK142 炮兵阵地

题目链接

PEEK142 炮兵阵地

题目描述

在一个 N x M 的网格地图上,'P' 代表平原(可部署),'H' 代表山地(不可部署)。部署一个炮兵后,它会攻击同一行左右各 2 格,以及同一列上下各 2 格。要求任意两个炮兵不能互相攻击,求最多能部署多少个炮兵。

解题思路

本题是又一个经典的棋盘类动态规划问题,解法仍然是状态压缩DP。与“互不侵犯的国王”问题相比,本题的核心区别在于炮兵的攻击范围更远,一个炮兵会影响到其上下两行,这使得 DP 的状态定义和转移变得更加复杂。

核心思想

由于第 i 行的决策会受到第 i-1 行和第 i-2 行的影响,因此我们的 DP 状态必须同时记录前两行的布局信息。

1. 状态表示

我们定义一个三维 DP 数组: dp[i][j][k]

其含义是:在处理完前 i 行后,第 i 行的布局为状态 j,第 i-1 行的布局为状态 k 时,能够部署的炮兵最大数量。

  • i: 当前处理的行号(0-indexed)。
  • j: 一个位掩码,表示第 i 行的布局。
  • k: 一个位掩码,表示第 i-1 行的布局。

2. 状态转移 dp[i][j][k] 的值,可以从所有与 jk 兼容的、合法的 dp[i-1][k][p] 状态转移而来。这里 p 代表了第 i-2 行的布局。转移方程为:

dp[i][j][k] = max(dp[i-1][k][p]) + count(j)

其中,count(j) 是状态 j 中炮兵的数量(即二进制中 1 的个数)。这个转移需要满足一系列的约束条件:

  • 地形约束:炮兵不能部署在山地上。如果 map_state[i] 是第 i 行山地的位掩码,则 (j & map_state[i]) == 0
  • 行内约束:同一行中,任意两个炮兵之间至少要间隔两个空格。位运算表示为 (j & (j << 1)) == 0(j & (j << 2)) == 0
  • 行间约束
    • i 行与第 i-1 行不冲突:(j & k) == 0
    • i 行与第 i-2 行不冲突:(j & p) == 0
    • (第 i-1 行与第 i-2 行不冲突 (k & p) == 0,这个条件在计算 dp[i-1][k][p] 时已经保证了)。

3. 优化与实现

  • 预处理地形:将输入的字符地图 'P''H' 转换为一个整数数组 map_state,其中每个整数是一个位掩码,1 代表山地。
  • 预处理合法布局:我们可以预先找出所有满足行内约束的布局状态,并计算好每个状态的炮兵数量,存入列表 statescounts。这可以显著减少 DP 过程中的状态空间。
  • DP 数组 Padding:为了方便处理边界情况(i=0i=1),我们可以给 DP 数组的行数增加一个 padding,例如定义 dp 表的大小为 (N+2) x |states| x |states|,其中 dp[i+2] 对应地图的第 i 行。这样 i=0 时,它会从 dp[1] 读取,而 dp[1]dp[0] 都是初始的 0,逻辑上是通顺的,无需特殊处理边界。

4. 最终答案

在填充完整个 DP 表格后,最终的答案是 dp[N+1](对应处理完第 N-1 行)中的最大值。

代码

#include <iostream>
#include <vector>
#include <string>
#include <algorithm>

using namespace std;

int countSetBits(int n) {
    int count = 0;
    while (n > 0) {
        n &= (n - 1);
        count++;
    }
    return count;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, m;
    cin >> n >> m;

    vector<int> map_state(n);
    for (int i = 0; i < n; ++i) {
        string row_str;
        cin >> row_str;
        for (int j = 0; j < m; ++j) {
            if (row_str[j] == 'H') {
                map_state[i] |= (1 << j);
            }
        }
    }

    vector<int> states;
    vector<int> counts;
    for (int i = 0; i < (1 << m); ++i) {
        if (!(i & (i << 1)) && !(i & (i << 2))) {
            states.push_back(i);
            counts.push_back(countSetBits(i));
        }
    }

    int num_states = states.size();
    // dp[i][j][k]: max cannons up to row i-1, row i-1 is states[j], row i-2 is states[k]
    vector<vector<vector<int>>> dp(n + 2, vector<vector<int>>(num_states, vector<int>(num_states, 0)));

    for (int i = 0; i < n; ++i) { // current row i
        for (int j = 0; j < num_states; ++j) { // state for row i
            int s_i = states[j];
            if (s_i & map_state[i]) continue;
            int num_i = counts[j];

            for (int k = 0; k < num_states; ++k) { // state for row i-1
                int s_prev = states[k];
                if (s_i & s_prev) continue;

                for (int p = 0; p < num_states; ++p) { // state for row i-2
                    int s_prev2 = states[p];
                    if (s_i & s_prev2) continue;
                    if (s_prev & s_prev2) continue;
                    
                    dp[i + 2][j][k] = max(dp[i + 2][j][k], dp[i + 1][k][p] + num_i);
                }
            }
        }
    }

    int max_cannons = 0;
    for (int j = 0; j < num_states; ++j) {
        for (int k = 0; k < num_states; ++k) {
            max_cannons = max(max_cannons, dp[n + 1][j][k]);
        }
    }

    cout << max_cannons << endl;

    return 0;
}
import java.util.Scanner;
import java.util.ArrayList;
import java.util.List;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();

        int[] mapState = new int[n];
        for (int i = 0; i < n; i++) {
            String rowStr = sc.next();
            for (int j = 0; j < m; j++) {
                if (rowStr.charAt(j) == 'H') {
                    mapState[i] |= (1 << j);
                }
            }
        }

        List<Integer> states = new ArrayList<>();
        List<Integer> counts = new ArrayList<>();
        for (int i = 0; i < (1 << m); i++) {
            if ((i & (i << 1)) == 0 && (i & (i << 2)) == 0) {
                states.add(i);
                counts.add(Integer.bitCount(i));
            }
        }

        int numStates = states.size();
        int[][][] dp = new int[n + 2][numStates][numStates];

        for (int i = 0; i < n; i++) { // current row i
            for (int j = 0; j < numStates; j++) { // state for row i
                int s_i = states.get(j);
                if ((s_i & mapState[i]) != 0) continue;
                int num_i = counts.get(j);

                for (int k = 0; k < numStates; k++) { // state for row i-1
                    int s_prev = states.get(k);
                    if ((s_i & s_prev) != 0) continue;

                    for (int p = 0; p < numStates; p++) { // state for row i-2
                        int s_prev2 = states.get(p);
                        if ((s_i & s_prev2) != 0) continue;
                        if ((s_prev & s_prev2) != 0) continue;
                        
                        dp[i + 2][j][k] = Math.max(dp[i + 2][j][k], dp[i + 1][k][p] + num_i);
                    }
                }
            }
        }

        int maxCannons = 0;
        for (int j = 0; j < numStates; j++) {
            for (int k = 0; k < numStates; k++) {
                maxCannons = Math.max(maxCannons, dp[n + 1][j][k]);
            }
        }
        System.out.println(maxCannons);
    }
}
import sys

def solve():
    try:
        n, m = map(int, sys.stdin.readline().split())
        map_state = [0] * n
        for i in range(n):
            row_str = sys.stdin.readline().strip()
            for j in range(m):
                if row_str[j] == 'H':
                    map_state[i] |= (1 << j)
    except (IOError, ValueError):
        return

    states = []
    counts = []
    for i in range(1 << m):
        if not (i & (i << 1)) and not (i & (i << 2)):
            states.append(i)
            counts.append(bin(i).count('1'))

    num_states = len(states)
    dp = [[[0] * num_states for _ in range(num_states)] for _ in range(n + 2)]
    
    for i in range(n):  # current row i
        for j in range(num_states):  # state for row i
            s_i = states[j]
            if s_i & map_state[i]:
                continue
            num_i = counts[j]

            for k in range(num_states):  # state for row i-1
                s_prev = states[k]
                if s_i & s_prev:
                    continue

                for p in range(num_states):  # state for row i-2
                    s_prev2 = states[p]
                    if s_i & s_prev2 or s_prev & s_prev2:
                        continue
                    
                    dp[i + 2][j][k] = max(dp[i + 2][j][k], dp[i + 1][k][p] + num_i)

    max_cannons = 0
    for j in range(num_states):
        for k in range(num_states):
            max_cannons = max(max_cannons, dp[n + 1][j][k])
            
    print(max_cannons)

solve()

算法及复杂度

  • 算法:状态压缩动态规划 (State Compression DP)。
  • 时间复杂度,其中 N 是行数,|S| 是满足行内约束的合法状态数量。对于 M=10|S| 的数量为 60,这个复杂度是可以接受的。
  • 空间复杂度,用于存储 DP 状态数组。可以使用滚动数组将空间优化到