小O的子序列最值(二)
[题目链接](https://www.nowcoder.com/practice/ccde5aaa096f418b81ee6ed9dc253504)
题意
给定两个长度为 的数组
和
,从
和
中各选一个非空子序列,要求
所选子序列的最大值不大于
所选子序列的最小值。求满足条件的选法总数,答案对
取模。
思路
枚举最大值分组
将两个数组分别排序。
对于从 中选取的子序列,其最大值是确定的。我们可以枚举
中每个不同的值
作为所选子序列的最大值,然后对每个
:
- 统计从
中选出最大值恰好为
的非空子序列数。
- 统计从
中选出最小值
的非空子序列数。
两者相乘即为以 为界的贡献,所有
的贡献求和即为答案。
计数公式
设排序后的 中,值
首次出现位置为
,末次出现位置为
(0-indexed)。
A 中最大值恰好为 的非空子序列数:
- 必须从
(所有等于
的元素)中至少选 1 个:共
种。
中所有元素均小于
,可以任意选取:共
种。
- 合计:
B 中最小值 的非空子序列数:
- 设
中
的元素个数为
(用二分查找得到)。
- 从这
个元素中选至少 1 个:
。
复杂度
排序 ,枚举每个不同值并二分查找
,预处理 2 的幂次
。总体
。
示例验证
输入:
2
1 2
3 4
排序后 ,
。
(
):countA =
;
中
的有 2 个,countB =
;贡献
。
(
):countA =
;
中
的有 2 个,countB =
;贡献
。
总答案 = ,与期望一致。
代码
C++
#include <bits/stdc++.h>
using namespace std;
const int MOD = 1e9 + 7;
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n;
cin >> n;
vector<int> a(n), b(n);
for (int i = 0; i < n; i++) cin >> a[i];
for (int i = 0; i < n; i++) cin >> b[i];
sort(a.begin(), a.end());
sort(b.begin(), b.end());
vector<long long> pw(n + 1);
pw[0] = 1;
for (int i = 1; i <= n; i++) pw[i] = pw[i-1] * 2 % MOD;
long long ans = 0;
int i = 0;
while (i < n) {
int j = i;
while (j < n && a[j] == a[i]) j++;
int v = a[i];
int lo = i, hi = j - 1;
long long countA = pw[lo] * ((pw[hi - lo + 1] - 1 + MOD) % MOD) % MOD;
int cb = n - (int)(lower_bound(b.begin(), b.end(), v) - b.begin());
long long countB = (pw[cb] - 1 + MOD) % MOD;
ans = (ans + countA * countB) % MOD;
i = j;
}
cout << ans << endl;
return 0;
}
Java
import java.util.*;
import java.io.*;
public class Main {
static final long MOD = 1_000_000_007L;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
int n = Integer.parseInt(br.readLine().trim());
int[] a = new int[n], b = new int[n];
StringTokenizer st = new StringTokenizer(br.readLine());
for (int i = 0; i < n; i++) a[i] = Integer.parseInt(st.nextToken());
st = new StringTokenizer(br.readLine());
for (int i = 0; i < n; i++) b[i] = Integer.parseInt(st.nextToken());
Arrays.sort(a);
Arrays.sort(b);
long[] pw = new long[n + 1];
pw[0] = 1;
for (int i = 1; i <= n; i++) pw[i] = pw[i-1] * 2 % MOD;
long ans = 0;
int i = 0;
while (i < n) {
int j = i;
while (j < n && a[j] == a[i]) j++;
int v = a[i];
int lo = i, hi = j - 1;
long countA = pw[lo] * ((pw[hi - lo + 1] - 1 + MOD) % MOD) % MOD;
int lb = lowerBound(b, v);
int cb = n - lb;
long countB = (pw[cb] - 1 + MOD) % MOD;
ans = (ans + countA * countB) % MOD;
i = j;
}
System.out.println(ans);
}
static int lowerBound(int[] arr, int val) {
int lo = 0, hi = arr.length;
while (lo < hi) {
int mid = (lo + hi) >>> 1;
if (arr[mid] < val) lo = mid + 1;
else hi = mid;
}
return lo;
}
}

京公网安备 11010502036488号