小红的蛋糕切割

[题目链接](https://www.nowcoder.com/practice/ea72b571afdb499fa344e56f06b5c71c)

思路

给定 的矩阵,需要从中切出一个正方形子矩阵,使得切出部分的美味度之和 与剩余部分的美味度之和 的差的绝对值 最小。

转化问题

设矩阵所有元素之和为 ,则:

$$

因此问题转化为:枚举所有可能的正方形子矩阵,找到子矩阵和 使得 最小。

二维前缀和

用二维前缀和数组 表示矩阵左上角 的元素之和。建立前缀和后,任意以 为左上角、 为右下角的子矩阵之和可以 计算:

$$

枚举所有正方形

枚举正方形的边长 (从 ),再枚举正方形右下角的位置 ,即可遍历所有合法正方形。对每个正方形计算子矩阵和,更新答案。

样例演示

矩阵为:

1 2 3
2 3 4
3 2 1

总和 ,我们需要找 最接近 的正方形。取左下角 子矩阵(第 2-3 行,第 1-2 列):,此时 ,即为最优答案。

复杂度分析

  • 时间复杂度。外层枚举边长 ,内层枚举位置
  • 空间复杂度,存储前缀和数组。

代码

> 本题仅支持 C++ 和 Java 提交。Python / JavaScript 由于常数过大,在本题数据规模下会超时。

#include <bits/stdc++.h>
using namespace std;

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m;
    cin >> n >> m;
    vector<vector<long long>> pre(n+1, vector<long long>(m+1, 0));
    for(int i = 1; i <= n; i++){
        for(int j = 1; j <= m; j++){
            long long x;
            cin >> x;
            pre[i][j] = x + pre[i-1][j] + pre[i][j-1] - pre[i-1][j-1];
        }
    }
    long long total = pre[n][m];
    long long ans = LLONG_MAX;
    int maxk = min(n, m);
    for(int k = 1; k <= maxk; k++){
        for(int i = k; i <= n; i++){
            for(int j = k; j <= m; j++){
                long long s = pre[i][j] - pre[i-k][j] - pre[i][j-k] + pre[i-k][j-k];
                long long diff = abs(2*s - total);
                if(diff < ans) ans = diff;
            }
        }
    }
    cout << ans << endl;
    return 0;
}
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt(), m = sc.nextInt();
        long[][] pre = new long[n + 1][m + 1];
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= m; j++) {
                long x = sc.nextLong();
                pre[i][j] = x + pre[i - 1][j] + pre[i][j - 1] - pre[i - 1][j - 1];
            }
        }
        long total = pre[n][m];
        long ans = Long.MAX_VALUE;
        int maxk = Math.min(n, m);
        for (int k = 1; k <= maxk; k++) {
            for (int i = k; i <= n; i++) {
                for (int j = k; j <= m; j++) {
                    long s = pre[i][j] - pre[i - k][j] - pre[i][j - k] + pre[i - k][j - k];
                    long diff = Math.abs(2 * s - total);
                    if (diff < ans) ans = diff;
                }
            }
        }
        System.out.println(ans);
    }
}
const readline = require('readline');
const rl = readline.createInterface({ input: process.stdin });
const lines = [];
rl.on('line', line => lines.push(line.trim()));
rl.on('close', () => {
    const [n, m] = lines[0].split(' ').map(Number);
    const pre = Array.from({length: n+1}, () => new Array(m+1).fill(0));
    for (let i = 1; i <= n; i++) {
        const row = lines[i].split(' ').map(Number);
        for (let j = 1; j <= m; j++) {
            pre[i][j] = row[j-1] + pre[i-1][j] + pre[i][j-1] - pre[i-1][j-1];
        }
    }
    const total = pre[n][m];
    let ans = Infinity;
    const maxk = Math.min(n, m);
    for (let k = 1; k <= maxk; k++) {
        for (let i = k; i <= n; i++) {
            for (let j = k; j <= m; j++) {
                const s = pre[i][j] - pre[i-k][j] - pre[i][j-k] + pre[i-k][j-k];
                const diff = Math.abs(2 * s - total);
                if (diff < ans) ans = diff;
            }
        }
    }
    console.log(ans);
});
import sys

def main():
    data = sys.stdin.buffer.read().split()
    idx = 0
    n = int(data[idx]); idx += 1
    m = int(data[idx]); idx += 1

    pre = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            pre[i][j] = int(data[idx]) + pre[i-1][j] + pre[i][j-1] - pre[i-1][j-1]
            idx += 1

    total = pre[n][m]
    ans = total
    maxk = min(n, m)
    for k in range(1, maxk + 1):
        for i in range(k, n + 1):
            pi = pre[i]
            pik = pre[i - k]
            for j in range(k, m + 1):
                s = pi[j] - pik[j] - pi[j - k] + pik[j - k]
                d = abs(2 * s - total)
                if d < ans:
                    ans = d
    print(ans)

main()