小红的最大中位数
[题目链接](https://www.nowcoder.com/practice/daaf3691ed1c4fa49b1bf38821b1348d)
思路
给定一个长度为 的数组,要求选出一个非空子序列,使其中位数尽可能大,并统计能取到最大中位数的子序列数量。奇数长度子序列的中位数为排序后正中间的数,偶数长度子序列的中位数为排序后中间两个数的平均值。答案对
取模。
最大中位数是什么?
设数组的最大值为 。任何只包含
的子序列,中位数显然就是
。而任何子序列的中位数都不可能超过
(因为所有元素
),所以最大中位数一定等于数组的最大值
。
哪些子序列的中位数等于
?
设子序列长度为 ,其中包含
个值为
的元素和
个值小于
的元素。将子序列排序后,前
个位置是小于
的元素,后
个位置都是
。
- 奇数长度
:中位数在第
个位置,需要该位置是
,即
,等价于
。
- 偶数长度
:中位数是第
和第
个位置的平均值,要等于
则两个位置都必须是
,即
,同样等价于
。
因此条件统一为:选出的 的个数严格大于非
的个数,即
。
计数
设数组中 出现
次,其余元素有
个。答案为:
$$
记 ,则答案变为
。
分两段计算:
到
:此时
,逐步递推
。每次
,增量计算即可。
到
(若
):此时
,贡献为
。后者可用
计算。
样例演示
数组 ,最大值
,
,
。
| 方案数 | ||
|---|---|---|
| 1 | ||
| 2 |
合计 。
复杂度分析
- 时间复杂度:
,预处理阶乘和逆元后,枚举
的循环总量为
。
- 空间复杂度:
,存储阶乘数组。
代码
#include <bits/stdc++.h>
using namespace std;
const int MOD = 1e9 + 7;
long long power(long long a, long long b, long long mod) {
long long res = 1;
a %= mod;
while (b > 0) {
if (b & 1) res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
int main() {
int n;
scanf("%d", &n);
vector<int> a(n);
for (int i = 0; i < n; i++) scanf("%d", &a[i]);
int M = *max_element(a.begin(), a.end());
int cnt = 0;
for (int x : a) if (x == M) cnt++;
int rest = n - cnt;
int maxn = n + 1;
vector<long long> fac(maxn), inv_fac(maxn);
fac[0] = 1;
for (int i = 1; i < maxn; i++) fac[i] = fac[i - 1] * i % MOD;
inv_fac[maxn - 1] = power(fac[maxn - 1], MOD - 2, MOD);
for (int i = maxn - 2; i >= 0; i--) inv_fac[i] = inv_fac[i + 1] * (i + 1) % MOD;
auto C = [&](int n, int k) -> long long {
if (k < 0 || k > n) return 0;
return fac[n] % MOD * inv_fac[k] % MOD * inv_fac[n - k] % MOD;
};
long long ans = 0;
// 第一段:j = 1 到 min(rest, cnt),递推前缀和 S(j-1)
long long prefS = 1; // S(0) = C(rest, 0) = 1
for (int j = 1; j <= min(rest, cnt); j++) {
ans = (ans + C(cnt, j) * prefS) % MOD;
prefS = (prefS + C(rest, j)) % MOD;
}
// 第二段:j = rest+1 到 cnt,内层和 = 2^rest
if (cnt > rest) {
long long pow2rest = power(2, rest, MOD);
long long sumC = 0;
for (int j = 0; j <= rest; j++) {
sumC = (sumC + C(cnt, j)) % MOD;
}
long long total = (power(2, cnt, MOD) - sumC + MOD) % MOD;
ans = (ans + total * pow2rest) % MOD;
}
printf("%lld\n", ans);
return 0;
}
import java.util.*;
public class Main {
static final long MOD = 1_000_000_007;
static long power(long a, long b, long mod) {
long res = 1;
a %= mod;
while (b > 0) {
if ((b & 1) == 1) res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int[] a = new int[n];
int M = 0;
for (int i = 0; i < n; i++) {
a[i] = sc.nextInt();
M = Math.max(M, a[i]);
}
int cnt = 0;
for (int x : a) if (x == M) cnt++;
int rest = n - cnt;
int maxn = n + 1;
long[] fac = new long[maxn];
long[] invFac = new long[maxn];
fac[0] = 1;
for (int i = 1; i < maxn; i++) fac[i] = fac[i - 1] * i % MOD;
invFac[maxn - 1] = power(fac[maxn - 1], MOD - 2, MOD);
for (int i = maxn - 2; i >= 0; i--) invFac[i] = invFac[i + 1] * (i + 1) % MOD;
long ans = 0;
long prefS = 1;
int limit = Math.min(rest, cnt);
for (int j = 1; j <= limit; j++) {
long cntJ = fac[cnt] % MOD * invFac[j] % MOD * invFac[cnt - j] % MOD;
ans = (ans + cntJ * prefS) % MOD;
long cRestJ = fac[rest] % MOD * invFac[j] % MOD * invFac[rest - j] % MOD;
prefS = (prefS + cRestJ) % MOD;
}
if (cnt > rest) {
long pow2rest = power(2, rest, MOD);
long sumC = 0;
for (int j = 0; j <= rest; j++) {
long c = fac[cnt] % MOD * invFac[j] % MOD * invFac[cnt - j] % MOD;
sumC = (sumC + c) % MOD;
}
long total = (power(2, cnt, MOD) - sumC + MOD) % MOD;
ans = (ans + total * pow2rest) % MOD;
}
System.out.println(ans);
}
}

京公网安备 11010502036488号