题目-网格

在这里插入图片描述

问题分析

本质是 C a t a l a n Catalan Catalan序列的简单扩展, 还是用补集的思想, 合法的方案数是 C n + m n C_{n + m} ^ n Cn+mn, 减去不合法的方案数, 以下是不合法方案数的分析
在这里插入图片描述
因为不合法的方案数一定都会经过 y = x + 1 y = x + 1 y=x+1直线, 并且最终到达 ( n , m ) (n, m) (n,m)

那么不合法的方案数等价于直接从 ( 0 , 0 ) (0, 0) (0,0)开始到达 ( x , y ) (x, y) (x,y), ( x , y ) (x, y) (x,y) ( n , m ) (n, m) (n,m)对于 y = x + 1 y = x + 1 y=x+1对称后的结果

因此目标就变成了如何求 ( x , y ) (x, y) (x,y), 求出坐标后的不合法的方案数就是 C x + y x C_{x + y} ^ x Cx+yx, 以下是推导过程

因为 ( x , y ) (x, y) (x,y) ( n , m ) (n, m) (n,m)是对称关系, 那么 ( x , y ) (x, y) (x,y) ( n , m ) (n, m) (n,m)的中点在 y = x + 1 y = x + 1 y=x+1
m + y 2 = n + x 2 + 1          ⋯ ( 1 ) \frac{m + y}{2} = \frac{n + x}{2} + 1 \;\;\;\; \cdots (1) 2m+y=2n+x+1(1)
并且直线的斜率是 − 1 -1 1
m − y n − x = − 1          ⋯ ( 2 ) \frac{m - y}{n - x} = -1 \;\;\;\; \cdots (2) nxmy=1(2)

整理 ( 1 ) (1) (1) ( 2 ) (2) (2)方程得到

