小O的子序列最值(二)

[题目链接](https://www.nowcoder.com/practice/ccde5aaa096f418b81ee6ed9dc253504)

题意

给定两个长度为 的数组 ,从 中各选一个非空子序列,要求 所选子序列的最大值不大于 所选子序列的最小值。求满足条件的选法总数,答案对 取模。

思路

枚举最大值分组

将两个数组分别排序。

对于从 中选取的子序列,其最大值是确定的。我们可以枚举 中每个不同的值 作为所选子序列的最大值,然后对每个

  1. 统计从 中选出最大值恰好为 的非空子序列数。
  2. 统计从 中选出最小值 的非空子序列数。

两者相乘即为以 为界的贡献,所有 的贡献求和即为答案。

计数公式

设排序后的 中,值 首次出现位置为 ,末次出现位置为 (0-indexed)。

A 中最大值恰好为 的非空子序列数:

  • 必须从 (所有等于 的元素)中至少选 1 个:共 种。
  • 中所有元素均小于 ,可以任意选取:共 种。
  • 合计:

B 中最小值 的非空子序列数:

  • 的元素个数为 (用二分查找得到)。
  • 从这 个元素中选至少 1 个:

复杂度

排序 ,枚举每个不同值并二分查找 ,预处理 2 的幂次 。总体

示例验证

输入:

2
1 2
3 4

排序后

  • ):countA = 的有 2 个,countB = ;贡献
  • ):countA = 的有 2 个,countB = ;贡献

总答案 = ,与期望一致。

代码

C++

#include <bits/stdc++.h>
using namespace std;

const int MOD = 1e9 + 7;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;

    vector<int> a(n), b(n);
    for (int i = 0; i < n; i++) cin >> a[i];
    for (int i = 0; i < n; i++) cin >> b[i];

    sort(a.begin(), a.end());
    sort(b.begin(), b.end());

    vector<long long> pw(n + 1);
    pw[0] = 1;
    for (int i = 1; i <= n; i++) pw[i] = pw[i-1] * 2 % MOD;

    long long ans = 0;
    int i = 0;
    while (i < n) {
        int j = i;
        while (j < n && a[j] == a[i]) j++;
        int v = a[i];
        int lo = i, hi = j - 1;
        long long countA = pw[lo] * ((pw[hi - lo + 1] - 1 + MOD) % MOD) % MOD;
        int cb = n - (int)(lower_bound(b.begin(), b.end(), v) - b.begin());
        long long countB = (pw[cb] - 1 + MOD) % MOD;
        ans = (ans + countA * countB) % MOD;
        i = j;
    }

    cout << ans << endl;
    return 0;
}

Java

import java.util.*;
import java.io.*;

public class Main {
    static final long MOD = 1_000_000_007L;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(br.readLine().trim());

        int[] a = new int[n], b = new int[n];
        StringTokenizer st = new StringTokenizer(br.readLine());
        for (int i = 0; i < n; i++) a[i] = Integer.parseInt(st.nextToken());
        st = new StringTokenizer(br.readLine());
        for (int i = 0; i < n; i++) b[i] = Integer.parseInt(st.nextToken());

        Arrays.sort(a);
        Arrays.sort(b);

        long[] pw = new long[n + 1];
        pw[0] = 1;
        for (int i = 1; i <= n; i++) pw[i] = pw[i-1] * 2 % MOD;

        long ans = 0;
        int i = 0;
        while (i < n) {
            int j = i;
            while (j < n && a[j] == a[i]) j++;
            int v = a[i];
            int lo = i, hi = j - 1;
            long countA = pw[lo] * ((pw[hi - lo + 1] - 1 + MOD) % MOD) % MOD;
            int lb = lowerBound(b, v);
            int cb = n - lb;
            long countB = (pw[cb] - 1 + MOD) % MOD;
            ans = (ans + countA * countB) % MOD;
            i = j;
        }

        System.out.println(ans);
    }

    static int lowerBound(int[] arr, int val) {
        int lo = 0, hi = arr.length;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (arr[mid] < val) lo = mid + 1;
            else hi = mid;
        }
        return lo;
    }
}