题意:

给你一个n*m(1e3)的矩阵,让你找出元素全部相同的子矩阵的个数。

思路:

可以预处理向左和向上的最大相同长度,然后对于每列用rmq维护一个区间最小值,

这个值表示向左延伸的长度,然后对于当前的元素,二分查找距离他最近的值小于他的上一个位置,

然后当前位置的贡献就是向左延伸的长度*纵坐标之差+1(这块矩阵完全相同,直接边长相乘)再加上上一个位置的贡献。

总复杂度就是n^2log(n)的

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <set>
#include <algorithm>
using namespace std;
typedef long long LL;
typedef pair<int,int>pii;
const int N = 1e3+5;
const int mod = 1e9+7;
int T,n,m,a[N][N],u[N][N],l[N][N],t[N];
int dp[N][10],mm[N];
void initrmq(int x){
  for(int i=1;i<=n;++i)
    dp[i][0] = l[i][x];
  for(int j = 1;j<=mm[n];++j)
    for(int i=1;i+(1<<j)-1<=n;++i)
      dp[i][j] = min(dp[i][j-1],dp[i+(1<<(j-1))][j-1]);
}
int rmq(int x,int y){
  int k = mm[y-x+1];
  return min(dp[x][k],dp[y-(1<<k)+1][k]);
}
int main(){
   scanf("%d",&T);
   while(T--){
     mm[0] = -1;
     scanf("%d%d",&n,&m);
     for(int i=1;i<=n;++i)
        mm[i] = ((i&(i-1))==0)?mm[i-1]+1:mm[i-1];
     for(int i=1;i<=n;++i)
       for(int j=1;j<=m;++j)
          scanf("%d",&a[i][j]);
      for(int i=1;i<=n;++i){
         l[i][1] = 1;
         for(int j=2;j<=m;++j)
           if(a[i][j]==a[i][j-1])l[i][j] = l[i][j-1]+1;
           else l[i][j] = 1;
      }
      for(int j=1;j<=m;++j){
         u[1][j] = 1;
         for(int i=2;i<=n;++i){
            if(a[i][j]==a[i-1][j])u[i][j] = u[i-1][j]+1;
            else u[i][j] = 1;
         }
      }
      LL ret = 0;
      for(int j=1;j<=m;++j){
         t[1] = l[1][j];
         initrmq(j);
         for(int i=2;i<=n;++i){
            int x = i ,y = i-u[i][j]+1;
            int tmp = rmq(y,x);
            if(tmp>=l[i][j]){
              t[i] = l[i][j]*u[i][j];
              continue;
            }
            int ans;
            while(x>=y){
               int mid = x+y>>1;
               tmp = rmq(mid,i);
               if(tmp<l[i][j])y = mid+1;
               else ans = mid,x = mid-1;
            }
            t[i] = l[i][j]*(i-ans+1);
            if(i-u[i][j]+1<=ans-1)t[i]+=t[ans-1];
         }
         for(int i=1;i<=n;++i)ret+=t[i];
      }
      printf("%I64d\n",ret);
   }
   return 0;
}