题目链接
题目描述
给定两个整数 与
(
),请你计算组合数
的值,并对模数
取模。
解题思路
本题是求解组合数模一个质数的模板题。
1. 组合数公式
组合数的基本公式为:
在进行模运算时,除法不能直接计算,需要转化为乘以除数的模逆元。
公式变为:
其中 是一个质数。
2. 模逆元
因为模数 是一个质数,我们可以使用费马小定理来计算一个数
的模逆元
。
费马小定理指出:如果 是一个质数,且
不是
的倍数,则有
。
由此可得:,所以
。
计算 可以通过快速幂算法高效完成。
3. 预处理
题目包含多组测试用例,且 的最大值达到了
。如果每次查询都重新计算阶乘和逆元,效率会很低。
一个更高效的方法是预处理。我们可以预先计算出 到
范围内所有数的阶乘及其模逆元。
具体的预处理步骤如下:
-
计算阶乘:
创建一个数组
fact
,fact[i]
存储。
这可以通过递推在
时间内完成:
fact[i] = (fact[i-1] * i) % p
。 -
计算阶乘的逆元:
创建一个数组
invFact
,invFact[i]
存储。
直接对每个阶乘求逆元效率不高。我们可以采用一种更快的线性方法:
-
首先用快速幂计算出最大阶乘
fact[N]
的逆元,即invFact[N] = power(fact[N], p-2)
。 -
然后利用关系
反向递推:
invFact[i-1] = (invFact[i] * i) % p
。
这样就可以在
的时间内计算出所有阶乘的逆元。
-
4. 查询
完成预处理后,对于每一组查询 ,我们可以直接通过预处理好的数组在
的时间内计算结果:
不要忘记处理边界情况:如果 或
,结果为
。
代码
#include <iostream>
#include <vector>
using namespace std;
const int MOD = 1000000007;
const int MAX_N = 500001;
vector<long long> fact(MAX_N);
vector<long long> invFact(MAX_N);
long long power(long long base, long long exp) {
long long res = 1;
base %= MOD;
while (exp > 0) {
if (exp % 2 == 1) res = (res * base) % MOD;
base = (base * base) % MOD;
exp /= 2;
}
return res;
}
void precompute() {
fact[0] = 1;
invFact[0] = 1;
for (int i = 1; i < MAX_N; i++) {
fact[i] = (fact[i - 1] * i) % MOD;
}
invFact[MAX_N - 1] = power(fact[MAX_N - 1], MOD - 2);
for (int i = MAX_N - 2; i >= 1; i--) {
invFact[i] = (invFact[i + 1] * (i + 1)) % MOD;
}
}
long long nCr_mod_p(int n, int r) {
if (r < 0 || r > n) {
return 0;
}
return (((fact[n] * invFact[r]) % MOD) * invFact[n - r]) % MOD;
}
int main() {
precompute();
int T;
cin >> T;
while (T--) {
int n, m;
cin >> n >> m; // 题目输入是 n, m
cout << nCr_mod_p(m, n) << endl; // 计算 C(m, n)
}
return 0;
}
import java.util.Scanner;
public class Main {
static final int MOD = 1000000007;
static final int MAX_N = 500001;
static long[] fact = new long[MAX_N];
static long[] invFact = new long[MAX_N];
public static long power(long base, long exp) {
long res = 1;
base %= MOD;
while (exp > 0) {
if (exp % 2 == 1) res = (res * base) % MOD;
base = (base * base) % MOD;
exp /= 2;
}
return res;
}
public static void precompute() {
fact[0] = 1;
invFact[0] = 1;
for (int i = 1; i < MAX_N; i++) {
fact[i] = (fact[i - 1] * i) % MOD;
}
invFact[MAX_N - 1] = power(fact[MAX_N - 1], MOD - 2);
for (int i = MAX_N - 2; i >= 1; i--) {
invFact[i] = (invFact[i + 1] * (i + 1)) % MOD;
}
}
public static long nCr_mod_p(int n, int r) {
if (r < 0 || r > n) {
return 0;
}
return (((fact[n] * invFact[r]) % MOD) * invFact[n - r]) % MOD;
}
public static void main(String[] args) {
precompute();
Scanner sc = new Scanner(System.in);
int T = sc.nextInt();
while (T-- > 0) {
int n = sc.nextInt();
int m = sc.nextInt();
System.out.println(nCr_mod_p(m, n));
}
}
}