题目-网格

问题分析
本质是 C a t a l a n Catalan Catalan序列的简单扩展, 还是用补集的思想, 合法的方案数是 C n + m n C_{n + m} ^ n Cn+mn, 减去不合法的方案数, 以下是不合法方案数的分析

因为不合法的方案数一定都会经过 y = x + 1 y = x + 1 y=x+1直线, 并且最终到达 ( n , m ) (n, m) (n,m)
那么不合法的方案数等价于直接从 ( 0 , 0 ) (0, 0) (0,0)开始到达 ( x , y ) (x, y) (x,y), ( x , y ) (x, y) (x,y)是 ( n , m ) (n, m) (n,m)对于 y = x + 1 y = x + 1 y=x+1对称后的结果
因此目标就变成了如何求 ( x , y ) (x, y) (x,y), 求出坐标后的不合法的方案数就是 C x + y x C_{x + y} ^ x Cx+yx, 以下是推导过程
因为 ( x , y ) (x, y) (x,y)和 ( n , m ) (n, m) (n,m)是对称关系, 那么 ( x , y ) (x, y) (x,y)和 ( n , m ) (n, m) (n,m)的中点在 y = x + 1 y = x + 1 y=x+1上
m + y 2 = n + x 2 + 1 ⋯ ( 1 ) \frac{m + y}{2} = \frac{n + x}{2} + 1 \;\;\;\; \cdots (1) 2m+y=2n+x+1⋯(1)
并且直线的斜率是 − 1 -1 −1
m − y n − x = − 1 ⋯ ( 2 ) \frac{m - y}{n - x} = -1 \;\;\;\; \cdots (2) n−xm−y=−1⋯(2)
整理 ( 1 ) (1) (1)和 ( 2 ) (2) (2)方程得到
{ m + y = n + x + 2 m − y = x − n \begin{cases} & m + y = n + x + 2 \\ & m - y = x - n \end{cases} { m+y=n+x+2m−y=x−n
联立求解 ( x , y ) = ( m − 1 , n + 1 ) (x, y) = (m - 1, n + 1) (x,y)=(m−1,n+1)
那么最终的答案就是 C n + m n − C m + n m − 1 C_{n + m} ^ n - C_{m + n} ^ {m - 1} Cn+mn−Cm+nm−1
算法步骤
观察数据范围 n , m ≤ 5000 n, m \le 5000 n,m≤5000, 因为组合数非常大并且没有取模操作, 需要写高精度实现组合数
- 实现高精度加法, 递推法计算组合数
- 实现高精度减法, 计算推导得到的公式
代码实现
递推法求组合数未通过, 因为超出了内存限制(即便使用了滚动数组优化)
#include <bits/stdc++.h>
using namespace std;
const int N = 5010;
int n, m;
vector<int> f[2][N];
void set_one(vector<int> &a) {
a.clear();
a.push_back(1);
}
vector<int> add(vector<int> &a, vector<int> &b) {
if (a.size() < b.size()) return add(b, a);
vector<int> ans;
int c = 0;
for (int i = 0; i < a.size(); ++i) {
c += a[i];
if (i < b.size()) c += b[i];
ans.push_back(c % 10);
c /= 10;
}
while (c) ans.push_back(c % 10), c /= 10;
return ans;
}
int cmp(vector<int> &a, vector<int> &b) {
if (a.size() < b.size()) return -1;
if (a.size() > b.size()) return 1;
for (int i = a.size() - 1; i >= 0; --i) {
if (a[i] > b[i]) return 1;
else return -1;
}
return 0;
}
vector<int> sub(vector<int> &a, vector<int> &b) {
if (cmp(a, b) != 1) return sub(b, a);
vector<int> ans;
int c = 0;
for (int i = 0; i < a.size(); ++i) {
int val = a[i] - c;
if (i < b.size()) val -= b[i];
ans.push_back((val % 10 + 10) % 10);
val < 0 ? c = 1 : c = 0;
}
while (ans.size() > 1 && ans.back() == 0) ans.pop_back();
return ans;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cin >> n >> m;
int maxv = max(n, m - 1);
for (int i = 0; i <= n + m; ++i) set_one(f[i & 1][0]);
for (int i = 1; i <= n + m; ++i) {
for (int j = 0; j <= maxv; ++j) {
f[i & 1][j] = f[i - 1 & 1][j];
if (j > 0) f[i & 1][j] = add(f[i & 1][j], f[i - 1 & 1][j - 1]);
}
}
vector<int> ans = sub(f[(n + m) & 1][n], f[(n + m) & 1][m - 1]);
for (int i = ans.size() - 1; i >= 0; --i) cout << ans[i];
cout << '\n';
return 0;
}
尝试做如下优化
- 基于组合数公式计算 C a b = a ! b ! ( a − b ) ! C_a ^ b = \frac{a!}{b! (a - b)!} Cab=b!(a−b)!a!, 对阶乘分解质因数, 通过勒让德公式计算
- 最终的组合数结果一定是 a n s = p 1 c 1 p 2 c 2 . . . p k c k ans = p_1 ^ {c_1} p_2 ^ {c_2} ... p_k ^ {c_k} ans=p1c1p2c2...pkck
- 然后再实现高精度乘法
因为 x ! = 1 × 2 × 3 × . . . × x x! = 1 \times 2 \times 3 \times ... \times x x!=1×2×3×...×x的最大质因子一定不会超过 x x x, 因此需要预处理 ≤ x \le x ≤x的所有质数, 算法时间复杂度 O ( n ) O(n) O(n)
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 10010;
int n, m;
int primes[N], cnt;
bool st[N];
map<int, int> mp;
void init(int n) {
for (int i = 2; i <= n; ++i) {
if (!st[i]) primes[cnt++] = i;
for (int j = 0; primes[j] <= n / i; ++j) {
st[i * primes[j]] = true;
if (i % primes[j] == 0) break;
}
}
}
vector<int> mul(vector<int> &a, int val) {
vector<int> ans;
LL c = 0;
for (int i = 0; i < a.size(); ++i) {
c += a[i] * val;
ans.push_back(c % 10);
c /= 10;
}
while (c) ans.push_back(c % 10), c /= 10;
return ans;
}
vector<int> sub(vector<int> &a, vector<int> &b) {
vector<int> ans;
int c = 0;
for (int i = 0; i < a.size(); ++i) {
int val = a[i] - c;
if (i < b.size()) val -= b[i];
ans.push_back((val % 10 + 10) % 10);
val < 0 ? c = 1 : c = 0;
}
while (ans.size() > 1 && ans.back() == 0) ans.pop_back();
return ans;
}
void div(int a, bool flag) {
for (int i = 0; i < cnt; ++i) {
int p = primes[i], t = 0;
for (int j = a; j; j /= p) t += j / p;
if (t == 0) continue;
flag ? mp[p] += t : mp[p] -= t;
}
}
vector<int> C(int a, int b) {
if (b < 0 || a < b) return {
0};
mp.clear();
div(a, true), div(b, false), div(a - b, false);
vector<int> ans = {
1};
for (int i = 0; i < cnt; ++i) {
int cnt = mp[primes[i]];
for (int j = 0; j < cnt; ++j) ans = mul(ans, primes[i]);
}
return ans;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
init(N - 1);
cin >> n >> m;
vector<int> a = C(n + m, n);
vector<int> b = C(n + m, m - 1);
vector<int> ans = sub(a, b);
for (int i = ans.size() - 1; i >= 0; --i) cout << ans[i];
cout << '\n';
return 0;
}

京公网安备 11010502036488号