题目链接

中位数之和

题目描述

给定一个长度为 的二进制数组(元素为0或1)。对于该数组所有长度恰好为 的子序列(为奇数),求出它们的中位数之和,结果对 取模。

  • 子序列: 从原数组中删除任意数量(可以为0)的元素,剩下的元素保持原相对顺序组成的序列。
  • 中位数: 长度为奇数 的数组,排序后的第 个元素。

输入:

  • 第一行一个整数 表示测试用例数量。
  • 每个测试用例第一行是 。第二行是 个0或1。

输出:

  • 对每个测试用例,输出中位数之和的模。

解题思路

直接枚举所有长度为 的子序列是不可行的,数量级会非常巨大。我们需要转换思路。

问题的核心突破口在于:

  1. 数组是二进制的。
  2. 我们要求的是中位数之和

因为数组中只有0和1,所以任何子序列的中位数也只可能是0或1。因此,中位数之和就等于中位数为1的子序列的数量

问题转化为:在所有长度为 的子序列中,有多少个的中位数是1?

设一个长度为 的子序列,当它被排序后,会呈现出 [0, 0, ..., 1, 1, ...] 的形式。 中位数是第 个元素。 要使中位数为1,当且仅当排序后第 个元素是1。这也就意味着,这个子序列中 1的数量必须大于等于

所以,我们的最终目标是:统计有多少长度为 的子序列,其中包含至少 个1。

这变成了一个组合计数问题。

  1. 首先,遍历原始数组,统计出其中 1 的总数,记为 count_one,和 0 的总数 count_zero
  2. 我们枚举子序列中 1 的数量,设为 i。根据中位数为1的条件, 的取值范围是
  3. 对于一个固定的 i,我们要构造一个包含 i 个1和 m-i 个0的子序列。
    • count_one 个1中选出 i 个,方案数为
    • count_zero 个0中选出 m-i 个,方案数为
    • 根据乘法原理,构成这样一个子序列的方案数是
  4. 根据加法原理,将所有可能的 i(从 )的方案数累加起来,就是最终答案。 总方案数 =

为了高效计算组合数 ,我们可以预先计算出阶乘及其模逆元。由于所有测试用例的 之和不超过 ,我们可以一次性预处理到

代码

#include <iostream>
#include <vector>
#include <numeric>

using namespace std;
using LL = long long;

const int MOD = 1e9 + 7;
const int MAXN = 200000;

LL fact[MAXN + 1];
LL invFact[MAXN + 1];

LL power(LL base, LL exp) {
    LL res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp /= 2;
    }
    return res;
}

void precompute() {
    fact[0] = 1;
    for (int i = 1; i <= MAXN; i++) {
        fact[i] = (fact[i - 1] * i) % MOD;
    }
    invFact[MAXN] = power(fact[MAXN], MOD - 2);
    for (int i = MAXN - 1; i >= 0; i--) {
        invFact[i] = (invFact[i + 1] * (i + 1)) % MOD;
    }
}

LL combinations(int n, int k) {
    if (k < 0 || k > n) {
        return 0;
    }
    return (((fact[n] * invFact[k]) % MOD) * invFact[n - k]) % MOD;
}

void solve() {
    int n, m;
    cin >> n >> m;
    int count_one = 0;
    for (int i = 0; i < n; ++i) {
        int x;
        cin >> x;
        if (x == 1) count_one++;
    }
    int count_zero = n - count_one;
    int k = (m + 1) / 2;
    
    LL ans = 0;
    for (int i = k; i <= m; ++i) {
        LL term = (combinations(count_one, i) * combinations(count_zero, m - i)) % MOD;
        ans = (ans + term) % MOD;
    }
    cout << ans << '\n';
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    precompute();
    int t;
    cin >> t;
    while (t--) {
        solve();
    }
    return 0;
}
import java.util.Scanner;

public class Main {
    static final int MOD = 1_000_000_007;
    static final int MAXN = 200000;
    static long[] fact = new long[MAXN + 1];
    static long[] invFact = new long[MAXN + 1];

