题目链接

取数游戏

题目描述

给定 组数据。每组数据给出一个由非负整数构成的 数字矩阵。需要从矩阵中选出若干个数,使得任意两个被选中的位置在八连通意义下不相邻(即上下左右与四个对角方向均不能相邻)。求所有被选数字之和的最大值。

输入:

  • 第一行一个正整数
  • 对于每组数据:第一行两个正整数 ;接下来 行,每行 个非负整数

输出:

  • 对于每组数据,输出一个非负整数,表示最大和

解题思路

将问题转化为“逐行状态压缩 DP”:

  • 用长度为 的二进制掩码 表示一行中被选的位置; 合法需满足同一行内不相邻:
  • 跨行时,上一行掩码 与当前行掩码 应满足八连通不相邻:
    • 垂直/重叠:
    • 左上/右下对角:
    • 右上/左下对角:
  • 预处理每行在掩码 下的取值和 。状态转移:
  • 初始为第 行各合法掩码的行和,答案为最后一行所有 的最大值。

实现要点:

  • 预枚举所有合法掩码列表以及它们之间的兼容关系列表,转移更快。
  • 使用 64 位整型存储答案(行内和与最终和可能较大)。

代码

#include <bits/stdc++.h>
using namespace std;
using int64 = long long;

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int T; 
    if (!(cin >> T)) return 0;
    while (T--) {
        int n, m; 
        cin >> n >> m;
        vector<vector<int64>> a(n, vector<int64>(m));
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < m; ++j) cin >> a[i][j];
        }

        // 枚举合法行掩码
        vector<int> masks;
        int full = 1 << m;
        for (int s = 0; s < full; ++s) {
            if (s & (s << 1)) continue; // 行内相邻
            masks.push_back(s);
        }

        int S = (int)masks.size();
        // 预处理兼容列表
        vector<vector<int>> compat(S);
        for (int i = 0; i < S; ++i) {
            int p = masks[i];
            for (int j = 0; j < S; ++j) {
                int s = masks[j];
                if ((s & p) == 0) {
                    int lp = (p << 1) & (full - 1);
                    int rp = (p >> 1);
                    if ((s & lp) == 0 && (s & rp) == 0) compat[i].push_back(j);
                }
            }
        }

        // 预处理每行在掩码 s 下的行和
        vector<vector<int64>> rowSum(n, vector<int64>(S, 0));
        for (int r = 0; r < n; ++r) {
            for (int j = 0; j < S; ++j) {
                int s = masks[j];
                int64 sum = 0;
                for (int c = 0; c < m; ++c) {
                    if (s & (1 << c)) sum += a[r][c];
                }
                rowSum[r][j] = sum;
            }
        }

        const int64 NEG = (int64)-4e18;
        vector<int64> dpPrev(S, NEG), dpCurr(S, NEG);
        // 第一行
        for (int j = 0; j < S; ++j) dpPrev[j] = rowSum[0][j];
        // 后续各行
        for (int r = 1; r < n; ++r) {
            fill(dpCurr.begin(), dpCurr.end(), NEG);
            for (int i = 0; i < S; ++i) {
                if (dpPrev[i] == NEG) continue;
                for (int j : compat[i]) {
                    dpCurr[j] = max(dpCurr[j], dpPrev[i] + rowSum[r][j]);
                }
            }
            dpPrev.swap(dpCurr);
        }

        int64 ans = 0;
        for (int j = 0; j < S; ++j) ans = max(ans, dpPrev[j]);
        cout << ans << "\n";
    }
    return 0;
}
import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int T = sc.nextInt();
        while (T-- > 0) {
            int n = sc.nextInt();
            int m = sc.nextInt();
            long[][] a = new long[n][m];
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < m; j++) a[i][j] = sc.nextLong();
            }

            List<Integer> masks = new ArrayList<>();
            int full = 1 << m;
            for (int s = 0; s < full; s++) {
                if ((s & (s << 1)) == 0) masks.add(s);
            }
            int S = masks.size();

            // compat[i] 存 masks 索引 j,使得 masks[i] 与 masks[j] 兼容(上一行->当前行)
            List<List<Integer>> compat = new ArrayList<>();
            for (int i = 0; i < S; i++) compat.add(new ArrayList<>());
            for (int i = 0; i < S; i++) {
                int p = masks.get(i);
                int lp = (p << 1) & (full - 1);
                int rp = (p >> 1);
                for (int j = 0; j < S; j++) {
                    int s = masks.get(j);
                    if ((s & p) == 0 && (s & lp) == 0 && (s & rp) == 0) compat.get(i).add(j);
                }
            }

            long[][] rowSum = new long[n][S];
            for (int r = 0; r < n; r++) {
                for (int j = 0; j < S; j++) {
                    int s = masks.get(j);
                    long sum = 0;
                    for (int c = 0; c < m; c++) if ((s & (1 << c)) != 0) sum += a[r][c];
                    rowSum[r][j] = sum;
                }
            }

            long NEG = (long)-4e18;
            long[] dpPrev = new long[S];
            long[] dpCurr = new long[S];
            Arrays.fill(dpPrev, NEG);
            Arrays.fill(dpCurr, NEG);

            for (int j = 0; j < S; j++) dpPrev[j] = rowSum[0][j];
            for (int r = 1; r < n; r++) {
                Arrays.fill(dpCurr, NEG);
                for (int i = 0; i < S; i++) {
                    if (dpPrev[i] == NEG) continue;
                    for (int j : compat.get(i)) {
                        long v = dpPrev[i] + rowSum[r][j];
                        if (v > dpCurr[j]) dpCurr[j] = v;
                    }
                }
                long[] tmp = dpPrev; dpPrev = dpCurr; dpCurr = tmp;
            }

            long ans = 0;
            for (int j = 0; j < S; j++) ans = Math.max(ans, dpPrev[j]);
            System.out.println(ans);
        }
    }
}
T = int(input().strip())
for _ in range(T):
    n, m = map(int, input().split())
    a = [list(map(int, input().split())) for __ in range(n)]
    masks = [s for s in range(1 << m) if (s & (s << 1)) == 0]
    S = len(masks)
    full = (1 << m) - 1

    # 兼容列表(上一行索引 i -> 当前行索引列表)
    compat = [[] for _ in range(S)]
    for i, p in enumerate(masks):
        lp = (p << 1) & full
        rp = (p >> 1)
        for j, s in enumerate(masks):
            if (s & p) == 0 and (s & lp) == 0 and (s & rp) == 0:
                compat[i].append(j)

    row_sum = [[0] * S for _ in range(n)]
    for r in range(n):
        for j, s in enumerate(masks):
            tot = 0
            c = s
            col = 0
            while c:
                if c & 1:
                    tot += a[r][col]
                c >>= 1
                col += 1
            # 若上面用位扫描,可保证 O(ones);也可直接 for col in range(m)
            row_sum[r][j] = tot

    NEG = -10**18
    dp_prev = [NEG] * S
    dp_curr = [NEG] * S
    for j in range(S):
        dp_prev[j] = row_sum[0][j]
    for r in range(1, n):
        for j in range(S):
            dp_curr[j] = NEG
        for i in range(S):
            if dp_prev[i] == NEG:
                continue
            val = dp_prev[i]
            for j in compat[i]:
                v = val + row_sum[r][j]
                if v > dp_curr[j]:
                    dp_curr[j] = v
        dp_prev, dp_curr = dp_curr, dp_prev

    ans = 0
    for j in range(S):
        if dp_prev[j] > ans:
            ans = dp_prev[j]
    print(ans)

算法及复杂度

  • 算法:状态压缩 DP(逐行转移,八连通不相邻约束)
  • 时间复杂度:,其中 为合法掩码数量(,实际远小于 ;若预先为每个 存兼容的 列表,常数更小)
  • 空间复杂度:(滚动数组)