REALHW97 地下探险

题目链接

地下探险

题目描述

一位探险家需要在一个二维网格表示的地下洞穴中,从入口移动到古代遗物所在地。

移动规则:

  • 每次只能向上、下、左、右四个方向移动一格。
  • 探险家携带的氧气瓶最多支持连续行走 步。

地图元素:

  • 0: 可通行的路径
  • 1: 无法通行的岩壁
  • 2: 氧气补给站。到达补给站时,氧气会瞬间充满,恢复到最大值

任务目标: 计算探险家从入口到达遗物位置所需的最短路径长度(总步数)。如果无法到达,则输出 -1。

解题思路

这是一个在时间、空间限制都非常严格的图论问题。简单的三维 BFS 会超内存,而之前几次优化的图建模方案在逻辑上存在缺陷或效率不足。正确的解法需要将不同类型的路径分开处理,构建一个精确的“关键点图”,然后求解最短路。

这里的“关键点”包括:起点、终点、所有氧气补给站

核心思想:构建关键点图 + Dijkstra

  1. 识别节点: 我们将起点、终点和所有补给站抽象为新图中的节点。设共有 个补给站,则新图共有 个节点(start, dest, station_0, ..., station_{P-1})。

  2. 计算边权(路径): 新图中节点间的“边”代表从一个关键点到另一个关键点的可行路径。路径分为三种:

    • 补给站之间 (Station-to-Station): 从一个补给站出发,氧气是满的。可以在氧气耗尽前到达另一个补给站。这类路径的最短距离可以通过一次多源 BFS 高效计算。这是该算法最核心的优化。
      • 正确的多源BFS实现:
        1. 区域划分: 从所有补给站同时开始 BFS,将地图划分为若干个“势力范围”。每个格子记录下离它最近的补给站是谁 (owner) 和到这个补给站的距离 (dist)。
        2. 边界扫描: 在 BFS 完全结束后,遍历整个地图。当发现一个格子和它的邻居属于不同补给站的势力范围时,就意味着我们找到了连接这两个补给站的一条路径。通过 dist1 + dist2 + 1 计算路径长度,并用它来更新这两个补给站之间的最短距离。
    • 起点到其他关键点 (Start-to-Others): 从起点出发,氧气是满的。我们需要计算在 步之内,能到达哪些补给站,以及能否直接到达终点。这可以通过一次从起点开始的限制步数的 BFS (k_limited_bfs) 来完成。
    • 其他关键点到终点 (Others-to-Dest): 任何一个关键点(起点或补给站)都可能作为到达终点的“最后一跳”。我们需要计算从所有补给站出发,在 步之内能否到达终点。这同样可以通过限制步数的 BFS 完成,但为了效率,可以反向思考,从终点进行一次 k_limited_bfs,计算它能到达哪些补给站。
  3. Dijkstra 求解: 根据上述计算出的所有可行路径(长度 )构建一个邻接表。然后,在这个关键点图上,以起点为源点,运行一次标准的 Dijkstra 算法,即可求得到达终点的最短总步数。

这个方案通过组合使用最高效的多源 BFS 和针对性的单源 BFS,精确地构建了问题模型,并在时间和空间上都达到了最优。

代码

#include <iostream>
#include <vector>
#include <queue>
#include <map>
#include <algorithm>

using namespace std;

const int INF = 1e9;

struct State {
    int r, c, dist;
};

struct DijkstraState {
    int u, dist;
    bool operator>(const DijkstraState& other) const {
        return dist > other.dist;
    }
};

int r_max, c_max;
vector<vector<int>> grid;
int dr[] = {-1, 1, 0, 0};
int dc[] = {0, 0, -1, 1};

bool is_valid(int r, int c) {
    return r >= 0 && r < r_max && c >= 0 && c < c_max && grid[r][c] != 1;
}

