题目链接

牛之国

题目描述

给定 个城市的二维坐标。所有城市之间都会同时开始修路,速度为 1 单位/年。任意两个城市之间的施工队都沿着直线路径相向而行。当两个施工队相遇时,这两个城市就连通了。

求使得所有城市构成一个连通图所需的最短时间,结果向上取整。

解题思路

这是一个关于连通性的最优化问题,可以巧妙地转化为求解图的最小生成树 (MST)。

时间、距离与连通

首先,我们分析两个城市 何时会连通。设它们的坐标分别为

  • 它们之间的欧几里得距离是
  • 由于两个城市的施工队相向而行,它们的相对速度是 单位/年。
  • 因此,它们相遇所需的时间是

问题转化:最小生成树

我们可以将 个城市看作图的 个顶点。任意两个城市之间都可能建立一条边,边的“权重”就是连接这两个城市所需的时间。

我们的目标是找到一个最小的时间 ,使得在 年后,整个图是连通的。

  • 在时间 时,任意两个城市 之间会存在一条边,当且仅当连接它们所需的时间
  • 这等价于 ,即

我们可以想象时间从小到大流逝的过程。一开始,所有城市都是孤立的。随着时间的推移,距离最近的城市对会首先连通,然后是次近的,以此类推。这恰好就是 Kruskal 算法 构造最小生成树的过程:不断加入权重最小且不会形成环的边,直到所有顶点连通。

当整个图恰好变为连通时,我们所加入的最后一条边,一定是这个连通图(也就是一棵生成树)中权重最大的那条边。这个最大的权重,就是我们所求的最小时间

因此,问题被转化为: 在由 个城市构成的完全图中,找到其最小生成树 (MST),并求出该 MST 中最长边的权重(即连接时间)。

算法步骤

  1. 构建边集:将所有 个城市看作顶点。计算每对城市之间的距离,并生成一个包含所有 条边的列表。为了避免浮点数精度问题和 sqrt 运算的开销,我们暂时使用欧几里得距离的平方作为边的权重进行排序。
  2. 排序边集:将所有边按照权重(距离平方)从小到大排序。
  3. Kruskal 算法
    • 初始化一个并查集 (Disjoint Set Union, DSU),每个城市自成一个集合。
    • 遍历排序后的边集。对于每条边 ,如果 不在同一个连通分量中(通过 find 操作判断),则:
      • 将这条边加入 MST。
      • 合并 所在的集合(通过 union 操作)。
      • 记录下当前这条边的权重(距离平方),因为这可能是最后加入 MST 的边。
    • 当所有城市都属于同一个连通分量时(即只剩下一个集合),我们找到了完整的 MST。此时记录的最后一条边的权重(距离平方),就是 MST 的最大边权,我们称之为
  4. 计算最终结果
    • MST 中最长边的实际距离为
    • 所需的最短时间为
    • 根据题意,将结果向上取整,最终答案为

由于坐标值的范围很大,距离的平方可能会超过 int 甚至 long long 的范围。在 C++ 中需要使用 __int128long double,或者在计算时注意溢出。不过,本题数据范围的坐标差的平方可以用 long long 存储。

代码

#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
#include <iomanip>

using namespace std;

struct Point {
    long long x, y;
    int id;
};

struct Edge {
    int u, v;
    long long weight; // Squared distance

    bool operator<(const Edge& other) const {
        return weight < other.weight;
    }
};

struct DSU {
    vector<int> parent;
    DSU(int n) {
        parent.resize(n);
        for (int i = 0; i < n; ++i) {
            parent[i] = i;
        }
    }

    int find(int i) {
        if (parent[i] == i) {
            return i;
        }
        return parent[i] = find(parent[i]);
    }

