魔改森林

题解

分情况讨论。
当 n, m ≤ 1000 时，直接在 (n+1)×(m+1) 的格点上做 DP：$dp_{i,j} = dp_{i,j-1} + dp_{i-1,j}$，障碍点的 dp 值保持为 0，复杂度 $O(nm)$。

// Small case (n, m <= 1000): grid DP over lattice points (0..n) x (0..m).
// Input coordinates are decremented, so they are presumably 1-based; obstacle
// (x, y) is mapped to lattice point (n - (x - 1), y - 1).
if (n <= 1000 && m <= 1000) {
        dp[0][0] = 1;
        memset(vis, false, sizeof(vis));
        for (int i = 0; i < k; ++ i) {
            cin >> x >> y;
            x --, y --;
            x = n - x;
            vis[x][y] = true;
        }
        // dp[i][j] = dp[i][j-1] + dp[i-1][j]. Blocked points are skipped, so
        // they stay 0. The origin is skipped too and keeps its seed value of 1.
        for (int i = 0; i <= n; ++ i) {
            for (int j = 0; j <= m; ++ j) {
                if (vis[i][j]) continue;
                if (i == 0 && j == 0) continue;
                // step = {{0, -1}, {-1, 0}}: the two predecessors of (i, j).
                // NOTE(review): loop variable `f` shadows the global factorial table `f`.
                for (int f = 0; f < 2; ++ f) {
                    int dx = step[f][0] + i;
                    int dy = step[f][1] + j;
                    if (dx >= 0 && dy >= 0) {
                        dp[i][j] += dp[dx][dy];
                    }
                    dp[i][j] %= mod;
                }
            }
        }
        cout << dp[n][m] << endl;
        return 0;
    }

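否则（n 或 m 超过 1000）网格太大，无法直接 DP；但障碍数 k 很小（代码中最多 5 个），可以对障碍做容斥。
从 $(x_1, y_1)$ 走到 $(x_2, y_2)$（$x_1 \le x_2$，$y_1 \le y_2$）的单调路径数为 $\binom{(x_2-x_1)+(y_2-y_1)}{x_2-x_1}$。把障碍按 $x+y$ 排序后，对每个能依次经过的障碍序列 $p_1, \dots, p_t$（记 $p_0 = (0,0)$，$p_{t+1} = (n,m)$）累加 $(-1)^t \prod_{i=0}^{t} \mathrm{paths}(p_i \to p_{i+1})$，用 DFS 枚举这些序列即可：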

// Inclusion-exclusion over chains of obstacles.
// (x, y) is the current chain end: the origin, or the last obstacle chosen.
// Only obstacles no[pos..k-1] may still be appended.
// val is the product of the path counts along the chain so far, times
// (-1)^(chain length), kept in [0, mod).
void dfs(int x, int y, int pos, int k, int n, int m, ll val) {
    // Close the chain here: add val * #paths from (x, y) to (n, m).
    ans += val * C(n-x+m-y, n-x);
    ans %= mod;
    for (int i = pos; i < k; ++ i) {
        // Extend only with obstacles reachable from (x, y) by a monotone path.
        if (x <= no[i].x && y <= no[i].y) {
            // New weight = -val * #paths((x, y) -> no[i]), normalised into [0, mod).
            dfs(no[i].x, no[i].y, i+1, k, n, m, ((val*-1ll*C(no[i].x-x+no[i].y-y, no[i].x-x))%mod+mod)%mod);
        }
    }
}

代码

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
// Shared integer type, limits and global tables.
using ll = long long;

// Bound for the small-case grid DP (indices 0..1000 plus slack).
const int maxn = 1000 + 55;
// Modulus applied to every count.
const ll mod = 998244353LL;
bool vis[maxn][maxn];  // vis[i][j]: lattice point (i, j) holds an obstacle
ll dp[maxn][maxn];     // dp[i][j]: paths from (0, 0) to (i, j), modulo `mod`
// Offsets to the two DP predecessors of (i, j): (i, j - 1) and (i - 1, j).
int step[2][2] = {
    {0, -1},
    {-1, 0},
};

// Size of the factorial table used by C() (init(250000) plus slack).
const int maxxn = 250000 + 55;
ll f[maxxn];  // f[i] = i! modulo `mod` once init() has run

// An obstacle position; main() converts it to lattice coordinates before use.
struct node {
    int x;
    int y;
};
// Obstacles for the inclusion-exclusion branch (room for five entries).
node no[5];
// Computes a^b mod p by binary exponentiation.
// Preconditions: b >= 0, p >= 1, and (p - 1)^2 fits in ll (true for `mod`).
// The base is first reduced into [0, p). Otherwise an unreduced `a` could
// overflow `a * a`, and a negative `a` would give a negative result.
ll qpow(ll a, ll b, ll p) {
    a %= p;
    if (a < 0) a += p;
    ll ret = 1 % p;  // 1 % p so that p == 1 correctly yields 0
    while (b) {
        if (b & 1) ret = ret * a % p;
        b >>= 1;
        a = a * a % p;
    }
    return ret;
}
// Fills f[0..n] with factorials modulo `mod` (f[i] = i!).
// n must be below maxxn, the size of `f`.
void init(int n) {
    f[0] = 1;
    for (int k = 1; k <= n; ++k)
        f[k] = f[k - 1] * k % mod;
}
// Binomial coefficient C(n, m) = n! / (m! * (n - m)!) modulo `mod`.
// Uses the factorial table `f`, which init() must have filled up to at least n.
// The division is done with a Fermat inverse (mod = 998244353 is prime).
// Returns 0 when n < 0, m < 0 or m > n: there are no such subsets or paths,
// and indexing `f` with a negative value would be undefined behaviour.
ll C(int n, int m) {
    if (n < 0 || m < 0 || m > n) return 0;
    ll num = f[n];
    ll den = f[m] * f[n - m] % mod;
    return num * qpow(den, mod - 2, mod) % mod;
}
ll ans = 0;
void dfs(int x, int y, int pos, int k, int n, int m, ll val) {
    ans += val * C(n-x+m-y, n-x);
    ans %= mod;
    for (int i = pos; i < k; ++ i) {
        if (x <= no[i].x && y <= no[i].y) {
            dfs(no[i].x, no[i].y, i+1, k, n, m, ((val*-1ll*C(no[i].x-x+no[i].y-y, no[i].x-x))%mod+mod)%mod);
        }
    }
}
int main() {
    int n, m, k, x, y;
    cin >> n >> m >> k;
    if (n <= 1000 && m <= 1000) {
        dp[0][0] = 1;
        memset(vis, false, sizeof(vis));
        for (int i = 0; i < k; ++ i) {
            cin >> x >> y;
            x --, y --;
            x = n - x;
            vis[x][y] = true;
        }
        for (int i = 0; i <= n; ++ i) {
            for (int j = 0; j <= m; ++ j) {
                if (vis[i][j]) continue;
                if (i == 0 && j == 0) continue;
                for (int f = 0; f < 2; ++ f) {
                    int dx = step[f][0] + i;
                    int dy = step[f][1] + j;
                    if (dx >= 0 && dy >= 0) {
                        dp[i][j] += dp[dx][dy];
                    }
                    dp[i][j] %= mod;
                }
            }
        }
        cout << dp[n][m] << endl;
        return 0;
    }
    for (int i = 0; i < k; ++ i) {
        cin >> no[i].x >> no[i].y;
        no[i].x --, no[i].y --;
        no[i].x = n - no[i].x;
    }
    sort(no, no+k, [](node x, node y){
        return x.x + x.y < y.x + y.y;
    });
    init(250000);
    dfs(0, 0, 0, k, n, m, 1);
    cout << ans << endl;
    return 0;
}