矩阵计数

[题目链接](https://www.nowcoder.com/practice/3f97500b339d46d7ad969ae414cf87ad)

思路

给定一个 的字符矩阵,每个元素是 red 之一。定义矩阵权值为矩阵中三种字符出现次数的最小值。给定阈值 ,求有多少个子方阵的权值不小于

二维前缀和 + 二分查找

核心观察:固定子方阵的左上角 ,随着边长 的增大,方阵内每种字符的数量单调不减,因此权值(三种字符数量的最小值)也单调不减。这意味着可以对边长进行二分查找,找到使权值 的最小边长 ,那么从 到最大可能边长的所有方阵都满足条件。

快速查询子矩阵字符数:对 red 三种字符分别建立二维前缀和。设 为矩阵前 行、前 列中字符 的出现次数,则左上角 、右下角 的子矩阵中字符 的数量为:

$$

算法流程

  1. 构建三种字符的二维前缀和数组。
  2. 枚举每个位置 作为子方阵的左上角。
  3. 二分查找满足条件的最小边长 :在 上二分,用前缀和 查询子方阵中每种字符的数量,取最小值与 比较。
  4. 存在,则贡献 个合法子方阵。

样例演示

样例 1["red","red","red"], myval=2

矩阵为:

r e d
r e d
r e d
  • 所有 方阵:只含一种字符,权值为 0(另外两种字符计数为 0),不满足
  • 所有 方阵:至多包含 种字符各 个,第 种字符为 ,权值为
  • 方阵:r、e、d 各出现 次,权值为 ,合法。

答案为

样例 2["red","edr","dre"], myval=1

矩阵为:

r e d
e d r
d r e

方阵权值为 ;但 方阵中许多都满足三种字符各至少出现 次。答案为

复杂度分析

  • 时间复杂度:。枚举 个左上角,每个做 的二分查找。
  • 空间复杂度:。存储三个 的前缀和数组。

代码

class Solution {
public:
    int matrixCount(vector<string>& a, int myval) {
        int n = a.size();
        vector<vector<vector<int>>> pre(3, vector<vector<int>>(n + 1, vector<int>(n + 1, 0)));
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= n; j++) {
                char ch = a[i - 1][j - 1];
                int idx = (ch == 'r') ? 0 : (ch == 'e') ? 1 : 2;
                for (int c = 0; c < 3; c++) {
                    pre[c][i][j] = pre[c][i - 1][j] + pre[c][i][j - 1]
                                 - pre[c][i - 1][j - 1] + (c == idx ? 1 : 0);
                }
            }
        }
        int ans = 0;
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= n; j++) {
                int maxK = min(n - i + 1, n - j + 1);
                int lo = 1, hi = maxK, best = maxK + 1;
                while (lo <= hi) {
                    int mid = (lo + hi) / 2;
                    int r2 = i + mid - 1, c2 = j + mid - 1;
                    int minCnt = INT_MAX;
                    for (int c = 0; c < 3; c++) {
                        int cnt = pre[c][r2][c2] - pre[c][i - 1][c2]
                                - pre[c][r2][j - 1] + pre[c][i - 1][j - 1];
                        minCnt = min(minCnt, cnt);
                    }
                    if (minCnt >= myval) {
                        best = mid;
                        hi = mid - 1;
                    } else {
                        lo = mid + 1;
                    }
                }
                if (best <= maxK) {
                    ans += maxK - best + 1;
                }
            }
        }
        return ans;
    }
};
import java.util.*;

public class Solution {
    public int matrixCount(String[] a, int myval) {
        int n = a.length;
        int[][][] pre = new int[3][n + 1][n + 1];
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= n; j++) {
                char ch = a[i - 1].charAt(j - 1);
                int idx = (ch == 'r') ? 0 : (ch == 'e') ? 1 : 2;
                for (int c = 0; c < 3; c++) {
                    pre[c][i][j] = pre[c][i - 1][j] + pre[c][i][j - 1]
                                 - pre[c][i - 1][j - 1] + (c == idx ? 1 : 0);
                }
            }
        }
        int ans = 0;
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= n; j++) {
                int maxK = Math.min(n - i + 1, n - j + 1);
                int lo = 1, hi = maxK, best = maxK + 1;
                while (lo <= hi) {
                    int mid = (lo + hi) / 2;
                    int r2 = i + mid - 1, c2 = j + mid - 1;
                    int minCnt = Integer.MAX_VALUE;
                    for (int c = 0; c < 3; c++) {
                        int cnt = pre[c][r2][c2] - pre[c][i - 1][c2]
                                - pre[c][r2][j - 1] + pre[c][i - 1][j - 1];
                        minCnt = Math.min(minCnt, cnt);
                    }
                    if (minCnt >= myval) {
                        best = mid;
                        hi = mid - 1;
                    } else {
                        lo = mid + 1;
                    }
                }
                if (best <= maxK) {
                    ans += maxK - best + 1;
                }
            }
        }
        return ans;
    }
}