// Two approaches to the minimum spanning tree: Prim and Kruskal (Prim和Kruskal两种方法)
// Prim
// Time: O(M log M); each edge enters the min-heap at most once per direction.
// Space: O(N + M) for the adjacency lists and the heap.
import java.util.*;
public class Solution {
    /**
     * Returns the total weight of a minimum spanning tree, built with Prim's algorithm.
     *
     * <p>Vertices are numbered 1..n. The tree is grown from vertex 1; if the graph is
     * disconnected, only the component that contains vertex 1 contributes to the result.
     *
     * @param n    number of vertices, labelled 1..n
     * @param m    number of edges (not read; every entry of {@code cost} is used)
     * @param cost undirected edges, each {@code [u, v, weight]} with 1-based endpoints
     * @return sum of the weights of the chosen tree edges; 0 when {@code n <= 1}
     */
    public int miniSpanningTree (int n, int m, int[][] cost) {
        if (n <= 1) {
            return 0; // zero or one vertex needs no edges (and n == 0 has no vertex 1)
        }

        // graph.get(u) holds [v, weight] pairs; index 0 is a dummy because vertices are 1-based.
        List<List<int[]>> graph = new ArrayList<>(n + 1);
        for (int i = 0; i <= n; i++) {
            graph.add(new ArrayList<>());
        }
        for (int[] c : cost) {
            graph.get(c[0]).add(new int[]{c[1], c[2]});
            graph.get(c[1]).add(new int[]{c[0], c[2]});
        }

        // Frontier of [v, weight] pairs, cheapest first. Integer.compare is used instead of
        // a[1] - b[1], which overflows (and misorders) for large or negative weights.
        PriorityQueue<int[]> minEdgeHeap = new PriorityQueue<>((a, b) -> Integer.compare(a[1], b[1]));
        boolean[] visited = new boolean[n + 1];
        int totalCost = 0;
        int inTree = 0;

        // Seed with vertex 1 at cost 0 so the start vertex is handled by the same loop.
        minEdgeHeap.offer(new int[]{1, 0});
        while (!minEdgeHeap.isEmpty() && inTree < n) { // stop once every vertex has joined
            int[] minEdge = minEdgeHeap.poll();
            int to = minEdge[0];
            if (visited[to]) {
                continue; // stale entry: the vertex already joined via a cheaper edge
            }
            visited[to] = true;
            inTree++;
            totalCost += minEdge[1];
            for (int[] nei : graph.get(to)) {
                if (!visited[nei[0]]) {
                    minEdgeHeap.offer(nei); // pairs are never mutated, so sharing them is safe
                }
            }
        }
        return totalCost;
    }
}
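// Kruskal
// Time: O(M log M + M·α(N)) = O(M log M)
// Sorting the M edges takes O(M log M); the M find/union operations take O(M·α(N)), i.e. ~O(M).
// Space: O(N) for the DSU arrays, plus the sort's auxiliary space (up to O(M)).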
import java.util.*;
public class Solution {
    /** Disjoint-set union over ids 0..size-1 with path compression and union by rank. */
    static class DSU {
        final int[] rank;
        final int[] root;

        DSU(int size) {
            this.rank = new int[size];
            this.root = new int[size];
            for (int i = 0; i < size; i++) {
                this.root[i] = i; // every id starts as its own root
            }
        }

        /** Returns the representative of {@code id}'s set, compressing the path on the way. */
        int findRoot(int id) {
            if (root[id] != id) {
                root[id] = findRoot(root[id]);
            }
            return root[id];
        }

        /** Merges the sets containing {@code idA} and {@code idB}, attaching the lower-rank root. */
        void union(int idA, int idB) {
            int rootA = findRoot(idA);
            int rootB = findRoot(idB);
            if (rootA == rootB) {
                return; // already one set; merging again would wrongly increase the rank
            }
            if (rank[rootA] < rank[rootB]) {
                root[rootA] = rootB;
            } else if (rank[rootA] > rank[rootB]) {
                root[rootB] = rootA;
            } else {
                root[rootB] = rootA;
                rank[rootA]++;
            }
        }
    }

    /**
     * Returns the total weight of a minimum spanning tree (forest, if disconnected),
     * built with Kruskal's algorithm.
     *
     * @param n    number of vertices, labelled 1..n
     * @param m    number of edges (not read; every entry of {@code cost} is used)
     * @param cost undirected edges, each {@code [u, v, weight]} with 1-based endpoints;
     *             the array's order is not modified
     * @return sum of the weights of the chosen edges
     */
    public int miniSpanningTree (int n, int m, int[][] cost) {
        // Sort a shallow copy so the caller's array keeps its order; the edge rows are only read.
        int[][] edges = cost.clone();
        // Integer.compare avoids the overflow of a[2] - b[2] for large or negative weights.
        Arrays.sort(edges, (a, b) -> Integer.compare(a[2], b[2]));

        DSU dsu = new DSU(n);
        int minCost = 0;
        int chosen = 0;
        for (int[] e : edges) {
            if (chosen == n - 1) {
                break; // a spanning tree over n vertices has exactly n - 1 edges
            }
            // Vertices are 1-based; the DSU uses 0-based ids.
            int rootA = dsu.findRoot(e[0] - 1);
            int rootB = dsu.findRoot(e[1] - 1);
            if (rootA == rootB) {
                continue; // endpoints already connected; this edge would form a cycle
            }
            dsu.union(rootA, rootB);
            minCost += e[2];
            chosen++;
        }
        return minCost;
    }
}