Kruskal's algorithm: Java implementation


import java.util.Queue;
import java.util.ArrayDeque;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.HashSet;
import java.util.List;
import java.util.ArrayList;

/**
 * Computes a minimum spanning tree of an edge-weighted undirected graph using
 * Kruskal's algorithm. For a disconnected graph the result is a minimum spanning
 * forest (one tree per connected component).
 *
 * <p>Edges are examined in ascending weight order. An edge is accepted when its two
 * endpoints are not yet in the same component; components are tracked with a
 * union-find structure.
 */
class KruskalMinimumSpanningTree {
    /** Accepted MST edges, in the order they were accepted (ascending weight). */
    private final Queue<Edge> mst = new ArrayDeque<>();
    /** Sum of the weights of the accepted edges. */
    private double weight;

    /**
     * Builds the minimum spanning tree (or forest) of {@code G}.
     *
     * @param G the graph to process; must not be null
     */
    public KruskalMinimumSpanningTree(EdgeWeightedGraph G) {
        // Min-heap ordered by Edge.compareTo, i.e. by ascending weight.
        Queue<Edge> pq = new PriorityQueue<>();
        for (Edge e : G.edges()) {
            pq.add(e);
        }
        UnionFind uf = new UnionFind(G.V());
        // A spanning tree on V vertices has exactly V-1 edges, so stop as soon as we have them.
        while (!pq.isEmpty() && mst.size() < G.V() - 1) {
            Edge e = pq.poll();
            int v = e.either();
            int w = e.other(v);
            // Skip edges whose endpoints are already connected: they would close a cycle.
            if (!uf.isConnected(v, w)) {
                uf.union(v, w);
                mst.add(e);
                weight += e.weight();
            }
        }
    }

    /**
     * Returns the MST edges in the order they were accepted.
     *
     * @return a fresh copy, so callers cannot modify this object's state
     */
    public Iterable<Edge> edges() {
        return new ArrayList<>(mst);
    }

    /**
     * Returns the total weight of the MST.
     *
     * @return the sum of the accepted edges' weights (0.0 when no edge was accepted)
     */
    public double weight() {
        return weight;
    }

    /** An undirected graph with weighted edges over vertices {@code 0 .. V-1}. */
    static class EdgeWeightedGraph {
        private final int V;
        /** adj[v] holds every edge incident to v (a self-loop appears twice in its list). */
        private final List<Edge>[] adj;
        /** Every added edge exactly once, in insertion order. */
        private final List<Edge> edges = new ArrayList<>();

        /**
         * Creates an empty graph with {@code V} vertices.
         *
         * @param V the number of vertices
         * @throws IllegalArgumentException if {@code V} is negative
         */
        public EdgeWeightedGraph(int V) {
            if (V < 0) {
                throw new IllegalArgumentException("Number of vertices must be non-negative: " + V);
            }
            this.V = V;
            // Java cannot create generic arrays directly; a raw array cast is safe here
            // because the array never escapes and only ever holds List<Edge>.
            @SuppressWarnings("unchecked")
            List<Edge>[] lists = (List<Edge>[]) new List[V];
            adj = lists;
            for (int v = 0; v < V; v++) {
                adj[v] = new ArrayList<>();
            }
        }

        /**
         * Adds the weighted edge {@code e} to this graph.
         *
         * @param e the edge to add; both endpoints must be valid vertices
         * @throws IllegalArgumentException if an endpoint is outside {@code 0 .. V-1}
         */
        public void addEdge(Edge e) {
            int v = e.either();
            int w = e.other(v);
            // Validate both endpoints before mutating, so a bad edge leaves the graph unchanged.
            validateVertex(v);
            validateVertex(w);
            adj[v].add(e);
            adj[w].add(e);
            edges.add(e);
        }

        /**
         * Returns the edges incident to {@code v}.
         *
         * @param v a vertex in {@code 0 .. V-1}
         * @return the edges incident to {@code v}
         * @throws IllegalArgumentException if {@code v} is out of range
         */
        public Iterable<Edge> adj(int v) {
            validateVertex(v);
            return adj[v];
        }

        /** Returns every edge of this graph exactly once, in insertion order (as a fresh copy). */
        Iterable<Edge> edges() {
            return new ArrayList<>(edges);
        }

        /** Returns the number of vertices. */
        public int V() {
            return V;
        }

        /** Returns the number of edges; a self-loop counts as one edge. */
        int E() {
            return edges.size();
        }

        /** Throws IllegalArgumentException unless {@code 0 <= v < V}. */
        private void validateVertex(int v) {
            if (v < 0 || v >= V) {
                throw new IllegalArgumentException(
                        "Vertex " + v + " is not between 0 and " + (V - 1));
            }
        }
    }

    /** An immutable weighted undirected edge {@code v-w}, ordered by weight. */
    static class Edge implements Comparable<Edge> {
        private final int v;
        private final int w;
        private final double weight;

