题目链接

小O的子序列最值(二)

题目描述

给定两个长度为 的数组

需要从这两个数组中分别选出一个非空的子序列,我们称之为

要求满足条件: 的最大值不大于 的最小值。

求解总共有多少种满足条件的选取方法,答案需要对 取模。

解题思路

这是一个组合计数问题。直接枚举所有子序列对的数量级是 ,会严重超时。我们需要一种更高效的计数方法。

核心思想

我们可以转变计数的主体。与其枚举子序列,我们可以枚举第一个子序列 可能的最大值 ,然后计算在这种情况下,有多少对 满足条件。

具体步骤

  1. 预处理

    • 对数组 进行升序排序。这使得我们能够快速地统计某个范围内元素的数量。
    • 预计算 的幂次,因为一个包含 个元素的集合,其非空子序列的数量是
  2. 枚举与计算: 我们遍历排序后的数组 中的每一个元素 ,并考虑当 的最大值为 时的情况。

    • 第一步:计算 的数量 要使 的最大值恰好为 ,需要满足两个条件:

      1. 中至少包含一个
      2. 中所有其他元素的值都小于或等于

      这个数量可以通过容斥原理计算: (所有元素都 的非空子序列数量) - (所有元素都 的非空子序列数量)

      设数组 的元素有 count_le 个, 的元素有 count_lt 个。 则满足条件的 数量为 。 由于数组 已经排序,count_lecount_lt 可以通过下标快速得到。

    • 第二步:计算 的数量 的最大值为 时,条件变为 。这意味着 中的所有元素都必须大于或等于

      我们需要计算数组 中有多少个元素 。设其数量为 count_ge。 那么,由这些元素构成的任何非空子序列都满足条件。这样的 数量为 。 由于数组 已经排序,count_ge 可以通过二分查找(lower_bound)高效地计算出来。

    • 第三步:累加 将第一步和第二步的结果相乘,就得到了当 时的总方案数。我们遍历 中所有不重复的值,将每种情况的方案数累加,即可得到最终答案。

代码

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

const int MOD = 1e9 + 7;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;

    vector<long long> a(n), b(n);
    for (int i = 0; i < n; ++i) cin >> a[i];
    for (int i = 0; i < n; ++i) cin >> b[i];

    sort(a.begin(), a.end());
    sort(b.begin(), b.end());

    vector<long long> pow2(n + 1);
    pow2[0] = 1;
    for (int i = 1; i <= n; ++i) {
        pow2[i] = (pow2[i - 1] * 2) % MOD;
    }

    long long ans = 0;
    int i = 0;
    while (i < n) {
        long long current_val = a[i];
        
        int j = i;
        while (j < n && a[j] == current_val) {
            j++;
        }
        
        // 1. 计算 sub_a 的数量
        // 小于 current_val 的元素数量是 i
        // 小于等于 current_val 的元素数量是 j
        long long lt_count = i;
        long long le_count = j;
        long long num_sub_a = (pow2[le_count] - pow2[lt_count] + MOD) % MOD;

        // 2. 计算 sub_b 的数量
        auto it = lower_bound(b.begin(), b.end(), current_val);
        long long ge_count = distance(it, b.end());
        
        if (ge_count > 0) {
            long long num_sub_b = (pow2[ge_count] - 1 + MOD) % MOD;
            // 3. 累加
            ans = (ans + (num_sub_a * num_sub_b) % MOD) % MOD;
        }
        
        i = j;
    }

    cout << ans << endl;

    return 0;
}
import java.util.Arrays;
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        long[] a = new long[n];
        long[] b = new long[n];
        for (int i = 0; i < n; i++) a[i] = sc.nextLong();
        for (int i = 0; i < n; i++) b[i] = sc.nextLong();

        Arrays.sort(a);
        Arrays.sort(b);

        long MOD = 1_000_000_007;

        long[] pow2 = new long[n + 1];
        pow2[0] = 1;
        for (int i = 1; i <= n; i++) {
            pow2[i] = (pow2[i - 1] * 2) % MOD;
        }

        long ans = 0;
        int i = 0;
        while (i < n) {
            long currentVal = a[i];
            
            int j = i;
            while (j < n && a[j] == currentVal) {
                j++;
            }
            
            // 1. 计算 sub_a 的数量
            int ltCount = i;
            int leCount = j;
            long numSubA = (pow2[leCount] - pow2[ltCount] + MOD) % MOD;
            
            // 2. 计算 sub_b 的数量
            int geIndex = binarySearchLowerBound(b, currentVal);
            int geCount = n - geIndex;
            
            if (geCount > 0) {
                long numSubB = (pow2[geCount] - 1 + MOD) % MOD;
                // 3. 累加
                ans = (ans + (numSubA * numSubB) % MOD) % MOD;
            }
            
            i = j;
        }

        System.out.println(ans);
    }

    // 手动实现 lower_bound
    private static int binarySearchLowerBound(long[] arr, long key) {
        int low = 0;
        int high = arr.length;
        while (low < high) {
            int mid = low + (high - low) / 2;
            if (arr[mid] >= key) {
                high = mid;
            } else {
                low = mid + 1;
            }
        }
        return low;
    }
}
import bisect

def main():
    n = int(input())
    a = list(map(int, input().split()))
    b = list(map(int, input().split()))

    a.sort()
    b.sort()

    MOD = 10**9 + 7

    pow2 = [1] * (n + 1)
    for i in range(1, n + 1):
        pow2[i] = (pow2[i - 1] * 2) % MOD

    ans = 0
    i = 0
    while i < n:
        current_val = a[i]
        
        j = i
        while j < n and a[j] == current_val:
            j += 1
        
        # 1. 计算 sub_a 的数量
        # 小于 current_val 的元素数量是 i
        # 小于等于 current_val 的元素数量是 j
        lt_count = i
        le_count = j
        num_sub_a = (pow2[le_count] - pow2[lt_count] + MOD) % MOD
        
        # 2. 计算 sub_b 的数量
        ge_idx = bisect.bisect_left(b, current_val)
        ge_count = n - ge_idx
        
        if ge_count > 0:
            num_sub_b = (pow2[ge_count] - 1 + MOD) % MOD
            # 3. 累加
            ans = (ans + num_sub_a * num_sub_b) % MOD
            
        i = j

    print(ans)

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:组合数学、排序、二分查找
  • 时间复杂度:
    • 排序两个数组需要
    • 预计算 的幂次需要
    • 遍历数组 的不重复元素,对于每个元素,在 中进行一次二分查找。总共是
    • 综上,总时间复杂度为
  • 空间复杂度:需要存储预计算的幂次数组,空间复杂度为