// 从指定点开始的k步限定BFS,返回到targets中各点的距离
vector<int> k_limited_bfs(int start_r, int start_c, int k, const vector<pair<int, int>>& targets) {
    vector<int> results(targets.size(), INF);
    if (!is_valid(start_r, start_c)) return results;

    map<pair<int, int>, int> target_coord_to_idx;
    for(size_t i = 0; i < targets.size(); ++i) {
        target_coord_to_idx[targets[i]] = i;
    }

    queue<State> q;
    vector<vector<int>> d(r_max, vector<int>(c_max, -1));
    q.push({start_r, start_c, 0});
    d[start_r][start_c] = 0;
    
    // 提前检查起点是否就是某个目标点
    if(target_coord_to_idx.count({start_r, start_c})) {
        results[target_coord_to_idx[{start_r, start_c}]] = 0;
    }

    while (!q.empty()) {
        State current = q.front();
        q.pop();

        if (current.dist >= k) continue;

        for (int i = 0; i < 4; ++i) {
            int nr = current.r + dr[i];
            int nc = current.c + dc[i];
            if (is_valid(nr, nc) && d[nr][nc] == -1) {
                d[nr][nc] = current.dist + 1;
                q.push({nr, nc, d[nr][nc]});
                if(target_coord_to_idx.count({nr, nc})) {
                    results[target_coord_to_idx[{nr, nc}]] = d[nr][nc];
                }
            }
        }
    }
    return results;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    cin >> r_max >> c_max;
    grid.assign(r_max, vector<int>(c_max));
    vector<pair<int, int>> stations;

    for (int i = 0; i < r_max; ++i) {
        for (int j = 0; j < c_max; ++j) {
            cin >> grid[i][j];
            if (grid[i][j] == 2) {
                stations.push_back({i, j});
            }
        }
    }

    int start_r, start_c, dest_r, dest_c, k;
    cin >> start_r >> start_c >> dest_r >> dest_c >> k;
    
    if (start_r == dest_r && start_c == dest_c) {
        cout << 0 << endl;
        return 0;
    }

    int num_stations = stations.size();
    int start_id = num_stations;
    int dest_id = num_stations + 1;
    int num_meta_nodes = num_stations + 2;
    vector<vector<pair<int, int>>> meta_adj(num_meta_nodes);

    // 1. 计算站与站之间的距离
    if (num_stations > 0) {
        vector<vector<int>> dist_to_station(r_max, vector<int>(c_max, INF));
        vector<vector<int>> owner(r_max, vector<int>(c_max, -1));
        queue<pair<int, int>> q;
        for (int i = 0; i < num_stations; ++i) {
            dist_to_station[stations[i].first][stations[i].second] = 0;
            owner[stations[i].first][stations[i].second] = i;
            q.push(stations[i]);
        }
        while(!q.empty()){
            pair<int, int> curr = q.front(); q.pop();
            for(int i=0; i<4; ++i){
                int nr = curr.first + dr[i], nc = curr.second + dc[i];
                if(is_valid(nr, nc) && owner[nr][nc] == -1){
                    owner[nr][nc] = owner[curr.first][curr.second];
                    dist_to_station[nr][nc] = dist_to_station[curr.first][curr.second] + 1;
                    q.push({nr, nc});
                }
            }
        }
        map<pair<int, int>, int> station_dist;
        for(int r=0; r<r_max; ++r) for(int c=0; c<c_max; ++c){
            for(int i=0; i<4; ++i){
                int nr = r + dr[i], nc = c + dc[i];
                if(is_valid(nr, nc) && owner[r][c]!=-1 && owner[nr][nc]!=-1 && owner[r][c]!=owner[nr][nc]){
                    int u = owner[r][c], v = owner[nr][nc];
                    int w = dist_to_station[r][c] + dist_to_station[nr][nc] + 1;
                    pair<int, int> key = {min(u, v), max(u, v)};
                    if (station_dist.find(key) == station_dist.end() || w < station_dist[key]) {
                        station_dist[key] = w;
                    }
                }
            }
        }
        for(auto const& [key_pair, w] : station_dist) {
            if (w <= k) {
                meta_adj[key_pair.first].push_back({key_pair.second, w});
                meta_adj[key_pair.second].push_back({key_pair.first, w});
            }
        }
    }

    // 2. 计算起点到其他关键点的距离
    vector<pair<int,int>> start_targets = stations;
    start_targets.push_back({dest_r, dest_c});
    vector<int> start_dists = k_limited_bfs(start_r, start_c, k, start_targets);
    if(start_dists.back() <= k) meta_adj[start_id].push_back({dest_id, start_dists.back()});
    for(int i=0; i<num_stations; ++i) if(start_dists[i] <= k) meta_adj[start_id].push_back({i, start_dists[i]});

    // 3. 计算其他关键点到终点的距离
    if (!stations.empty()) {
        vector<int> dest_dists = k_limited_bfs(dest_r, dest_c, k, stations);
        for(int i=0; i<num_stations; ++i) if(dest_dists[i] <= k) meta_adj[i].push_back({dest_id, dest_dists[i]});
    }

    // 4. Dijkstra
    priority_queue<DijkstraState, vector<DijkstraState>, greater<DijkstraState>> pq;
    vector<int> final_dist(num_meta_nodes, INF);
    final_dist[start_id] = 0;
    pq.push({start_id, 0});
    while (!pq.empty()) {
        DijkstraState current = pq.top(); pq.pop();
        int u = current.u;
        if(current.dist > final_dist[u]) continue;
        if(u == dest_id) break;
        for(auto& edge : meta_adj[u]){
            int v = edge.first, weight = edge.second;
            if(final_dist[u] + weight < final_dist[v]){
                final_dist[v] = final_dist[u] + weight;
                pq.push({v, final_dist[v]});
            }
        }
    }
    
    if (final_dist[dest_id] == INF) cout << -1 << endl;
    else cout << final_dist[dest_id] << endl;

    return 0;
}
import java.util.*;

