题目链接

电话线

题目描述

给定一个包含 个电话杆和 段可用电缆的图。每段电缆连接两个电话杆并有相应的长度。

目标是连接电话杆 1 (公共电话网) 和电话杆 (庄园)。电话公司免费提供最多 段电缆。如果选择的连接路径需要超过 段电缆,则需要自行支付超出的部分。支付的费用等于所需购买的电缆中最长的那一段的长度

你需要计算连接 1 号和 号电话杆所需的最小费用。如果无法连接,则输出 -1。

解题思路

这是一个典型的**“最小化最大值”问题,是二分答案**的经典应用场景。

我们要求解的是“最小的费用”,这个费用实际上是路径上某条边的长度。我们可以对这个“费用”进行二分查找。

1. 二分答案

我们二分最终的答案,即二分我们愿意支付的最大费用,记为 max_len

  • 单调性:如果一个最大费用 max_len 是可行的(即存在一条路径,其付费电缆长度都不超过 max_len),那么任何大于 max_len 的费用也一定是可行的(因为允许的边集变大了)。这个单调性是二分答案的基础。
  • 二分范围:费用的下界是 0(所有电缆都免费),上界可以是所有电缆长度的最大值。
  • check(max_len) 函数:我们需要设计一个函数来验证,当最大付费长度为 max_len 时,是否能成功连接 1 号和 号电话杆。

2. check(max_len) 函数的设计

check(max_len) 的核心是判断:是否存在一条从 1 到 的路径,使得我们使用的付费电缆长度都不超过 max_len

这可以转化为一个最短路径问题。我们对图中的电缆(边)重新定义“成本”:

  • 如果一段电缆的长度 L > max_len,那么它必须使用免费名额。我们可以认为它的“新成本”是 1(因为它消耗了一个免费名额)。
  • 如果一段电缆的长度 L <= max_len,那么我们可以选择为它付费(费用不超过 max_len),所以它不必使用免费名额。我们可以认为它的“新成本”是 0。

现在问题就变成了:在这个“新成本”图上,从 1 到 最短路径是多少?这个最短路径的长度,就代表了要连接 1 和 至少需要使用多少个免费名额

  • 设这个最短路径长度(最少免费名额数)为 min_free_needed
  • 如果 min_free_needed <= K (我们拥有的免费名额),说明在最大付费 max_len 的限制下,我们可以成功连接。check(max_len) 返回 true
  • 如果 min_free_needed > K,说明免费名额不够用,连接失败。check(max_len) 返回 false

3. 求解 0-1 权图最短路

由于新图的边权只有 0 和 1,我们可以使用一种比标准 Dijkstra 更高效的算法:0-1 宽度优先搜索 (0-1 BFS)

  • 使用一个双端队列 (deque) 代替优先队列。
  • 当松弛一个点 时:
    • 如果通向 的边是 0 权边,将 加入队首
    • 如果通向 的边是 1 权边,将 加入队尾
  • 这样可以保证队列中的节点始终按距离单调排列,从而以 的线性时间复杂度完成最短路计算。

算法整体流程

  1. 对最终费用 max_len 进行二分查找。
  2. check(max_len) 函数中,运行 0-1 BFS 计算从 1 到 所需的最少免费名额。
  3. 根据 check 的结果调整二分范围,找到满足条件的最小 max_len
  4. 如果二分查找结束后,找不到任何可行的 max_len,说明无法连通,输出 -1。

代码

#include <iostream>
#include <vector>
#include <deque>
#include <algorithm>

using namespace std;

const int INF = 1e9;
int n, p, k;
vector<pair<int, int>> adj[1001];

bool check(int max_len) {
    vector<int> dist(n + 1, INF);
    deque<int> dq;

    dist[1] = 0;
    dq.push_front(1);

    while (!dq.empty()) {
        int u = dq.front();
        dq.pop_front();

        for (auto& edge : adj[u]) {
            int v = edge.first;
            int len = edge.second;
            int cost = (len > max_len);

            if (dist[u] + cost < dist[v]) {
                dist[v] = dist[u] + cost;
                if (cost == 1) {
                    dq.push_back(v);
                } else {
                    dq.push_front(v);
                }
            }
        }
    }
    return dist[n] <= k;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    cin >> n >> p >> k;

    int max_l = 0;
    for (int i = 0; i < p; ++i) {
        int u, v, l;
        cin >> u >> v >> l;
        adj[u].push_back({v, l});
        adj[v].push_back({u, l});
        max_l = max(max_l, l);
    }

    int left = 0, right = max_l, ans = -1;
    while (left <= right) {
        int mid = left + (right - left) / 2;
        if (check(mid)) {
            ans = mid;
            right = mid - 1;
        } else {
            left = mid + 1;
        }
    }

    cout << ans << endl;

    return 0;
}
import java.util.*;

