题目链接

中位数之和

题目描述

给定一个长度为 的二进制数组 (每个元素为 )。记 为奇数。对于数组 的所有长度恰为 的子序列,求它们中位数之和,并对 取模。

名词解释:

  • 子序列:如果数组 可以从 中删除几个(可能是零)元素得到,那么 就是 的子序列。
  • 中位数:长度为奇数 的数组的中位数是排序后的第 个元素。

解题思路

这是一个组合计数问题。直接枚举所有长度为 的子序列是不可行的,因为其数量可能非常巨大。我们需要找到一种更高效的数学方法。

1. 问题转化

由于数组是二进制的(只包含 ),任何子序列排序后都会形成一段连续的 和一段连续的 (其中一段可能为空)。

一个长度为 的子序列的中位数,只可能是

因此,所有子序列的中位数之和,就等于 (中位数为0的子序列数量 * 0) + (中位数为1的子序列数量 * 1),这恰好等于中位数为1的子序列的数量

所以,原问题被转化为了:计算有多少个长度为 的子序列,其排序后的中位数为

2. 中位数为 1 的条件

对于一个长度为 为奇数)的二进制子序列,它的中位数是排序后的第 mid_pos = (k+1)/2 个元素。

要使这个元素为 ,那么在这个子序列中, 的数量必须至少mid_pos 个。

  • 如果 的数量小于 mid_pos,那么 的数量就大于或等于 k - mid_pos + 1 = k - (k+1)/2 + 1 = (k-1)/2 + 1 = (k+1)/2。排序后,第 mid_pos 个位置必然是
  • 如果 的数量大于或等于 mid_pos,那么 的数量就小于或等于 k - mid_pos = (k-1)/2。排序后,前 (k-1)/2 个位置最多只能被 填满,第 mid_pos 个位置必然是

结论:子序列的中位数为 ,当且仅当该子序列中 的个数不少于

3. 组合计数

现在问题变成了:从原数组 中,选出一个长度为 的子序列,要求其中 的个数大于等于 ,求这样的子序列有多少个。

我们可以分类讨论子序列中 的个数。

  1. 首先,我们统计原数组 的总数,记为 zerosones

  2. 我们枚举子序列中 的个数 ,根据中位数为 的条件, 的取值范围是

  3. 对于一个固定的 ,我们要在子序列中包含

    • ones 中选出 个的方法数是
    • zeros 中选出 个的方法数是

    根据乘法原理,包含 的子序列数量为

  4. 根据加法原理,将所有可能的 的情况相加,就是最终的答案:

4. 实现

由于 的范围较大,我们需要预处理阶乘和阶乘的逆元,以便在 的时间内查询组合数。

  • 预处理:计算 的值及其模 的逆元。
  • 查询:对于每个测试用例,先统计 的个数,然后循环 ,累加组合数乘积即可。

代码

#include <iostream>
#include <vector>
#include <numeric>

using namespace std;

const int MOD = 1000000007;
const int MAX_N = 200001;

vector<long long> fact(MAX_N);
vector<long long> invFact(MAX_N);

long long power(long long base, long long exp) {
    long long res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp /= 2;
    }
    return res;
}

void precompute() {
    fact[0] = 1;
    for (int i = 1; i < MAX_N; i++) {
        fact[i] = (fact[i - 1] * i) % MOD;
    }
    invFact[MAX_N - 1] = power(fact[MAX_N - 1], MOD - 2);
    for (int i = MAX_N - 2; i >= 0; i--) {
        invFact[i] = (invFact[i + 1] * (i + 1)) % MOD;
    }
}

long long nCr_mod_p(int n, int r) {
    if (r < 0 || r > n) {
        return 0;
    }
    return (((fact[n] * invFact[r]) % MOD) * invFact[n - r]) % MOD;
}

void solve() {
    int n, k;
    cin >> n >> k;
    int ones = 0;
    for (int i = 0; i < n; ++i) {
        int val;
        cin >> val;
        if (val == 1) {
            ones++;
        }
    }
    int zeros = n - ones;

    long long total_sum = 0;
    int mid_pos = (k + 1) / 2;

    for (int i = mid_pos; i <= k; ++i) {
        long long comb_ones = nCr_mod_p(ones, i);
        long long comb_zeros = nCr_mod_p(zeros, k - i);
        long long term = (comb_ones * comb_zeros) % MOD;
        total_sum = (total_sum + term) % MOD;
    }
    cout << total_sum << endl;
}

