与或和
题目描述见链接 .
正解部分
子任务1: 求有多少全 1矩阵 .
子任务2: 求有多少存在 1的矩阵 → 总矩阵数目 − 全 0矩阵数目 .
其中 全 0矩阵数目 等于 ∑i=1Ni∑j=1nj=(2N(N+1))2 个.
于是只需解决 子任务1 即可, 怎么解决呢 ? 我们使用 单调栈 .
设 (i,j) 及其上方连续 1 的个数是 Up[i,j], 最后答案为 res,
枚举矩阵的下边界, O(N), 从左向右维护一个 Up[i,j] 单调递增 的 单调栈,
设当前 单调栈 中 高度不同 的 左上端点 数量为 tmp, 那么在新加入一个元素 Up[i,j] 时, 分 2 类情况讨论,
- 满足单调性, 则 tmp 个 不同高度 的 左上端点 各可以获得 相应高度的 右端点,
而这 Up[i,j] 个 1 又可以作为 Up[i,j] 个 不同高度 的 左上端点, 且 可以 以自己为 右端点 .
于是 tmp+=Up[i,j], res+=tmp . - 不满足单调性, 此时栈中 tmp 个 … 存在不能以这 Up[i,j] 个 1 为 右端点 的 左端点,
且这些 左端点 对应的矩形都已经在前面计算过了, 已无用, 需要丢掉,
设 Up[i,stk[k]]≤Up[i,j], 不断弹出栈顶直到 stk[top]=stk[k] 即可,
在弹出一个栈顶时, 有 t=(stk[top]−stk[top−1])∗(Up[i,stk[top]−Up[i,j]) 个 左端点 消失了,
因此每次弹出时, tmp−=t, 弹完后, 栈 中 左端点 就全部可以正常使用了,
同上 tmp+=Up[i,j], res+=tmp .
最后得到 res .
实现部分
上面已经说得很清楚了 .
#include<bits/stdc++.h>
#define reg register
typedef long long ll;
int read(){
char c;
int s = 0, flag = 1;
while((c=getchar()) && !isdigit(c))
if(c == '-'){ flag = -1, c = getchar(); break ; }
while(isdigit(c)) s = s*10 + c-'0', c = getchar();
return s * flag;
}
const int maxn = 1002;
const int mod = 998244353;
int N;
int top;
int Ans;
int Ans_1;
int Max_v;
int Ld[maxn];
int Rd[maxn];
int stk[maxn];
int A[maxn][maxn];
int Up[maxn][maxn];
int sum_1[maxn][maxn][2];
//int sum_2[maxn][maxn][2];
ll pw[maxn];
ll Calc(int p){
ll res = 0;
for(reg int j = 1; j <= N; j ++)
for(reg int i = 1; i <= N; i ++){
Up[i][j] = 0;
if(A[i][j] & pw[p]) Up[i][j] = Up[i-1][j] + 1;
}
for(reg int i = 1; i <= N; i ++){
top = 1; ll tmp = 0;
for(reg int j = 1; j <= N; j ++){
while(top != 1 && Up[i][stk[top]] > Up[i][j])
tmp -= (stk[top]-stk[top-1]) * (Up[i][stk[top]] - Up[i][j]), top --;
tmp += Up[i][j], stk[++ top] = j;
res = (res + tmp) % mod;
}
}
return res % mod;
}
int main(){
freopen("mob.in", "r", stdin);
freopen("mob.out", "w", stdout);
N = read();
for(reg int i = 1; i <= N; i ++)
for(reg int j = 1; j <= N; j ++) A[i][j] = read(), Max_v = std::max(Max_v, A[i][j]);
pw[0] = 1;
for(reg int i = 1; i <= 100; i ++) pw[i] = pw[i-1]<<1;
ll tot = N*(N+1)/2 % mod; tot = tot*tot % mod;
for(reg int p = 0; pw[p] <= Max_v; p ++){
Ans_1 = (Ans_1 + (1ll*pw[p]*Calc(p))%mod) % mod;
for(reg int i = 1; i <= N; i ++)
for(reg int j = 1; j <= N; j ++) A[i][j] ^= pw[p];
int num = (tot - Calc(p) + mod) % mod;
Ans = (Ans + (1ll*pw[p]*num)%mod) % mod;
}
printf("%d\n", Ans_1);
printf("%d\n", Ans);
return 0;
}