class Main {
    static int n, p, k;
    static List<List<Edge>> adj;
    static final int INF = Integer.MAX_VALUE;

    static class Edge {
        int to;
        int length;

        Edge(int to, int length) {
            this.to = to;
            this.length = length;
        }
    }

    static boolean check(int maxLen) {
        int[] dist = new int[n + 1];
        Arrays.fill(dist, INF);
        Deque<Integer> dq = new ArrayDeque<>();

        dist[1] = 0;
        dq.addFirst(1);

        while (!dq.isEmpty()) {
            int u = dq.pollFirst();

            for (Edge edge : adj.get(u)) {
                int v = edge.to;
                int len = edge.length;
                int cost = (len > maxLen) ? 1 : 0;

                if (dist[u] != INF && dist[u] + cost < dist[v]) {
                    dist[v] = dist[u] + cost;
                    if (cost == 1) {
                        dq.addLast(v);
                    } else {
                        dq.addFirst(v);
                    }
                }
            }
        }
        return dist[n] <= k;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        n = sc.nextInt();
        p = sc.nextInt();
        k = sc.nextInt();

        adj = new ArrayList<>();
        for (int i = 0; i <= n; i++) {
            adj.add(new ArrayList<>());
        }

        int maxL = 0;
        for (int i = 0; i < p; i++) {
            int u = sc.nextInt();
            int v = sc.nextInt();
            int l = sc.nextInt();
            adj.get(u).add(new Edge(v, l));
            adj.get(v).add(new Edge(u, l));
            maxL = Math.max(maxL, l);
        }

        int left = 0, right = maxL, ans = -1;
        
        // Special case: if 1 cannot reach N at all.
        if (!check(maxL)) {
            System.out.println(-1);
            return;
        }

        while (left <= right) {
            int mid = left + (right - left) / 2;
            if (check(mid)) {
                ans = mid;
                right = mid - 1;
            } else {
                left = mid + 1;
            }
        }
        
        System.out.println(ans);
    }
}
import collections

def check(max_len, n, adj, k):
    dist = {i: float('inf') for i in range(1, n + 1)}
    dist[1] = 0
    dq = collections.deque([1])

    while dq:
        u = dq.popleft()

        for v, length in adj.get(u, []):
            cost = 1 if length > max_len else 0
            if dist[u] + cost < dist[v]:
                dist[v] = dist[u] + cost
                if cost == 1:
                    dq.append(v)
                else:
                    dq.appendleft(v)
    
    return dist[n] <= k

def main():
    n, p, k = map(int, input().split())
    
    adj = {}
    max_l = 0
    for _ in range(p):
        u, v, l = map(int, input().split())
        if u not in adj: adj[u] = []
        if v not in adj: adj[v] = []
        adj[u].append((v, l))
        adj[v].append((u, l))
        max_l = max(max_l, l)

    left, right = 0, max_l
    ans = -1

    # Check for basic connectivity first. If even with all cables free it's impossible, output -1.
    # The largest possible max_len is max_l, meaning all cables cost 0 free slots.
    if not check(max_l, n, adj, k):
        print(-1)
        return

    while left <= right:
        mid = (left + right) // 2
        if check(mid, n, adj, k):
            ans = mid
            right = mid - 1
        else:
            left = mid + 1
            
    print(ans)

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:二分答案 + 0-1 宽度优先搜索 (BFS)
  • 时间复杂度:。其中 是电话杆数, 是电缆数, 是最大电缆长度。二分答案需要 次迭代,每次迭代内部运行一次 0-1 BFS,其复杂度为
  • 空间复杂度:。用于存储图的邻接表以及 0-1 BFS 所需的距离数组和双端队列。