Kruskal's Algorithm — Minimum Spanning Tree with Union-Find

Sanjeev Sharma
2 min read

Advertisement

Problem

Given a connected, weighted, undirected graph, find its minimum spanning tree (MST) — the tree that connects all vertices with the minimum total edge weight.


Algorithm

  1. Sort all edges by weight
  2. For each edge (u, v, w) in sorted order:
    • If find(u) != find(v): add edge, union(u, v)
  3. Stop when V-1 edges added

Solutions

Python

def kruskal(n, edges):
    """Build a minimum spanning tree with Kruskal's algorithm and union-find.

    Args:
        n: Number of vertices; vertices are the integers 0 .. n-1.
        edges: Iterable of (u, v, w) triples, one per undirected edge of
            weight w. It is read but never modified.

    Returns:
        A tuple (mst_cost, mst_edges): the summed weight of the accepted
        edges and the edges themselves as (u, v, w) tuples, in the order
        they were accepted. If the edges run out before n-1 of them have
        been accepted, the partial result gathered so far is returned.
    """
    # Sort a copy so the caller's edge collection keeps its original order;
    # sorted() also lets any iterable be passed, not just a list.
    ordered = sorted(edges, key=lambda edge: edge[2])
    parent = list(range(n))
    rank = [0] * n

    def find(x):
        # First pass locates the root, second pass points every visited
        # node straight at it (path compression).
        root = x
        while parent[root] != root:
            root = parent[root]
        while parent[x] != root:
            nxt = parent[x]
            parent[x] = root
            x = nxt
        return root

    def union(x, y):
        root_x, root_y = find(x), find(y)
        if root_x == root_y:
            # Same component already: this edge would close a cycle.
            return False
        # Union by rank: hang the shallower tree under the deeper one.
        if rank[root_x] < rank[root_y]:
            root_x, root_y = root_y, root_x
        parent[root_y] = root_x
        if rank[root_x] == rank[root_y]:
            rank[root_x] += 1
        return True

    mst_cost, mst_edges = 0, []
    for u, v, w in ordered:
        if not union(u, v):
            continue
        mst_cost += w
        mst_edges.append((u, v, w))
        # A spanning tree on n vertices has exactly n-1 edges.
        if len(mst_edges) == n - 1:
            break
    return mst_cost, mst_edges

JavaScript

/**
 * Builds a minimum spanning tree using Kruskal's algorithm with union-find.
 *
 * @param {number} n Number of vertices; vertices are the integers 0..n-1.
 * @param {Array<Array<number>>} edges Undirected edges as [u, v, w] triples. Not modified.
 * @returns {Array} A pair [cost, mst]: the summed weight of the accepted edges and
 *   the accepted edges as [u, v, w] triples in acceptance order. If the edges run
 *   out before n-1 are accepted, the partial result gathered so far is returned.
 */
function kruskal(n, edges) {
    // Sort a copy so the caller's array keeps its original order.
    const ordered = [...edges].sort((a, b) => a[2] - b[2]);
    const parent = Array.from({ length: n }, (_, i) => i);
    const rank = new Array(n).fill(0);

    // Locate the root, then point every visited node straight at it (path compression).
    const find = (x) => {
        let root = x;
        while (parent[root] !== root) root = parent[root];
        while (parent[x] !== root) {
            const next = parent[x];
            parent[x] = root;
            x = next;
        }
        return root;
    };

    const union = (x, y) => {
        let rootX = find(x);
        let rootY = find(y);
        // Same component already: this edge would close a cycle.
        if (rootX === rootY) return false;
        // Union by rank: hang the shallower tree under the deeper one.
        if (rank[rootX] < rank[rootY]) {
            const tmp = rootX;
            rootX = rootY;
            rootY = tmp;
        }
        parent[rootY] = rootX;
        if (rank[rootX] === rank[rootY]) rank[rootX]++;
        return true;
    };

    let cost = 0;
    const mst = [];
    for (const [u, v, w] of ordered) {
        if (!union(u, v)) continue;
        cost += w;
        mst.push([u, v, w]);
        // A spanning tree on n vertices has exactly n-1 edges.
        if (mst.length === n - 1) break;
    }
    return [cost, mst];
}

Java

import java.util.*;
/**
 * Returns the total weight of a minimum spanning tree found with Kruskal's algorithm.
 *
 * @param n     number of vertices; vertices are numbered 0..n-1
 * @param edges edges as {u, v, weight} triples; the array is not reordered
 * @return sum of the weights of the accepted edges; if the edges run out before
 *         n-1 are accepted, the sum gathered so far
 */
