题目链接

取数游戏

题目描述

给定一个 的由非负整数构成的数字矩阵,你需要在其中取出若干个数字,使得取出的任意两个数字不相邻(若一个数字在另外一个数字相邻的8个格子中的一个,即认为这两个数字相邻),求取出数字和的最大值。

解题思路

这是一个在网格中选取不相邻元素以获取最大和的经典问题。由于矩阵的维度非常小(),这提示我们可以使用状态压缩动态规划(Bitmask DP)来解决。

1. 核心思想与状态定义

我们逐行处理矩阵,计算每一行在不同选取方案下的最大和。

我们定义一个二维DP数组:

  • :代表当前处理到第 行(从0开始)。

  • :一个整数,它的二进制表示法代表了第 行的选取状态。如果 的第 位是 ,表示我们选取了第 行第 列的数字;如果是 ,则不选取。

  • 的值:表示处理完前 行(即第0到i行),并且第 行的选取状态恰好为 时,所能获得的最大数字和。

2. 状态转移

为了计算 ,我们需要考虑它能从上一行(第 行)的哪些状态转移而来。

这里的 是第 行按照 状态选取的数字之和。

操作则需要遍历上一行所有可能的选取状态 ,但前提是 必须是兼容的。

3. 兼容性判断

兼容性包含两个层面:

A. 行内兼容性

一个 本身必须是合法的。根据题意,同一行内不能选取相邻的数字。

这意味着 的二进制表示中,不能有两个相邻的

用位运算可以高效地检查:

B. 行间兼容性

行的状态 必须与第 行的状态 兼容。

如果我们在第 行选取了第 列的数字(即 的第 位为1),那么在第 行,第 , , 列的数字都不能被选取。

这可以用三个位运算条件来概括:

  1. (正上方不相邻)

  2. (左上方不相邻)

  3. (右上方不相邻)

只有同时满足这两个层面兼容性的 才能进行状态转移。

4. 算法流程

  1. 初始化 (Base Case): 处理第 行。对于所有满足行内兼容性的 ,直接计算 ,其中 是第 行按 选取数字的和。

  2. 递推: 从第 行开始,遍历到第 行: 对于每个 : 对于每个满足行内兼容性的 对于每个满足行内兼容性的 : 如果 满足行间兼容性:

  3. 最终答案: 遍历 表的最后一行,即 ,其中的最大值就是最终答案。不要忘记,如果不选取任何数,和为0,所以最终答案至少为0。

代码

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

void solve() {
    int R, C;
    cin >> R >> C;
    vector<vector<int>> grid(R, vector<int>(C));
    for (int i = 0; i < R; ++i) {
        for (int j = 0; j < C; ++j) {
            cin >> grid[i][j];
        }
    }

    int max_mask = 1 << C;
    vector<vector<int>> dp(R, vector<int>(max_mask, 0));

    // Base case: 第 0 行
    for (int mask = 0; mask < max_mask; ++mask) {
        if ((mask & (mask << 1)) == 0) { // 行内兼容
            int current_sum = 0;
            for (int j = 0; j < C; ++j) {
                if ((mask >> j) & 1) {
                    current_sum += grid[0][j];
                }
            }
            dp[0][mask] = current_sum;
        }
    }

    // DP 递推
    for (int i = 1; i < R; ++i) {
        for (int mask_i = 0; mask_i < max_mask; ++mask_i) {
            if ((mask_i & (mask_i << 1)) == 0) { // 当前行 mask_i 必须合法
                int current_sum = 0;
                for (int j = 0; j < C; ++j) {
                    if ((mask_i >> j) & 1) {
                        current_sum += grid[i][j];
                    }
                }

                int max_prev_sum = 0;
                for (int mask_prev = 0; mask_prev < max_mask; ++mask_prev) {
                    // 检查行间兼容性
                    if ((mask_i & mask_prev) == 0 &&
                        (mask_i & (mask_prev << 1)) == 0 &&
                        (mask_i & (mask_prev >> 1)) == 0) {
                        max_prev_sum = max(max_prev_sum, dp[i-1][mask_prev]);
                    }
                }
                dp[i][mask_i] = current_sum + max_prev_sum;
            }
        }
    }

    int ans = 0;
    for (int mask = 0; mask < max_mask; ++mask) {
        ans = max(ans, dp[R - 1][mask]);
    }
    cout << ans << endl;
}

