본문 바로가기
알고리즘/알고리즘 종류

[MST] 최소 스패닝 트리 - 크루스칼(Kruskal), 프림(Prim)

by D.O.T 2024. 7. 11.
MST란?

 

최소 신장 트리(Minimum Spanning Tree)로 가중 그래프(Weight Graph)에서 가장 적은 비용(Weight)으로 모든 정점(Vertex)으로 이동할 수 있는 부분 그래프(Sub Graph)이다. MST는 각 정점에서 정점 사이에 하나의 간선(Edge)만 필요하므로 총 N-1 개의 간선을 가지게 되므로 마치 트리의 형태를 하고 있어서 트리라는 이름이 붙는다.


크루스칼 알고리즘

 

1. 모든 간선을 비용이 적은 순서를 기준으로 정렬한다. - Sort

2. 간선에 붙은 정점 u와 v를 기준으로 부모를 찾는다. - Find

3. 정점 u의 부모와 정점 v의 부모가 다르다면 서로 다른 집합에 속해있지만, 1번 과정으로 인해 해당 간선이 두 정점의 최소 비용의 간선임을 알 수 있으므로 하나의 집합으로 합친다. - Union

 

+ Union 과정에서 이어진 간선이 최소 비용이므로 해당 간선 정보로 MST의 총 비용이나 경로를 알 수 있다!

+ 1번 Sort하는 과정을 배열로 할 수도 있지만, Priority Queue로도 할 수 있다.

 

크루스칼 노드
package com.study.datastructrue.graph.mst.kruskal;

public class KruskalNode implements Comparable<KruskalNode> {

    int u;
    int v;
    int weight;

    public KruskalNode(int u, int v, int weight) {
        this.u = u;
        this.v = v;
        this.weight = weight;
    }

    @Override
    public int compareTo(KruskalNode o) {
        return this.weight - o.weight;
    }

}

정렬을 위해 Comparable 인터페이스의 메소드를 구현해준다.

 

Kruskal Algorithm - Array
package com.study.datastructrue.graph.mst.kruskal;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

public class ArrayKruskal {

    private List<KruskalNode> edges;
    private int[] parent, ranks;
    private int minimumDistance;
    private Set<Integer> path;

    public ArrayKruskal(int size) {
        this.edges = new ArrayList<>();
        this.parent = new int[size + 1];
        this.ranks = new int[size + 1];
        this.minimumDistance = Integer.MAX_VALUE;
        this.path = new LinkedHashSet<>();
        initParent(size + 1);
    }

    public void update() {
        this.minimumDistance = 0;
        Collections.sort(this.edges);
        for (KruskalNode node : this.edges) {
            int parentA = find(node.u);
            int parentB = find(node.v);
            if (parentA != parentB) {
                union(parentA, parentB);
                this.minimumDistance += node.weight;
                this.path.add(node.u);
                this.path.add(node.v);
            }
        }
    }

    public int getMinimumDistance() {
        return minimumDistance;
    }

    public void printTrace() {
        System.out.println(path.toString());
    }

    public void add(int u, int v, int w) {
        this.edges.add(new KruskalNode(u, v, w));
    }

    private void initParent(int size) {
        for (int i = 0; i < size; i++) {
            this.parent[i] = i;
            this.ranks[i] = 0; // rank 초기화
        }
    }

    private int find(int x) {
        if (x != parent[x]) {
            parent[x] = find(parent[x]);
        }
        return parent[x];
    }

    private void union(int x, int y) {
        if (ranks[x] < ranks[y]) {
            int temp = ranks[x];
            ranks[x] = ranks[y];
            ranks[y] = temp;
        }

        parent[y] = x;

        if(ranks[x] == ranks[y]) {
            ranks[x] = ranks[y] + 1;
        }
    }

}

1. find 연산은 부모를 찾을 때까지 recursion을 진행한다. (성능 개선, 경로압축 기법 적용)

2. 정점 u와 정점 v가 다를 경우 Union 연산이 진행된다. 이 때, Rank로 트리의 높이를 저장하면서 각 tree의 계층을 표현할 수 있다. (Union 연산의 개선)

3. Set 자료구조를 이용해 Union 연산마다 매 노드를 추가해 중복된 노드는 무시하고 경로를 구할 수 있다.

 

Kruskal Algorithm - Priority Queue
package com.study.datastructrue.graph.mst.kruskal;

import java.util.LinkedHashSet;
import java.util.PriorityQueue;
import java.util.Set;

public class PriorityQueueKruskal {

    private PriorityQueue<KruskalNode> edges;
    private Set<Integer> path;
    private int[] parent, ranks;
    private int minimumDistance;

    public PriorityQueueKruskal(int size) {
        this.edges = new PriorityQueue<>();
        this.parent = new int[size + 1];
        this.ranks = new int[size + 1];
        this.minimumDistance = Integer.MAX_VALUE;
        this.path = new LinkedHashSet<>();
        initParent(size + 1);
    }

    public void update() {
        this.minimumDistance = 0;
        PriorityQueue<KruskalNode> temp = new PriorityQueue<>(this.edges);
        while (!temp.isEmpty()) {
            KruskalNode node = temp.poll();
            int parentA = find(node.u);
            int parentB = find(node.v);
            if (parentA != parentB) {
                union(parentA, parentB);
                this.minimumDistance += node.weight;
                this.path.add(node.u);
                this.path.add(node.v);
            }
        }
    }

    public int getMinimumDistance() {
        return minimumDistance;
    }

    public void printTrace() {
        System.out.println(path.toString());
    }

    public void add(int u, int v, int w) {
        this.edges.add(new KruskalNode(u, v, w));
    }

    private void initParent(int size) {
        for (int i = 0; i < size; i++) {
            this.parent[i] = i;
        }
    }