int main() {
    precompute();
    int T;
    cin >> T;
    while (T--) {
        solve();
    }
    return 0;
}
import java.util.Scanner;

public class Main {
    static final int MOD = 1000000007;
    static final int MAX_N = 200001;
    static long[] fact = new long[MAX_N];
    static long[] invFact = new long[MAX_N];

    public static long power(long base, long exp) {
        long res = 1;
        base %= MOD;
        while (exp > 0) {
            if (exp % 2 == 1) res = (res * base) % MOD;
            base = (base * base) % MOD;
            exp /= 2;
        }
        return res;
    }

    public static void precompute() {
        fact[0] = 1;
        for (int i = 1; i < MAX_N; i++) {
            fact[i] = (fact[i - 1] * i) % MOD;
        }
        invFact[MAX_N - 1] = power(fact[MAX_N - 1], MOD - 2);
        for (int i = MAX_N - 2; i >= 0; i--) {
            invFact[i] = (invFact[i + 1] * (i + 1)) % MOD;
        }
    }

    public static long nCr_mod_p(int n, int r) {
        if (r < 0 || r > n) {
            return 0;
        }
        return (((fact[n] * invFact[r]) % MOD) * invFact[n - r]) % MOD;
    }

    public static void main(String[] args) {
        precompute();
        Scanner sc = new Scanner(System.in);
        int T = sc.nextInt();
        while (T-- > 0) {
            int n = sc.nextInt();
            int k = sc.nextInt();
            int ones = 0;
            for (int i = 0; i < n; i++) {
                if (sc.nextInt() == 1) {
                    ones++;
                }
            }
            int zeros = n - ones;

            long total_sum = 0;
            int mid_pos = (k + 1) / 2;

            for (int i = mid_pos; i <= k; i++) {
                long comb_ones = nCr_mod_p(ones, i);
                long comb_zeros = nCr_mod_p(zeros, k - i);
                long term = (comb_ones * comb_zeros) % MOD;
                total_sum = (total_sum + term) % MOD;
            }
            System.out.println(total_sum);
        }
    }
}
MOD = 1000000007
MAX_N = 200001

fact = [1] * MAX_N
invFact = [1] * MAX_N

for i in range(1, MAX_N):
    fact[i] = (fact[i - 1] * i) % MOD

invFact[MAX_N - 1] = pow(fact[MAX_N - 1], MOD - 2, MOD)
for i in range(MAX_N - 2, -1, -1):
    invFact[i] = (invFact[i + 1] * (i + 1)) % MOD

def nCr_mod_p(n, r):
    if r < 0 or r > n:
        return 0
    numerator = fact[n]
    denominator = (invFact[r] * invFact[n - r]) % MOD
    return (numerator * denominator) % MOD

def solve():
    n, k = map(int, input().split())
    a = list(map(int, input().split()))
    
    ones = a.count(1)
    zeros = n - ones
    
    total_sum = 0
    mid_pos = (k + 1) // 2
    
    for i in range(mid_pos, k + 1):
        comb_ones = nCr_mod_p(ones, i)
        comb_zeros = nCr_mod_p(zeros, k - i)
        term = (comb_ones * comb_zeros) % MOD
        total_sum = (total_sum + term) % MOD
        
    print(total_sum)

T = int(input())
for _ in range(T):
    solve()

算法及复杂度

  • 算法:组合数学、数论(模逆元)、预处理

  • 时间复杂度:预处理阶乘和逆元的时间复杂度为 ,其中 的最大值()。对于每个测试用例,我们需要统计 的个数,时间复杂度为 ,然后进行一个长度约为 的循环,每次循环内部是 的组合数查询。因此,每个测试用例的处理时间是 。由于所有测试用例的 之和不超过 ,总时间复杂度为 ,可以接受。

  • 空间复杂度:需要两个数组来存储阶乘和阶乘的逆元,大小为 。因此,总空间复杂度为