public int kruskal(int n, int[][] edges) {
    // Sort a shallow copy so the caller's array keeps its original order.
    int[][] byWeight = edges.clone();
    // Integer.compare replaces a[2]-b[2]: the subtraction overflows when two weights
    // are far apart (e.g. Integer.MAX_VALUE and -2) and then orders them wrongly.
    Arrays.sort(byWeight, (a, b) -> Integer.compare(a[2], b[2]));

    int[] par = new int[n];
    int[] rank = new int[n];
    for (int i = 0; i < n; i++) {
        par[i] = i;
    }

    int cost = 0;
    int accepted = 0;
    for (int[] e : byWeight) {
        int ru = find(par, e[0]);
        int rv = find(par, e[1]);
        if (ru == rv) {
            continue; // both endpoints already connected: edge would close a cycle
        }
        // Union by rank: hang the shallower tree under the deeper one.
        if (rank[ru] < rank[rv]) {
            int t = ru;
            ru = rv;
            rv = t;
        }
        par[rv] = ru;
        if (rank[ru] == rank[rv]) {
            rank[ru]++;
        }
        cost += e[2];
        // A spanning tree on n vertices has exactly n-1 edges.
        if (++accepted == n - 1) {
            break;
        }
    }
    return cost;
}

/**
 * Returns the root of x's set in parent array p, pointing every node visited on the
 * way directly at that root (path compression).
 */
int find(int[] p, int x) {
    int root = x;
    while (p[root] != root) {
        root = p[root];
    }
    while (p[x] != root) {
        int next = p[x];
        p[x] = root;
        x = next;
    }
    return root;
}

C++

#include <algorithm>
#include <array>
#include <stdexcept>
#include <utility>
#include <vector>
using namespace std;
/// Number of slots in the global union-find arrays; valid vertex ids are 0..kMaxVertices-1.
constexpr int kMaxVertices = 100001;

/// Union-find state shared by find/unite. kruskal() resets the first n entries on every call.
int par[kMaxVertices], rnk[kMaxVertices];

/// Returns the root of x's set and points every node visited on the way directly at it.
int find(int x) {
    int root = x;
    while (par[root] != root) root = par[root];
    while (par[x] != root) {
        const int next = par[x];
        par[x] = root;
        x = next;
    }
    return root;
}

/// Merges the sets containing x and y using union by rank.
/// @return false if x and y were already in the same set, true if a merge happened.
bool unite(int x, int y) {
    int px = find(x);
    int py = find(y);
    if (px == py) return false;
    // Hang the shallower tree under the deeper one.
    if (rnk[px] < rnk[py]) std::swap(px, py);
    par[py] = px;
    if (rnk[px] == rnk[py]) ++rnk[px];
    return true;
}

/// Returns the total weight of a minimum spanning tree found with Kruskal's algorithm.
/// @param n     number of vertices (ids 0..n-1); must satisfy 0 <= n <= kMaxVertices
/// @param edges edges as {u, v, weight}; sorted in place by weight as a side effect
/// @return sum of the weights of the accepted edges
/// @throws std::invalid_argument if n does not fit the global union-find arrays
/// @throws std::out_of_range if an edge endpoint lies outside [0, n)
int kruskal(int n, std::vector<std::array<int, 3>>& edges) {
    // Without these checks an oversized n or a stray endpoint would index past the
    // entries of par/rnk that are initialised below.
    if (n < 0 || n > kMaxVertices) {
        throw std::invalid_argument("kruskal: vertex count out of range");
    }
    for (const auto& e : edges) {
        if (e[0] < 0 || e[0] >= n || e[1] < 0 || e[1] >= n) {
            throw std::out_of_range("kruskal: edge endpoint out of range");
        }
    }

    std::sort(edges.begin(), edges.end(),
              [](const auto& a, const auto& b) { return a[2] < b[2]; });
    for (int i = 0; i < n; ++i) {
        par[i] = i;
        rnk[i] = 0;
    }

    int cost = 0;
    int accepted = 0;
    for (const auto& e : edges) {
        if (!unite(e[0], e[1])) continue;
        cost += e[2];
        // A spanning tree on n vertices has exactly n-1 edges.
        if (++accepted == n - 1) break;
    }
    return cost;
}

C

/* Disjoint-set forest: par[i] is the parent link of i, rnk[i] the rank of its tree. */
int par[100001];
int rnk[100001];

/* Returns the root of x's set and points every node walked on the way directly at it. */
int find(int x)
{
    int root = x;
    int next;

    while (par[root] != root)
        root = par[root];

    while (par[x] != root) {
        next = par[x];
        par[x] = root;
        x = next;
    }
    return root;
}

/*
 * Merges the sets holding x and y (union by rank).
 * Returns 0 when both were already in the same set, 1 when a merge happened.
 */
int unite(int x, int y)
{
    int a = find(x);
    int b = find(y);

    if (a == b)
        return 0;

    /* Keep `a` as the root of the higher-ranked tree so `b` is attached beneath it. */
    if (rnk[a] < rnk[b]) {
        int tmp = a;
        a = b;
        b = tmp;
    }

    par[b] = a;
    if (rnk[a] == rnk[b])
        rnk[a]++;
    return 1;
}

Complexity

| Step | Time |
|---|---|
| Sort edges | O(E log E) |
| Union-Find ops | O(E α(V)) |
| Total | O(E log E) |

Advertisement

Sanjeev Sharma

Written by

Sanjeev Sharma

Full Stack Engineer · E-mopro