首先扔一个结论:记最终答案是 ans,另外 S(n,m) 表示前 n 个 自然数 的 m 次幂和:
S(n,m)ans=i=0∑n−1im=n−S(n,m)S(n,m+1)
考虑如何计算它:观察到 m≤5000,可以考虑拉格朗日插值(一个 m 次幂和的前缀和一定是一个不超过 m+1 次幂的多项式,计算出来前 m+2 即可):
记:
xiyik=i=S(i,m)=n
转化为基本的拉格朗日插值:
已知:f 是一个次数不超过 n 的多项式,f(xi)=yi(i∈[0,n]⋂N),求:f(k)=?。
解
f(k)=∑i=0nyi∏i=jxi−xjk−xj
其中,展开得:
f(k)=y0(x0−x1k−x1 ⋅ x0−x2k−x2 ⋯x0−xnk−xn )+y1(x1−x0k−x0 ⋅ x1−x2k−x2 ⋯x1−xnk−xn )+y2(x2−x0k−x0 ⋅ x2−x1k−x1 ⋯x2−xnk−xn )⋯+yn(xn−x0k−x0 ⋅ xn−x1k−x1 ⋯xn−xn−1k−xn−1 )
该式对所有 i∈[0,n]⋂N 均满足 f(xi)=yi。
理由是 yj(j=i) 的系数必含 k−xi 项,此时求 f(k):
f(k)=y0(x0−xik−xi ⋯ )+y1(x1−xik−xi ⋯ )+y2(x2−xik−xi ⋯ )⋯+yn(xn−xik−xi ⋯)
这些项用 k=xi 代入即全部消掉,其中只有一项 yi 的系数是特殊的:
f(k)=yi(xi−x0xi−x0 ⋅ xi−x1xi−x1 ⋯ xi−xnxi−xn)
后面一项的值显然为 1,因此对任意 i∈[0,n]⋂N 均有 f(xi)=yi。
又根据 n+1 个点能唯一确定一个不超过 n 次的多项式,该 f(k) 的表达式可以表示所有的 f(k)。
直接将题中 k 的值代入解析式 f(k)=∑i=0nyi∏i=jxi−xjk−xj 求解即可。
考虑证明,严谨证明直接套贝叶斯公式就好,其实没那么麻烦:
在第 i 名的分布列的比例关系是:0m,1m,⋯,(n−1)m,然后就直接对着算就好了。
#include<cstdio>
#include<cstring>
#define int long long
int init(){
char c = getchar();
int x = 0, f = 1;
for (; c < '0' || c > '9'; c = getchar())
if (c == '-') f = -1;
for (; c >= '0' && c <= '9'; c = getchar())
x = (x << 1) + (x << 3) + (c ^ 48);
return x * f;
}
void print(int x){
if (x < 0) x = -x, putchar('-');
if (x > 9) print(x / 10);
putchar(x % 10 + '0');
}
const int N = (int) 1e4 + 5, Mod = 998244353;
int x[N], y[N]; int ok[N << 1];
int quick_mod(int a, int b){ // 这道题 5000 * 5000 的数据范围卡朴素快速幂,需要优化普通的快速幂到 O (M * log + m^2),第一个 M 是可能的逆元个数,本题中约为 2m 个
if (b == Mod - 2) {
if (a < N && ok[a + N] != -1)
return ok[a + N];
}
int s = 1, olda = a, oldb = b;
a = (a + Mod) % Mod;
while (b) {
if (b & 1) s = s * a % Mod;
a = a * a % Mod; b >>= 1;
}
if (oldb == Mod - 2 && olda < N)
ok[olda + N] = s;
return s;
}
int S(int n, int m){
for (int i = 1; i <= m + 2; ++i)
x[i] = i, y[i] = (y[i - 1] + quick_mod(i - 1, m)) % Mod;
int k = n, ans = 0;
for (int i = 1; i <= m + 2; ++i) {
int mul = y[i];
for (int j = 1; j <= m + 2; ++j)
if (i != j)
mul = (mul * ((k - x[j]) % Mod + Mod) % Mod) * quick_mod(x[i] - x[j], Mod - 2) % Mod;
ans = (ans + mul) % Mod;
}
return ans;
}
signed main(){
memset(ok, -1, sizeof(ok));
int n = init(), m = init();
int s1 = S(n, m + 1), s2 = S(n, m);
int ans = s1 * quick_mod(s2, Mod - 2) % Mod;
print(((n - ans) % Mod + Mod) % Mod), putchar('\n');
}