题目链接 排座位
本题知识点:不定方程解的数量 + 容斥(二项式反演)
题目需要求最大间隔的期望,我们设 表示最大间隔恰好为 的方案数量。那么根据期望的定义,答案为
不好计算,我们设 表示最大间隔不超过 的方案数量。显然有
现在需要计算 。
此时再分析一下题意:总共 个位置, 个人——那么就有 段连续的空位,空位的长度总和为 ,并且每个空位长度都在 范围。
很容易抽象出不定方程的模型:未知数个数 ,总和 。
- 如果没有 的约束,这就是个经典的不定方程非负整数解的问题(推导见 oiwiki 里面的插板法)
- 如果本题的限制是 ,也就是每个变量的上界限制互不相同,那么就只能用二进制枚举的办法去容斥(见 oiwiki 里面的容斥原理),复杂度将会是指数级别。本题套用这个做法时间复杂度超标。
对于本题的限制 ,我们的强力武器——二项式反演就登场了。
二项式反演其实是容斥的一个例子,由于应用过于广泛所以被抽象成为一个模型。它的作用在于:将求 “恰好” 的计数问题转换为求 “至多” 或 “至少” 的计数问题。假如我们钦定 个元素不满足约束的方案数量为 ,那么答案就是
关键是结合容斥的思想理解 “钦定” 这个词的含义。关于二项式定理详细的叙述和证明以及更多例题参考网上其他博客。
就这个题而言,我们求的是恰好 个变量满足 。所以定义 :含义是我钦定 个变量不满足 的条件,也就是先给它们赋值 ,总和 也变为 。这个时候我去求 “不定方程非负整数解” 的问题,就会保证至少有 个变量不符合要求,数量是
那么 表达式为
整个题目就解决了。
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define PII pair<int, int>
#define endl "\n"
/********************** Core code begins **********************/
const int MOD = 998244353, N = 1e6 + 7;
int fac[N], inv[N], invfac[N];
int qpow(int x, int k) {
int res = 1;
while (k) {
if (k % 2) {
res = res * x % MOD;
}
x = x * x % MOD;
k /= 2;
}
return res;
}
int ny(int x) {
return qpow(x, MOD - 2);
}
void init() {
fac[0] = inv[0] = invfac[0] = 1;
fac[1] = inv[1] = invfac[1] = 1;
for (int i = 2; i < N; i++) {
fac[i] = fac[i - 1] * i % MOD;
inv[i] = ((MOD - MOD / i * inv[MOD % i]) % MOD + MOD) % MOD;
invfac[i] = invfac[i - 1] * inv[i] % MOD;
}
}
int C(int n, int m) {
if (n < m || n < 0 || m < 0) {
return 0;
}
return fac[n] * invfac[m] % MOD * invfac[n - m] % MOD;
}
void SolveTest() {
int n, m;
cin >> n >> m;
vector<int> g(m + 1);
for (int x = 0; x <= m - n + 1; x++) {
int tmp = 0, sig = 1;
for (int t = 0; t <= (m - n) / (x + 1); t++) {
tmp = (tmp + sig * C(n + 1, t) % MOD * C(m - t * (x + 1), n) % MOD) % MOD;
sig = MOD - sig;
}
g[x] = tmp;
}
int ans = 0;
for (int i = 1; i <= m - n; i++) {
ans = (ans + i * (g[i] - g[i - 1] + MOD) % MOD) % MOD;
}
ans = ans * ny(C(m, n)) % MOD;
cout << ans << endl;
}
/********************** Core code ends ***********************/
signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int T = 1;
// cin >> T;
init();
for (int i = 1; i <= T; i++) {
SolveTest();
}
return 0;
}