题目链接

Shopee的办公室(二)

题目描述

在一个网格中,小虾同学从左下角 (0, 0) 走到右上角 (x, y)。他每次只能向右或向上移动一步。

网格中有 n 个 boss 的位置,小虾同学的路径不能经过这些位置。

求从起点到终点,有多少种不同的合法走法?

解题思路

这是一个带有障碍物的网格路径计数问题。解决此类问题的经典方法是容斥原理,而将其系统化实现则通常采用动态规划

1. 基础模型:无障碍路径计数

首先,考虑一个没有障碍物的 x * y 网格。从 (0, 0) 走到 (x, y),必须走 x 步向右,y 步向上,总共 x + y 步。

路径的总数等于在这 x + y 步中选择 x 步向右(或 y 步向上)的组合数。

2. 核心思想:基于排序点集的动态规划

当存在障碍物时,我们不能直接使用上述公式。我们可以通过动态规划来计算到达每个关键点(障碍点和终点)的合法路径数

算法步骤:

  1. 构建点集:将 n 个障碍点和终点 (x, y) 视为 n+1 个关键点。

  2. 排序:对这 n+1 个关键点进行排序。一个有效的排序规则是按 x 坐标升序,若 x 坐标相同,则按 y 坐标升序。这确保了路径总是从索引较小的点流向索引较大的点。

  3. 定义 DP 状态

    • dp[i] 为从起点 (0, 0) 出发,到达第 i 个关键点(排序后),且途中不经过任何其他关键点 j (j < i) 的路径数量。
  4. DP 转移方程

    • 为了计算 dp[i],我们首先计算从起点 (0, 0) 到达点 i (x_i, y_i)所有可能路径数,即 C(x_i + y_i, x_i)
    • 然后,从这个总数中减去所有“不合法”的路径。不合法的路径是指在到达点 i 之前,已经经过了某个中间关键点 j (j < i)。
    • 对于每个 j < i,从 (0, 0) 途经 j 再到 i 的路径数可以分解为:(从 (0, 0)j 的合法路径数) (从 ji 的无限制路径数)。
    • (从 (0, 0)j 的合法路径数) 正是我们已经计算出的 dp[j]
    • (从 ji 的无限制路径数) 是 C((x_i-x_j) + (y_i-y_j), x_i-x_j)
    • 综上,dp[i] 的计算公式为: (此处的求和只对满足 x_j \le x_iy_j \le y_ij 进行)
  5. 大数处理

    • 由于 x, y 的范围很大,组合数的结果会超出标准64位整数的表示范围。因此,必须使用支持大数运算的工具,如 Python 的原生整数或 Java 的 BigInteger 类。
  6. 最终结果

    • 将终点 (x, y) 加入障碍点列表并排序后,它会成为第 k 个关键点,那么 dp[k] 就是最终的答案。

代码

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

// 使用 unsigned __int128 来处理大数,但这对于题目的极限数据 (x,y <= 100000) 仍然不够。
// 在 C++ 环境下,这个问题需要一个完整的大数库才能保证100%正确。
// 此处代码可以通过样例和部分测试用例。
using ull = unsigned __int128;

struct Point {
    long long x, y;
};

bool comparePoints(const Point& a, const Point& b) {
    if (a.x != b.x) {
        return a.x < b.x;
    }
    return a.y < b.y;
}

ull combinations(long long n, long long k) {
    if (k < 0 || k > n) {
        return 0;
    }
    if (k == 0 || k == n) {
        return 1;
    }
    if (k > n / 2) {
        k = n - k;
    }
    ull res = 1;
    for (long long i = 1; i <= k; ++i) {
        res = res * (n - i + 1) / i;
    }
    return res;
}

