题目链接 排座位

本题知识点:不定方程解的数量 + 容斥(二项式反演)

题目需要求最大间隔的期望,我们设 f(x)f(x) 表示最大间隔恰好xx 的方案数量。那么根据期望的定义,答案为

ans=x=0mnxf(x)(mn)ans=\sum_{x=0}^{m-n}\dfrac{x\cdot f(x)}{m\choose n}

f(x)f(x) 不好计算,我们设 g(x)g(x) 表示最大间隔不超过 xx 的方案数量。显然有

f(x)={g(0),x=0g(x)g(x1),x>0f(x)=\begin{cases} g(0), &x = 0\\ g(x)-g(x-1), &x>0 \end{cases}

现在需要计算 g(x)g(x)

此时再分析一下题意:总共 mm 个位置,nn 个人——那么就有 n+1n+1 段连续的空位,空位的长度总和为 mnm-n,并且每个空位长度都在 [0,x][0,x] 范围。

很容易抽象出不定方程的模型:未知数个数 k=n+1k=n+1,总和 s=mns=m-n

a1+a2++ak=si[1,k], 0aixa_1+a_2+\dots +a_k =s \\ \forall i \in [1,k],\ 0 \le a_i \le x
  • 如果没有 aika_i\le k 的约束,这就是个经典的不定方程非负整数解的问题(推导见 oiwiki 里面的插板法
(s+k1s){s+k-1\choose {s}}
  • 如果本题的限制是 aikia_i \le k_i,也就是每个变量的上界限制互不相同,那么就只能用二进制枚举的办法去容斥(见 oiwiki 里面的容斥原理),复杂度将会是指数级别。本题套用这个做法时间复杂度超标。

对于本题的限制 aika_i\le k ,我们的强力武器——二项式反演就登场了。

二项式反演其实是容斥的一个例子,由于应用过于广泛所以被抽象成为一个模型。它的作用在于:将求 “恰好” 的计数问题转换为求 “至多” 或 “至少” 的计数问题。假如我们钦定 tt 个元素不满足约束的方案数量为 h(t)h(t),那么答案就是

h(0)(k1)h(1)+(k2)h(2)...h(0)-{k\choose 1}h(1) + {k\choose 2}h(2)-...

关键是结合容斥的思想理解 “钦定” 这个词的含义。关于二项式定理详细的叙述和证明以及更多例题参考网上其他博客。

就这个题而言,我们求的是恰好 kk 个变量满足 aixa_i\le x 。所以定义 h(t)h(t):含义是我钦定 tt 个变量不满足 aixa_i\le x 的条件,也就是先给它们赋值 x+1x+1,总和 ss 也变为 st(x+1)s-t\cdot (x+1)。这个时候我去求 “不定方程非负整数解” 的问题,就会保证至少有 tt 个变量不符合要求,数量是

h(t)=(stx+k1stx)h(t) = {{s-t\cdot x + k-1}\choose {s-t\cdot x}}

那么g(x)g(x) 表达式为

g(x)=(k0)h(0)(k1)h(1)+(k2)h(2)(k3)h(3)+=t=1mnt(1)t(n+1i)(mtxn)\begin{aligned} g(x)&={k\choose 0}h(0)-{k\choose 1}h(1)+{k\choose 2}h(2)-{k\choose 3}h(3) +\dots \\ &=\sum_{t=1}^{\lfloor \frac{m-n}{t} \rfloor} (-1)^{t} {{n+1}\choose {i}}{m-t\cdot x \choose n} \end{aligned}

整个题目就解决了。

#include <bits/stdc++.h>
using namespace std;
#define int long long
#define PII pair<int, int>
#define endl "\n"
/**********************  Core code begins  **********************/
const int MOD = 998244353, N = 1e6 + 7;
int fac[N], inv[N], invfac[N];

int qpow(int x, int k) {
    int res = 1;
    while (k) {
        if (k % 2) {
            res = res * x % MOD;
        }
        x = x * x % MOD;
        k /= 2;
    }
    return res;
}
int ny(int x) {
    return qpow(x, MOD - 2);
}
void init() {
    fac[0] = inv[0] = invfac[0] = 1;
    fac[1] = inv[1] = invfac[1] = 1;
    for (int i = 2; i < N; i++) {
        fac[i] = fac[i - 1] * i % MOD;
        inv[i] = ((MOD - MOD / i * inv[MOD % i]) % MOD + MOD) % MOD;
        invfac[i] = invfac[i - 1] * inv[i] % MOD;
    }
}
int C(int n, int m) {
    if (n < m || n < 0 || m < 0) {
        return 0;
    }
    return fac[n] * invfac[m] % MOD * invfac[n - m] % MOD;
}

void SolveTest() {
    int n, m;
    cin >> n >> m;
    vector<int> g(m + 1);
    for (int x = 0; x <= m - n + 1; x++) {
        int tmp = 0, sig = 1;
        for (int t = 0; t <= (m - n) / (x + 1); t++) {
            tmp = (tmp + sig * C(n + 1, t) % MOD * C(m - t * (x + 1), n) % MOD) % MOD;
            sig = MOD - sig;
        }
        g[x] = tmp;
    }
    int ans = 0;
    for (int i = 1; i <= m - n; i++) {
        ans = (ans + i * (g[i] - g[i - 1] + MOD) % MOD) % MOD;
    }
    ans = ans * ny(C(m, n)) % MOD;
    cout << ans << endl;
}

/**********************  Core code ends  ***********************/
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int T = 1;
    // cin >> T;
    init();
    for (int i = 1; i <= T; i++) {
        SolveTest();
    }
    return 0;
}