CS/자료구조

DSU(Disjoint Set Union) / Union Find

_우지 2023. 5. 2. 20:25

 

Disjoint Set Union

Disjoint Set(서로소 집합)이란 서로 공통된 원소가 없는 집합을 말합니다.

보통 Union Find 라고 불립니다.

 

https://www.youtube.com/watch?v=Usy6eEkhesc

 

Union

2개의 집합을 1개의 집합으로 합치는 방식입니다.

 

https://www.youtube.com/watch?v=Usy6eEkhesc

 

기본적으로 높이 가 작은 쪽을 큰 쪽으로 합칩니다.

 

유니온 바이 랭크(Union by rank)

두개의 disjoint set 을 합칠 때 항상 작은 쪽을 큰 쪽에 합칩니다.

 

 

Find

어떤 element가 주어질때, 이 element 가 속해져 있는 루트 를 반환합니다.

 

아래 예시에서는 3 이라는 원소를 Find 하게 되면 루트인 6이 나오게 됩니다.

 

 

경로 압축

경로 압축은 Find 의 시간 복잡도를 개선하기 위해 원소를 Find 할 경우 바로 루트원소가 나올 수 있도록 하는 방법입니다.

개선된 시간복잡도는 O(1) 또는 O(logN) 이라고 합니다.

* 실제로는 아커만 상수라고 합니다. 하지만 생각하기 편하게 상수 시간복잡도를 가진다고 정리했습니다.

 

 

Union & Find 코드

문제마다 조금씩 코드를 수정해야하지만, 가장 기본이 되는 베이스 코드는 다음과 같습니다. 

  function find(u) {
    if (u == parent[u]) {
      return u;
    }
    return (parent[u] = find(parent[u]));
  }

  function merge(u, v) {
    const U = find(u);
    const V = find(v);
    if (U === V) return;

    parent[U] = V;
  }

 

유니온을 랭크 바이 유니온 을 사용하여 더 엄격하게 구현하면 다음과 같다.

하지만 이 방식을 사용하려면 rank 배열을 하나 더 선언해야하고, 배열에 값을 할당, 참조하는 시간이 꽤 소요 되어 사용하지않고 위 코드의 Union 처럼 한방향으로 넣는게 오히려 더 빨랐다.

  function merge(u, v) {
    const U = find(u);
    const V = find(v);
    if (U === V) return;

    if (rank[U] > rank[V]) {
      parent[V] = U;
      rank[U] += rank[V];
    } else {
      parent[U] = V;
      rank[V] += rank[U];
    }
  }

 

관련 문제 풀이

한번 관련된 문제를 풀어보면서 제대로 이해했는지 점검해봅시다.

문제링크 - https://www.acmicpc.net/problem/1717

 

다음과 같은 입력이 들어올때 0은 유니온(Union) 1은 파인드(Find) 를 호출합니다.

7 8
0 1 3
1 1 7
0 7 6
1 7 1
0 3 7
0 4 2
0 1 1
1 1 1

 

그럼 위 입력이 어떻게 되는지 알아보겠습니다.

 

1) 0 1 3

 

2) 1 1 7

여기서는 NO 를 출력하게 됩니다. 

 

 

3) 0 7 6

 

4) 1 7 1

같은 그룹에 속해 있지 않습니다. 1 의 루트는 3 이고, 7 의 루트는 6 이기 때문입니다.

NO 를 출력하게 됩니다. 

 

 

5) 0 3 7

 

6) 0 4 2

 

7) 0 1 1

변화 없습니다.

 

 

8) 1 1 1

1 과 1은 같은 그룹에 속해 있기 때문에 YES 가 출력됩니다.

또한 이 과정에서 경로 압축이 됩니다.

 

 

전체 코드

const readline = require("readline");

const rl = readline.createInterface({
  input: process.stdin,
  output: process.stdout,
});

let input = [];

rl.on("line", function (line) {
  input.push(line);
}).on("close", function () {
  input = input.map((item) => item.split(" ").map(Number)).reverse();
  main();

  process.exit();
});

function main() {
  const vertexSize = 1_000_000;

  const [n, m] = input.pop();
  const parent = Array.from({ length: vertexSize + 1 }, (_, i) => i);
  let answer = "";

  for (let i = 0; i < m; i++) {
    const [op, a, b] = input.pop();

    if (op === 1) {
      const A = find(a);
      const B = find(b);

      if (A === B) answer += "YES\n";
      if (A !== B) answer += "NO\n";
    }
    if (op === 0) merge(a, b);
  }

  console.log(answer);

  function find(u) {
    if (u == parent[u]) {
      return u;
    }
    return (parent[u] = find(parent[u]));
  }

  function merge(u, v) {
    const U = find(u);
    const V = find(v);
    if (U === V) return;

    parent[U] = V;
  }
}