题意
有 n∗n 的棋盘,有 n 个车,放置 n 个车使之满足下面条件:
- 每个格子都被攻击到
- 恰好有 k 个车互相攻击
求方案数,对 998244353 取模。
n,k≤200000
分析
看了题解和听了学长讲解,觉得这题没那么难,可是比赛时就是没想到啊,果然还是实力不够啊!
因为每个格子都被攻击到,所以每行都得放一个车或者每列都得放一个车。
不妨假设每行都只放一个车。
那么恰好 k 个车互相攻击,意味着 n 个车放了 n−k 列。
为什么捏?
假设放了 m 列,每一列放的车为 ti,显然 i=1∑mti=n,而且每列互相攻击数为 ti−1。
那么所有互相攻击数为 i=1∑m(ti−1)=n−m,因此 m=n−k。
那么问题转化为,有 n 个车,每行放一个车,有 n−k 列,每一列都至少放一个车的方案数。这个其实是个标准的集合划分问题。
如果你还没看出来,我们还可以把问题看成:有 n 个不同球, m 个不同箱子,每个箱子至少放一个球,求方案数。
于是我们可以快乐容斥了:
ans=i=0∑m(−1)iCmi(m−i)n
upd:第二类斯特林数要求盒子没有区别,所以要除掉 m!。但这题盒子有区别,所以不用除。所以这个严格说并不是第二类斯特林数(丢人了
复杂度 O(nlogn)
代码如下
#include <bits/stdc++.h>
#include<ext/pb_ds/hash_policy.hpp>
#include<ext/pb_ds/assoc_container.hpp>
#define N 200005
using namespace __gnu_pbds;
using namespace std;
typedef long long LL;
typedef unsigned long long uLL;
const int mod = 998244353;
struct custom_hash {
static uint64_t splitmix64(uint64_t x) {
x += 0x9e3779b97f4a7c15;
x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9;
x = (x ^ (x >> 27)) * 0x94d049bb133111eb;
return x ^ (x >> 31);
}
size_t operator()(uint64_t x) const {
static const uint64_t FIXED_RANDOM = chrono::steady_clock::now().time_since_epoch().count();
return splitmix64(x + FIXED_RANDOM);
}
};
LL z = 1;
LL read(){
LL x, f = 1;
char ch;
while(ch = getchar(), ch < '0' || ch > '9') if(ch == '-') f = -1;
x = ch - '0';
while(ch = getchar(), ch >= '0' && ch <= '9') x = x * 10 + ch - 48;
return x * f;
}
int ksm(int a, int b, int p){
int s = 1;
while(b){
if(b & 1) s = z * s * a % p;
a = z * a * a % p;
b >>= 1;
}
return s;
}
int inv[N], fac[N], maxn = N - 5;
int C(int n, int m){
return z * fac[n] * inv[m] % mod * inv[n - m] % mod;
}
int S(int n, int k){
int i, s = 0;
for(i = 0; i <= k; i++){
if(i % 2) s = (s - z * C(k, i) * ksm(k - i, n, mod) % mod) % mod;
else s = (s + z * C(k, i) * ksm(k - i, n, mod) % mod) % mod;
}
return s;
}
int main(){
int i, j, m;
LL n, k;
for(fac[0] = i = 1; i <= maxn; i++) fac[i] = z * fac[i - 1] * i % mod;
inv[maxn] = ksm(fac[maxn], mod - 2, mod);
for(i = maxn - 1; i >= 0; i--) inv[i] = z * inv[i + 1] * (i + 1) % mod;
n = read(); k = read();
if(k >= n) printf("0"), exit(0);
j = z * S(n, n - k) * C(n, n - k) % mod;
if(k > 0) j = j * 2 % mod;
printf("%d", (j + mod) % mod);
return 0;
}