题目链接

【模板】组合数

题目描述

给定两个整数 ,请你计算组合数 的值,并对模数 取模。

输入:

  • 第一行输入一个整数 ,表示测试用例数量。
  • 接下来 行,每行输入两个整数

输出:

  • 对于每个测试用例,在一行上输出 的值。

解题思路

这是一个典型的求解组合数模质数的问题。由于有多组查询,使用预处理的方法效率最高。

  1. 组合数公式

    • 组合数 (在本题中是 )的计算公式为:
    • 在模运算中,除法不能直接计算,需要转化为乘以除数的 模逆元
  2. 模逆元

    • 一个数 在模 下的逆元 满足
    • 由于本题的模数 是一个质数,我们可以使用 费马小定理 来求逆元。
    • 费马小定理指出,如果 是质数,对于任意整数 ,有
    • 由此可得,,所以
    • 可以通过 快速幂 算法高效计算。
  3. 预处理

    • 由于 的最大值可达 ,我们可以预先计算出从 的阶乘值及其模逆元,并将它们存储在数组中。这样每次查询时就可以直接使用,达到 的查询效率。
    • 预处理阶乘数组 fact: fact[i] = i! % MOD
    • 预处理阶乘的逆元数组 invFact: invFact[i] = (i!)^-1 % MOD
      • 直接对每个阶乘求逆元效率较低()。
      • 更高效的方法是:先用快速幂求出最大阶乘 fact[N] 的逆元 invFact[N]
      • 然后利用递推关系 invFact[i-1] = invFact[i] * i % MOD,从后向前计算出所有阶乘的逆元。这样总的预处理时间复杂度接近
  4. 计算组合数

    • 有了预处理的数组,计算组合数就变得非常简单:
    • 如果 ,则组合数为0。

代码

#include <iostream>
#include <vector>

using namespace std;
using LL = long long;

const int MOD = 1e9 + 7;
const int MAXN = 500000;

LL fact[MAXN + 1];
LL invFact[MAXN + 1];

LL power(LL base, LL exp) {
    LL res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp /= 2;
    }
    return res;
}

void precompute() {
    fact[0] = 1;
    invFact[0] = 1;
    for (int i = 1; i <= MAXN; i++) {
        fact[i] = (fact[i - 1] * i) % MOD;
    }
    invFact[MAXN] = power(fact[MAXN], MOD - 2);
    for (int i = MAXN - 1; i >= 1; i--) {
        invFact[i] = (invFact[i + 1] * (i + 1)) % MOD;
    }
}

LL combinations(int b, int a) {
    if (a < 0 || a > b) {
        return 0;
    }
    return (((fact[b] * invFact[a]) % MOD) * invFact[b - a]) % MOD;
}

void solve() {
    int a, b;
    cin >> a >> b;
    cout << combinations(b, a) << '\n';
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    precompute();
    int t;
    cin >> t;
    while (t--) {
        solve();
    }
    return 0;
}
import java.util.Scanner;

public class Main {
    static final int MOD = 1_000_000_007;
    static final int MAXN = 500000;
    static long[] fact = new long[MAXN + 1];
    static long[] invFact = new long[MAXN + 1];

    static long power(long base, long exp) {
        long res = 1;
        base %= MOD;
        while (exp > 0) {
            if (exp % 2 == 1) {
                res = (res * base) % MOD;
            }
            base = (base * base) % MOD;
            exp /= 2;
        }
        return res;
    }

    static void precompute() {
        fact[0] = 1;
        for (int i = 1; i <= MAXN; i++) {
            fact[i] = (fact[i - 1] * i) % MOD;
        }
        invFact[MAXN] = power(fact[MAXN], MOD - 2);
        for (int i = MAXN - 1; i >= 0; i--) {
            invFact[i] = (invFact[i + 1] * (i + 1)) % MOD;
        }
    }

    static long combinations(int b, int a) {
        if (a < 0 || a > b) {
            return 0;
        }
        return (((fact[b] * invFact[a]) % MOD) * invFact[b - a]) % MOD;
    }

    public static void main(String[] args) {
        precompute();
        Scanner sc = new Scanner(System.in);
        int t = sc.nextInt();
        while (t-- > 0) {
            int a = sc.nextInt();
            int b = sc.nextInt();
            System.out.println(combinations(b, a));
        }
    }
}
MOD = 1_000_000_007
MAXN = 500000

fact = [1] * (MAXN + 1)
invFact = [1] * (MAXN + 1)

for i in range(1, MAXN + 1):
    fact[i] = (fact[i - 1] * i) % MOD

invFact[MAXN] = pow(fact[MAXN], MOD - 2, MOD)
for i in range(MAXN - 1, -1, -1):
    invFact[i] = (invFact[i + 1] * (i + 1)) % MOD

def combinations(b, a):
    if a < 0 or a > b:
        return 0
    # C(b, a) = b! / (a! * (b-a)!)
    return (fact[b] * invFact[a] * invFact[b - a]) % MOD

t = int(input())
for _ in range(t):
    a, b = map(int, input().split())
    print(combinations(b, a))

算法及复杂度

  • 算法:组合数学、费马小定理、快速幂、预处理
  • 时间复杂度:预处理 ,其中 。每个测试用例的查询为 。总时间复杂度为
  • 空间复杂度:,用于存储阶乘和阶乘逆元的数组。