题目链接

小红送外卖

题目描述

在一个带权无向图中,给定一个起始点(1号美食街)和 个送餐目的地。对于每个目的地,都需要从起始点出发,送达后再返回起始点。求完成所有 次送餐任务所需的最短总骑行距离。

解题思路

这个问题的核心是计算从一个固定的起点(1号节点)到多个不同目的地的最短往返距离之和。

  1. 独立任务:每次送餐都是一次独立的任务,从美食街出发,到达目的地,再返回。送餐的顺序不影响总距离。
  2. 往返距离:对于一次到学校 的送餐任务,其最短骑行距离为“1到的最短路”+“到1的最短路”。因为是无向图,这两段距离是相等的。所以,一次任务的距离就是
  3. 总距离:总的最短距离就是所有 次任务的最短距离之和。
  4. 单源最短路径:因此,问题转化为:首先计算出1号节点到图中所有其他节点的最短路径,然后根据给定的 个目的地,将对应的往返距离累加起来。
  5. 算法选择:由于图的边权都是非负的,这是一个经典的单源最短路径问题,使用 Dijkstra 算法 是最合适的选择。

算法流程

  1. 构建图的邻接表表示。
  2. 从1号节点作为源点,运行 Dijkstra 算法,计算出到所有其他节点的最短距离,并存储在 dist 数组中。
  3. 初始化总距离 total_distance = 0
  4. 读取 个送餐目的地。对于每一个目的地 dest,查询 dist[dest],然后将 dist[dest] * 2 累加到 total_distance 中。
  5. 输出最终的 total_distance

代码

#include <iostream>
#include <vector>
#include <queue>

using namespace std;
using ll = long long;
using P = pair<ll, int>;

const ll INF = 1e18;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, m, k;
    cin >> n >> m >> k;

    vector<vector<P>> adj(n + 1);
    for (int i = 0; i < m; ++i) {
        int u, v, w;
        cin >> u >> v >> w;
        adj[u].push_back({v, w});
        adj[v].push_back({u, w});
    }

    vector<ll> dist(n + 1, INF);
    priority_queue<P, vector<P>, greater<P>> pq;

    dist[1] = 0;
    pq.push({0, 1});

    while (!pq.empty()) {
        auto [d, u] = pq.top();
        pq.pop();

        if (d > dist[u]) {
            continue;
        }

        for (auto& edge : adj[u]) {
            int v = edge.first;
            int w = edge.second;
            if (dist[u] + w < dist[v]) {
                dist[v] = dist[u] + w;
                pq.push({dist[v], v});
            }
        }
    }

    ll total_distance = 0;
    for (int i = 0; i < k; ++i) {
        int dest;
        cin >> dest;
        if (dist[dest] != INF) {
            total_distance += dist[dest] * 2;
        }
    }

    cout << total_distance << endl;

    return 0;
}
import java.util.*;

public class Main {
    static class Edge {
        int to;
        int weight;

        Edge(int to, int weight) {
            this.to = to;
            this.weight = weight;
        }
    }

    static class State implements Comparable<State> {
        long dist;
        int u;

        State(long dist, int u) {
            this.dist = dist;
            this.u = u;
        }

        @Override
        public int compareTo(State other) {
            return Long.compare(this.dist, other.dist);
        }
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();
        int k = sc.nextInt();

        List<List<Edge>> adj = new ArrayList<>();
        for (int i = 0; i <= n; i++) {
            adj.add(new ArrayList<>());
        }

        for (int i = 0; i < m; i++) {
            int u = sc.nextInt();
            int v = sc.nextInt();
            int w = sc.nextInt();
            adj.get(u).add(new Edge(v, w));
            adj.get(v).add(new Edge(u, w));
        }

        long[] dist = new long[n + 1];
        Arrays.fill(dist, Long.MAX_VALUE);
        PriorityQueue<State> pq = new PriorityQueue<>();

        dist[1] = 0;
        pq.offer(new State(0, 1));

        while (!pq.isEmpty()) {
            State current = pq.poll();
            long d = current.dist;
            int u = current.u;

            if (d > dist[u]) {
                continue;
            }

            for (Edge edge : adj.get(u)) {
                int v = edge.to;
                int w = edge.weight;
                if (dist[u] + w < dist[v]) {
                    dist[v] = dist[u] + w;
                    pq.offer(new State(dist[v], v));
                }
            }
        }

        long totalDistance = 0;
        for (int i = 0; i < k; i++) {
            int dest = sc.nextInt();
            if (dist[dest] != Long.MAX_VALUE) {
                totalDistance += dist[dest] * 2;
            }
        }

        System.out.println(totalDistance);
    }
}
import sys
import heapq

def main():
    input = sys.stdin.readline
    n, m, k = map(int, input().split())

    adj = [[] for _ in range(n + 1)]
    for _ in range(m):
        u, v, w = map(int, input().split())
        adj[u].append((v, w))
        adj[v].append((u, w))

    dist = [float('inf')] * (n + 1)
    pq = [(0, 1)]  # (distance, node)
    dist[1] = 0

    while pq:
        d, u = heapq.heappop(pq)

        if d > dist[u]:
            continue

        for v, w in adj[u]:
            if dist[u] + w < dist[v]:
                dist[v] = dist[u] + w
                heapq.heappush(pq, (dist[v], v))

    destinations = list(map(int, input().split()))
    
    total_distance = 0
    for dest in destinations:
        if dist[dest] != float('inf'):
            total_distance += dist[dest] * 2
            
    print(total_distance)

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:Dijkstra 算法
  • 时间复杂度:,其中 是顶点数, 是边数, 是送餐次数。Dijkstra 算法的时间复杂度为 ,后续累加距离的复杂度为
  • 空间复杂度:,用于存储邻接表和 Dijkstra 算法所需的辅助数组。