题目链接
题目描述
给定 组数据。每组数据给出一个由非负整数构成的
数字矩阵。需要从矩阵中选出若干个数,使得任意两个被选中的位置在八连通意义下不相邻(即上下左右与四个对角方向均不能相邻)。求所有被选数字之和的最大值。
输入:
- 第一行一个正整数
- 对于每组数据:第一行两个正整数
、
;接下来
行,每行
个非负整数
输出:
- 对于每组数据,输出一个非负整数,表示最大和
解题思路
将问题转化为“逐行状态压缩 DP”:
- 用长度为
的二进制掩码
表示一行中被选的位置;
合法需满足同一行内不相邻:
。
- 跨行时,上一行掩码
与当前行掩码
应满足八连通不相邻:
- 垂直/重叠:
- 左上/右下对角:
- 右上/左下对角:
- 垂直/重叠:
- 预处理每行在掩码
下的取值和
。状态转移:
- 初始为第
行各合法掩码的行和,答案为最后一行所有
的最大值。
实现要点:
- 预枚举所有合法掩码列表以及它们之间的兼容关系列表,转移更快。
- 使用 64 位整型存储答案(行内和与最终和可能较大)。
代码
#include <bits/stdc++.h>
using namespace std;
using int64 = long long;
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int T;
if (!(cin >> T)) return 0;
while (T--) {
int n, m;
cin >> n >> m;
vector<vector<int64>> a(n, vector<int64>(m));
for (int i = 0; i < n; ++i) {
for (int j = 0; j < m; ++j) cin >> a[i][j];
}
// 枚举合法行掩码
vector<int> masks;
int full = 1 << m;
for (int s = 0; s < full; ++s) {
if (s & (s << 1)) continue; // 行内相邻
masks.push_back(s);
}
int S = (int)masks.size();
// 预处理兼容列表
vector<vector<int>> compat(S);
for (int i = 0; i < S; ++i) {
int p = masks[i];
for (int j = 0; j < S; ++j) {
int s = masks[j];
if ((s & p) == 0) {
int lp = (p << 1) & (full - 1);
int rp = (p >> 1);
if ((s & lp) == 0 && (s & rp) == 0) compat[i].push_back(j);
}
}
}
// 预处理每行在掩码 s 下的行和
vector<vector<int64>> rowSum(n, vector<int64>(S, 0));
for (int r = 0; r < n; ++r) {
for (int j = 0; j < S; ++j) {
int s = masks[j];
int64 sum = 0;
for (int c = 0; c < m; ++c) {
if (s & (1 << c)) sum += a[r][c];
}
rowSum[r][j] = sum;
}
}
const int64 NEG = (int64)-4e18;
vector<int64> dpPrev(S, NEG), dpCurr(S, NEG);
// 第一行
for (int j = 0; j < S; ++j) dpPrev[j] = rowSum[0][j];
// 后续各行
for (int r = 1; r < n; ++r) {
fill(dpCurr.begin(), dpCurr.end(), NEG);
for (int i = 0; i < S; ++i) {
if (dpPrev[i] == NEG) continue;
for (int j : compat[i]) {
dpCurr[j] = max(dpCurr[j], dpPrev[i] + rowSum[r][j]);
}
}
dpPrev.swap(dpCurr);
}
int64 ans = 0;
for (int j = 0; j < S; ++j) ans = max(ans, dpPrev[j]);
cout << ans << "\n";
}
return 0;
}
import java.util.*;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int T = sc.nextInt();
while (T-- > 0) {
int n = sc.nextInt();
int m = sc.nextInt();
long[][] a = new long[n][m];
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) a[i][j] = sc.nextLong();
}
List<Integer> masks = new ArrayList<>();
int full = 1 << m;
for (int s = 0; s < full; s++) {
if ((s & (s << 1)) == 0) masks.add(s);
}
int S = masks.size();
// compat[i] 存 masks 索引 j,使得 masks[i] 与 masks[j] 兼容(上一行->当前行)
List<List<Integer>> compat = new ArrayList<>();
for (int i = 0; i < S; i++) compat.add(new ArrayList<>());
for (int i = 0; i < S; i++) {
int p = masks.get(i);
int lp = (p << 1) & (full - 1);
int rp = (p >> 1);
for (int j = 0; j < S; j++) {
int s = masks.get(j);
if ((s & p) == 0 && (s & lp) == 0 && (s & rp) == 0) compat.get(i).add(j);
}
}
long[][] rowSum = new long[n][S];
for (int r = 0; r < n; r++) {
for (int j = 0; j < S; j++) {
int s = masks.get(j);
long sum = 0;
for (int c = 0; c < m; c++) if ((s & (1 << c)) != 0) sum += a[r][c];
rowSum[r][j] = sum;
}
}
long NEG = (long)-4e18;
long[] dpPrev = new long[S];
long[] dpCurr = new long[S];
Arrays.fill(dpPrev, NEG);
Arrays.fill(dpCurr, NEG);
for (int j = 0; j < S; j++) dpPrev[j] = rowSum[0][j];
for (int r = 1; r < n; r++) {
Arrays.fill(dpCurr, NEG);
for (int i = 0; i < S; i++) {
if (dpPrev[i] == NEG) continue;
for (int j : compat.get(i)) {
long v = dpPrev[i] + rowSum[r][j];
if (v > dpCurr[j]) dpCurr[j] = v;
}
}
long[] tmp = dpPrev; dpPrev = dpCurr; dpCurr = tmp;
}
long ans = 0;
for (int j = 0; j < S; j++) ans = Math.max(ans, dpPrev[j]);
System.out.println(ans);
}
}
}
T = int(input().strip())
for _ in range(T):
n, m = map(int, input().split())
a = [list(map(int, input().split())) for __ in range(n)]
masks = [s for s in range(1 << m) if (s & (s << 1)) == 0]
S = len(masks)
full = (1 << m) - 1
# 兼容列表(上一行索引 i -> 当前行索引列表)
compat = [[] for _ in range(S)]
for i, p in enumerate(masks):
lp = (p << 1) & full
rp = (p >> 1)
for j, s in enumerate(masks):
if (s & p) == 0 and (s & lp) == 0 and (s & rp) == 0:
compat[i].append(j)
row_sum = [[0] * S for _ in range(n)]
for r in range(n):
for j, s in enumerate(masks):
tot = 0
c = s
col = 0
while c:
if c & 1:
tot += a[r][col]
c >>= 1
col += 1
# 若上面用位扫描,可保证 O(ones);也可直接 for col in range(m)
row_sum[r][j] = tot
NEG = -10**18
dp_prev = [NEG] * S
dp_curr = [NEG] * S
for j in range(S):
dp_prev[j] = row_sum[0][j]
for r in range(1, n):
for j in range(S):
dp_curr[j] = NEG
for i in range(S):
if dp_prev[i] == NEG:
continue
val = dp_prev[i]
for j in compat[i]:
v = val + row_sum[r][j]
if v > dp_curr[j]:
dp_curr[j] = v
dp_prev, dp_curr = dp_curr, dp_prev
ans = 0
for j in range(S):
if dp_prev[j] > ans:
ans = dp_prev[j]
print(ans)
算法及复杂度
- 算法:状态压缩 DP(逐行转移,八连通不相邻约束)
- 时间复杂度:
,其中
为合法掩码数量(
,实际远小于
;若预先为每个
存兼容的
列表,常数更小)
- 空间复杂度:
(滚动数组)