题意

n n n*n nn 的棋盘,有 n n n 个车,放置 n n n 个车使之满足下面条件:

  1. 每个格子都被攻击到
  2. 恰好有 k k k 个车互相攻击

求方案数,对 998244353 998244353 998244353 取模。
n , k 200000 n,k\leq 200000 n,k200000

分析

看了题解和听了学长讲解,觉得这题没那么难,可是比赛时就是没想到啊,果然还是实力不够啊!
因为每个格子都被攻击到,所以每行都得放一个车或者每列都得放一个车。
不妨假设每行都只放一个车。
那么恰好 k k k 个车互相攻击,意味着 n n n 个车放了 n k n-k nk 列。
为什么捏?
假设放了 m m m 列,每一列放的车为 t i t_i ti,显然 i = 1 m t i = n \sum\limits_{i=1}^{m} t_i=n i=1mti=n,而且每列互相攻击数为 t i 1 t_i-1 ti1
那么所有互相攻击数为 i = 1 m ( t i 1 ) = n m \sum\limits_{i=1}^{m}(t_i-1)=n-m i=1m(ti1)=nm,因此 m = n k m=n-k m=nk
那么问题转化为,有 n n n 个车,每行放一个车,有 n k n-k nk 列,每一列都至少放一个车的方案数。这个其实是个标准的集合划分问题。
如果你还没看出来,我们还可以把问题看成:有 n n n 个不同球, m m m 个不同箱子,每个箱子至少放一个球,求方案数。
于是我们可以快乐容斥了:
a n s = <munderover> i = 0 m </munderover> ( 1 ) i C m i ( m i ) n ans=\sum\limits_{i=0}^{m}(-1)^iC_{m}^{i}(m-i)^n ans=i=0m(1)iCmi(mi)n
upd:第二类斯特林数要求盒子没有区别,所以要除掉 m ! m! m!。但这题盒子有区别,所以不用除。所以这个严格说并不是第二类斯特林数(丢人了
复杂度 O ( n l o g n ) O(nlogn) O(nlogn)

代码如下

#include <bits/stdc++.h>
#include<ext/pb_ds/hash_policy.hpp>
#include<ext/pb_ds/assoc_container.hpp>
#define N 200005
using namespace __gnu_pbds;
using namespace std;
typedef long long LL;
typedef unsigned long long uLL;
const int mod = 998244353;
struct custom_hash {
	   static uint64_t splitmix64(uint64_t x) {
	   	x += 0x9e3779b97f4a7c15;
		x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9;
		x = (x ^ (x >> 27)) * 0x94d049bb133111eb;
		return x ^ (x >> 31);
	}
	size_t operator()(uint64_t x) const {
		static const uint64_t FIXED_RANDOM = chrono::steady_clock::now().time_since_epoch().count();
		return splitmix64(x + FIXED_RANDOM);
	}
};
LL z = 1;
LL read(){
	LL x, f = 1;
	char ch;
	while(ch = getchar(), ch < '0' || ch > '9') if(ch == '-') f = -1;
	x = ch - '0';
	while(ch = getchar(), ch >= '0' && ch <= '9') x = x * 10 + ch - 48;
	return x * f;
}
int ksm(int a, int b, int p){
	int s = 1;
	while(b){
		if(b & 1) s = z * s * a % p;
		a = z * a * a % p;
		b >>= 1;
	}
	return s;
}
int inv[N], fac[N], maxn = N - 5;
int C(int n, int m){
	return z * fac[n] * inv[m] % mod * inv[n - m] % mod;
}
int S(int n, int k){
	int i, s = 0;
	for(i = 0; i <= k; i++){
		if(i % 2) s = (s - z * C(k, i) * ksm(k - i, n, mod) % mod) % mod;
		else s = (s + z * C(k, i) * ksm(k - i, n, mod) % mod) % mod;
	}
	return s;
}
int main(){
	int i, j, m;
	LL n, k;
	for(fac[0] = i = 1; i <= maxn; i++) fac[i] = z * fac[i - 1] * i % mod;
	inv[maxn] = ksm(fac[maxn], mod - 2, mod);
	for(i = maxn - 1; i >= 0; i--) inv[i] = z * inv[i + 1] * (i + 1) % mod;
	n = read(); k = read();
	if(k >= n) printf("0"), exit(0);
	j = z * S(n, n - k) * C(n, n - k) % mod;
	if(k > 0) j = j * 2 % mod;
	printf("%d", (j + mod) % mod);
	return 0;
}