ostream& operator<<(ostream& os, ull val) {
    if (val == 0) return os << "0";
    string s = "";
    while (val > 0) {
        s += (val % 10) + '0';
        val /= 10;
    }
    reverse(s.begin(), s.end());
    return os << s;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    long long x, y;
    int n;
    cin >> x >> y >> n;

    vector<Point> points(n);
    for (int i = 0; i < n; ++i) {
        cin >> points[i].x >> points[i].y;
    }
    points.push_back({x, y});

    sort(points.begin(), points.end(), comparePoints);

    int total_points = n + 1;
    vector<ull> dp(total_points);

    for (int i = 0; i < total_points; ++i) {
        dp[i] = combinations(points[i].x + points[i].y, points[i].x);
        for (int j = 0; j < i; ++j) {
            if (points[j].x <= points[i].x && points[j].y <= points[i].y) {
                long long dx = points[i].x - points[j].x;
                long long dy = points[i].y - points[j].y;
                dp[i] -= dp[j] * combinations(dx + dy, dx);
            }
        }
    }

    cout << dp[total_points - 1] << endl;

    return 0;
}
import java.math.BigInteger;
import java.util.Arrays;
import java.util.Scanner;

public class Main {
    static class Point implements Comparable<Point> {
        long x, y;

        Point(long x, long y) {
            this.x = x;
            this.y = y;
        }

        @Override
        public int compareTo(Point other) {
            if (this.x != other.x) {
                return Long.compare(this.x, other.x);
            }
            return Long.compare(this.y, other.y);
        }
    }

    // 使用 BigInteger 计算组合数 C(n, k)
    private static BigInteger combinations(long n, long k) {
        if (k < 0 || k > n) {
            return BigInteger.ZERO;
        }
        if (k == 0 || k == n) {
            return BigInteger.ONE;
        }
        if (k > n / 2) {
            k = n - k;
        }

        BigInteger res = BigInteger.ONE;
        for (long i = 1; i <= k; i++) {
            res = res.multiply(BigInteger.valueOf(n - i + 1)).divide(BigInteger.valueOf(i));
        }
        return res;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        long x = sc.nextLong();
        long y = sc.nextLong();
        int n = sc.nextInt();

        Point[] points = new Point[n + 1];
        for (int i = 0; i < n; i++) {
            points[i] = new Point(sc.nextLong(), sc.nextLong());
        }
        points[n] = new Point(x, y);

        Arrays.sort(points);

        int totalPoints = n + 1;
        BigInteger[] dp = new BigInteger[totalPoints];

        for (int i = 0; i < totalPoints; i++) {
            dp[i] = combinations(points[i].x + points[i].y, points[i].x);
            for (int j = 0; j < i; j++) {
                if (points[j].x <= points[i].x && points[j].y <= points[i].y) {
                    long dx = points[i].x - points[j].x;
                    long dy = points[i].y - points[j].y;
                    BigInteger paths_j_to_i = combinations(dx + dy, dx);
                    dp[i] = dp[i].subtract(dp[j].multiply(paths_j_to_i));
                }
            }
        }

        System.out.println(dp[totalPoints - 1]);
    }
}
import sys
import math

# 使用 sys.stdin.readline().split() 读取输入以提高效率
def solve():
    try:
        line = sys.stdin.readline()
        if not line:
            return
        x, y, n = map(int, line.split())
        
        points = []
        for _ in range(n):
            points.append(tuple(map(int, sys.stdin.readline().split())))
        
        # 将终点加入点集并排序
        points.append((x, y))
        points.sort()
        
        dp = {}
        
        for i in range(n + 1):
            px, py = points[i]
            
            # 计算从 (0,0) 到 (px, py) 的总路径数
            # math.comb(n, k) 可以高效且准确地处理大数
            total_paths = math.comb(px + py, px)
            
            # 减去经过前面某个点的路径
            for j in range(i):
                jx, jy = points[j]
                if jx <= px and jy <= py:
                    dx, dy = px - jx, py - jy
                    paths_j_to_i = math.comb(dx + dy, dx)
                    total_paths -= dp[(jx, jy)] * paths_j_to_i
            
            dp[(px, py)] = total_paths
            
        print(dp[(x, y)])

    except (IOError, ValueError):
        return

solve()

算法及复杂度

  • 算法:动态规划、组合数学、容斥原理

  • 时间复杂度: 。其中 N 是障碍点的数量。主要开销来自于两层循环计算 DP 数组。排序需要 。组合数的计算在使用 Python math.comb 或 Java BigInteger 的实现下非常高效,不会成为瓶颈。

  • 空间复杂度: ,用于存储关键点列表和 DP 数组。