魔改森林
题解
分两种情况讨论。
当 n, m ≤ 1000 时，直接在网格上递推即可：dp[i][j] = dp[i-1][j] + dp[i][j-1]，障碍点（以及被堵住的起点）的 dp 值为 0，复杂度 O(nm)。
// Small-grid branch: O(n*m) DP over lattice points (0..n) x (0..m), where
// dp[i][j] counts monotone paths from (0, 0) to (i, j) modulo `mod` and
// blocked points keep dp = 0.
if (n <= 1000 && m <= 1000) {
    memset(vis, false, sizeof(vis));
    for (int i = 0; i < k; ++i) {
        cin >> x >> y;
        x--, y--;
        x = n - x;  // flip the row index so the walk starts at (0, 0)
        // An obstacle outside the lattice cannot block any path; skipping it
        // also prevents writing outside vis[][].
        if (x < 0 || x > n || y < 0 || y > m) continue;
        vis[x][y] = true;
    }
    // A blocked start admits no path at all (the inclusion-exclusion branch
    // also yields 0 in that case), so only seed the DP when it is free.
    dp[0][0] = vis[0][0] ? 0 : 1;
    for (int i = 0; i <= n; ++i) {
        for (int j = 0; j <= m; ++j) {
            if (vis[i][j] || (i == 0 && j == 0)) continue;
            if (i > 0) dp[i][j] += dp[i - 1][j];  // arrive from (i-1, j)
            if (j > 0) dp[i][j] += dp[i][j - 1];  // arrive from (i, j-1)
            dp[i][j] %= mod;
        }
    }
    cout << dp[n][m] << '\n';
    return 0;
}
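否则 n 或 m 较大，但障碍数 k 非常小，可以利用容斥原理暴力枚举障碍子集：
答案 = Σ_S (-1)^{|S|} · P(S)，其中 P(S) 是从起点出发、依次经过 S 中所有障碍点再到终点的路径数（若 S 中的点不能按顺序依次到达，则 P(S) = 0）；从 (x1, y1) 走到 (x2, y2)（x1 ≤ x2, y1 ≤ y2）的路径数为 C(x2-x1+y2-y1, x2-x1)。把障碍按 x+y 排序后，用 DFS 逐个把可达的障碍加入链中，并交替改变符号：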
void dfs(int x, int y, int pos, int k, int n, int m, ll val) { ans += val * C(n-x+m-y, n-x); ans %= mod; for (int i = pos; i < k; ++ i) { if (x <= no[i].x && y <= no[i].y) { dfs(no[i].x, no[i].y, i+1, k, n, m, ((val*-1ll*C(no[i].x-x+no[i].y-y, no[i].x-x))%mod+mod)%mod); } } }
代码
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <vector>
using namespace std;
typedef long long ll;

const int maxn = 1e3 + 55;
const ll mod = 998244353;

// State for the small-grid DP branch over lattice points (0..n) x (0..m).
bool vis[maxn][maxn];  // vis[i][j]: point (i, j) is blocked
ll dp[maxn][maxn];     // dp[i][j]: paths from (0, 0) to (i, j), mod `mod`

// f[i] = i! mod `mod`, filled by init(). Sized at run time so the table
// always covers C(n + m, .) instead of relying on a fixed upper bound.
vector<ll> f;

// An obstacle in the row-flipped lattice coordinates used by both branches.
struct node {
    int x, y;
};
vector<node> no;  // in-grid obstacles for the inclusion-exclusion branch

// Computes a^b mod p by binary exponentiation.
ll qpow(ll a, ll b, ll p) {
    ll ret = 1;
    while (b) {
        if (b & 1) ret = ret * a % p;
        b >>= 1;
        a = a * a % p;
    }
    return ret;
}

// Fills f[0..n] with factorials modulo `mod`.
void init(int n) {
    f.assign(n + 1, 0);
    f[0] = 1;
    for (int i = 1; i <= n; ++i) f[i] = f[i - 1] * i % mod;
}

// Binomial coefficient n! / (m! (n-m)!) mod `mod`, using a Fermat inverse
// (mod is prime). Returns 0 when the arguments are out of range
// (n < 0, m < 0 or m > n). Requires init() to have covered index n.
ll C(int n, int m) {
    if (n < 0 || m < 0 || m > n) return 0;
    ll denom = f[m] * f[n - m] % mod;
    return f[n] * qpow(denom, mod - 2, mod) % mod;
}

ll ans = 0;

// Inclusion-exclusion over chains of obstacles (no[] sorted by x + y).
// (x, y) is the last point on the current chain (the start for the root
// call); `val` is (-1)^(chain length) times the number of paths from the
// start through every chain obstacle to (x, y), kept in [0, mod).
// Adds val * paths((x, y) -> (n, m)) to `ans`, then tries to extend the
// chain with each later obstacle (index >= pos) reachable from (x, y).
void dfs(int x, int y, int pos, int k, int n, int m, ll val) {
    ans = (ans + val * C(n - x + m - y, n - x)) % mod;
    for (int idx = pos; idx < k; ++idx) {
        const auto &p = no[idx];
        if (p.x < x || p.y < y) continue;  // not reachable from (x, y)
        ll seg = C(p.x - x + p.y - y, p.x - x);  // paths (x, y) -> p
        // One more obstacle on the chain flips the sign; stay non-negative.
        dfs(p.x, p.y, idx + 1, k, n, m, (mod - val * seg % mod) % mod);
    }
}

int main() {
    int n, m, k;
    cin >> n >> m >> k;

    if (n <= 1000 && m <= 1000) {
        // Small grid: plain O(n*m) DP.
        memset(vis, false, sizeof(vis));
        for (int i = 0; i < k; ++i) {
            int x, y;
            cin >> x >> y;
            x--, y--;
            x = n - x;  // flip the row index so the walk starts at (0, 0)
            // Obstacles outside the lattice block nothing; skipping them also
            // prevents writing outside vis[][].
            if (x < 0 || x > n || y < 0 || y > m) continue;
            vis[x][y] = true;
        }
        // A blocked start admits no path; this matches the inclusion-exclusion
        // branch below, which also yields 0 in that case.
        dp[0][0] = vis[0][0] ? 0 : 1;
        for (int i = 0; i <= n; ++i) {
            for (int j = 0; j <= m; ++j) {
                if (vis[i][j] || (i == 0 && j == 0)) continue;
                if (i > 0) dp[i][j] += dp[i - 1][j];  // arrive from (i-1, j)
                if (j > 0) dp[i][j] += dp[i][j - 1];  // arrive from (i, j-1)
                dp[i][j] %= mod;
            }
        }
        cout << dp[n][m] << '\n';
        return 0;
    }

    // Large grid with few obstacles: inclusion-exclusion over obstacle chains.
    for (int i = 0; i < k; ++i) {
        node p;
        cin >> p.x >> p.y;
        p.x--, p.y--;
        p.x = n - p.x;
        // Off-grid obstacles lie on no path; dropping them keeps every
        // binomial argument inside [0, n + m].
        if (p.x < 0 || p.x > n || p.y < 0 || p.y > m) continue;
        no.push_back(p);
    }
    // A distinct obstacle reachable from another always has a strictly larger
    // x + y, so after sorting dfs only needs to look forward.
    sort(no.begin(), no.end(), [](const node &a, const node &b) {
        return a.x + a.y < b.x + b.y;
    });
    init(n + m);  // largest binomial needed is C(n + m, n)
    dfs(0, 0, 0, static_cast<int>(no.size()), n, m, 1);
    cout << ans << '\n';
    return 0;
}