题目链接

小红的蛋糕切割

题目描述

小红有一个 的矩形蛋糕,每个区域都有一个美味度。

她希望切割出一个正方形的小蛋糕给自己吃,剩下的部分给小紫吃。

她希望两人吃的部分的美味度之和尽可能接近。设小红吃的美味度之和为 ,小紫吃的为 ,请你输出 的最小值。

解题思路

这是一个在矩阵中寻找最优子结构的问题。我们可以通过问题转化和二维前缀和来高效解决。

1. 问题转化

设整个蛋糕的总美味度为

小红切走一块美味度为 的正方形,则小紫吃剩下的部分,其美味度

我们需要最小化

代入,得到:

因此,原问题转化为:遍历所有可能的正方形区域,计算其美味度之和 ,并找到使 最小的值。

2. 使用二维前缀和优化

暴力枚举所有正方形并每次都重新计算其和,效率很低。为了快速计算任意子矩阵的和,我们可以使用二维前缀和

  • 预计算

    我们创建一个 prefix_sum 数组,其中 prefix_sum[i][j] 存储以 (0,0) 为左上角、(i-1, j-1) 为右下角的矩形区域的美味度之和。这个数组可以在 的时间内构建完成。

  • 查询

    构建完成后,任意一个以 (r1, c1) 为左上角、(r2, c2) 为右下角的子矩阵的和都可以通过 prefix_sum 数组在 时间内算出。

3. 算法步骤

  1. 计算总和

    遍历一遍原始蛋糕矩阵,计算出总美味度

  2. 构建前缀和数组

    根据递推公式 prefix_sum[i][j] = cake[i-1][j-1] + prefix_sum[i-1][j] + prefix_sum[i][j-1] - prefix_sum[i-1][j-1],构建二维前缀和数组。

  3. 枚举所有正方形

    • 用三层循环遍历所有可能的正方形:左上角坐标 (i, j) 和边长 k

    • 对于每个正方形,利用前缀和数组在 内计算出其和

    • 计算 并更新全局最小差值。

  4. 输出结果

    遍历结束后,输出记录的最小差值。

代码

#include <bits/stdc++.h>

using namespace std;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, m;
    cin >> n >> m;

    vector<vector<long long>> cake(n, vector<long long>(m));
    long long total_sum = 0;
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < m; ++j) {
            cin >> cake[i][j];
            total_sum += cake[i][j];
        }
    }

    vector<vector<long long>> prefix_sum(n + 1, vector<long long>(m + 1, 0));
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= m; ++j) {
            prefix_sum[i][j] = cake[i - 1][j - 1] + prefix_sum[i - 1][j] + prefix_sum[i][j - 1] - prefix_sum[i - 1][j - 1];
        }
    }

    long long min_diff = -1;

    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < m; ++j) {
            for (int k = 1; i + k <= n && j + k <= m; ++k) {
                int r1 = i, c1 = j;
                int r2 = i + k - 1, c2 = j + k - 1;
                
                long long s_red = prefix_sum[r2 + 1][c2 + 1] - prefix_sum[r1][c2 + 1] - prefix_sum[r2 + 1][c1] + prefix_sum[r1][c1];
                long long current_diff = abs(2 * s_red - total_sum);
                
                if (min_diff == -1 || current_diff < min_diff) {
                    min_diff = current_diff;
                }
            }
        }
    }

    cout << min_diff << endl;

    return 0;
}
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();

        long[][] cake = new long[n][m];
        long totalSum = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                cake[i][j] = sc.nextLong();
                totalSum += cake[i][j];
            }
        }

        long[][] prefixSum = new long[n + 1][m + 1];
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= m; j++) {
                prefixSum[i][j] = cake[i - 1][j - 1] + prefixSum[i - 1][j] + prefixSum[i][j - 1] - prefixSum[i - 1][j - 1];
            }
        }

        long minDiff = -1;

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                for (int k = 1; i + k <= n && j + k <= m; k++) {
                    int r1 = i, c1 = j;
                    int r2 = i + k - 1, c2 = j + k - 1;

                    long sRed = prefixSum[r2 + 1][c2 + 1] - prefixSum[r1][c2 + 1] - prefixSum[r2 + 1][c1] + prefixSum[r1][c1];
                    long currentDiff = Math.abs(2 * sRed - totalSum);
                    
                    if (minDiff == -1 || currentDiff < minDiff) {
                        minDiff = currentDiff;
                    }
                }
            }
        }
        System.out.println(minDiff);
    }
}
import sys

def solve():
    n, m = map(int, sys.stdin.readline().split())
    cake = []
    total_sum = 0
    for _ in range(n):
        row = list(map(int, sys.stdin.readline().split()))
        total_sum += sum(row)
        cake.append(row)

    prefix_sum = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            prefix_sum[i][j] = cake[i-1][j-1] + prefix_sum[i-1][j] + prefix_sum[i][j-1] - prefix_sum[i-1][j-1]

    min_diff = -1

    for i in range(n):
        for j in range(m):
            for k in range(1, min(n - i, m - j) + 1):
                r1, c1 = i, j
                r2, c2 = i + k - 1, j + k - 1
                
                s_red = prefix_sum[r2+1][c2+1] - prefix_sum[r1][c2+1] - prefix_sum[r2+1][c1] + prefix_sum[r1][c1]
                current_diff = abs(2 * s_red - total_sum)
                
                if min_diff == -1 or current_diff < min_diff:
                    min_diff = current_diff

    print(min_diff)

solve()

算法及复杂度

  • 算法:二维前缀和 + 枚举

  • 时间复杂度: 。构建前缀和数组需要 。之后三层循环枚举所有正方形,复杂度为

  • 空间复杂度: ,用于存储二维前缀和数组。