题目链接
题目描述
给定两个整数 和
,请你计算组合数
的值,并对模数
取模。
输入:
- 第一行输入一个整数
,表示测试用例数量。
- 接下来
行,每行输入两个整数
。
输出:
- 对于每个测试用例,在一行上输出
的值。
解题思路
这是一个典型的求解组合数模质数的问题。由于有多组查询,使用预处理的方法效率最高。
-
组合数公式
- 组合数
(在本题中是
)的计算公式为:
- 在模运算中,除法不能直接计算,需要转化为乘以除数的 模逆元。
- 组合数
-
模逆元
- 一个数
在模
下的逆元
满足
。
- 由于本题的模数
是一个质数,我们可以使用 费马小定理 来求逆元。
- 费马小定理指出,如果
是质数,对于任意整数
且
,有
。
- 由此可得,
,所以
。
可以通过 快速幂 算法高效计算。
- 一个数
-
预处理
- 由于
的最大值可达
,我们可以预先计算出从
到
的阶乘值及其模逆元,并将它们存储在数组中。这样每次查询时就可以直接使用,达到
的查询效率。
- 预处理阶乘数组
fact
:fact[i] = i! % MOD
。 - 预处理阶乘的逆元数组
invFact
:invFact[i] = (i!)^-1 % MOD
。- 直接对每个阶乘求逆元效率较低(
)。
- 更高效的方法是:先用快速幂求出最大阶乘
fact[N]
的逆元invFact[N]
。 - 然后利用递推关系
invFact[i-1] = invFact[i] * i % MOD
,从后向前计算出所有阶乘的逆元。这样总的预处理时间复杂度接近。
- 直接对每个阶乘求逆元效率较低(
- 由于
-
计算组合数
- 有了预处理的数组,计算组合数就变得非常简单:
- 如果
,则组合数为0。
- 有了预处理的数组,计算组合数就变得非常简单:
代码
#include <iostream>
#include <vector>
using namespace std;
using LL = long long;
const int MOD = 1e9 + 7;
const int MAXN = 500000;
LL fact[MAXN + 1];
LL invFact[MAXN + 1];
LL power(LL base, LL exp) {
LL res = 1;
base %= MOD;
while (exp > 0) {
if (exp % 2 == 1) res = (res * base) % MOD;
base = (base * base) % MOD;
exp /= 2;
}
return res;
}
void precompute() {
fact[0] = 1;
invFact[0] = 1;
for (int i = 1; i <= MAXN; i++) {
fact[i] = (fact[i - 1] * i) % MOD;
}
invFact[MAXN] = power(fact[MAXN], MOD - 2);
for (int i = MAXN - 1; i >= 1; i--) {
invFact[i] = (invFact[i + 1] * (i + 1)) % MOD;
}
}
LL combinations(int b, int a) {
if (a < 0 || a > b) {
return 0;
}
return (((fact[b] * invFact[a]) % MOD) * invFact[b - a]) % MOD;
}
void solve() {
int a, b;
cin >> a >> b;
cout << combinations(b, a) << '\n';
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
precompute();
int t;
cin >> t;
while (t--) {
solve();
}
return 0;
}
import java.util.Scanner;
public class Main {
static final int MOD = 1_000_000_007;
static final int MAXN = 500000;
static long[] fact = new long[MAXN + 1];
static long[] invFact = new long[MAXN + 1];
static long power(long base, long exp) {
long res = 1;
base %= MOD;
while (exp > 0) {
if (exp % 2 == 1) {
res = (res * base) % MOD;
}
base = (base * base) % MOD;
exp /= 2;
}
return res;
}
static void precompute() {
fact[0] = 1;
for (int i = 1; i <= MAXN; i++) {
fact[i] = (fact[i - 1] * i) % MOD;
}
invFact[MAXN] = power(fact[MAXN], MOD - 2);
for (int i = MAXN - 1; i >= 0; i--) {
invFact[i] = (invFact[i + 1] * (i + 1)) % MOD;
}
}
static long combinations(int b, int a) {
if (a < 0 || a > b) {
return 0;
}
return (((fact[b] * invFact[a]) % MOD) * invFact[b - a]) % MOD;
}
public static void main(String[] args) {
precompute();
Scanner sc = new Scanner(System.in);
int t = sc.nextInt();
while (t-- > 0) {
int a = sc.nextInt();
int b = sc.nextInt();
System.out.println(combinations(b, a));
}
}
}
MOD = 1_000_000_007
MAXN = 500000
fact = [1] * (MAXN + 1)
invFact = [1] * (MAXN + 1)
for i in range(1, MAXN + 1):
fact[i] = (fact[i - 1] * i) % MOD
invFact[MAXN] = pow(fact[MAXN], MOD - 2, MOD)
for i in range(MAXN - 1, -1, -1):
invFact[i] = (invFact[i + 1] * (i + 1)) % MOD
def combinations(b, a):
if a < 0 or a > b:
return 0
# C(b, a) = b! / (a! * (b-a)!)
return (fact[b] * invFact[a] * invFact[b - a]) % MOD
t = int(input())
for _ in range(t):
a, b = map(int, input().split())
print(combinations(b, a))
算法及复杂度
- 算法:组合数学、费马小定理、快速幂、预处理
- 时间复杂度:预处理
,其中
。每个测试用例的查询为
。总时间复杂度为
。
- 空间复杂度:
,用于存储阶乘和阶乘逆元的数组。