题目链接
题目描述
给定一个长度为 的二进制数组
(每个元素为
或
)。记
为奇数。对于数组
的所有长度恰为
的子序列,求它们中位数之和,并对
取模。
名词解释:
- 子序列:如果数组
可以从
中删除几个(可能是零)元素得到,那么
就是
的子序列。
- 中位数:长度为奇数
的数组的中位数是排序后的第
个元素。
解题思路
这是一个组合计数问题。直接枚举所有长度为 的子序列是不可行的,因为其数量可能非常巨大。我们需要找到一种更高效的数学方法。
1. 问题转化
由于数组是二进制的(只包含 和
),任何子序列排序后都会形成一段连续的
和一段连续的
(其中一段可能为空)。
一个长度为 的子序列的中位数,只可能是
或
。
因此,所有子序列的中位数之和,就等于 (中位数为0的子序列数量 * 0) + (中位数为1的子序列数量 * 1)
,这恰好等于中位数为1的子序列的数量。
所以,原问题被转化为了:计算有多少个长度为 的子序列,其排序后的中位数为
?
2. 中位数为 1 的条件
对于一个长度为 (
为奇数)的二进制子序列,它的中位数是排序后的第
mid_pos = (k+1)/2
个元素。
要使这个元素为 ,那么在这个子序列中,
的数量必须至少为
mid_pos
个。
- 如果
的数量小于
mid_pos
,那么的数量就大于或等于
k - mid_pos + 1 = k - (k+1)/2 + 1 = (k-1)/2 + 1 = (k+1)/2
。排序后,第mid_pos
个位置必然是。
- 如果
的数量大于或等于
mid_pos
,那么的数量就小于或等于
k - mid_pos = (k-1)/2
。排序后,前(k-1)/2
个位置最多只能被填满,第
mid_pos
个位置必然是。
结论:子序列的中位数为 ,当且仅当该子序列中
的个数不少于
。
3. 组合计数
现在问题变成了:从原数组 中,选出一个长度为
的子序列,要求其中
的个数大于等于
,求这样的子序列有多少个。
我们可以分类讨论子序列中 的个数。
-
首先,我们统计原数组
中
和
的总数,记为
zeros
和ones
。 -
我们枚举子序列中
的个数
,根据中位数为
的条件,
的取值范围是
。
-
对于一个固定的
,我们要在子序列中包含
个
和
个
。
- 从
ones
个中选出
个的方法数是
。
- 从
zeros
个中选出
个的方法数是
。
根据乘法原理,包含
个
和
个
的子序列数量为
。
- 从
-
根据加法原理,将所有可能的
的情况相加,就是最终的答案:
4. 实现
由于 的范围较大,我们需要预处理阶乘和阶乘的逆元,以便在
的时间内查询组合数。
- 预处理:计算
到
的值及其模
的逆元。
- 查询:对于每个测试用例,先统计
和
的个数,然后循环
从
到
,累加组合数乘积即可。
代码
#include <iostream>
#include <vector>
#include <numeric>
using namespace std;
const int MOD = 1000000007;
const int MAX_N = 200001;
vector<long long> fact(MAX_N);
vector<long long> invFact(MAX_N);
long long power(long long base, long long exp) {
long long res = 1;
base %= MOD;
while (exp > 0) {
if (exp % 2 == 1) res = (res * base) % MOD;
base = (base * base) % MOD;
exp /= 2;
}
return res;
}
void precompute() {
fact[0] = 1;
for (int i = 1; i < MAX_N; i++) {
fact[i] = (fact[i - 1] * i) % MOD;
}
invFact[MAX_N - 1] = power(fact[MAX_N - 1], MOD - 2);
for (int i = MAX_N - 2; i >= 0; i--) {
invFact[i] = (invFact[i + 1] * (i + 1)) % MOD;
}
}
long long nCr_mod_p(int n, int r) {
if (r < 0 || r > n) {
return 0;
}
return (((fact[n] * invFact[r]) % MOD) * invFact[n - r]) % MOD;
}
void solve() {
int n, k;
cin >> n >> k;
int ones = 0;
for (int i = 0; i < n; ++i) {
int val;
cin >> val;
if (val == 1) {
ones++;
}
}
int zeros = n - ones;
long long total_sum = 0;
int mid_pos = (k + 1) / 2;
for (int i = mid_pos; i <= k; ++i) {
long long comb_ones = nCr_mod_p(ones, i);
long long comb_zeros = nCr_mod_p(zeros, k - i);
long long term = (comb_ones * comb_zeros) % MOD;
total_sum = (total_sum + term) % MOD;
}
cout << total_sum << endl;
}
int main() {
precompute();
int T;
cin >> T;
while (T--) {
solve();
}
return 0;
}
import java.util.Scanner;
public class Main {
static final int MOD = 1000000007;
static final int MAX_N = 200001;
static long[] fact = new long[MAX_N];
static long[] invFact = new long[MAX_N];
public static long power(long base, long exp) {
long res = 1;
base %= MOD;
while (exp > 0) {
if (exp % 2 == 1) res = (res * base) % MOD;
base = (base * base) % MOD;
exp /= 2;
}
return res;
}
public static void precompute() {
fact[0] = 1;
for (int i = 1; i < MAX_N; i++) {
fact[i] = (fact[i - 1] * i) % MOD;
}
invFact[MAX_N - 1] = power(fact[MAX_N - 1], MOD - 2);
for (int i = MAX_N - 2; i >= 0; i--) {
invFact[i] = (invFact[i + 1] * (i + 1)) % MOD;
}
}
public static long nCr_mod_p(int n, int r) {
if (r < 0 || r > n) {
return 0;
}
return (((fact[n] * invFact[r]) % MOD) * invFact[n - r]) % MOD;
}
public static void main(String[] args) {
precompute();
Scanner sc = new Scanner(System.in);
int T = sc.nextInt();
while (T-- > 0) {
int n = sc.nextInt();
int k = sc.nextInt();
int ones = 0;
for (int i = 0; i < n; i++) {
if (sc.nextInt() == 1) {
ones++;
}
}
int zeros = n - ones;
long total_sum = 0;
int mid_pos = (k + 1) / 2;
for (int i = mid_pos; i <= k; i++) {
long comb_ones = nCr_mod_p(ones, i);
long comb_zeros = nCr_mod_p(zeros, k - i);
long term = (comb_ones * comb_zeros) % MOD;
total_sum = (total_sum + term) % MOD;
}
System.out.println(total_sum);
}
}
}
MOD = 1000000007
MAX_N = 200001
fact = [1] * MAX_N
invFact = [1] * MAX_N
for i in range(1, MAX_N):
fact[i] = (fact[i - 1] * i) % MOD
invFact[MAX_N - 1] = pow(fact[MAX_N - 1], MOD - 2, MOD)
for i in range(MAX_N - 2, -1, -1):
invFact[i] = (invFact[i + 1] * (i + 1)) % MOD
def nCr_mod_p(n, r):
if r < 0 or r > n:
return 0
numerator = fact[n]
denominator = (invFact[r] * invFact[n - r]) % MOD
return (numerator * denominator) % MOD
def solve():
n, k = map(int, input().split())
a = list(map(int, input().split()))
ones = a.count(1)
zeros = n - ones
total_sum = 0
mid_pos = (k + 1) // 2
for i in range(mid_pos, k + 1):
comb_ones = nCr_mod_p(ones, i)
comb_zeros = nCr_mod_p(zeros, k - i)
term = (comb_ones * comb_zeros) % MOD
total_sum = (total_sum + term) % MOD
print(total_sum)
T = int(input())
for _ in range(T):
solve()
算法及复杂度
-
算法:组合数学、数论(模逆元)、预处理
-
时间复杂度:预处理阶乘和逆元的时间复杂度为
,其中
是
的最大值(
)。对于每个测试用例,我们需要统计
和
的个数,时间复杂度为
,然后进行一个长度约为
的循环,每次循环内部是
的组合数查询。因此,每个测试用例的处理时间是
。由于所有测试用例的
之和不超过
,总时间复杂度为
,可以接受。
-
空间复杂度:需要两个数组来存储阶乘和阶乘的逆元,大小为
。因此,总空间复杂度为
。