int main() {
    int T;
    cin >> T;
    while (T--) {
        solve();
    }
    return 0;
}
import java.util.Scanner;
import java.util.Arrays;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int T = sc.nextInt();
        while (T-- > 0) {
            solve(sc);
        }
    }

    private static void solve(Scanner sc) {
        int R = sc.nextInt();
        int C = sc.nextInt();
        int[][] grid = new int[R][C];
        for (int i = 0; i < R; i++) {
            for (int j = 0; j < C; j++) {
                grid[i][j] = sc.nextInt();
            }
        }

        int maxMask = 1 << C;
        int[][] dp = new int[R][maxMask];

        // Base case: 第 0 行
        for (int mask = 0; mask < maxMask; mask++) {
            if ((mask & (mask << 1)) == 0) { // 行内兼容
                int currentSum = 0;
                for (int j = 0; j < C; j++) {
                    if (((mask >> j) & 1) == 1) {
                        currentSum += grid[0][j];
                    }
                }
                dp[0][mask] = currentSum;
            }
        }

        // DP 递推
        for (int i = 1; i < R; i++) {
            for (int maskI = 0; maskI < maxMask; maskI++) {
                if ((maskI & (maskI << 1)) == 0) { // 当前行 maskI 必须合法
                    int currentSum = 0;
                    for (int j = 0; j < C; j++) {
                        if (((maskI >> j) & 1) == 1) {
                            currentSum += grid[i][j];
                        }
                    }

                    int maxPrevSum = 0;
                    for (int maskPrev = 0; maskPrev < maxMask; maskPrev++) {
                        // 检查行间兼容性
                        if ((maskI & maskPrev) == 0 &&
                            (maskI & (maskPrev << 1)) == 0 &&
                            (maskI & (maskPrev >> 1)) == 0) {
                            maxPrevSum = Math.max(maxPrevSum, dp[i - 1][maskPrev]);
                        }
                    }
                    dp[i][maskI] = currentSum + maxPrevSum;
                }
            }
        }

        int ans = 0;
        for (int mask = 0; mask < maxMask; mask++) {
            ans = Math.max(ans, dp[R - 1][mask]);
        }
        System.out.println(ans);
    }
}
def solve():
    R, C = map(int, input().split())
    grid = []
    for _ in range(R):
        grid.append(list(map(int, input().split())))

    max_mask = 1 << C
    dp = [[0] * max_mask for _ in range(R)]

    # Base case: 第 0 行
    for mask in range(max_mask):
        if (mask & (mask << 1)) == 0:  # 行内兼容
            current_sum = 0
            for j in range(C):
                if (mask >> j) & 1:
                    current_sum += grid[0][j]
            dp[0][mask] = current_sum

    # DP 递推
    for i in range(1, R):
        for mask_i in range(max_mask):
            if (mask_i & (mask_i << 1)) == 0:  # 当前行 mask_i 必须合法
                current_sum = 0
                for j in range(C):
                    if (mask_i >> j) & 1:
                        current_sum += grid[i][j]

                max_prev_sum = 0
                for mask_prev in range(max_mask):
                    # 检查行间兼容性
                    if (mask_i & mask_prev) == 0 and \
                       (mask_i & (mask_prev << 1)) == 0 and \
                       (mask_i & (mask_prev >> 1)) == 0:
                        max_prev_sum = max(max_prev_sum, dp[i-1][mask_prev])
                
                dp[i][mask_i] = current_sum + max_prev_sum
    
    ans = 0
    if R > 0:
        ans = max(dp[R - 1])
    
    print(ans)


T = int(input())
for _ in range(T):
    solve()

算法及复杂度

  • 算法:状态压缩动态规划 (Bitmask DP)

  • 时间复杂度状态共有 个。计算每个状态需要遍历上一行的所有 个状态。因此总复杂度为 。由于 ,这是完全可以接受的。

  • 空间复杂度。用于存储表。可以优化到 ,因为计算第 行的状态只需要第 行的信息。