题目链接
题目描述
给定两个长度为 的数组
和
。
需要从这两个数组中分别选出一个非空的子序列,我们称之为 和
。
要求满足条件: 的最大值不大于
的最小值。
求解总共有多少种满足条件的选取方法,答案需要对 取模。
解题思路
这是一个组合计数问题。直接枚举所有子序列对的数量级是 ,会严重超时。我们需要一种更高效的计数方法。
核心思想
我们可以转变计数的主体。与其枚举子序列,我们可以枚举第一个子序列 可能的最大值
,然后计算在这种情况下,有多少对
满足条件。
具体步骤
-
预处理:
- 对数组
和
进行升序排序。这使得我们能够快速地统计某个范围内元素的数量。
- 预计算
的幂次,因为一个包含
个元素的集合,其非空子序列的数量是
。
- 对数组
-
枚举与计算: 我们遍历排序后的数组
中的每一个元素
,并考虑当
的最大值为
时的情况。
-
第一步:计算
的数量 要使
的最大值恰好为
,需要满足两个条件:
中至少包含一个
。
中所有其他元素的值都小于或等于
。
这个数量可以通过容斥原理计算: (所有元素都
的非空子序列数量) - (所有元素都
的非空子序列数量)
设数组
中
的元素有
count_le
个,的元素有
count_lt
个。 则满足条件的数量为
。 由于数组
已经排序,
count_le
和count_lt
可以通过下标快速得到。 -
第二步:计算
的数量 当
的最大值为
时,条件变为
。这意味着
中的所有元素都必须大于或等于
。
我们需要计算数组
中有多少个元素
。设其数量为
count_ge
。 那么,由这些元素构成的任何非空子序列都满足条件。这样的数量为
。 由于数组
已经排序,
count_ge
可以通过二分查找(lower_bound
)高效地计算出来。 -
第三步:累加 将第一步和第二步的结果相乘,就得到了当
时的总方案数。我们遍历
中所有不重复的值,将每种情况的方案数累加,即可得到最终答案。
-
代码
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
const int MOD = 1e9 + 7;
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n;
cin >> n;
vector<long long> a(n), b(n);
for (int i = 0; i < n; ++i) cin >> a[i];
for (int i = 0; i < n; ++i) cin >> b[i];
sort(a.begin(), a.end());
sort(b.begin(), b.end());
vector<long long> pow2(n + 1);
pow2[0] = 1;
for (int i = 1; i <= n; ++i) {
pow2[i] = (pow2[i - 1] * 2) % MOD;
}
long long ans = 0;
int i = 0;
while (i < n) {
long long current_val = a[i];
int j = i;
while (j < n && a[j] == current_val) {
j++;
}
// 1. 计算 sub_a 的数量
// 小于 current_val 的元素数量是 i
// 小于等于 current_val 的元素数量是 j
long long lt_count = i;
long long le_count = j;
long long num_sub_a = (pow2[le_count] - pow2[lt_count] + MOD) % MOD;
// 2. 计算 sub_b 的数量
auto it = lower_bound(b.begin(), b.end(), current_val);
long long ge_count = distance(it, b.end());
if (ge_count > 0) {
long long num_sub_b = (pow2[ge_count] - 1 + MOD) % MOD;
// 3. 累加
ans = (ans + (num_sub_a * num_sub_b) % MOD) % MOD;
}
i = j;
}
cout << ans << endl;
return 0;
}
import java.util.Arrays;
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
long[] a = new long[n];
long[] b = new long[n];
for (int i = 0; i < n; i++) a[i] = sc.nextLong();
for (int i = 0; i < n; i++) b[i] = sc.nextLong();
Arrays.sort(a);
Arrays.sort(b);
long MOD = 1_000_000_007;
long[] pow2 = new long[n + 1];
pow2[0] = 1;
for (int i = 1; i <= n; i++) {
pow2[i] = (pow2[i - 1] * 2) % MOD;
}
long ans = 0;
int i = 0;
while (i < n) {
long currentVal = a[i];
int j = i;
while (j < n && a[j] == currentVal) {
j++;
}
// 1. 计算 sub_a 的数量
int ltCount = i;
int leCount = j;
long numSubA = (pow2[leCount] - pow2[ltCount] + MOD) % MOD;
// 2. 计算 sub_b 的数量
int geIndex = binarySearchLowerBound(b, currentVal);
int geCount = n - geIndex;
if (geCount > 0) {
long numSubB = (pow2[geCount] - 1 + MOD) % MOD;
// 3. 累加
ans = (ans + (numSubA * numSubB) % MOD) % MOD;
}
i = j;
}
System.out.println(ans);
}
// 手动实现 lower_bound
private static int binarySearchLowerBound(long[] arr, long key) {
int low = 0;
int high = arr.length;
while (low < high) {
int mid = low + (high - low) / 2;
if (arr[mid] >= key) {
high = mid;
} else {
low = mid + 1;
}
}
return low;
}
}
import bisect
def main():
n = int(input())
a = list(map(int, input().split()))
b = list(map(int, input().split()))
a.sort()
b.sort()
MOD = 10**9 + 7
pow2 = [1] * (n + 1)
for i in range(1, n + 1):
pow2[i] = (pow2[i - 1] * 2) % MOD
ans = 0
i = 0
while i < n:
current_val = a[i]
j = i
while j < n and a[j] == current_val:
j += 1
# 1. 计算 sub_a 的数量
# 小于 current_val 的元素数量是 i
# 小于等于 current_val 的元素数量是 j
lt_count = i
le_count = j
num_sub_a = (pow2[le_count] - pow2[lt_count] + MOD) % MOD
# 2. 计算 sub_b 的数量
ge_idx = bisect.bisect_left(b, current_val)
ge_count = n - ge_idx
if ge_count > 0:
num_sub_b = (pow2[ge_count] - 1 + MOD) % MOD
# 3. 累加
ans = (ans + num_sub_a * num_sub_b) % MOD
i = j
print(ans)
if __name__ == "__main__":
main()
算法及复杂度
- 算法:组合数学、排序、二分查找
- 时间复杂度:
- 排序两个数组需要
。
- 预计算
的幂次需要
。
- 遍历数组
的不重复元素,对于每个元素,在
中进行一次二分查找。总共是
。
- 综上,总时间复杂度为
。
- 排序两个数组需要
- 空间复杂度:需要存储预计算的幂次数组,空间复杂度为
。