public class Main {
    static final int INF = Integer.MAX_VALUE;
    static int R, C, K;
    static int[][] grid;
    static int[] dr = {-1, 1, 0, 0};
    static int[] dc = {0, 0, -1, 1};

    static class State {
        int r, c, dist;
        State(int r, int c, int dist) {
            this.r = r; this.c = c; this.dist = dist;
        }
    }
    
    static class Point {
        int r, c;
        Point(int r, int c) { this.r = r; this.c = c; }
        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            Point point = (Point) o;
            return r == point.r && c == point.c;
        }
        @Override
        public int hashCode() {
            return Objects.hash(r, c);
        }
    }

    static boolean isValid(int r, int c) {
        return r >= 0 && r < R && c >= 0 && c < C && grid[r][c] != 1;
    }

    static int[] kLimitedBfs(int startR, int startC, List<Point> targets) {
        int[] results = new int[targets.size()];
        Arrays.fill(results, INF);
        if (!isValid(startR, startC)) return results;

        Map<Point, Integer> targetCoordToIdx = new HashMap<>();
        for (int i = 0; i < targets.size(); i++) {
            targetCoordToIdx.put(targets.get(i), i);
        }

        Queue<State> q = new LinkedList<>();
        int[][] d = new int[R][C];
        for(int[] row : d) Arrays.fill(row, -1);
        
        q.offer(new State(startR, startC, 0));
        d[startR][startC] = 0;

        if (targetCoordToIdx.containsKey(new Point(startR, startC))) {
            results[targetCoordToIdx.get(new Point(startR, startC))] = 0;
        }

        while (!q.isEmpty()) {
            State current = q.poll();
            if (current.dist >= K) continue;

            for (int i = 0; i < 4; i++) {
                int nr = current.r + dr[i];
                int nc = current.c + dc[i];
                if (isValid(nr, nc) && d[nr][nc] == -1) {
                    d[nr][nc] = current.dist + 1;
                    q.offer(new State(nr, nc, d[nr][nc]));
                    Point nextPoint = new Point(nr, nc);
                    if (targetCoordToIdx.containsKey(nextPoint)) {
                        results[targetCoordToIdx.get(nextPoint)] = d[nr][nc];
                    }
                }
            }
        }
        return results;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        R = sc.nextInt(); C = sc.nextInt();
        grid = new int[R][C];
        List<Point> stations = new ArrayList<>();
        for (int i = 0; i < R; i++) {
            for (int j = 0; j < C; j++) {
                grid[i][j] = sc.nextInt();
                if (grid[i][j] == 2) stations.add(new Point(i, j));
            }
        }
        int startR = sc.nextInt(), startC = sc.nextInt();
        int endR = sc.nextInt(), endC = sc.nextInt();
        K = sc.nextInt();
        
        if (startR == endR && startC == endC) { System.out.println(0); return; }

        int numStations = stations.size();
        int startId = numStations, destId = numStations + 1, numMetaNodes = numStations + 2;
        List<List<int[]>> metaAdj = new ArrayList<>();
        for (int i = 0; i < numMetaNodes; i++) metaAdj.add(new ArrayList<>());

        if (numStations > 0) {
            int[][] distToStation = new int[R][C];
            int[][] owner = new int[R][C];
            for(int[] row : distToStation) Arrays.fill(row, INF);
            for(int[] row : owner) Arrays.fill(row, -1);
            Queue<Point> q = new LinkedList<>();
            for (int i = 0; i < numStations; i++) {
                Point p = stations.get(i);
                distToStation[p.r][p.c] = 0;
                owner[p.r][p.c] = i;
                q.offer(p);
            }
            while (!q.isEmpty()) {
                Point curr = q.poll();
                for (int i = 0; i < 4; i++) {
                    int nr = curr.r + dr[i], nc = curr.c + dc[i];
                    if (isValid(nr, nc) && owner[nr][nc] == -1) {
                        owner[nr][nc] = owner[curr.r][curr.c];
                        distToStation[nr][nc] = distToStation[curr.r][curr.c] + 1;
                        q.offer(new Point(nr, nc));
                    }
                }
            }
            Map<Long, Integer> stationDist = new HashMap<>();
            for (int r = 0; r < R; r++) for (int c = 0; c < C; c++) {
                for (int i = 0; i < 4; i++) {
                    int nr = r + dr[i], nc = c + dc[i];
                    if (isValid(nr, nc) && owner[r][c] != -1 && owner[nr][nc] != -1 && owner[r][c] != owner[nr][nc]) {
                        int u = owner[r][c], v = owner[nr][nc];
                        int w = distToStation[r][c] + distToStation[nr][nc] + 1;
                        long key = ((long) Math.min(u, v) << 32) | (long) Math.max(u, v);
                        stationDist.put(key, Math.min(stationDist.getOrDefault(key, INF), w));
                    }
                }
            }
            for (Map.Entry<Long, Integer> entry : stationDist.entrySet()) {
                if (entry.getValue() <= K) {
                    int u = (int) (entry.getKey() >> 32);
                    int v = (int) (entry.getKey() & 0xFFFFFFFFL);
                    metaAdj.get(u).add(new int[]{v, entry.getValue()});
                    metaAdj.get(v).add(new int[]{u, entry.getValue()});
                }
            }
        }

        List<Point> startTargets = new ArrayList<>(stations);
        startTargets.add(new Point(endR, endC));
        int[] startDists = kLimitedBfs(startR, startC, startTargets);
        if (startDists[startDists.length - 1] <= K) metaAdj.get(startId).add(new int[]{destId, startDists[startDists.length - 1]});
        for (int i = 0; i < numStations; i++) if (startDists[i] <= K) metaAdj.get(startId).add(new int[]{i, startDists[i]});

        if (!stations.isEmpty()) {
            int[] destDists = kLimitedBfs(endR, endC, stations);
            for (int i = 0; i < numStations; i++) if (destDists[i] <= K) metaAdj.get(i).add(new int[]{destId, destDists[i]});
        }

        PriorityQueue<int[]> pq = new PriorityQueue<>(Comparator.comparingInt(a -> a[0]));
        int[] finalDist = new int[numMetaNodes];
        Arrays.fill(finalDist, INF);
        finalDist[startId] = 0;
        pq.offer(new int[]{0, startId});
        while (!pq.isEmpty()) {
            int[] current = pq.poll();
            int d = current[0], u = current[1];
            if (d > finalDist[u]) continue;
            if (u == destId) break;
            for (int[] edge : metaAdj.get(u)) {
                int v = edge[0], weight = edge[1];
                if (finalDist[u] + weight < finalDist[v]) {
                    finalDist[v] = finalDist[u] + weight;
                    pq.offer(new int[]{finalDist[v], v});
                }
            }
        }
        
        System.out.println(finalDist[destId] == INF ? -1 : finalDist[destId]);
    }
}
import collections
import heapq

