题目链接
题目描述
给定一个 的由非负整数构成的数字矩阵,你需要在其中取出若干个数字,使得取出的任意两个数字不相邻(若一个数字在另外一个数字相邻的8个格子中的一个,即认为这两个数字相邻),求取出数字和的最大值。
解题思路
这是一个在网格中选取不相邻元素以获取最大和的经典问题。由于矩阵的维度非常小(),这提示我们可以使用状态压缩动态规划(Bitmask DP)来解决。
1. 核心思想与状态定义
我们逐行处理矩阵,计算每一行在不同选取方案下的最大和。
我们定义一个二维DP数组:。
-
:代表当前处理到第
行(从0开始)。
-
:一个整数,它的二进制表示法代表了第
行的选取状态。如果
的第
位是
,表示我们选取了第
行第
列的数字;如果是
,则不选取。
-
的值:表示处理完前
行(即第0到i行),并且第
行的选取状态恰好为
时,所能获得的最大数字和。
2. 状态转移
为了计算 ,我们需要考虑它能从上一行(第
行)的哪些状态转移而来。
这里的 是第
行按照
状态选取的数字之和。
操作则需要遍历上一行所有可能的选取状态
,但前提是
和
必须是兼容的。
3. 兼容性判断
兼容性包含两个层面:
A. 行内兼容性
一个 本身必须是合法的。根据题意,同一行内不能选取相邻的数字。
这意味着 的二进制表示中,不能有两个相邻的
。
用位运算可以高效地检查:。
B. 行间兼容性
第 行的状态
必须与第
行的状态
兼容。
如果我们在第 行选取了第
列的数字(即
的第
位为1),那么在第
行,第
,
,
列的数字都不能被选取。
这可以用三个位运算条件来概括:
-
(正上方不相邻)
-
(左上方不相邻)
-
(右上方不相邻)
只有同时满足这两个层面兼容性的 和
才能进行状态转移。
4. 算法流程
-
初始化 (Base Case): 处理第
行。对于所有满足行内兼容性的
,直接计算
,其中
是第
行按
选取数字的和。
-
递推: 从第
行开始,遍历到第
行: 对于每个
从
到
: 对于每个满足行内兼容性的
:
对于每个满足行内兼容性的
: 如果
和
满足行间兼容性:
-
最终答案: 遍历
表的最后一行,即
,其中的最大值就是最终答案。不要忘记,如果不选取任何数,和为0,所以最终答案至少为0。
代码
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
void solve() {
int R, C;
cin >> R >> C;
vector<vector<int>> grid(R, vector<int>(C));
for (int i = 0; i < R; ++i) {
for (int j = 0; j < C; ++j) {
cin >> grid[i][j];
}
}
int max_mask = 1 << C;
vector<vector<int>> dp(R, vector<int>(max_mask, 0));
// Base case: 第 0 行
for (int mask = 0; mask < max_mask; ++mask) {
if ((mask & (mask << 1)) == 0) { // 行内兼容
int current_sum = 0;
for (int j = 0; j < C; ++j) {
if ((mask >> j) & 1) {
current_sum += grid[0][j];
}
}
dp[0][mask] = current_sum;
}
}
// DP 递推
for (int i = 1; i < R; ++i) {
for (int mask_i = 0; mask_i < max_mask; ++mask_i) {
if ((mask_i & (mask_i << 1)) == 0) { // 当前行 mask_i 必须合法
int current_sum = 0;
for (int j = 0; j < C; ++j) {
if ((mask_i >> j) & 1) {
current_sum += grid[i][j];
}
}
int max_prev_sum = 0;
for (int mask_prev = 0; mask_prev < max_mask; ++mask_prev) {
// 检查行间兼容性
if ((mask_i & mask_prev) == 0 &&
(mask_i & (mask_prev << 1)) == 0 &&
(mask_i & (mask_prev >> 1)) == 0) {
max_prev_sum = max(max_prev_sum, dp[i-1][mask_prev]);
}
}
dp[i][mask_i] = current_sum + max_prev_sum;
}
}
}
int ans = 0;
for (int mask = 0; mask < max_mask; ++mask) {
ans = max(ans, dp[R - 1][mask]);
}
cout << ans << endl;
}
int main() {
int T;
cin >> T;
while (T--) {
solve();
}
return 0;
}
import java.util.Scanner;
import java.util.Arrays;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int T = sc.nextInt();
while (T-- > 0) {
solve(sc);
}
}
private static void solve(Scanner sc) {
int R = sc.nextInt();
int C = sc.nextInt();
int[][] grid = new int[R][C];
for (int i = 0; i < R; i++) {
for (int j = 0; j < C; j++) {
grid[i][j] = sc.nextInt();
}
}
int maxMask = 1 << C;
int[][] dp = new int[R][maxMask];
// Base case: 第 0 行
for (int mask = 0; mask < maxMask; mask++) {
if ((mask & (mask << 1)) == 0) { // 行内兼容
int currentSum = 0;
for (int j = 0; j < C; j++) {
if (((mask >> j) & 1) == 1) {
currentSum += grid[0][j];
}
}
dp[0][mask] = currentSum;
}
}
// DP 递推
for (int i = 1; i < R; i++) {
for (int maskI = 0; maskI < maxMask; maskI++) {
if ((maskI & (maskI << 1)) == 0) { // 当前行 maskI 必须合法
int currentSum = 0;
for (int j = 0; j < C; j++) {
if (((maskI >> j) & 1) == 1) {
currentSum += grid[i][j];
}
}
int maxPrevSum = 0;
for (int maskPrev = 0; maskPrev < maxMask; maskPrev++) {
// 检查行间兼容性
if ((maskI & maskPrev) == 0 &&
(maskI & (maskPrev << 1)) == 0 &&
(maskI & (maskPrev >> 1)) == 0) {
maxPrevSum = Math.max(maxPrevSum, dp[i - 1][maskPrev]);
}
}
dp[i][maskI] = currentSum + maxPrevSum;
}
}
}
int ans = 0;
for (int mask = 0; mask < maxMask; mask++) {
ans = Math.max(ans, dp[R - 1][mask]);
}
System.out.println(ans);
}
}
def solve():
R, C = map(int, input().split())
grid = []
for _ in range(R):
grid.append(list(map(int, input().split())))
max_mask = 1 << C
dp = [[0] * max_mask for _ in range(R)]
# Base case: 第 0 行
for mask in range(max_mask):
if (mask & (mask << 1)) == 0: # 行内兼容
current_sum = 0
for j in range(C):
if (mask >> j) & 1:
current_sum += grid[0][j]
dp[0][mask] = current_sum
# DP 递推
for i in range(1, R):
for mask_i in range(max_mask):
if (mask_i & (mask_i << 1)) == 0: # 当前行 mask_i 必须合法
current_sum = 0
for j in range(C):
if (mask_i >> j) & 1:
current_sum += grid[i][j]
max_prev_sum = 0
for mask_prev in range(max_mask):
# 检查行间兼容性
if (mask_i & mask_prev) == 0 and \
(mask_i & (mask_prev << 1)) == 0 and \
(mask_i & (mask_prev >> 1)) == 0:
max_prev_sum = max(max_prev_sum, dp[i-1][mask_prev])
dp[i][mask_i] = current_sum + max_prev_sum
ans = 0
if R > 0:
ans = max(dp[R - 1])
print(ans)
T = int(input())
for _ in range(T):
solve()
算法及复杂度
-
算法:状态压缩动态规划 (Bitmask DP)
-
时间复杂度:
。
状态共有
个。计算每个状态需要遍历上一行的所有
个状态。因此总复杂度为
。由于
,这是完全可以接受的。
-
空间复杂度:
。用于存储
表。可以优化到
,因为计算第
行的状态只需要第
行的信息。