업데이트:

문제 링크

백준 2887번 - 행성 터널 (Platinum 5)

문제 설명

3차원 좌표상의 점 N($\leq$ 100,000)개가 주어져 있다.
두 점 $A = (x_A, y_A, z_A)$와 $B = (x_B, y_B, z_B)$을 잇는데 드는 비용은 $\mathrm{min}(|x_A - x_B|, |y_A - y_B|, |z_A - z_B|)$이다.
N개의 점이 모두 연결되도록 하는데 필요한 최소 비용을 구하시오.

정답 코드 및 설명

기본적으로 최소 신장 트리를 구하는 문제임은 쉽게 알 수 있다. 하지만 모든 점들 간의 간선을 전부 만든다면, $O(N^2)$개의 간선을 만들어야하는데 여기에서 벌써 시간 초과가 발생한다. N이 최대 100,000이기 때문이다. 따라서, 필요한 간선만을 추가하여 최소 신장 트리를 구해야 한다.

먼저, 같은 점을 잇는 간선이 여러 개인 경우에 최소 신장 트리를 구한다고 생각해보자.
당연하게도, 같은 점을 잇는 간선 중 최소의 길이인 것이 쓰이게 될 것이다.
따라서, 두 점 A, B를 잇는 가중치가 $|x_A - x_B|$, $|y_A - y_B|$, $|z_A - z_B|$인 3개의 간선을 추가하고 생각해도 무방하다.

이제 크루스칼 알고리즘을 생각해보자.
크루스칼 알고리즘은 아직 사용하지 않은 간선 중 최소 길이의 간선을 본다.
이 간선의 양 끝 점이 이미 연결되어 있다면 다음 길이의 간선으로 넘어가고, 연결되어 있지 않다면 두 점을 잇는 방식으로 진행된다.
점 $A = (x_A, y_A, z_A)$, $B = (x_B, y_B, z_B)$, $C = (x_C, y_C, z_C)$가 있다고 하자.
$x_A \leq x_B \leq x_C$라면, A-C를 잇는 가중치 $|x_A - x_C|$인 간선 1은 무조건 A-B를 잇는 가중치 $|x_A - x_B|$인 간선 2와 B-C를 잇는 가중치 $|x_B - x_C|$인 간선 3보다 뒤에 나온다.
따라서, 간선 1을 보는 시점에서는 간선 2와 3은 이미 본 뒤고, 따라서 A와 B, B와 C는 서로 연결되어 있다.
이를 다시 말하면, 간선 1을 보는 시점에서 A와 C는 항상 연결되어 있으므로 간선 1을 아예 고려할 필요가 없다는 의미가 된다.

따라서, 다음의 세 종류에 해당하는 간선들만 추가하여 최소 신장 트리를 구하면 된다.

  1. N개의 점을 x좌표 순으로 정렬하였을 때 서로 이웃하는 점들의 x좌표 차이에 해당하는 간선
  2. N개의 점을 y좌표 순으로 정렬하였을 때 서로 이웃하는 점들의 y좌표 차이에 해당하는 간선
  3. N개의 점을 z좌표 순으로 정렬하였을 때 서로 이웃하는 점들의 z좌표 차이에 해당하는 간선

이제 전체 시간 복잡도를 구해보자.
간선을 추가하려면 정렬을 세 번 수행해야 하므로, $O(N\log N)$의 시간이 필요하다.
추가된 간선들로 최소 신장 트리를 구하려면 $O(E \log E) = O(N \log N)$의 시간이 필요하다.
따라서 전체 시간 복잡도는 $O(N \log N)$이다. N = 100,000이므로 1초 안에 수행이 가능하다!
위에서는 설명의 편의를 위해 크루스칼 알고리즘을 사용했지만, 프림 알고리즘을 사용해도 무방하다. 최소 신장 트리의 유일성은 이미 증명된 사실이기 때문이다.

코드로 구현하면 아래와 같다.
간선을 추가하는 부분은 코드의 28행 ~ 39행에 해당한다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;
import java.util.stream.Collectors;

public class BOJ2887 {
    int n;
    Planet[] planets;
    PriorityQueue<int[]> edges;
    int[] sets;

    void input() throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        n = Integer.parseInt(br.readLine());
        planets = new Planet[n];
        StringTokenizer st;
        int x, y, z;
        for (int i = 0; i < n; i++) {
            st = new StringTokenizer(br.readLine());
            x = Integer.parseInt(st.nextToken());
            y = Integer.parseInt(st.nextToken());
            z = Integer.parseInt(st.nextToken());
            planets[i] = new Planet(i, new Point(x, y, z));
        }
    }

    void setEdges() {
        edges = new PriorityQueue<>(Comparator.comparingInt(o -> o[2]));
        for (int i = 0; i < 3; i++) {
            int index = i;
            List<Planet> sortedPlanets = Arrays.stream(planets)
                    .sorted(Comparator.comparingInt(p -> p.point.point[index]))
                    .collect(Collectors.toList());
            for (int j = 1; j < n; j++) {
                Planet first = sortedPlanets.get(j);
                Planet second = sortedPlanets.get(j - 1);
                edges.offer(new int[]{second.planetNo, first.planetNo,
                        first.point.point[index] - second.point.point[index]});
            }
        }
    }

    long Kruskal() {
        sets = new int[n];
        Arrays.fill(sets, -1);
        long cost = 0;
        int count = 0;
        while (count < n && !edges.isEmpty()) {
            int[] currEdge = edges.poll();
            if (union(currEdge[0], currEdge[1])) {
                count++;
                cost += currEdge[2];
            }
        }
        return cost;
    }

    int findSet(int a) {
        if (sets[a] < 0) return a;
        return sets[a] = findSet(sets[a]);
    }

    boolean union(int a, int b) {
        int setA = findSet(a);
        int setB = findSet(b);
        if (setA == setB) return false;
        if (-sets[setA] < -sets[setB]) {
            int temp = setA;
            setA = setB;
            setB = temp;
        }
        sets[setA] += sets[setB];
        sets[setB] = setA;
        return true;
    }

    void solution() throws IOException {
        input();
        setEdges();
        System.out.println(Kruskal());
    }

    public static void main(String[] args) throws IOException {
        new BOJ2887().solution();
    }

    static class Point {
        int[] point = new int[3];

        public Point(int x, int y, int z) {
            point[0] = x;
            point[1] = y;
            point[2] = z;
        }
    }

    static class Planet {
        int planetNo;
        Point point;

        public Planet(int planetNo, Point point) {
            this.planetNo = planetNo;
            this.point = point;
        }
    }
}

댓글남기기