题目链接
题目描述
给定一个长度为 的二进制数组(元素为0或1)。对于该数组所有长度恰好为
的子序列(
为奇数),求出它们的中位数之和,结果对
取模。
- 子序列: 从原数组中删除任意数量(可以为0)的元素,剩下的元素保持原相对顺序组成的序列。
- 中位数: 长度为奇数
的数组,排序后的第
个元素。
输入:
- 第一行一个整数
表示测试用例数量。
- 每个测试用例第一行是
。第二行是
个0或1。
输出:
- 对每个测试用例,输出中位数之和的模。
解题思路
直接枚举所有长度为 的子序列是不可行的,数量级会非常巨大。我们需要转换思路。
问题的核心突破口在于:
- 数组是二进制的。
- 我们要求的是中位数之和。
因为数组中只有0和1,所以任何子序列的中位数也只可能是0或1。因此,中位数之和就等于中位数为1的子序列的数量。
问题转化为:在所有长度为 的子序列中,有多少个的中位数是1?
设一个长度为 的子序列,当它被排序后,会呈现出
[0, 0, ..., 1, 1, ...]
的形式。
中位数是第 个元素。
要使中位数为1,当且仅当排序后第
个元素是1。这也就意味着,这个子序列中 1的数量必须大于等于
。
所以,我们的最终目标是:统计有多少长度为 的子序列,其中包含至少
个1。
这变成了一个组合计数问题。
- 首先,遍历原始数组,统计出其中
1
的总数,记为count_one
,和0
的总数count_zero
。 - 我们枚举子序列中
1
的数量,设为i
。根据中位数为1的条件,的取值范围是
。
- 对于一个固定的
i
,我们要构造一个包含i
个1和m-i
个0的子序列。- 从
count_one
个1中选出i
个,方案数为。
- 从
count_zero
个0中选出m-i
个,方案数为。
- 根据乘法原理,构成这样一个子序列的方案数是
。
- 从
- 根据加法原理,将所有可能的
i
(从到
)的方案数累加起来,就是最终答案。 总方案数 =
为了高效计算组合数 ,我们可以预先计算出阶乘及其模逆元。由于所有测试用例的
之和不超过
,我们可以一次性预处理到
。
代码
#include <iostream>
#include <vector>
#include <numeric>
using namespace std;
using LL = long long;
const int MOD = 1e9 + 7;
const int MAXN = 200000;
LL fact[MAXN + 1];
LL invFact[MAXN + 1];
LL power(LL base, LL exp) {
LL res = 1;
base %= MOD;
while (exp > 0) {
if (exp % 2 == 1) res = (res * base) % MOD;
base = (base * base) % MOD;
exp /= 2;
}
return res;
}
void precompute() {
fact[0] = 1;
for (int i = 1; i <= MAXN; i++) {
fact[i] = (fact[i - 1] * i) % MOD;
}
invFact[MAXN] = power(fact[MAXN], MOD - 2);
for (int i = MAXN - 1; i >= 0; i--) {
invFact[i] = (invFact[i + 1] * (i + 1)) % MOD;
}
}
LL combinations(int n, int k) {
if (k < 0 || k > n) {
return 0;
}
return (((fact[n] * invFact[k]) % MOD) * invFact[n - k]) % MOD;
}
void solve() {
int n, m;
cin >> n >> m;
int count_one = 0;
for (int i = 0; i < n; ++i) {
int x;
cin >> x;
if (x == 1) count_one++;
}
int count_zero = n - count_one;
int k = (m + 1) / 2;
LL ans = 0;
for (int i = k; i <= m; ++i) {
LL term = (combinations(count_one, i) * combinations(count_zero, m - i)) % MOD;
ans = (ans + term) % MOD;
}
cout << ans << '\n';
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
precompute();
int t;
cin >> t;
while (t--) {
solve();
}
return 0;
}
import java.util.Scanner;
public class Main {
static final int MOD = 1_000_000_007;
static final int MAXN = 200000;
static long[] fact = new long[MAXN + 1];
static long[] invFact = new long[MAXN + 1];
static long power(long base, long exp) {
long res = 1;
base %= MOD;
while (exp > 0) {
if (exp % 2 == 1) res = (res * base) % MOD;
base = (base * base) % MOD;
exp /= 2;
}
return res;
}
static void precompute() {
fact[0] = 1;
for (int i = 1; i <= MAXN; i++) {
fact[i] = (fact[i - 1] * i) % MOD;
}
invFact[MAXN] = power(fact[MAXN], MOD - 2);
for (int i = MAXN - 1; i >= 0; i--) {
invFact[i] = (invFact[i + 1] * (i + 1)) % MOD;
}
}
static long combinations(int n, int k) {
if (k < 0 || k > n) return 0;
return (((fact[n] * invFact[k]) % MOD) * invFact[n - k]) % MOD;
}
static void solve(Scanner sc) {
int n = sc.nextInt();
int m = sc.nextInt();
int count_one = 0;
for (int i = 0; i < n; i++) {
if (sc.nextInt() == 1) count_one++;
}
int count_zero = n - count_one;
int k = (m + 1) / 2;
long ans = 0;
for (int i = k; i <= m; i++) {
long term = (combinations(count_one, i) * combinations(count_zero, m - i)) % MOD;
ans = (ans + term) % MOD;
}
System.out.println(ans);
}
public static void main(String[] args) {
precompute();
Scanner sc = new Scanner(System.in);
int t = sc.nextInt();
while (t-- > 0) {
solve(sc);
}
}
}
MOD = 1_000_000_007
MAXN = 200000
# 全局预处理阶乘和逆元
fact = [1] * (MAXN + 1)
invFact = [1] * (MAXN + 1)
def power(base, exp):
res = 1
base %= MOD
while exp > 0:
if exp % 2 == 1:
res = (res * base) % MOD
base = (base * base) % MOD
exp //= 2
return res
def precompute():
fact[0] = 1
for i in range(1, MAXN + 1):
fact[i] = (fact[i - 1] * i) % MOD
# 计算 fact[MAXN] 的逆元
invFact[MAXN] = power(fact[MAXN], MOD - 2)
# 从后向前推导所有阶乘的逆元
for i in range(MAXN - 1, -1, -1):
invFact[i] = (invFact[i + 1] * (i + 1)) % MOD
def combinations(n, k):
if k < 0 or k > n:
return 0
# C(n, k) = n! / (k! * (n-k)!)
# C(n, k) = fact[n] * invFact[k] * invFact[n-k]
numerator = fact[n]
denominator = (invFact[k] * invFact[n - k]) % MOD
return (numerator * denominator) % MOD
def solve():
n, m = map(int, input().split())
# 读取数组并计算1的数量
arr = list(map(int, input().split()))
count_one = sum(arr)
count_zero = n - count_one
# 中位数为1,则子序列中1的数量至少为 k
k = (m + 1) // 2
ans = 0
# 枚举子序列中1的数量 i (从 k到m)
for i in range(k, m + 1):
# 从 count_one 个1中选i个, 从 count_zero 个0中选 m-i 个
# 组合数 C(count_one, i) * C(count_zero, m - i)
if i > count_one or m - i > count_zero:
continue # 如果选的数量超过总数,则跳过
term = (combinations(count_one, i) * combinations(count_zero, m - i)) % MOD
ans = (ans + term) % MOD
print(ans)
def main():
precompute()
t = int(input())
for _ in range(t):
solve()
if __name__ == "__main__":
main()
算法及复杂度
- 算法:组合数学、计数原理、费马小定理、快速幂
- 时间复杂度:预处理为
。对于每个测试用例,复杂度为
,其中
用于读入和计数,
用于循环累加组合数。由于所有
的总和以及
,总时间复杂度为
,符合题目要求。
- 空间复杂度:
,用于存储预计算的阶乘和逆元数组。