        /** Creates a weighted edge {@code v-w}. */
        Edge(int v, int w, double weight) {
            this.v = v;
            this.w = w;
            this.weight = weight;
        }

        /** Returns one endpoint of this edge (the first one given to the constructor). */
        int either() {
            return v;
        }

        /**
         * Returns the endpoint of this edge that is not {@code vertex}.
         *
         * @param vertex one endpoint of this edge
         * @return the other endpoint (for a self-loop, {@code vertex} itself)
         * @throws IllegalArgumentException if {@code vertex} is not an endpoint of this edge
         */
        int other(int vertex) {
            if (vertex == v) {
                return w;
            }
            if (vertex == w) {
                return v;
            }
            throw new IllegalArgumentException("Vertex " + vertex + " is not an endpoint of " + this);
        }

        /** Compares this edge to {@code that} by weight. */
        @Override
        public int compareTo(Edge that) {
            return Double.compare(this.weight, that.weight);
        }

        /** Returns the weight of this edge. */
        double weight() {
            return weight;
        }

        /** Returns a human-readable representation of this edge. */
        @Override
        public String toString() {
            return "Edge: " + v + "-" + w + ", Weight: " + weight;
        }
    }

    /**
     * Disjoint-set (union-find) over elements {@code 0 .. size-1}.
     * Uses union by rank with path halving.
     */
    static class UnionFind {
        /** parent[i] is i's parent in its component tree; a root is its own parent. */
        private final int[] parent;
        /** For a root, an upper bound on the height of its tree. */
        private final int[] rank;

        /** Creates {@code size} singleton components. */
        public UnionFind(int size) {
            this.parent = new int[size];
            this.rank = new int[size];
            for (int i = 0; i < size; i++) {
                parent[i] = i;
            }
        }

        /** Returns true if {@code v} and {@code w} are in the same component. */
        public boolean isConnected(int v, int w) {
            return find(v) == find(w);
        }

        /** Merges the components containing {@code v} and {@code w}; a no-op if already merged. */
        public void union(int v, int w) {
            int rootV = find(v);
            int rootW = find(w);
            if (rootV == rootW) {
                return; // already in the same component: nothing to merge
            }
            // Attach the lower-rank tree beneath the higher-rank root to keep trees shallow.
            if (rank[rootV] < rank[rootW]) {
                parent[rootV] = rootW;
            } else if (rank[rootV] > rank[rootW]) {
                parent[rootW] = rootV;
            } else {
                parent[rootW] = rootV;
                rank[rootV]++;
            }
        }

        /** Returns the root of {@code v}'s component, halving the path along the way. */
        private int find(int v) {
            while (parent[v] != v) {
                parent[v] = parent[parent[v]]; // point v at its grandparent
                v = parent[v];
            }
            return v;
        }
    }

    /** Builds the 8-vertex sample graph used by {@link #main}. */
    private static EdgeWeightedGraph dummyGraph() {
        String[][] input = {
            {"0", "7", "0.16"},
            {"2", "3", "0.17"},
            {"1", "7", "0.19"},
            {"0", "2", "0.26"},
            {"5", "7", "0.28"},
            {"1", "3", "0.29"},
            {"1", "5", "0.32"},
            {"2", "7", "0.34"},
            {"4", "5", "0.35"},
            {"1", "2", "0.36"},
            {"4", "7", "0.37"},
            {"0", "4", "0.38"},
            {"6", "2", "0.40"},
            {"3", "6", "0.52"},
            {"6", "0", "0.58"},
            {"6", "4", "0.93"}
        };
        EdgeWeightedGraph G = new EdgeWeightedGraph(8);
        for (String[] row : input) {
            int v = Integer.parseInt(row[0]);
            int w = Integer.parseInt(row[1]);
            double weight = Double.parseDouble(row[2]);
            G.addEdge(new Edge(v, w, weight));
        }
        return G;
    }

    /** Prints the MST weight and edges of the sample graph. */
    public static void main(String[] args) {
        KruskalMinimumSpanningTree kruskalMinimumSpanningTree =
                new KruskalMinimumSpanningTree(dummyGraph());
        System.out.println("KruskalMinimumSpanningTree: Weight = "
                + kruskalMinimumSpanningTree.weight());
        System.out.println("MST edges: ");
        kruskalMinimumSpanningTree.edges()
                .forEach(edge ->
                        System.out.print(edge.either() + "-" + edge.other(edge.either()) + ", "));
    }
}

Output
KruskalMinimumSpanningTree: Weight = 1.81
MST edges: 
0-7, 2-3, 1-7, 0-2, 5-7, 4-5, 6-2, 

Comments

Popular posts from this blog

SQL basic interview question

gsutil Vs Storage Transfer Service Vs Transfer Appliance