小红的最大中位数

[题目链接](https://www.nowcoder.com/practice/daaf3691ed1c4fa49b1bf38821b1348d)

思路

给定一个长度为 的数组,要求选出一个非空子序列,使其中位数尽可能大,并统计能取到最大中位数的子序列数量。奇数长度子序列的中位数为排序后正中间的数,偶数长度子序列的中位数为排序后中间两个数的平均值。答案对 取模。

最大中位数是什么?

设数组的最大值为 。任何只包含 的子序列,中位数显然就是 。而任何子序列的中位数都不可能超过 (因为所有元素 ),所以最大中位数一定等于数组的最大值

哪些子序列的中位数等于

设子序列长度为 ,其中包含 个值为 的元素和 个值小于 的元素。将子序列排序后,前 个位置是小于 的元素,后 个位置都是

  • 奇数长度 :中位数在第 个位置,需要该位置是 ,即 ,等价于
  • 偶数长度 :中位数是第 和第 个位置的平均值,要等于 则两个位置都必须是 ,即 ,同样等价于

因此条件统一为:选出的 的个数严格大于非 的个数,即

计数

设数组中 出现 次,其余元素有 个。答案为:

$$

,则答案变为

分两段计算:

  1. :此时 ,逐步递推 。每次 ,增量计算即可。
  2. (若 ):此时 ,贡献为 。后者可用 计算。

样例演示

数组 ,最大值

的范围 方案数
1
2

合计

复杂度分析

  • 时间复杂度:,预处理阶乘和逆元后,枚举 的循环总量为
  • 空间复杂度:,存储阶乘数组。

代码

#include <bits/stdc++.h>
using namespace std;

const int MOD = 1e9 + 7;

long long power(long long a, long long b, long long mod) {
    long long res = 1;
    a %= mod;
    while (b > 0) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}

int main() {
    int n;
    scanf("%d", &n);
    vector<int> a(n);
    for (int i = 0; i < n; i++) scanf("%d", &a[i]);

    int M = *max_element(a.begin(), a.end());
    int cnt = 0;
    for (int x : a) if (x == M) cnt++;
    int rest = n - cnt;

    int maxn = n + 1;
    vector<long long> fac(maxn), inv_fac(maxn);
    fac[0] = 1;
    for (int i = 1; i < maxn; i++) fac[i] = fac[i - 1] * i % MOD;
    inv_fac[maxn - 1] = power(fac[maxn - 1], MOD - 2, MOD);
    for (int i = maxn - 2; i >= 0; i--) inv_fac[i] = inv_fac[i + 1] * (i + 1) % MOD;

    auto C = [&](int n, int k) -> long long {
        if (k < 0 || k > n) return 0;
        return fac[n] % MOD * inv_fac[k] % MOD * inv_fac[n - k] % MOD;
    };

    long long ans = 0;

    // 第一段:j = 1 到 min(rest, cnt),递推前缀和 S(j-1)
    long long prefS = 1; // S(0) = C(rest, 0) = 1
    for (int j = 1; j <= min(rest, cnt); j++) {
        ans = (ans + C(cnt, j) * prefS) % MOD;
        prefS = (prefS + C(rest, j)) % MOD;
    }

    // 第二段:j = rest+1 到 cnt,内层和 = 2^rest
    if (cnt > rest) {
        long long pow2rest = power(2, rest, MOD);
        long long sumC = 0;
        for (int j = 0; j <= rest; j++) {
            sumC = (sumC + C(cnt, j)) % MOD;
        }
        long long total = (power(2, cnt, MOD) - sumC + MOD) % MOD;
        ans = (ans + total * pow2rest) % MOD;
    }

    printf("%lld\n", ans);
    return 0;
}
import java.util.*;

public class Main {
    static final long MOD = 1_000_000_007;

    static long power(long a, long b, long mod) {
        long res = 1;
        a %= mod;
        while (b > 0) {
            if ((b & 1) == 1) res = res * a % mod;
            a = a * a % mod;
            b >>= 1;
        }
        return res;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int[] a = new int[n];
        int M = 0;
        for (int i = 0; i < n; i++) {
            a[i] = sc.nextInt();
            M = Math.max(M, a[i]);
        }
        int cnt = 0;
        for (int x : a) if (x == M) cnt++;
        int rest = n - cnt;

        int maxn = n + 1;
        long[] fac = new long[maxn];
        long[] invFac = new long[maxn];
        fac[0] = 1;
        for (int i = 1; i < maxn; i++) fac[i] = fac[i - 1] * i % MOD;
        invFac[maxn - 1] = power(fac[maxn - 1], MOD - 2, MOD);
        for (int i = maxn - 2; i >= 0; i--) invFac[i] = invFac[i + 1] * (i + 1) % MOD;

        long ans = 0;

        long prefS = 1;
        int limit = Math.min(rest, cnt);
        for (int j = 1; j <= limit; j++) {
            long cntJ = fac[cnt] % MOD * invFac[j] % MOD * invFac[cnt - j] % MOD;
            ans = (ans + cntJ * prefS) % MOD;
            long cRestJ = fac[rest] % MOD * invFac[j] % MOD * invFac[rest - j] % MOD;
            prefS = (prefS + cRestJ) % MOD;
        }

        if (cnt > rest) {
            long pow2rest = power(2, rest, MOD);
            long sumC = 0;
            for (int j = 0; j <= rest; j++) {
                long c = fac[cnt] % MOD * invFac[j] % MOD * invFac[cnt - j] % MOD;
                sumC = (sumC + c) % MOD;
            }
            long total = (power(2, cnt, MOD) - sumC + MOD) % MOD;
            ans = (ans + total * pow2rest) % MOD;
        }

        System.out.println(ans);
    }
}