def solve():
    r_max, c_max = map(int, input().split())
    grid = [list(map(int, input().split())) for _ in range(r_max)]
    stations = []
    for r in range(r_max):
        for c in range(c_max):
            if grid[r][c] == 2:
                stations.append((r, c))
    
    start_r, start_c = map(int, input().split())
    dest_r, dest_c = map(int, input().split())
    k = int(input())

    if (start_r, start_c) == (dest_r, dest_c):
        print(0)
        return

    dr = [-1, 1, 0, 0]
    dc = [0, 0, -1, 1]

    def is_valid(r, c):
        return 0 <= r < r_max and 0 <= c < c_max and grid[r][c] != 1

    def k_limited_bfs(start_r, start_c, targets):
        results = {target: float('inf') for target in targets}
        if not is_valid(start_r, start_c): return [results[t] for t in targets]

        q = collections.deque([(start_r, start_c, 0)])
        d = {(start_r, start_c): 0}
        
        if (start_r, start_c) in results:
            results[(start_r, start_c)] = 0
        
        while q:
            r, c, dist = q.popleft()
            if dist >= k: continue
            
            for i in range(4):
                nr, nc = r + dr[i], c + dc[i]
                if is_valid(nr, nc) and (nr, nc) not in d:
                    d[(nr, nc)] = dist + 1
                    q.append((nr, nc, dist + 1))
                    if (nr, nc) in results:
                        results[(nr, nc)] = dist + 1
        return [results[t] for t in targets]

    num_stations = len(stations)
    start_id = num_stations
    dest_id = num_stations + 1
    num_meta_nodes = num_stations + 2
    meta_adj = [[] for _ in range(num_meta_nodes)]

    if num_stations > 0:
        dist_to_station = [[float('inf')] * c_max for _ in range(r_max)]
        owner = [[-1] * c_max for _ in range(r_max)]
        q = collections.deque()
        for i, (r, c) in enumerate(stations):
            dist_to_station[r][c] = 0
            owner[r][c] = i
            q.append((r, c))
        
        while q:
            r, c = q.popleft()
            for i in range(4):
                nr, nc = r + dr[i], c + dc[i]
                if is_valid(nr, nc) and owner[nr][nc] == -1:
                    owner[nr][nc] = owner[r][c]
                    dist_to_station[nr][nc] = dist_to_station[r][c] + 1
                    q.append((nr, nc))
        
        station_dist = collections.defaultdict(lambda: float('inf'))
        for r in range(r_max):
            for c in range(c_max):
                for i in range(4):
                    nr, nc = r + dr[i], c + dc[i]
                    if is_valid(nr, nc) and owner[r][c] != -1 and owner[nr][nc] != -1 and owner[r][c] != owner[nr][nc]:
                        u, v = owner[r][c], owner[nr][nc]
                        key = tuple(sorted((u, v)))
                        w = dist_to_station[r][c] + dist_to_station[nr][nc] + 1
                        station_dist[key] = min(station_dist[key], w)

        for (u, v), w in station_dist.items():
            if w <= k:
                meta_adj[u].append((v, w))
                meta_adj[v].append((u, w))

    start_targets = stations + [(dest_r, dest_c)]
    start_dists = k_limited_bfs(start_r, start_c, start_targets)
    if start_dists[-1] <= k: meta_adj[start_id].append((dest_id, start_dists[-1]))
    for i in range(num_stations):
        if start_dists[i] <= k: meta_adj[start_id].append((i, start_dists[i]))

    if stations:
        dest_dists = k_limited_bfs(dest_r, dest_c, stations)
        for i in range(num_stations):
            if dest_dists[i] <= k: meta_adj[i].append((dest_id, dest_dists[i]))

    pq = [(0, start_id)]
    final_dist = [float('inf')] * num_meta_nodes
    final_dist[start_id] = 0
    
    while pq:
        d, u = heapq.heappop(pq)
        if d > final_dist[u]: continue
        if u == dest_id: break
        for v, weight in meta_adj[u]:
            if final_dist[u] + weight < final_dist[v]:
                final_dist[v] = final_dist[u] + weight
                heapq.heappush(pq, (final_dist[v], v))
    
    result = final_dist[dest_id]
    print(result if result != float('inf') else -1)

solve()

算法及复杂度

  • 算法: 多源 BFS + 单源 BFS + Dijkstra
  • 时间复杂度: ,其中 是补给站数量。最坏情况下,k_limited_bfs 会接近遍历全图,其调用次数与关键点数量相关,成为瓶颈。
  • 空间复杂度: ,主要由 BFS 使用的距离和所有者网格决定。