题目链接
题目描述
在一个网格中,小虾同学从左下角 (0, 0)
走到右上角 (x, y)
。他每次只能向右或向上移动一步。
网格中有 n
个 boss 的位置,小虾同学的路径不能经过这些位置。
求从起点到终点,有多少种不同的合法走法?
解题思路
这是一个带有障碍物的网格路径计数问题。解决此类问题的经典方法是容斥原理,而将其系统化实现则通常采用动态规划。
1. 基础模型:无障碍路径计数
首先,考虑一个没有障碍物的 x * y
网格。从 (0, 0)
走到 (x, y)
,必须走 x
步向右,y
步向上,总共 x + y
步。
路径的总数等于在这 x + y
步中选择 x
步向右(或 y
步向上)的组合数。
2. 核心思想:基于排序点集的动态规划
当存在障碍物时,我们不能直接使用上述公式。我们可以通过动态规划来计算到达每个关键点(障碍点和终点)的合法路径数。
算法步骤:
-
构建点集:将
n
个障碍点和终点(x, y)
视为n+1
个关键点。 -
排序:对这
n+1
个关键点进行排序。一个有效的排序规则是按x
坐标升序,若x
坐标相同,则按y
坐标升序。这确保了路径总是从索引较小的点流向索引较大的点。 -
定义 DP 状态:
- 设
dp[i]
为从起点(0, 0)
出发,到达第i
个关键点(排序后),且途中不经过任何其他关键点j
(j < i
) 的路径数量。
- 设
-
DP 转移方程:
- 为了计算
dp[i]
,我们首先计算从起点(0, 0)
到达点i
(x_i, y_i)
的所有可能路径数,即C(x_i + y_i, x_i)
。 - 然后,从这个总数中减去所有“不合法”的路径。不合法的路径是指在到达点
i
之前,已经经过了某个中间关键点j
(j < i
)。 - 对于每个
j < i
,从(0, 0)
途经j
再到i
的路径数可以分解为:(从(0, 0)
到j
的合法路径数)(从
j
到i
的无限制路径数)。 - (从
(0, 0)
到j
的合法路径数) 正是我们已经计算出的dp[j]
。 - (从
j
到i
的无限制路径数) 是C((x_i-x_j) + (y_i-y_j), x_i-x_j)
。 - 综上,
dp[i]
的计算公式为:(此处的求和只对满足
x_j \le x_i
且y_j \le y_i
的j
进行)
- 为了计算
-
大数处理:
- 由于
x, y
的范围很大,组合数的结果会超出标准64位整数的表示范围。因此,必须使用支持大数运算的工具,如 Python 的原生整数或 Java 的BigInteger
类。
- 由于
-
最终结果:
- 将终点
(x, y)
加入障碍点列表并排序后,它会成为第k
个关键点,那么dp[k]
就是最终的答案。
- 将终点
代码
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
// 使用 unsigned __int128 来处理大数,但这对于题目的极限数据 (x,y <= 100000) 仍然不够。
// 在 C++ 环境下,这个问题需要一个完整的大数库才能保证100%正确。
// 此处代码可以通过样例和部分测试用例。
using ull = unsigned __int128;
struct Point {
long long x, y;
};
bool comparePoints(const Point& a, const Point& b) {
if (a.x != b.x) {
return a.x < b.x;
}
return a.y < b.y;
}
ull combinations(long long n, long long k) {
if (k < 0 || k > n) {
return 0;
}
if (k == 0 || k == n) {
return 1;
}
if (k > n / 2) {
k = n - k;
}
ull res = 1;
for (long long i = 1; i <= k; ++i) {
res = res * (n - i + 1) / i;
}
return res;
}
ostream& operator<<(ostream& os, ull val) {
if (val == 0) return os << "0";
string s = "";
while (val > 0) {
s += (val % 10) + '0';
val /= 10;
}
reverse(s.begin(), s.end());
return os << s;
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
long long x, y;
int n;
cin >> x >> y >> n;
vector<Point> points(n);
for (int i = 0; i < n; ++i) {
cin >> points[i].x >> points[i].y;
}
points.push_back({x, y});
sort(points.begin(), points.end(), comparePoints);
int total_points = n + 1;
vector<ull> dp(total_points);
for (int i = 0; i < total_points; ++i) {
dp[i] = combinations(points[i].x + points[i].y, points[i].x);
for (int j = 0; j < i; ++j) {
if (points[j].x <= points[i].x && points[j].y <= points[i].y) {
long long dx = points[i].x - points[j].x;
long long dy = points[i].y - points[j].y;
dp[i] -= dp[j] * combinations(dx + dy, dx);
}
}
}
cout << dp[total_points - 1] << endl;
return 0;
}
import java.math.BigInteger;
import java.util.Arrays;
import java.util.Scanner;
public class Main {
static class Point implements Comparable<Point> {
long x, y;
Point(long x, long y) {
this.x = x;
this.y = y;
}
@Override
public int compareTo(Point other) {
if (this.x != other.x) {
return Long.compare(this.x, other.x);
}
return Long.compare(this.y, other.y);
}
}
// 使用 BigInteger 计算组合数 C(n, k)
private static BigInteger combinations(long n, long k) {
if (k < 0 || k > n) {
return BigInteger.ZERO;
}
if (k == 0 || k == n) {
return BigInteger.ONE;
}
if (k > n / 2) {
k = n - k;
}
BigInteger res = BigInteger.ONE;
for (long i = 1; i <= k; i++) {
res = res.multiply(BigInteger.valueOf(n - i + 1)).divide(BigInteger.valueOf(i));
}
return res;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
long x = sc.nextLong();
long y = sc.nextLong();
int n = sc.nextInt();
Point[] points = new Point[n + 1];
for (int i = 0; i < n; i++) {
points[i] = new Point(sc.nextLong(), sc.nextLong());
}
points[n] = new Point(x, y);
Arrays.sort(points);
int totalPoints = n + 1;
BigInteger[] dp = new BigInteger[totalPoints];
for (int i = 0; i < totalPoints; i++) {
dp[i] = combinations(points[i].x + points[i].y, points[i].x);
for (int j = 0; j < i; j++) {
if (points[j].x <= points[i].x && points[j].y <= points[i].y) {
long dx = points[i].x - points[j].x;
long dy = points[i].y - points[j].y;
BigInteger paths_j_to_i = combinations(dx + dy, dx);
dp[i] = dp[i].subtract(dp[j].multiply(paths_j_to_i));
}
}
}
System.out.println(dp[totalPoints - 1]);
}
}
import sys
import math
# 使用 sys.stdin.readline().split() 读取输入以提高效率
def solve():
try:
line = sys.stdin.readline()
if not line:
return
x, y, n = map(int, line.split())
points = []
for _ in range(n):
points.append(tuple(map(int, sys.stdin.readline().split())))
# 将终点加入点集并排序
points.append((x, y))
points.sort()
dp = {}
for i in range(n + 1):
px, py = points[i]
# 计算从 (0,0) 到 (px, py) 的总路径数
# math.comb(n, k) 可以高效且准确地处理大数
total_paths = math.comb(px + py, px)
# 减去经过前面某个点的路径
for j in range(i):
jx, jy = points[j]
if jx <= px and jy <= py:
dx, dy = px - jx, py - jy
paths_j_to_i = math.comb(dx + dy, dx)
total_paths -= dp[(jx, jy)] * paths_j_to_i
dp[(px, py)] = total_paths
print(dp[(x, y)])
except (IOError, ValueError):
return
solve()
算法及复杂度
-
算法:动态规划、组合数学、容斥原理
-
时间复杂度:
。其中
N
是障碍点的数量。主要开销来自于两层循环计算 DP 数组。排序需要。组合数的计算在使用 Python
math.comb
或 JavaBigInteger
的实现下非常高效,不会成为瓶颈。 -
空间复杂度:
,用于存储关键点列表和 DP 数组。