题意

给定一个 的整数矩阵,要求计算出所有非空子矩阵的元素按位与结果,并将这些结果从小到大排序,求出中位数(即第 个数,其中 为子矩阵总数)。

题解

矩阵元素最大值不超过 ,因此所有子矩阵的按位与结果必然落在 的范围内。我们可以直接统计这 种结果各自出现的总次数,最后通过累加频次寻找出中位数。

先枚举子矩阵的上边界 r1 和下边界 r2,并将这若干行压缩成一个一维数组 b。原问题转化为了求一维数组的所有连续子数组按位与的值的频次。从左到右扫描数组,维护以当前位置结尾的所有不同按位与值及其出现次数。由于与运算的结果具有单调不降的特性,从当前位置向左延伸的子数组,最多只会产生约 种不同的值。合并相同的值后,再把当前新产生的各个状态以及频次累加到全局结果中即可。

考虑到 ,为了进一步优化常数,可以做一个小特判:当 时,先将整个矩阵进行转置。这能确保外层 的暴力枚举始终发生在较小的维度上。

复杂度

  • 时间复杂度:
  • 空间复杂度:

参考代码与链接

#include <bits/stdc++.h>
using namespace std;

using ll = long long;
using pil = pair<int, ll>;

void solve() {
    int n, m;
    cin >> n >> m;
    
    vector<vector<int>> a(n, vector<int>(m));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < m; ++j) {
            cin >> a[i][j];
        }
    }
    
    if (n > m) {
        vector<vector<int>> a_T(m, vector<int>(n));
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < m; ++j) {
                a_T[j][i] = a[i][j];
            }
        }
        swap(n, m);
        a = std::move(a_T);
    }
    
    vector<ll> cnt(1024, 0);
    vector<int> b(m);
    
    for (int r1 = 0; r1 < n; ++r1) {
        for (int j = 0; j < m; ++j) b[j] = a[r1][j];
        
        for (int r2 = r1; r2 < n; ++r2) {
            if (r2 > r1) {
                for (int j = 0; j < m; ++j) b[j] &= a[r2][j];
            }
            
            vector<pil> prev_ands;
            for (int j = 0; j < m; ++j) {
                vector<pil> next_ands;
                next_ands.push_back({b[j], 1});
                
                for (auto& p : prev_ands) {
                    int new_val = p.first & b[j];
                    if (next_ands.back().first == new_val) {
                        next_ands.back().second += p.second;
                    } else {
                        next_ands.push_back({new_val, p.second});
                    }
                }
                
                for (auto& p : next_ands) {
                    cnt[p.first] += p.second;
                }
                prev_ands = std::move(next_ands);
            }
        }
    }
    
    ll total = 0;
    for (int i = 0; i < 1024; ++i) total += cnt[i];
    
    ll target = (total + 1) / 2;
    ll curr = 0;
    
    for (int i = 0; i < 1024; ++i) {
        curr += cnt[i];
        if (curr >= target) {
            cout << i << '\n';
            return;
        }
    }
}

int main() {
    ios::sync_with_stdio(false), cin.tie(nullptr);
    int T; cin >> T;
    while (T--) solve();
}