    void unite(int i, int j) {
        int root_i = find(i);
        int root_j = find(j);
        if (root_i != root_j) {
            parent[root_i] = root_j;
        }
    }
};

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;

    vector<Point> points(n);
    for (int i = 0; i < n; ++i) {
        cin >> points[i].x >> points[i].y;
        points[i].id = i;
    }

    if (n <= 1) {
        cout << 0 << endl;
        return 0;
    }

    vector<Edge> edges;
    for (int i = 0; i < n; ++i) {
        for (int j = i + 1; j < n; ++j) {
            long long dx = points[i].x - points[j].x;
            long long dy = points[i].y - points[j].y;
            edges.push_back({i, j, dx * dx + dy * dy});
        }
    }

    sort(edges.begin(), edges.end());

    DSU dsu(n);
    long long max_edge_weight = 0;
    int edges_count = 0;
    for (const auto& edge : edges) {
        if (dsu.find(edge.u) != dsu.find(edge.v)) {
            dsu.unite(edge.u, edge.v);
            max_edge_weight = edge.weight;
            edges_count++;
            if (edges_count == n - 1) {
                break;
            }
        }
    }

    double time = sqrt((double)max_edge_weight) / 2.0;
    cout << (long long)ceil(time) << endl;

    return 0;
}
import java.util.Scanner;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class Main {
    static class Point {
        long x, y;
        Point(long x, long y) {
            this.x = x;
            this.y = y;
        }
    }

    static class Edge implements Comparable<Edge> {
        int u, v;
        long weight; // Squared distance

        Edge(int u, int v, long weight) {
            this.u = u;
            this.v = v;
            this.weight = weight;
        }

        @Override
        public int compareTo(Edge other) {
            return Long.compare(this.weight, other.weight);
        }
    }

    static class DSU {
        int[] parent;
        DSU(int n) {
            parent = new int[n];
            for (int i = 0; i < n; i++) {
                parent[i] = i;
            }
        }

        int find(int i) {
            if (parent[i] == i) {
                return i;
            }
            return parent[i] = find(parent[i]);
        }

        void unite(int i, int j) {
            int root_i = find(i);
            int root_j = find(j);
            if (root_i != root_j) {
                parent[root_i] = root_j;
            }
        }
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        Point[] points = new Point[n];
        for (int i = 0; i < n; i++) {
            points[i] = new Point(sc.nextLong(), sc.nextLong());
        }

        if (n <= 1) {
            System.out.println(0);
            return;
        }

        List<Edge> edges = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                long dx = points[i].x - points[j].x;
                long dy = points[i].y - points[j].y;
                edges.add(new Edge(i, j, dx * dx + dy * dy));
            }
        }

        Collections.sort(edges);

        DSU dsu = new DSU(n);
        long maxEdgeWeight = 0;
        int edgesCount = 0;
        for (Edge edge : edges) {
            if (dsu.find(edge.u) != dsu.find(edge.v)) {
                dsu.unite(edge.u, edge.v);
                maxEdgeWeight = edge.weight;
                edgesCount++;
                if (edgesCount == n - 1) {
                    break;
                }
            }
        }

        double time = Math.sqrt(maxEdgeWeight) / 2.0;
        System.out.println((long) Math.ceil(time));
    }
}
import sys
import math

class DSU:
    def __init__(self, n):
        self.parent = list(range(n))

    def find(self, i):
        if self.parent[i] == i:
            return i
        self.parent[i] = self.find(self.parent[i])
        return self.parent[i]

    def unite(self, i, j):
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i != root_j:
            self.parent[root_i] = root_j
            return True
        return False

def solve():
    n = int(sys.stdin.readline())
    if n <= 1:
        print(0)
        return
        
    points = []
    for _ in range(n):
        points.append(list(map(int, sys.stdin.readline().split())))

    edges = []
    for i in range(n):
        for j in range(i + 1, n):
            dx = points[i][0] - points[j][0]
            dy = points[i][1] - points[j][1]
            dist_sq = dx * dx + dy * dy
            edges.append((dist_sq, i, j))

    edges.sort()

    dsu = DSU(n)
    max_edge_weight = 0
    edges_count = 0
    for weight, u, v in edges:
        if dsu.unite(u, v):
            max_edge_weight = weight
            edges_count += 1
            if edges_count == n - 1:
                break
    
    time = math.sqrt(max_edge_weight) / 2.0
    print(math.ceil(time))

solve()

算法及复杂度

  • 算法:最小生成树 (Kruskal 算法) + 并查集
  • 时间复杂度。其中 是城市数量。主要开销在于:
    1. 生成所有 条边需要 的时间。
    2. 对这 条边进行排序,需要 的时间。
    3. Kruskal 算法本身,遍历所有边并执行并查集操作,需要 的时间,其中 是反阿克曼函数,增长极其缓慢,可视为常数。 因此,总时间复杂度的瓶颈是排序,为
  • 空间复杂度。主要用于存储所有 条边的列表。并查集需要 的空间。