    private int find(int x) {
        if (x == parent[x]) {
            return x;
        }
        return parent[x] = find(parent[x]);
    }

    private void union(int x, int y) {
        if (x == y) {
            return;
        }

        if (ranks[x] < ranks[y]) {
            int temp = ranks[x];
            ranks[x] = ranks[y];
            ranks[y] = temp;
        }

        parent[y] = x;

        if(ranks[x] == ranks[y]) {
            ranks[x] = ranks[y] + 1;
        }
    }

}

- Priority Queue는 poll 연산을 통해 최소 신장 트리를 갱신할 수 있지만, 일회용이므로 복사하게 되는 시간이 발생한다. 

- 따로 정렬 객체를 부르지 않아도 된다는 장점이 있지만 음.. Priority Queue가 편하면 이 방법을, 아니면 ArrayList 방법을 사용할 수 있도록 하려고 두가지 방법을 다 작성했다.

 

실사용 Main
package com.study.datastructrue.graph.mst.kruskal;

public class Main {

    public static void main(String[] args) {
        PriorityQueueKruskal pKruskal = new PriorityQueueKruskal(5);
        pKruskal.add(1, 2, 6);      pKruskal.add(1, 3, 3);
        pKruskal.add(1, 4, 1);      pKruskal.add(2, 5, 4);
        pKruskal.add(3, 4, 2);      pKruskal.add(3, 5, 5);
        pKruskal.add(4, 5, 7);      pKruskal.update();

        System.out.println(pKruskal.getMinimumDistance());
        pKruskal.printTrace();

        ArrayKruskal aKruskal = new ArrayKruskal(5);
        aKruskal.add(1, 2, 6);      aKruskal.add(1, 3, 3);
        aKruskal.add(1, 4, 1);      aKruskal.add(2, 5, 4);
        aKruskal.add(3, 4, 2);      aKruskal.add(3, 5, 5);
        aKruskal.add(4, 5, 7);      aKruskal.update();

        System.out.println(aKruskal.getMinimumDistance());
        aKruskal.printTrace();
    }

}

프림 알고리즘

 

1. 모든 간선이 비용이 적은 순서로 정렬한다. - Sort

2. 시작 정점을 정한다.

3. 시작 정점으로부터 방문하지 않은 정점과 연결된 간선만 최소 비용 간선으로 채택한다.

4. 3번에서 구한 정점과의 최소 비용 간선을 구하는 과정을 반복한다.

 

+  3번, 4번 과정을 반복하는 과정을 하려면 PriorityQueue로 구현하는 것이 좋다.

+ 최소 비용을 기준으로 방문하지 않은 정점만 방문하므로 최소 비용 경로를 구할 수 있다.

 

프림 노드
package com.study.datastructrue.graph.mst.prim;

public class PrimNode implements Comparable<PrimNode>{

    int v;
    int weight;

    public PrimNode(int v, int weight) {
        this.v = v;
        this.weight = weight;
    }

    @Override
    public int compareTo(PrimNode o) {
        return this.weight - o.weight;
    }

}

- 크루스칼은 간선의 비용을 기준으로 정렬하기 때문에 u, v 정보가 모두 필요하다.

- 프림은 다음 정점까지의 비용을 기준으로 정렬하기 때문에 u 정보가 필요 없다.

 

Prim Algorithm
package com.study.datastructrue.graph.mst.prim;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;

public class Prim {

    private List<PrimNode>[] graph;
    private List<Integer> path;
    private boolean[] visited;
    private int minimumDistance;

    public Prim(int size) {
        this.graph = new List[size + 1];
        this.path = new ArrayList<>();
        this.visited = new boolean[size + 1];
        this.minimumDistance = Integer.MAX_VALUE;
        initGraph(size);
    }

    private void initGraph(int size) {
        for (int i = 0; i <= size; i++) {
            this.graph[i] = new ArrayList<>();
        }
    }

    public void add(int u, int v, int w) {
        this.graph[u].add(new PrimNode(v, w));
        this.graph[v].add(new PrimNode(u, w));
    }

    public void update(int start) {
        Arrays.fill(visited, false);
        path.clear();
        this.minimumDistance = 0;

        PriorityQueue<PrimNode> pq = new PriorityQueue<>();
        pq.offer(new PrimNode(start, 0));

        while(!pq.isEmpty()) {
            PrimNode node = pq.poll();

            if (visited[node.v]) continue;
            visited[node.v] = true;
            this.minimumDistance += node.weight;
            path.add(node.v);

            for (PrimNode next: graph[node.v]) {
                if (!visited[next.v]) {
                    pq.offer(next);
                }
            }
        }
    }

    public int getMinimumDistance() {
        return this.minimumDistance;
    }

    public void printTrace() {
        for (int v: path) {
            System.out.print(v + " ");
        }
        System.out.println();
    }

}

- Path는 방문하지 않은 정점 순서대로 최소 비용으로 채택될테니 List 를 통해 추가하면 된다.

 

Main
package com.study.datastructrue.graph.mst.prim;

import com.study.datastructrue.graph.mst.kruskal.PriorityQueueKruskal;

public class Main {

    public static void main(String[] args) {
        Prim prim = new Prim(5);
        prim.add(1, 2, 6);      prim.add(1, 3, 3);
        prim.add(1, 4, 1);      prim.add(2, 5, 4);
        prim.add(3, 4, 2);      prim.add(3, 5, 5);
        prim.add(4, 5, 7);      prim.update(1);

        System.out.println(prim.getMinimumDistance());
        prim.printTrace();
    }

}

 

보통 간선이 많은 경우, 간선을 기준으로 정렬하는 크루스칼

정점이 많은 경우, 정점을 기준으로 정렬하는 프림을 사용하면 효율적이겠다.