题目链接
题目描述
小红有两个长度都为 的数组
和
,它们仅包含0和1。现在小红生成一个
的二维矩阵
,满足
(
是异或操作)。
请计算出矩阵 的所有子矩阵的数值之和,结果对
取模。
思路分析
1. 问题转化:从子矩阵和到单点贡献
直接枚举所有子矩阵(数量级为 )来求和是不可行的。一个经典的优化思路是转换视角,计算矩阵中每个元素
对总和的贡献。
一个元素 的总贡献等于
的值乘以包含它的子矩阵的数量。
一个子矩阵由其左上角 和右下角
决定。要使一个子矩阵包含元素
(0-indexed),必须满足:
对于行:
- 左上角的行
有
种选择(从
到
)。
- 右下角的行
有
种选择(从
到
)。 对于列:
- 左上角的列
有
种选择(从
到
)。
- 右下角的列
有
种选择(从
到
)。
因此,包含元素 的子矩阵数量为
。
所有子矩阵的数值之和 就可以表示为:
代入 :
这个公式的计算复杂度是 ,对于
依然会超时。我们需要进一步优化。
2. 线性优化:分离变量
我们观察到求和式中的项可以分离。令权重函数 。则公式变为:
我们可以将内层和式提出来:
现在,我们分析内层的和 。这个和的值取决于
:
- 如果
:内层和变为
。
- 如果
:内层和变为
。
可以看到,内层的和只依赖于两个可以预先算出的值:
(只对
的项求和)
(对所有
求和)
于是,总和 可以被重写为:
令 和
。
则最终公式为:
其中 也可以表示为
。
所有这些求和都可以在 时间内完成,因此总时间复杂度降至
。
3. 算法步骤
- 定义模数
。
- 初始化
为 0。
- 遍历
从
到
: a. 计算
。 b. 根据
的值,将
累加到
或
。 c. 根据
的值(如果
),将
累加到
。
- 计算
。
- 根据公式
计算最终结果。注意处理减法时的取模,防止出现负数。
代码
#include <iostream>
#include <vector>
#include <numeric>
using namespace std;
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n;
cin >> n;
vector<long long> a(n), b(n);
for (int i = 0; i < n; ++i) cin >> a[i];
for (int i = 0; i < n; ++i) cin >> b[i];
long long M = 1e9 + 7;
long long s_aw0 = 0, s_aw1 = 0;
long long s_bw = 0;
for (long long i = 0; i < n; ++i) {
long long w_i = ((i + 1) * (n - i)) % M;
if (a[i] == 0) {
s_aw0 = (s_aw0 + w_i) % M;
} else {
s_aw1 = (s_aw1 + w_i) % M;
}
if (b[i] == 1) {
s_bw = (s_bw + w_i) % M;
}
}
long long s_w = (s_aw0 + s_aw1) % M;
long long term1 = (s_aw0 * s_bw) % M;
long long s_w_minus_s_bw = (s_w - s_bw + M) % M;
long long term2 = (s_aw1 * s_w_minus_s_bw) % M;
long long total_sum = (term1 + term2) % M;
cout << total_sum << endl;
return 0;
}
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
long[] a = new long[n];
long[] b = new long[n];
for (int i = 0; i < n; i++) {
a[i] = sc.nextLong();
}
for (int i = 0; i < n; i++) {
b[i] = sc.nextLong();
}
long M = 1_000_000_007;
long sAw0 = 0, sAw1 = 0;
long sBw = 0;
for (long i = 0; i < n; i++) {
long wI = ((i + 1) * (n - i)) % M;
if (a[(int)i] == 0) {
sAw0 = (sAw0 + wI) % M;
} else {
sAw1 = (sAw1 + wI) % M;
}
if (b[(int)i] == 1) {
sBw = (sBw + wI) % M;
}
}
long sW = (sAw0 + sAw1) % M;
long term1 = (sAw0 * sBw) % M;
long sWMinusSBw = (sW - sBw + M) % M;
long term2 = (sAw1 * sWMinusSBw) % M;
long totalSum = (term1 + term2) % M;
System.out.println(totalSum);
}
}
import sys
def solve():
n = int(sys.stdin.readline())
a = list(map(int, sys.stdin.readline().split()))
b = list(map(int, sys.stdin.readline().split()))
M = 10**9 + 7
s_aw0 = 0
s_aw1 = 0
s_bw = 0
# 在循环中累加时可以不取模,减少取模运算次数
# 只要中间结果不超过Python整数上限即可
for i in range(n):
w_i = (i + 1) * (n - i)
if a[i] == 0:
s_aw0 += w_i
else:
s_aw1 += w_i
if b[i] == 1:
s_bw += w_i
# 在最后计算前统一取模
s_aw0 %= M
s_aw1 %= M
s_bw %= M
s_w = (s_aw0 + s_aw1) % M
term1 = (s_aw0 * s_bw) % M
s_w_minus_s_bw = (s_w - s_bw + M) % M
term2 = (s_aw1 * s_w_minus_s_bw) % M
total_sum = (term1 + term2) % M
print(total_sum)
solve()
算法及复杂度
- 算法:数学推导 + 贡献法
- 时间复杂度:
。我们只需要遍历数组一次来计算所需的各个和。
- 空间复杂度:
,用于存储输入的两个数组
和
。