题目链接

【模板】组合数

题目描述

给定两个整数 (),请你计算组合数 的值,并对模数 取模。

解题思路

本题是求解组合数模一个质数的模板题。

1. 组合数公式

组合数的基本公式为:

在进行模运算时,除法不能直接计算,需要转化为乘以除数的模逆元

公式变为:

其中 是一个质数。

2. 模逆元

因为模数 是一个质数,我们可以使用费马小定理来计算一个数 的模逆元

费马小定理指出:如果 是一个质数,且 不是 的倍数,则有

由此可得:,所以

计算 可以通过快速幂算法高效完成。

3. 预处理

题目包含多组测试用例,且 的最大值达到了 。如果每次查询都重新计算阶乘和逆元,效率会很低。

一个更高效的方法是预处理。我们可以预先计算出 范围内所有数的阶乘及其模逆元。

具体的预处理步骤如下:

  1. 计算阶乘

    创建一个数组 factfact[i] 存储

    这可以通过递推在 时间内完成:fact[i] = (fact[i-1] * i) % p

  2. 计算阶乘的逆元

    创建一个数组 invFactinvFact[i] 存储

    直接对每个阶乘求逆元效率不高。我们可以采用一种更快的线性方法:

    • 首先用快速幂计算出最大阶乘 fact[N] 的逆元,即 invFact[N] = power(fact[N], p-2)

    • 然后利用关系 反向递推: invFact[i-1] = (invFact[i] * i) % p

    这样就可以在 的时间内计算出所有阶乘的逆元。

4. 查询

完成预处理后,对于每一组查询 ,我们可以直接通过预处理好的数组在 的时间内计算结果:

不要忘记处理边界情况:如果 ,结果为

代码

#include <iostream>
#include <vector>

using namespace std;

const int MOD = 1000000007;
const int MAX_N = 500001;

vector<long long> fact(MAX_N);
vector<long long> invFact(MAX_N);

long long power(long long base, long long exp) {
    long long res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp /= 2;
    }
    return res;
}

void precompute() {
    fact[0] = 1;
    invFact[0] = 1;
    for (int i = 1; i < MAX_N; i++) {
        fact[i] = (fact[i - 1] * i) % MOD;
    }
    invFact[MAX_N - 1] = power(fact[MAX_N - 1], MOD - 2);
    for (int i = MAX_N - 2; i >= 1; i--) {
        invFact[i] = (invFact[i + 1] * (i + 1)) % MOD;
    }
}

long long nCr_mod_p(int n, int r) {
    if (r < 0 || r > n) {
        return 0;
    }
    return (((fact[n] * invFact[r]) % MOD) * invFact[n - r]) % MOD;
}

int main() {
    precompute();
    int T;
    cin >> T;
    while (T--) {
        int n, m;
        cin >> n >> m; // 题目输入是 n, m
        cout << nCr_mod_p(m, n) << endl; // 计算 C(m, n)
    }
    return 0;
}

import java.util.Scanner;

public class Main {
    static final int MOD = 1000000007;
    static final int MAX_N = 500001;
    static long[] fact = new long[MAX_N];
    static long[] invFact = new long[MAX_N];

    public static long power(long base, long exp) {
        long res = 1;
        base %= MOD;
        while (exp > 0) {
            if (exp % 2 == 1) res = (res * base) % MOD;
            base = (base * base) % MOD;
            exp /= 2;
        }
        return res;
    }

    public static void precompute() {
        fact[0] = 1;
        invFact[0] = 1;
        for (int i = 1; i < MAX_N; i++) {
            fact[i] = (fact[i - 1] * i) % MOD;
        }
        invFact[MAX_N - 1] = power(fact[MAX_N - 1], MOD - 2);
        for (int i = MAX_N - 2; i >= 1; i--) {
            invFact[i] = (invFact[i + 1] * (i + 1)) % MOD;
        }
    }

    public static long nCr_mod_p(int n, int r) {
        if (r < 0 || r > n) {
            return 0;
        }
        return (((fact[n] * invFact[r]) % MOD) * invFact[n - r]) % MOD;
    }

    public static void main(String[] args) {
        precompute();
        Scanner sc = new Scanner(System.in);
        int T = sc.nextInt();
        while (T-- > 0) {
            int n = sc.nextInt();
            int m = sc.nextInt();
            System.out.println(nCr_mod_p(m, n));
        }
    }
}