Minimum Spanning Tree
Minimum Spanning Tree¶
Overview¶
A Minimum Spanning Tree (MST) is a tree that connects all vertices of a graph while minimizing the sum of edge weights. We'll learn about Kruskal and Prim algorithms, as well as the Union-Find data structure.
Table of Contents¶
1. MST Concept¶
Spanning Tree¶
Spanning Tree: A subgraph that includes all vertices
with no cycles
Conditions:
- Vertices: V
- Edges: V-1
- All vertices connected
- No cycles
Minimum Spanning Tree (MST)¶
MST: A spanning tree with minimum sum of edge weights
(1)──4──(2)
 │╲      │╲
 2  1    5  3
 │    ╲  │    ╲
(3)──6──(4)──7──(5)
MST (total weight: 10):
(1)──4──(2)
 │╲        ╲
 2  1        3
 │    ╲        ╲
(3)    (4)      (5)
MST edges:
(1)-1-(4), (1)-2-(3), (2)-3-(5), (1)-4-(2)
→ 1 + 2 + 3 + 4 = 10 (V - 1 = 4 edges; all weights are distinct, so this MST is unique)
MST Properties¶
1. Cut Property: When dividing a graph into two sets,
the minimum weight edge crossing the cut is in the MST
2. Cycle Property: The maximum weight edge in a cycle
is not in the MST
3. Uniqueness: If all edge weights are distinct, the MST is unique
2. Union-Find (Disjoint Set Union)¶
Concept¶
Disjoint Sets: Sets with no common elements
Operations:
- Find(x): Returns the representative element of the set containing x
- Union(x, y): Merges the sets containing x and y
Use Cases:
- Cycle detection
- Connected component management
- Kruskal's algorithm
Basic Implementation¶
// C
#define MAX_N 100001

// parent[i] is the parent link of element i; a root points to itself.
int parent[MAX_N];

// Put each of the elements 0..n-1 into its own singleton set.
void init(int n) {
    int idx = 0;
    while (idx < n) {
        parent[idx] = idx;  // a fresh element is the root of its own set
        idx++;
    }
}
// Return the representative (root) of the set containing x.
// Follows parent links until it reaches an element that is its own parent.
// This basic version does not modify parent[] (no path compression).
int find(int x) {
    int node = x;
    while (parent[node] != node) {
        node = parent[node];
    }
    return node;
}
// Merge the sets containing x and y.
// The root of x's set is linked under the root of y's set; nothing happens
// when both elements already share a root.
void unite(int x, int y) {
    int root_x = find(x);
    int root_y = find(y);
    if (root_x == root_y) {
        return;
    }
    parent[root_x] = root_y;
}
Optimization 1: Path Compression¶
Connect all nodes on the path directly to the root during Find
  (5)                (5)
   │                / | \
  (3)      →     (1)(2)(3)
  /│              │
(1)(2)           (4)
 │
(4)
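Time Complexity: amortized O(log n) per operation with path compression alone; nearly O(1), i.e. O(α(n)) amortized, when combined with union by rank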
// C - Path Compression
// Return the root of the set containing x, with path compression:
// afterwards every element on the path from x to the root points
// directly at the root.
int find(int x) {
    // Pass 1: walk up to the root.
    int root = x;
    while (parent[root] != root) {
        root = parent[root];
    }
    // Pass 2: re-point each element on the path at the root.
    while (parent[x] != root) {
        int next = parent[x];
        parent[x] = root;
        x = next;
    }
    return root;
}
Optimization 2: Union by Rank¶
Attach smaller tree to larger tree
Tree1 (Rank 2)        Tree2 (Rank 1)
      (a)                  (b)
     / │ \                  │
  (c) (d) (e)              (f)
After union:
        (a)
      /│╲   \
  (c)(d)(e) (b)
             │
            (f)
// C - Path Compression + Union by Rank
int parent[MAX_N];    // parent[i]: parent link of element i (a root points to itself)
int rank_arr[MAX_N];  // rank_arr[i]: rank compared by unite() when merging roots

// Make each of the elements 0..n-1 a singleton set with rank 0.
void init(int n) {
    int i = n;
    while (i-- > 0) {
        rank_arr[i] = 0;
        parent[i] = i;
    }
}
// Return the root of the set containing x, with path compression:
// afterwards every element on the path from x to the root points
// directly at the root.
int find(int x) {
    // Pass 1: walk up to the root.
    int root = x;
    while (parent[root] != root) {
        root = parent[root];
    }
    // Pass 2: re-point each element on the path at the root.
    while (parent[x] != root) {
        int next = parent[x];
        parent[x] = root;
        x = next;
    }
    return root;
}
// Merge the sets containing x and y using union by rank.
// The root with the smaller rank is linked under the root with the larger
// rank; on a tie y's root goes under x's root and x's root gains one rank.
void unite(int x, int y) {
    int keep = find(x);   // root that will remain a root
    int drop = find(y);   // root that will be linked under `keep`
    if (keep == drop) return;

    if (rank_arr[keep] < rank_arr[drop]) {
        // Make `keep` the root with the larger rank.
        int tmp = keep;
        keep = drop;
        drop = tmp;
    }
    if (rank_arr[keep] == rank_arr[drop]) {
        rank_arr[keep]++;  // equal ranks: the merged tree is one level taller
    }
    parent[drop] = keep;
}
C++/Python Implementation¶
// C++
// Disjoint-set forest over the elements 0..n-1 with path compression
// and union by rank.
class UnionFind {
private:
    vector<int> parent, rank_;

public:
    // Create n singleton sets; every element starts as its own root.
    UnionFind(int n) : parent(n), rank_(n, 0) {
        for (int i = 0; i < n; ++i) {
            parent[i] = i;
        }
    }

    // Return the root of x's set and point every element on the path
    // from x to the root directly at that root.
    int find(int x) {
        int root = x;
        while (parent[root] != root) {
            root = parent[root];
        }
        while (parent[x] != root) {
            int next = parent[x];
            parent[x] = root;
            x = next;
        }
        return root;
    }

    // Merge the sets of x and y. Returns false when they were already in
    // the same set, true when a merge happened.
    bool unite(int x, int y) {
        int rootX = find(x);
        int rootY = find(y);
        if (rootX == rootY) return false;

        if (rank_[rootX] < rank_[rootY]) {
            parent[rootX] = rootY;
        } else {
            parent[rootY] = rootX;
            if (rank_[rootX] == rank_[rootY]) {
                ++rank_[rootX];  // tie: the surviving root gains one rank
            }
        }
        return true;
    }

    // True when x and y currently belong to the same set.
    bool connected(int x, int y) {
        int rootX = find(x);
        int rootY = find(y);
        return rootX == rootY;
    }
};
# Python
class UnionFind:
    """Disjoint-set forest with path compression and union by rank."""

    def __init__(self, n):
        """Create n singleton sets labelled 0..n-1."""
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x):
        """Return the root of x's set, compressing the path on the way."""
        # Pass 1: locate the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Pass 2: point every element on the path directly at the root.
        while self.parent[x] != root:
            nxt = self.parent[x]
            self.parent[x] = root
            x = nxt
        return root

    def union(self, x, y):
        """Merge the sets of x and y.

        Returns:
            False if x and y were already in the same set, True otherwise.
        """
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return False
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        else:
            self.parent[root_y] = root_x
            if self.rank[root_x] == self.rank[root_y]:
                # Tie: the surviving root gains one rank.
                self.rank[root_x] += 1
        return True

    def connected(self, x, y):
        """Return True when x and y belong to the same set."""
        root_x = self.find(x)
        root_y = self.find(y)
        return root_x == root_y
3. Kruskal Algorithm¶
Concept¶
Sort edges by weight and select edges that don't create cycles
Principle:
1. Sort all edges in ascending order by weight
2. Select from the smallest edge
3. Skip if it creates a cycle (check with Union-Find)
4. Stop when V-1 edges are selected
Time Complexity: O(E log E)
Example¶
Graph:
(0)──7──(1)
 │╲    ╱│
 5  8  9  7
 │   ╲╱   │
(2)──5──(3)
Sorted edges: (2,3,5), (0,2,5), (0,1,7), (1,3,7), (0,3,8), (1,2,9)
Selection process:
1. (2,3,5) selected → No cycle ✓
2. (0,2,5) selected → No cycle ✓
3. (0,1,7) selected → No cycle ✓
4. V-1=3 edges selected, complete
MST: (2,3), (0,2), (0,1)
Total weight: 5+5+7 = 17
Implementation¶
// C
#define MAX_E 100001

// A weighted edge between vertices u and v.
typedef struct {
    int u, v, weight;
} Edge;

Edge edges[MAX_E];  // edge list; kruskal() sorts it in place
int parent[MAX_E];  // union-find parent links, indexed by vertex number

// qsort comparator: orders edges by ascending weight.
// Returns a negative, zero or positive value when a's weight is less than,
// equal to or greater than b's weight.
// The result is built from comparisons rather than a subtraction, because
// `a->weight - b->weight` overflows (undefined behaviour) when the two
// weights are far apart, e.g. INT_MIN and 1.
int cmp(const void* a, const void* b) {
    const Edge* lhs = (const Edge*)a;
    const Edge* rhs = (const Edge*)b;
    return (lhs->weight > rhs->weight) - (lhs->weight < rhs->weight);
}
// Return the root of the set containing x, with path compression:
// afterwards every element on the path from x to the root points
// directly at the root.
int find(int x) {
    // Pass 1: walk up to the root.
    int root = x;
    while (parent[root] != root) {
        root = parent[root];
    }
    // Pass 2: re-point each element on the path at the root.
    while (parent[x] != root) {
        int next = parent[x];
        parent[x] = root;
        x = next;
    }
    return root;
}
int kruskal(int V, int E) {
// Initialize
for (int i = 0; i < V; i++)
parent[i] = i;
// Sort
qsort(edges, E, sizeof(Edge), cmp);
int mstWeight = 0;
int edgeCount = 0;
for (int i = 0; i < E && edgeCount < V - 1; i++) {
int pu = find(edges[i].u);
int pv = find(edges[i].v);
if (pu != pv) {
parent[pu] = pv;
mstWeight += edges[i].weight;
edgeCount++;
}
}
return mstWeight;
}
// C++
// A weighted edge between vertices u and v.
struct Edge {
    int u, v, weight;

    // Edges compare by weight alone; the endpoints are ignored.
    bool operator<(const Edge& rhs) const {
        return rhs.weight > weight;
    }
};
// Kruskal's algorithm.
// V is the number of vertices (0..V-1). `edges` is sorted in place by weight.
// Returns the summed weight of the accepted edges (at most V-1 of them).
int kruskal(int V, vector<Edge>& edges) {
    sort(edges.begin(), edges.end());
    UnionFind uf(V);

    int total = 0;
    int taken = 0;
    for (auto it = edges.begin(); it != edges.end() && taken < V - 1; ++it) {
        if (!uf.unite(it->u, it->v)) {
            continue;  // endpoints already connected: the edge would close a cycle
        }
        total += it->weight;
        ++taken;
    }
    return total;
}
# Python
def kruskal(V, edges):
    """Return the total weight selected by Kruskal's algorithm.

    Args:
        V: Number of vertices, labelled 0..V-1.
        edges: Iterable of (u, v, weight) triples. It is not modified.

    Returns:
        The summed weight of the accepted edges. For a connected graph this
        is the MST weight; for a disconnected graph it is the weight of the
        edges that could be accepted without forming a cycle.
    """
    uf = UnionFind(V)
    mst_weight = 0
    edge_count = 0
    # Iterate over a sorted copy so the caller's edge list keeps its order.
    for u, v, w in sorted(edges, key=lambda edge: edge[2]):
        if edge_count >= V - 1:
            break  # a spanning tree has exactly V-1 edges
        if uf.union(u, v):
            mst_weight += w
            edge_count += 1
    return mst_weight
4. Prim Algorithm¶
Concept¶
Gradually expand the MST starting from a starting vertex
Principle:
1. Start from an arbitrary vertex
2. Among edges going out from vertices in the MST,
select the edge with the smallest weight
3. Add the new vertex to the MST
4. Stop when all vertices are included
Time Complexity:
- Priority Queue: O(E log V)
- Adjacency Matrix: O(V²)
Example¶
Graph (starting from 0):
(0)──7──(1)
 │╲    ╱│
 5  8  9  7
 │   ╲╱   │
(2)──5──(3)
Steps:
1. Start: MST = {0}
Adjacent edges: (0,1,7), (0,2,5), (0,3,8)
Select: (0,2,5) → MST = {0,2}
2. Adjacent edges: (0,1,7), (0,3,8), (2,1,9), (2,3,5)
Select: (2,3,5) → MST = {0,2,3}
3. Adjacent edges: (0,1,7), (2,1,9), (3,1,7)
Select: (0,1,7) or (3,1,7) → MST = {0,1,2,3}
Result: Total weight = 5+5+7 = 17
Implementation (Priority Queue)¶
// C++
// Prim's algorithm with a min-heap, starting from vertex 0.
// V is the number of vertices (0..V-1); adj[u] lists {neighbour, weight}
// pairs for vertex u.
// Returns the summed weight of the tree grown from vertex 0. Only vertices
// reachable from vertex 0 are included. Returns 0 for an empty graph.
int prim(int V, const vector<vector<pair<int,int>>>& adj) {
    // Guard: with no vertices, the seed entry for vertex 0 below would index
    // empty containers.
    if (V <= 0) return 0;

    vector<bool> inMST(V, false);
    // {weight, vertex}; greater<> turns the queue into a min-heap on weight
    priority_queue<pair<int,int>, vector<pair<int,int>>, greater<>> pq;
    int mstWeight = 0;
    pq.push({0, 0});  // starting vertex, reached at cost 0
    while (!pq.empty()) {
        auto [w, u] = pq.top();
        pq.pop();
        if (inMST[u]) continue;  // stale entry: u was already added more cheaply
        inMST[u] = true;
        mstWeight += w;
        for (const auto& [v, weight] : adj[u]) {
            if (!inMST[v]) {
                pq.push({weight, v});
            }
        }
    }
    return mstWeight;
}
# Python
import heapq
def prim(V, adj):
    """Return the weight of the spanning tree grown from vertex 0 (Prim).

    Args:
        V: Number of vertices, labelled 0..V-1.
        adj: Adjacency list; adj[u] is an iterable of (v, weight) pairs.

    Returns:
        The summed weight of the tree grown from vertex 0. Only vertices
        reachable from vertex 0 are included. Returns 0 for an empty graph.
    """
    # Guard: with no vertices, the seed entry for vertex 0 below would raise
    # IndexError on the empty in_mst list.
    if V <= 0:
        return 0

    in_mst = [False] * V
    pq = [(0, 0)]  # (weight, vertex); vertex 0 is reached at cost 0
    mst_weight = 0
    while pq:
        w, u = heapq.heappop(pq)
        if in_mst[u]:
            continue  # stale entry: u was already added more cheaply
        in_mst[u] = True
        mst_weight += w
        for v, weight in adj[u]:
            if not in_mst[v]:
                heapq.heappush(pq, (weight, v))
    return mst_weight
Implementation (Adjacency Matrix, V²)¶
// C++ - Better for dense graphs
// Prim's algorithm on an adjacency matrix, O(V^2), starting from vertex 0.
// adj[u][v] is the weight of edge u-v; a value of 0 is treated as "no edge".
// Returns the summed weight of the tree grown from vertex 0. Only vertices
// reachable from vertex 0 are included. Returns 0 for an empty graph.
int primMatrix(int V, const vector<vector<int>>& adj) {
    // Guard: key[0] below would write past the end of an empty vector.
    if (V <= 0) return 0;

    vector<int> key(V, INT_MAX);  // cheapest known edge into each vertex
    vector<bool> inMST(V, false);
    key[0] = 0;
    int mstWeight = 0;
    for (int count = 0; count < V; count++) {
        // Select the not-yet-added vertex with the minimum key value.
        int u = -1;
        for (int v = 0; v < V; v++) {
            if (!inMST[v] && (u == -1 || key[v] < key[u])) {
                u = v;
            }
        }
        // If the cheapest remaining key is still INT_MAX, none of the remaining
        // vertices is reachable. Stop here rather than add INT_MAX to the sum,
        // which would overflow.
        if (key[u] == INT_MAX) break;

        inMST[u] = true;
        mstWeight += key[u];
        // Relax the keys of vertices adjacent to u.
        for (int v = 0; v < V; v++) {
            if (adj[u][v] && !inMST[v] && adj[u][v] < key[v]) {
                key[v] = adj[u][v];
            }
        }
    }
    return mstWeight;
}
5. Algorithm Comparison¶
Kruskal vs Prim¶
┌─────────────┬───────────────────┬────────────────────┐
│             │ Kruskal           │ Prim               │
├─────────────┼───────────────────┼────────────────────┤
│ Approach    │ Edge-centric      │ Vertex-centric     │
│ Data Struct │ Union-Find        │ Priority Queue     │
│ Time        │ O(E log E)        │ O(E log V)         │
│ Best for    │ Sparse graphs     │ Dense graphs       │
│ Complexity  │ Relatively simple │ Relatively complex │
└─────────────┴───────────────────┴────────────────────┘
Selection Criteria¶
Sparse graphs (E ≈ V): Kruskal is better
Dense graphs (E ≈ V²): Prim is better
Edge list input: Kruskal is better
Adjacency list input: Prim is better
6. Practice Problems¶
Problem 1: Minimum Spanning Tree¶
Find the total weight of the MST for the given graph.
Solution Code
def solution(V, edges):
    """Return the total MST weight of the given graph using Kruskal.

    Args:
        V: Number of vertices, labelled 0..V-1.
        edges: Iterable of (u, v, weight) triples. It is not modified.

    Returns:
        The summed weight of the accepted edges (at most V-1 of them).
    """
    uf = UnionFind(V)
    total = 0
    count = 0
    # Iterate over a sorted copy so the caller's edge list keeps its order.
    for u, v, w in sorted(edges, key=lambda edge: edge[2]):
        if count >= V - 1:
            break  # a spanning tree has exactly V-1 edges
        if uf.union(u, v):
            total += w
            count += 1
    return total
Problem 2: City Division Plan¶
Divide N villages into 2 groups and connect each group with minimum cost.
Hint
After constructing the MST, remove the largest edge to create 2 groups.
Solution Code
def divide_villages(V, edges):
    """Return the minimum cost of connecting V villages split into 2 groups.

    Builds an MST with Kruskal's algorithm and drops its heaviest edge,
    which leaves two connected groups.

    Args:
        V: Number of villages, labelled 0..V-1.
        edges: Iterable of (u, v, weight) triples. It is not modified.

    Returns:
        The MST weight minus its largest edge weight, or 0 when no edge
        was selected (for example a single village).
    """
    uf = UnionFind(V)
    mst_edges = []
    # Iterate over a sorted copy so the caller's edge list keeps its order.
    for u, v, w in sorted(edges, key=lambda edge: edge[2]):
        if uf.union(u, v):
            mst_edges.append(w)
            if len(mst_edges) == V - 1:
                break
    # No selected edge means nothing to pay for and nothing to remove;
    # max() on an empty list would raise ValueError.
    if not mst_edges:
        return 0
    # Remove the largest edge to split the tree into two groups.
    return sum(mst_edges) - max(mst_edges)
Recommended Problems¶
| Difficulty | Problem | Platform | Algorithm |
|---|---|---|---|
| ⭐⭐ | Minimum Spanning Tree | BOJ | Kruskal/Prim |
| ⭐⭐ | Sanggeun's Travel | BOJ | MST Concept |
| ⭐⭐⭐ | City Division Plan | BOJ | MST Application |
| ⭐⭐⭐ | Network Connection | BOJ | MST |
| ⭐⭐⭐ | Min Cost to Connect | LeetCode | Prim |
Template Summary¶
Union-Find¶
class UnionFind:
    """Disjoint-set forest with path compression and union by rank."""

    def __init__(self, n):
        """Create n singleton sets labelled 0..n-1."""
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x):
        """Return the root of x's set, compressing the path on the way."""
        # Pass 1: locate the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Pass 2: point every element on the path directly at the root.
        while self.parent[x] != root:
            nxt = self.parent[x]
            self.parent[x] = root
            x = nxt
        return root

    def union(self, x, y):
        """Merge the sets of x and y.

        Returns:
            False if x and y were already in the same set, True otherwise.
        """
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return False
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        else:
            self.parent[root_y] = root_x
            if self.rank[root_x] == self.rank[root_y]:
                # Tie: the surviving root gains one rank.
                self.rank[root_x] += 1
        return True
Kruskal¶
def kruskal(V, edges):
    """Return the total weight selected by Kruskal's algorithm.

    Args:
        V: Number of vertices, labelled 0..V-1.
        edges: Iterable of (u, v, weight) triples. It is not modified.

    Returns:
        The summed weight of every edge that joined two separate components.
    """
    uf = UnionFind(V)
    total = 0
    # Iterate over a sorted copy so the caller's edge list keeps its order.
    for u, v, w in sorted(edges, key=lambda edge: edge[2]):
        if uf.union(u, v):
            total += w
    return total
Prim¶
def prim(V, adj):
    """Return the weight of the spanning tree grown from vertex 0 (Prim).

    Args:
        V: Number of vertices, labelled 0..V-1.
        adj: Adjacency list; adj[u] is an iterable of (v, weight) pairs.

    Returns:
        The summed weight of the tree grown from vertex 0. Only vertices
        reachable from vertex 0 are included. Returns 0 for an empty graph.
    """
    # Guard: with no vertices, the seed entry for vertex 0 below would raise
    # IndexError on the empty in_mst list.
    if V <= 0:
        return 0

    in_mst = [False] * V
    pq = [(0, 0)]  # (weight, vertex); vertex 0 is reached at cost 0
    total = 0
    while pq:
        w, u = heapq.heappop(pq)
        if in_mst[u]:
            continue  # stale entry: u was already added more cheaply
        in_mst[u] = True
        total += w
        for v, weight in adj[u]:
            if not in_mst[v]:
                heapq.heappush(pq, (weight, v))
    return total
Next Steps¶
- 16_LCA_and_Tree_Queries.md - LCA, Tree Queries
References¶
- MST Visualization
- Union-Find Tutorial
- Introduction to Algorithms (CLRS) - Chapter 23