{ m + y = n + x + 2 m − y = x − n \begin{cases} & m + y = n + x + 2 \\ & m - y = x - n \end{cases} { m+y=n+x+2my=xn
联立求解 ( x , y ) = ( m − 1 , n + 1 ) (x, y) = (m - 1, n + 1) (x,y)=(m1,n+1)

那么最终的答案就是 C n + m n − C m + n m − 1 C_{n + m} ^ n - C_{m + n} ^ {m - 1} Cn+mnCm+nm1

算法步骤

观察数据范围 n , m ≤ 5000 n, m \le 5000 n,m5000, 因为组合数非常大并且没有取模操作, 需要写高精度实现组合数

  • 实现高精度加法, 递推法计算组合数
  • 实现高精度减法, 计算推导得到的公式

代码实现

递推法求组合数未通过, 因为超出了内存限制(即便使用了滚动数组优化)

#include <bits/stdc++.h>

using namespace std;

const int N = 5010;

int n, m;
vector<int> f[2][N];

void set_one(vector<int> &a) {
   
    a.clear();
    a.push_back(1);
}

vector<int> add(vector<int> &a, vector<int> &b) {
   
    if (a.size() < b.size()) return add(b, a);
    vector<int> ans;

    int c = 0;
    for (int i = 0; i < a.size(); ++i) {
   
        c += a[i];
        if (i < b.size()) c += b[i];
        ans.push_back(c % 10);
        c /= 10;
    }

    while (c) ans.push_back(c % 10), c /= 10;
    return ans;
}

int cmp(vector<int> &a, vector<int> &b) {
   
    if (a.size() < b.size()) return -1;
    if (a.size() > b.size()) return 1;
    for (int i = a.size() - 1; i >= 0; --i) {
   
        if (a[i] > b[i]) return 1;
        else return -1;
    }
    return 0;
}

vector<int> sub(vector<int> &a, vector<int> &b) {
   
    if (cmp(a, b) != 1) return sub(b, a);

    vector<int> ans;
    int c = 0;
    for (int i = 0; i < a.size(); ++i) {
   
        int val = a[i] - c;
        if (i < b.size()) val -= b[i];
        ans.push_back((val % 10 + 10) % 10);
        val < 0 ? c = 1 : c = 0;
    }

    while (ans.size() > 1 && ans.back() == 0) ans.pop_back();
    return ans;
}


int main() {
   
    ios::sync_with_stdio(false);
    cin.tie(0);

    cin >> n >> m;

    int maxv = max(n, m - 1);

    for (int i = 0; i <= n + m; ++i) set_one(f[i & 1][0]);
    for (int i = 1; i <= n + m; ++i) {
   
        for (int j = 0; j <= maxv; ++j) {
   
            f[i & 1][j] = f[i - 1 & 1][j];
            if (j > 0) f[i & 1][j] = add(f[i & 1][j], f[i - 1 & 1][j - 1]);
        }
    }

    vector<int> ans = sub(f[(n + m) & 1][n], f[(n + m) & 1][m - 1]);

    for (int i = ans.size() - 1; i >= 0; --i) cout << ans[i];
    cout << '\n';

    return 0;
}

尝试做如下优化

  • 基于组合数公式计算 C a b = a ! b ! ( a − b ) ! C_a ^ b = \frac{a!}{b! (a - b)!} Cab=b!(ab)!a!, 对阶乘分解质因数, 通过勒让德公式计算
  • 最终的组合数结果一定是 a n s = p 1 c 1 p 2 c 2 . . . p k c k ans = p_1 ^ {c_1} p_2 ^ {c_2} ... p_k ^ {c_k} ans=p1c1p2c2...pkck
  • 然后再实现高精度乘法

因为 x ! = 1 × 2 × 3 × . . . × x x! = 1 \times 2 \times 3 \times ... \times x x!=1×2×3×...×x的最大质因子一定不会超过 x x x, 因此需要预处理 ≤ x \le x x的所有质数, 算法时间复杂度 O ( n ) O(n) O(n)

#include <bits/stdc++.h>

using namespace std;

typedef long long LL;
const int N = 10010;

int n, m;
int primes[N], cnt;
bool st[N];
map<int, int> mp;

void init(int n) {
   
    for (int i = 2; i <= n; ++i) {
   
        if (!st[i]) primes[cnt++] = i;
        for (int j = 0; primes[j] <= n / i; ++j) {
   
            st[i * primes[j]] = true;
            if (i % primes[j] == 0) break;
        }
    }
}

vector<int> mul(vector<int> &a, int val) {
   
    vector<int> ans;

    LL c = 0;
    for (int i = 0; i < a.size(); ++i) {
   
        c += a[i] * val;
        ans.push_back(c % 10);
        c /= 10;
    }
    while (c) ans.push_back(c % 10), c /= 10;

    return ans;
}

vector<int> sub(vector<int> &a, vector<int> &b) {
   
    vector<int> ans;
    int c = 0;

    for (int i = 0; i < a.size(); ++i) {
   
        int val = a[i] - c;
        if (i < b.size()) val -= b[i];
        ans.push_back((val % 10 + 10) % 10);
        val < 0 ? c = 1 : c = 0;
    }

    while (ans.size() > 1 && ans.back() == 0) ans.pop_back();
    return ans;
}

void div(int a, bool flag) {
   
    for (int i = 0; i < cnt; ++i) {
   
        int p = primes[i], t = 0;
        for (int j = a; j; j /= p) t += j / p;
        if (t == 0) continue;
        flag ? mp[p] += t : mp[p] -= t;
    }
}

vector<int> C(int a, int b) {
   
    if (b < 0 || a < b) return {
   0};
    mp.clear();
    div(a, true), div(b, false), div(a - b, false);
    vector<int> ans = {
   1};
    for (int i = 0; i < cnt; ++i) {
   
        int cnt = mp[primes[i]];
        for (int j = 0; j < cnt; ++j) ans = mul(ans, primes[i]);
    }
    return ans;
}

int main() {
   
    ios::sync_with_stdio(false);
    cin.tie(0);

    init(N - 1);

    cin >> n >> m;
    vector<int> a = C(n + m, n);
    vector<int> b = C(n + m, m - 1);
    vector<int> ans = sub(a, b);

    for (int i = ans.size() - 1; i >= 0; --i) cout << ans[i];
    cout << '\n';

    return 0;
}