    static long power(long base, long exp) {
        long res = 1;
        base %= MOD;
        while (exp > 0) {
            if (exp % 2 == 1) res = (res * base) % MOD;
            base = (base * base) % MOD;
            exp /= 2;
        }
        return res;
    }

    static void precompute() {
        fact[0] = 1;
        for (int i = 1; i <= MAXN; i++) {
            fact[i] = (fact[i - 1] * i) % MOD;
        }
        invFact[MAXN] = power(fact[MAXN], MOD - 2);
        for (int i = MAXN - 1; i >= 0; i--) {
            invFact[i] = (invFact[i + 1] * (i + 1)) % MOD;
        }
    }

    static long combinations(int n, int k) {
        if (k < 0 || k > n) return 0;
        return (((fact[n] * invFact[k]) % MOD) * invFact[n - k]) % MOD;
    }

    static void solve(Scanner sc) {
        int n = sc.nextInt();
        int m = sc.nextInt();
        int count_one = 0;
        for (int i = 0; i < n; i++) {
            if (sc.nextInt() == 1) count_one++;
        }
        int count_zero = n - count_one;
        int k = (m + 1) / 2;

        long ans = 0;
        for (int i = k; i <= m; i++) {
            long term = (combinations(count_one, i) * combinations(count_zero, m - i)) % MOD;
            ans = (ans + term) % MOD;
        }
        System.out.println(ans);
    }

    public static void main(String[] args) {
        precompute();
        Scanner sc = new Scanner(System.in);
        int t = sc.nextInt();
        while (t-- > 0) {
            solve(sc);
        }
    }
}
MOD = 1_000_000_007
MAXN = 200000

# 全局预处理阶乘和逆元
fact = [1] * (MAXN + 1)
invFact = [1] * (MAXN + 1)

def power(base, exp):
    res = 1
    base %= MOD
    while exp > 0:
        if exp % 2 == 1:
            res = (res * base) % MOD
        base = (base * base) % MOD
        exp //= 2
    return res

def precompute():
    fact[0] = 1
    for i in range(1, MAXN + 1):
        fact[i] = (fact[i - 1] * i) % MOD
    
    # 计算 fact[MAXN] 的逆元
    invFact[MAXN] = power(fact[MAXN], MOD - 2)
    # 从后向前推导所有阶乘的逆元
    for i in range(MAXN - 1, -1, -1):
        invFact[i] = (invFact[i + 1] * (i + 1)) % MOD

def combinations(n, k):
    if k < 0 or k > n:
        return 0
    # C(n, k) = n! / (k! * (n-k)!)
    # C(n, k) = fact[n] * invFact[k] * invFact[n-k]
    numerator = fact[n]
    denominator = (invFact[k] * invFact[n - k]) % MOD
    return (numerator * denominator) % MOD

def solve():
    n, m = map(int, input().split())
    # 读取数组并计算1的数量
    arr = list(map(int, input().split()))
    count_one = sum(arr)
    count_zero = n - count_one
    
    # 中位数为1,则子序列中1的数量至少为 k
    k = (m + 1) // 2
    
    ans = 0
    # 枚举子序列中1的数量 i (从 k到m)
    for i in range(k, m + 1):
        # 从 count_one 个1中选i个, 从 count_zero 个0中选 m-i 个
        # 组合数 C(count_one, i) * C(count_zero, m - i)
        if i > count_one or m - i > count_zero:
            continue # 如果选的数量超过总数,则跳过
            
        term = (combinations(count_one, i) * combinations(count_zero, m - i)) % MOD
        ans = (ans + term) % MOD
        
    print(ans)

def main():
    precompute()
    t = int(input())
    for _ in range(t):
        solve()

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:组合数学、计数原理、费马小定理、快速幂
  • 时间复杂度:预处理为 。对于每个测试用例,复杂度为 ,其中 用于读入和计数, 用于循环累加组合数。由于所有 的总和以及 ,总时间复杂度为 ,符合题目要求。
  • 空间复杂度:,用于存储预计算的阶乘和逆元数组。