[Educational Codeforces Round 86 (Rated for Div. 2)] E. Placing Rooks (组合数学,容斥原理)
E. Placing Rooks
time limit per test
2 seconds
memory limit per test
512 megabytes
input
standard input
output
standard output
Calculate the number of ways to place \(n\) rooks on an \(n \times n\) chessboard so that both of the following conditions are met:
- each empty cell is under attack;
- exactly \(k\) pairs of rooks attack each other.
An empty cell is under attack if there is at least one rook in the same row or at least one rook in the same column. Two rooks attack each other if they share the same row or column and there are no other rooks between them. For example, there are only two pairs of rooks that attack each other in the following picture (the picture is not reproduced here; one such placement is three rooks filling a single column, where the middle rook blocks the outer two):
One of the ways to place the rooks for \(n=3\) and \(k=2\)
Two ways to place the rooks are considered different if there exists at least one cell which is empty in one of the ways but contains a rook in another way.
The answer might be large, so print it modulo \(998244353\).
Input
The only line of the input contains two integers \(n\) and \(k\) (\(1 \le n \le 200000\); \(0 \le k \le \frac{n(n-1)}{2}\)).
Output
Print one integer — the number of ways to place the rooks, taken modulo \(998244353\).
Examples
input
Copy
3 2
output
Copy
6
input
Copy
3 3
output
Copy
0
input
Copy
4 0
output
Copy
24
input
Copy
1337 42
output
Copy
807905441
题意:
在一个 \(n \times n\) 的国际象棋棋盘上放置 \(n\) 个车,问有多少种放置方法同时满足以下条件:
- 每一个空格子都受到攻击
- 有刚好\(\mathit k\)对车相互攻击
思路:
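要使每一个空格子都受到攻击,必须满足每一行都有车,或者每一列都有车:否则若第 \(r\) 行和第 \(c\) 列都没有车,格子 \((r,c)\) 就是一个不受攻击的空格子。由于一共只有 \(n\) 个车,这等价于每一行恰好有一个车,或每一列恰好有一个车。
由于行和列是对称的,我们可以只考虑每一行都有一个车的情况,最后将答案乘以 \(2\)(\(k=0\) 的情况见下文)。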
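在每一行恰好有一个车的前提下,同一行内不会有相互攻击的车,相互攻击只发生在同一列中:一列中若有 \(t\) 个车,就贡献 \(t-1\) 对相互攻击的车。因此若有车的列数为 \(m\),相互攻击的对数就是 \(n-m\)。
要有刚好 \(k\) 对车相互攻击,就必须恰好有 \(n-k\) 列有车;当 \(k \ge n\) 时不存在合法摆法,答案为 \(0\)。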
我们令\(c=n-k\),我们从\(\mathit n\)列中任意选择\(\mathit c\)列让其有车子,一共有\(C(n,c)\)种方案。
接下来就是求:
有多少种方案把 \(n\) 个车放入这 \(c\) 列(每列有 \(n\) 格)中,使得每一行恰好有一个车,且每一列至少有一个车。这等价于计算从 \(n\) 行到 \(c\) 列的满射个数。
根据容斥原理(枚举强制为空的列数 \(i\))可以得到方案数为:\(\sum \limits_{i = 0}^{c} (-1)^i {{c}\choose{i}} (c-i)^n\)
根据乘法原理,只考虑每一行都有车的情况时,答案为:\({{n}\choose{c}} \sum \limits_{i = 0}^{c} (-1)^i {{c}\choose{i}} (c-i)^n\)
如果 \(k>0\),再将答案乘以 \(2\):此时“每一行都有车”和“每一列都有车”不可能同时成立(同时成立意味着车的摆放是一个排列,此时 \(k=0\)),因此这两种情况互不重叠且对称。
当 \(k=0\) 时,每一行和每一列都恰好有一个车,两种情况是同一批摆法,所以不用乘以 \(2\),答案就是 \(n!\)。
代码:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <vector>
#include <iomanip>
#include <sstream>
#include <bitset>
#include <unordered_map>
// #include <bits/stdc++.h>
// ---- Shorthand macros used throughout this template ----
// Macro arguments are parenthesized so expression arguments (e.g. `a ? b : c`, `*p`)
// expand with the intended precedence.
#define ALL(x) (x).begin(), (x).end()
#define sz(a) int((a).size())
// rep: i runs over [x, n); repd: i runs over [x, n] inclusive. n is re-evaluated every iteration.
#define rep(i,x,n) for(int i=(x);i<(n);i++)
#define repd(i,x,n) for(int i=(x);i<=(n);i++)
#define pii pair<int,int>
#define pll pair<long long ,long long>
// Unties iostreams from C stdio for faster cin/cout.
#define gbtb ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
#define MS0(X) memset((X), 0, sizeof((X)))
#define MSC0(X) memset((X), '\0', sizeof((X)))
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define eps 1e-6
// Debug print "[expr value]", active only when DEBUG_Switch is nonzero.
// do { } while (0) makes the expansion a single statement, so in
// `if (c) chu(x); else ...` the else binds to the caller's if, not to the macro's.
#define chu(x) do { if (DEBUG_Switch) cout<<"["<<#x<<" "<<(x)<<"]"<<endl; } while (0)
// Read 3 / 2 / 1 ints with scanf. None of them carries its own semicolon;
// the caller supplies it (du1 previously had one, breaking `if (...) du1(a); else ...`).
#define du3(a,b,c) scanf("%d %d %d",&(a),&(b),&(c))
#define du2(a,b) scanf("%d %d",&(a),&(b))
#define du1(a) scanf("%d",&(a))
using namespace std;
typedef long long ll;
/// Greatest common divisor by the Euclidean algorithm, written iteratively.
/// For negative inputs the sign of the result follows C++ `%` semantics.
ll gcd(ll a, ll b)
{
    while (b != 0)
    {
        const ll rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
/// Least common multiple of a and b, built on gcd().
/// Returns 0 when either argument is 0; without this guard lcm(0, 0) would
/// divide by gcd(0, 0) == 0 (undefined behavior).
/// Divides before multiplying to reduce the chance of intermediate overflow.
ll lcm(ll a, ll b)
{
    if (a == 0 || b == 0) return 0;
    return a / gcd(a, b) * b;
}
/// Computes a^b mod MOD by binary exponentiation.
/// @param a   base; any value, including negative (it is reduced into [0, MOD) first)
/// @param b   exponent; must be non-negative (a negative b now yields 1 % MOD instead of looping forever)
/// @param MOD positive modulus, small enough that (MOD - 1)^2 fits in long long
/// @return a^b mod MOD in [0, MOD); 0^0 is taken as 1, the usual convention.
long long powmod(long long a, long long b, long long MOD)
{
    a %= MOD;
    if (a < 0) a += MOD;  // C++ '%' keeps the dividend's sign; normalize to [0, MOD)
    long long ans = 1 % MOD;  // '1 % MOD' keeps the result in range when MOD == 1
    while (b > 0)
    {
        if (b & 1) ans = ans * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return ans;
}
/// Computes a^b exactly (no modulus) by binary exponentiation.
/// @param a base
/// @param b exponent; must be non-negative
/// @return a^b; 0^0 is taken as 1. The caller must ensure a^b fits in long long.
long long poww(long long a, long long b)
{
    long long ans = 1;
    while (b > 0)
    {
        if (b & 1) ans *= a;
        b >>= 1;
        // Square only while exponent bits remain: the final squaring would be unused
        // and can overflow (signed overflow is UB) even when a^b itself fits.
        if (b > 0) a *= a;
    }
    return ans;
}
/// Prints the ints in V on one line, separated by single spaces, then a newline.
/// An empty vector produces no output at all (not even the newline).
void Pv(const std::vector<int> &V)
{
    for (std::size_t idx = 0; idx < V.size(); ++idx)
    {
        if (idx != 0) printf(" ");
        printf("%d", V[idx]);
    }
    if (!V.empty()) printf("\n");
}
/// Prints the long longs in V on one line, separated by single spaces, then a newline.
/// An empty vector produces no output at all (not even the newline).
void Pvl(const std::vector<long long> &V)
{
    for (std::size_t idx = 0; idx < V.size(); ++idx)
    {
        if (idx != 0) printf(" ");
        printf("%lld", V[idx]);
    }
    if (!V.empty()) printf("\n");
}
/// Reads the next integer (optionally signed by a preceding '-') from `in`,
/// skipping any non-digit characters before it. Overflow is not checked.
/// @param in stream to read from; defaults to stdin, so existing `readll()` calls are unchanged
/// @return the parsed value, or 0 if end-of-file is reached before any digit
inline long long readll(FILE *in = stdin)
{
    long long tmp = 0, fh = 1;
    int c = getc(in);  // int, not char: EOF must stay distinguishable from real bytes
    while (c < '0' || c > '9')
    {
        // EOF is neither a digit nor '-', so without this check the loop spins forever.
        if (c == EOF) return 0;
        if (c == '-') fh = -1;
        c = getc(in);
    }
    while (c >= '0' && c <= '9')
    {
        tmp = tmp * 10 + (c - '0');
        c = getc(in);
    }
    return tmp * fh;
}
/// Reads the next int (optionally signed by a preceding '-') from `in`,
/// skipping any non-digit characters before it. Overflow is not checked.
/// @param in stream to read from; defaults to stdin, so existing `readint()` calls are unchanged
/// @return the parsed value, or 0 if end-of-file is reached before any digit
inline int readint(FILE *in = stdin)
{
    int tmp = 0, fh = 1;
    int c = getc(in);  // int, not char: EOF must stay distinguishable from real bytes
    while (c < '0' || c > '9')
    {
        // EOF is neither a digit nor '-', so without this check the loop spins forever.
        if (c == EOF) return 0;
        if (c == '-') fh = -1;
        c = getc(in);
    }
    while (c >= '0' && c <= '9')
    {
        tmp = tmp * 10 + (c - '0');
        c = getc(in);
    }
    return tmp * fh;
}
/// Prints a slice of arr on one line, space-separated and newline-terminated.
/// strat == 1 (default) prints arr[1..n]; strat == 0 prints arr[0..n-1];
/// any other strat prints arr[strat..n]. An empty range prints nothing.
void pvarr_int(int *arr, int n, int strat = 1)
{
    const int last = (strat == 0) ? n - 1 : n;
    for (int idx = strat; idx <= last; ++idx)
    {
        const char sep = (idx == last) ? '\n' : ' ';
        printf("%d%c", arr[idx], sep);
    }
}
/// Prints a slice of arr on one line, space-separated and newline-terminated.
/// strat == 1 (default) prints arr[1..n]; strat == 0 prints arr[0..n-1];
/// any other strat prints arr[strat..n]. An empty range prints nothing.
void pvarr_LL(long long *arr, int n, int strat = 1)
{
    const int last = (strat == 0) ? n - 1 : n;
    for (int idx = strat; idx <= last; ++idx)
    {
        const char sep = (idx == last) ? '\n' : ' ';
        printf("%lld%c", arr[idx], sep);
    }
}
// Size of the factorial tables below; C(a, b) indexes fac[a], so every a passed to C() must be < maxn.
const int maxn = 1000010;
// 0x3f3f3f3f (~1.06e9), a conventional "infinity" sentinel; not referenced elsewhere in this file.
const int inf = 0x3f3f3f3f;
/*** TEMPLATE CODE * * STARTS HERE ***/
// Set to 1 to read input from a local file in main() and to enable chu() debug prints.
#define DEBUG_Switch 0
// Modulus required by the problem; pre() relies on it being prime (Fermat inverse).
const ll mod = 998244353ll;
// fac[i] = i! mod `mod`; inv[i] = (i!)^{-1} mod `mod`. Both are filled by pre().
ll fac[maxn], inv[maxn];
/// Fills fac[i] = i! mod `mod` and inv[i] = (i!)^{-1} mod `mod` for every i in [0, maxn).
/// The inverse of the largest factorial comes from Fermat's little theorem (mod is prime);
/// the remaining ones follow downward via (i-1)!^{-1} = i!^{-1} * i.
void pre()
{
    fac[0] = 1;
    for (int i = 1; i != maxn; ++i)
        fac[i] = fac[i - 1] * i % mod;

    const int top = maxn - 1;
    inv[top] = powmod(fac[top], mod - 2, mod);
    for (int i = top; i > 0; --i)
        inv[i - 1] = inv[i] * i % mod;
}
/// Binomial coefficient C(a, b) mod `mod`, read from the tables built by pre().
/// Returns 0 when b < 0 or b > a; otherwise requires a < maxn.
ll C(ll a, ll b)
{
    const bool outOfRange = (b < 0) || (b > a);
    if (outOfRange)
        return 0;
    const ll invDenominator = inv[b] * inv[a - b] % mod;
    return fac[a] * invDenominator % mod;
}
int main()
{
#if DEBUG_Switch
freopen("C:\\code\\input.txt", "r", stdin);
#endif
//freopen("C:\\code\\output.txt","r",stdin);
ll n = readll();
ll k = readll();
if (k >= n)
{
printf("0\n");
return 0;
}
pre();
ll ans = 0ll;
ll c = n - k;
repd(i, 1, c)
{
if (i % 2 == c % 2)
{
ans += powmod(i, n, mod) * C(c, i) % mod;
ans %= mod;
} else
{
ans += mod - powmod(i, n, mod) * C(c, i) % mod;
ans %= mod;
}
}
ans = ans * C(n, c) % mod;
if (k > 0)
{
ans = ans * 2ll % mod;
}
printf("%lld\n", ans );
return 0;
}