题目链接

【模板】最小生成树

题目描述

给定一个包含 个顶点和 条边的无向连通图,边权均为整数。你需要求出该图的最小生成树(MST)。

输出要求:

  1. 最小生成树的所有边的权值之和。
  2. 构成最小生成树的边的原始编号(从 1 开始)。若有多种方案,输出任意一种。

解题思路

本题是求解无向连通图的最小生成树,是经典的图论问题。常用的算法有 KruskalPrim。考虑到本题需要输出构成 MST 的边的编号,使用 Kruskal 算法会更直观方便。

Kruskal 算法

Kruskal 算法是一种基于贪心思想的算法,其核心步骤如下:

  1. 边的排序:将图中所有的 条边按照权重从小到大进行排序。
  2. 并查集初始化:创建一个并查集(Disjoint Set Union, DSU),其中每个顶点初始时都各自属于一个独立的集合。
  3. 遍历与合并:按顺序遍历排序后的边。对于每一条边
    • 使用并查集的 find 操作检查顶点 是否已经属于同一个连通分量(即 find(u) == find(v))。
    • 如果它们在同一个连通分量中,说明加入这条边不会形成环路。此时,我们将这条边加入最小生成树中,并使用并查集的 unite 操作将 所在的集合合并。
    • 如果它们已经在同一个连通分量中,则跳过这条边,因为它会形成环路。
  4. 终止条件:当最小生成树中的边数达到 时,算法结束。因为一个包含 个顶点的树有且仅有 条边。

数据结构

  • 边结构体:需要存储边的两个端点 ,权重 ,以及它的原始输入编号 id
  • 并查集:用于维护图的连通分量,高效地判断任意两个顶点是否连通,以避免形成环路。它主要包含两个操作:
    • find(i): 查找元素 所在集合的根节点。为了提高效率,通常会使用路径压缩优化。
    • unite(i, j): 合并元素 所在的集合。

算法执行过程中,我们累加被选入 MST 的边的权重,并记录它们的编号,最后按要求输出即可。

代码

#include <iostream>
#include <vector>
#include <algorithm>
#include <numeric>

using namespace std;

struct Edge {
    int u, v, w, id;
};

bool compareEdges(const Edge& a, const Edge& b) {
    return a.w < b.w;
}

struct DSU {
    vector<int> parent;
    DSU(int n) {
        parent.resize(n + 1);
        iota(parent.begin(), parent.end(), 0);
    }

    int find(int i) {
        if (parent[i] == i) {
            return i;
        }
        return parent[i] = find(parent[i]);
    }

    void unite(int i, int j) {
        int root_i = find(i);
        int root_j = find(j);
        if (root_i != root_j) {
            parent[root_i] = root_j;
        }
    }
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int n, m;
    cin >> n >> m;

    vector<Edge> edges(m);
    for (int i = 0; i < m; ++i) {
        cin >> edges[i].u >> edges[i].v >> edges[i].w;
        edges[i].id = i + 1;
    }

    sort(edges.begin(), edges.end(), compareEdges);

    DSU dsu(n);
    long long total_weight = 0;
    vector<int> mst_edge_ids;
    int edge_count = 0;

    for (const auto& edge : edges) {
        if (dsu.find(edge.u) != dsu.find(edge.v)) {
            dsu.unite(edge.u, edge.v);
            total_weight += edge.w;
            mst_edge_ids.push_back(edge.id);
            edge_count++;
            if (edge_count == n - 1) {
                break;
            }
        }
    }

    cout << total_weight << "\n";
    for (int i = 0; i < mst_edge_ids.size(); ++i) {
        cout << mst_edge_ids[i] << (i == mst_edge_ids.size() - 1 ? "" : " ");
    }
    cout << "\n";

    return 0;
}
import java.util.Scanner;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

class Main {
    static class Edge implements Comparable<Edge> {
        int u, v, w, id;

        public Edge(int u, int v, int w, int id) {
            this.u = u;
            this.v = v;
            this.w = w;
            this.id = id;
        }

        @Override
        public int compareTo(Edge other) {
            return Integer.compare(this.w, other.w);
        }
    }

    static class DSU {
        private int[] parent;

        public DSU(int n) {
            parent = new int[n + 1];
            for (int i = 0; i <= n; i++) {
                parent[i] = i;
            }
        }

        public int find(int i) {
            if (parent[i] == i) {
                return i;
            }
            return parent[i] = find(parent[i]);
        }

        public void unite(int i, int j) {
            int root_i = find(i);
            int root_j = find(j);
            if (root_i != root_j) {
                parent[root_i] = root_j;
            }
        }
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();

        List<Edge> edges = new ArrayList<>();
        for (int i = 0; i < m; i++) {
            int u = sc.nextInt();
            int v = sc.nextInt();
            int w = sc.nextInt();
            edges.add(new Edge(u, v, w, i + 1));
        }

        Collections.sort(edges);

        DSU dsu = new DSU(n);
        long totalWeight = 0;
        List<Integer> mstEdgeIds = new ArrayList<>();
        int edgeCount = 0;

        for (Edge edge : edges) {
            if (dsu.find(edge.u) != dsu.find(edge.v)) {
                dsu.unite(edge.u, edge.v);
                totalWeight += edge.w;
                mstEdgeIds.add(edge.id);
                edgeCount++;
                if (edgeCount == n - 1) {
                    break;
                }
            }
        }

        System.out.println(totalWeight);
        for (int i = 0; i < mstEdgeIds.size(); i++) {
            System.out.print(mstEdgeIds.get(i) + (i == mstEdgeIds.size() - 1 ? "" : " "));
        }
        System.out.println();
    }
}
class DSU:
    def __init__(self, n):
        self.parent = list(range(n + 1))

    def find(self, i):
        if self.parent[i] == i:
            return i
        self.parent[i] = self.find(self.parent[i])
        return self.parent[i]

    def unite(self, i, j):
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i != root_j:
            self.parent[root_i] = root_j
            return True
        return False

def main():
    n, m = map(int, input().split())
    
    edges = []
    for i in range(m):
        u, v, w = map(int, input().split())
        edges.append((u, v, w, i + 1))
        
    edges.sort(key=lambda x: x[2])
    
    dsu = DSU(n)
    total_weight = 0
    mst_edge_ids = []
    edge_count = 0
    
    for u, v, w, id in edges:
        if dsu.unite(u, v):
            total_weight += w
            mst_edge_ids.append(id)
            edge_count += 1
            if edge_count == n - 1:
                break
    
    print(total_weight)
    print(*mst_edge_ids)

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:Kruskal 算法 + 并查集
  • 时间复杂度:。算法的瓶颈在于对 条边进行排序,其时间复杂度为 。之后遍历边的过程,并查集的 findunite 操作的平均时间复杂度接近于常数(即 ,其中 是反阿克曼函数,增长极其缓慢),因此总的遍历复杂度为 。综上,总时间复杂度为
  • 空间复杂度:。用于存储 条边和并查集的父节点数组。