그냥 하는 노트와 메모장

DSU(Disjoint-set union) 문제들 본문

Algorithms

DSU(Disjoint-set union) 문제들

coloredrabbit 2018. 5. 17. 15:29

* DSU (Disjoint-Set Union, Union-Find 또는 상호배제 집합)

  DSU에 관련된 문제 중 푼 문제를 소개합니당



1. Codeforce

  1-1. Serial Time! (http://codeforces.com/contest/60/problem/B)

  단순히 BFS 돌려도 되지만, DSU로 풀어보고 싶어서 적용해봤다. 인접한 6 이웃에 대해 #이 아니라면 접근이 가능하다.

  

코드 : 

#include <cstdio>
int tp[10][10][10], p[1001], s[1001];
int find(int u) {
	if (u == p[u]) return u;
	return p[u] = find(p[u]);
}
void merge(int u, int v) {
	u = find(u), v = find(v);
	if (u == v) return;
	p[u] = v;
	s[v] += s[u];
}
int main() {
	char board[10][10][11], ch;
	int k, n, m, l, x, y, cnt = 1, i, j, dy[] = { 1,-1,0,0,0,0 }, dx[] = { 0,0,1,-1,0,0 }, dz[] = { 0,0,0,0,1,-1 };
	scanf("%d%d%d", &k, &n, &m);
	for (l = 0; l < k; l++) {
		scanf("%c\n", &ch);
		for (x = 0; x<n; x++) {
			scanf("%s", board[l][x]);
			for (y = 0; y<m; y++)
				if (board[l][x][y] == '.') {
					tp[l][x][y] = cnt, s[cnt] = 1;
					p[cnt] = cnt;
					cnt++;
				}
		}
	}
	scanf("%d%d", &x, &y);
	x--, y--;
	for (l = 0; l<k; l++) for (i = 0; i<n; i++) for (j = 0; j<m; j++) {
		if (!tp[l][i][j]) continue;
		for(int c = 0;c<6;c++) {
			int tz = l + dz[c], ty = i + dy[c], tx = j + dx[c];
			if (0 <= tz && tz < k && 0 <= ty && ty < n && 0 <= tx && tx < m && tp[tz][ty][tx]) {
				if (find(tp[l][i][j]) != find(tp[tz][ty][tx]))
					merge(tp[l][i][j], tp[tz][ty][tx]);
			}
		}
	}
	printf("%d", s[find(tp[0][x][y])]);
	return 0;
}
  

  1-2. pSort (http://codeforces.com/contest/28/problem/B)

  A와 B가 swap이 가능하고, B와 C가 swap이 가능하다면 A와 C가 swap이 가능하다. 집합 개념으로 본다면 같은 집합 내의 요소들끼리는 모두 swap이 가능하므로, 집합을 먼저 모두 구한 다음 각 집합에 대해 sorting 한다. 그 다음 1부터 N까지 차례로 나열되어 있는지 확인하면 된다.


코드 :

#include <cstdio>
int p[100];
int find(int u) {
	if (u == p[u]) return u;
	return p[u] = find(p[u]);
}
void merge(int u, int v) {
	u = find(u), v = find(v);
	if (u == v) return;
	p[u] = v;
}
int main() {
	int n, d[100], fn[100], i, j, ans = 1, vi[100] = {}, pn;
	scanf("%d", &n);
	for (i = 0; i<n; i++) scanf("%d", &d[i]);
	for (i = 0; i<n; i++) {
		scanf("%d", &fn[i]);
		p[i] = i;
	}
	for (i = 0; i<n; i++) {
		if (0 <= i - fn[i]) merge(i, i - fn[i]);
		if (i + fn[i] < n) merge(i, i + fn[i]);
	}
	for (i = 0; i<n; i++)
		ans &= find(i) == find(d[i]-1);
	printf("%s", ans ? "YES" : "NO");
	return 0;
}



  1-3. Roads not only in Berland (http://codeforces.com/contest/25/problem/D) (Needed to refactoring proc)

  다소 tree와 graph에 대한 개념이 필요한 문제.

  tree는 노드가 N개이고 간선이 N-1개다. 만약 노드가 N개이고 간선이 N-1인 그래프가 트리가 아니라면 이는 반드시 사이클이 존재하게 되어 있다. 이런 사이클이 존재하도록 하는 간선을 집합끼리 묶어주는 간선으로 이동시키는 문제.


  집합이 하나가 될 때까지 아래 작업을 수행한다.

    1. 현재 집합에 사이클이 없다면 다음 집합으로 넘어간다.

    2. 현재 집합에 있는 사이클을 이루는 간선 하나를 다른 집합과 현재 집합의 각 루트를 연결시켜주는 용도로 사용한다.


코드 : 

#include <cstdio>
#include <vector>
#include <cstring>
using namespace std;
int *p, n;
int find(int u) {
	if (u == p[u]) return u;
	return p[u] = find(p[u]);
}

void merge(int u, int v) {
	u = find(u), v = find(v);
	if (u == v) return;
	p[u] = v;
}

vector<vector<int>> adj;
bool adj_a[1000][1000], vi[1000];
struct _a { int i, j, u, v; };
vector<_a> ans;
struct _r { int u, v; };
_r dfs(int u, int p) {
	_r ret{-1,-1};
	vi[u] = 1;
	for (auto it = adj[u].begin(); it != adj[u].end();) {
		int v = *it;
		if (adj_a[u][v]) {
			if (!vi[v]) {
				_r ret = dfs(v, u);
				if (ret.u != -1) return ret;
			}
			else if (p != v) return _r{ u,v };
			it++;
		}
		else it = adj[u].erase(it);
	}
	return ret;
}

void rec(int root) {
	if (root == n)
		return;
	int i, f = 1;
	_r del;
	for (i = 0; i < n; i++) {
		if (f) {
			memset(vi, 0, sizeof vi);
			del = dfs(root, -1);
			f = 0;
		}
		if (del.u == -1) break;
		if (find(root) != find(i)) {
			ans.push_back(_a{ del.u, del.v, find(root), find(i) });
			adj_a[del.u][del.v] = adj_a[del.v][del.u] = 0;
			merge(root, i);
			f = 1;
		}
	}
	if (i != n) rec(root + 1);
}

int main() {
	int i,j, u, v;
	scanf("%d", &n);
	adj.resize(n);
	p = new int[n];
	for (i = 0; i < n - 1; i++) {
		scanf("%d%d", &u, &v);
		u--, v--;
		adj[u].push_back(v);
		adj[v].push_back(u);
		adj_a[u][v] = adj_a[v][u] = 1;
	}
	for (i = 0; i < n; i++) p[i] = i;

	for (i = 0; i < n; i++) {
		if (vi[n]) continue;
		for (j = 0; j < n; j++) if (adj_a[i][j] && find(i) != find(j))
			merge(i, j);
	}
	rec(0);	
	printf("%d\n", ans.size());
	for (_a& sa : ans)
		printf("%d %d %d %d\n", sa.i+1, sa.j+1, sa.u+1, sa.v+1);
	return 0;
}




  1-4. Cthulhu (http://codeforces.com/contest/103/problem/B)

  DSU가 Detecting cyclic graph에 사용할 수 있다. 현재 인접한 정점이 오로지 하나의 간선을 갖는 하나의 무방향 그래프 안에 있는 인접하지 않는 정점 A와 B에 대해 A와 B를 연결하면 사이클이 있음을 직관적으로 알 수 있다. 이는 데이터가 들어오는 방식에 따라 알 수 있다.

  Cthulhu 문제는 그래프 내에 사이클이 하나만 존재하며, 모든 정점은 이 그래프와 연결되어 있어야 한다(그래프가 오로지 하나만 존재해야 한다). 따라서 이 문제는 DFS 또는 DSU, 입맛에 맞는 풀이로 해결하시면 되겠다.


DFS sol. 코드 :

#include <cstdio>
#include <vector>
using namespace std;
vector<vector<int>> adj;
bool vi[100];
int dfs(int u, int p) {
	int ret = 0;
	vi[u] = 1;
	for (int&v : adj[u]) {
		if (vi[v]) ret += v != p;
		else ret += dfs(v, u);
	}
	return ret;
}

int main() {
	int n, m, x, y, i, trees = 0, ans = 0;
	scanf("%d%d", &n, &m);
	adj.resize(n);
	while (m--) {
		scanf("%d%d", &x, &y);
		x--, y--;
		adj[x].push_back(y);
		adj[y].push_back(x);
	}
	for (i = 0; i<n; i++)
		if (!vi[i]) {
			trees++;
			ans = dfs(i, -1) / 2;
		}
	printf("%s", trees == 1 && ans == 1 ? "FHTAGN!" : "NO");

	return 0;
}



DSU sol. 코드 :

#include <cstdio>
int p[100],cyc[100];
int find(int u) {
	return u == p[u] ? u : p[u] = find(p[u]);
}
void merge(int u, int v) {
	u = find(u), v = find(v);
	p[u] = v;
	cyc[v] += cyc[u];
}
int main() {
	int n, m,i, u, v, t_cnt = 0, c = 0, vi[100] = {},rt;
	scanf("%d%d", &n, &m);
	for (i = 0; i < n; i++) p[i] = i;
	while (m--) {
		scanf("%d%d", &u, &v);
		u--, v--;
		if (find(u) != find(v))
			merge(u, v);
		else 
			cyc[find(u)] += 1;
	}
	for(i=0;i<n;i++)
		if (!vi[(rt = find(i))]) {
			vi[rt] = 1;
			t_cnt++, c = c > cyc[rt] ? c : cyc[rt];
		}
	printf("%s", t_cnt == 1 && c == 1 ? "FHTAGN!" : "NO");
	return 0;
}





2. 백준 온라인 저지

  2-1. 방 청소 (https://www.acmicpc.net/problem/9938)

  우선 문제 상황을 보면 N제한이 매우 높기 때문에 경로 압축을 사용해야 한다.

  집합의 정의를 "연결 가능한(또는 연쇄적으로 밀 수 있는) 서랍 집합"으로 표현할 수 있다. 가령 23번째 서랍에 술을 넣어야 하는 경우, 23번째에 이미 다른 술이 들어 있다면, 그 술이 어디론가 이동할 수 있는지 없는지 판단한다. 이는 연쇄적으로 일어나며 결국 마지막 술을 옮길 수 있는지 없는지의 문제로 바뀌어 버린다.

  따라서 각 집합의 루트는 마지막 술의 Bi 인덱스를 저장하면 된다.(Ai의 위치에 넣었다면 Bi는 비어있는 상태이면 이는 곧 필요에 의해 Bi로 이동할 수 있다.)


코드 : 

#include <cstdio>
int *p;
bool *used;
void swap_v(int& a, int& b) { int t = a; a = b; b = t; }
int find(int u) {
	if (p[u] == u) return u;
	return p[u] = find(p[u]);
}

void merge(int u, int v) {
	u = find(u), v = find(v);
	if (u == v) return;
	p[u] = v;
}

int main() {
	int N, L, i, a, b, f;
	scanf("%d%d", &N, &L);
	p = new int[L], used = new bool[L] {};
	for (i = 0; i < L; i++) p[i] = i;
	while (N--) {
		f = 1;
		scanf("%d%d", &a, &b);
		a--, b--;
		a = find(a), b = find(b);
		if (!used[a]) {
			used[a] = 1;
			merge(a, b);
		}
		else if (!used[b])
			used[b] = 1;
		else f = 0;

		printf("%s\n", f ? "LADICA" : "SMECE");
	}
	return 0;
}




  2-2. 문명 (https://www.acmicpc.net/problem/14868)

  이 문제 역시 각 문명을 집합체로 보면 100000개나 되므로 경로 압축을 사용해야 한다.

  하루마다 4방향으로 문명이 전파되므로 이는 BFS으로 해야함을 알 수 있다. 여기서 BFS를 통해 만나는 문명들을 같은 집합체로 union해주면 된다. 주의해야할 점은 최초 같은 발상지는 하루가 지나지 않아도 같은 문명으로 판단한다.


코드 :

#include <cstdio>
#include <queue>
using namespace std;
void swap_v(int& a, int& b) { int t = a; a = b; b = t; }
int max_v(int a, int b) { return a > b ? a : b; }
struct _d { int y, x; };
struct _e { 
	int y, s;
	_e() : y(0), s(-1) {};
};
struct _nq { int y, x, uy, ux; };
int *p, *r, *s;

int find(int u) {
	if (u == p[u]) return u;
	return p[u] = find(p[u]);
}

void merge(int u, int v) {
	u = find(u), v = find(v);
	if (u == v) return;
	if (r[u] > r[v]) swap_v(u, v);
	p[u] = v;
	s[v] += s[u];
	if (r[u] == r[v]) ++r[v];
}

int main() {
	int N, K, i, j, dy[] = { 0,0,1,-1 }, dx[] = { 1,-1,0,0 }, ans=0;
	_e vi[2000][2000];
	queue<_d> q;
	queue<_nq> nextq;
	_d u;
	scanf("%d%d", &N, &K);
	p = new int[K], r = new int[K], s = new int[K];
	for(i=0;i<K;i++) {
		p[i] = i, r[i] = 0, s[i] = 1;
		scanf("%d%d", &u.y, &u.x);
		u.y--, u.x--;
		q.push(u);
		vi[u.y][u.x].s = p[i];
	}
	while (!q.empty()) {
		while (!q.empty() && s[find(0)] < K) {
			u = q.front(), q.pop();
			for (i = 0; i < 4; i++) {
				int ty = u.y + dy[i], tx = u.x + dx[i];
				if (0 <= ty && ty < N && 0 <= tx && tx < N) {
					if (vi[ty][tx].s == -1)// 처음으로 문명을 전파하는 지역
						nextq.push(_nq{ ty,tx,u.y,u.x });
					else if (find(vi[ty][tx].s) != find(vi[u.y][u.x].s)) { // 문명이 다른 지역
						merge(vi[ty][tx].s, vi[u.y][u.x].s);
						ans = max_v(ans, max_v(vi[ty][tx].y, vi[u.y][u.x].y));
					}
				}
			}
		}
		if (s[find(0)] >= K) break;
		while (!nextq.empty()) {
			_nq v = nextq.front(); nextq.pop();
			if (vi[v.y][v.x].s != -1) continue;
			vi[v.y][v.x].s = vi[v.uy][v.ux].s;
			vi[v.y][v.x].y = vi[v.uy][v.ux].y + 1;
			q.push(_d{ v.y, v.x });
		}
	}	
	printf("%d", ans);
	return 0;
}




  2-3. 욱제와 그의 팬들 (https://www.acmicpc.net/problem/15352) (Needed to refactoring proc)

  이 문제도..ㅋㅋ N제한이 크기 때문에 경로 압축을 사용한다. 

  쿼리의 수를 보면 매우 크기 때문 매 쿼리마다 O(1) 또는 O(lgN) 시간을 가져야한다.


  희한하게 이 문제는 merge하는 것이 아니라 delete 연산하는 것처럼 보인다. 일렬로 되어 있는 데이터 집합에서 원형이 아니기 때문에 왼쪽과 오른쪽 각각 집합체를 저장하고, 마지막으로 현재 위치에 대해 양옆으로 선물을 얼마나 줘야하는지 집합체를 선언한다. 이러한 과정은 Disjoint-set union and deletion이 아니라 구성되어 있는 트리를 그대로 가져가며 쿼리를 수행할 수 있게 한다. 이는 문제가 DSU 구조를 루트 중심으로 구성해야할 때 쓸 수 있는데, 루트가 해당 집합체를 대표할 수 있기 때문이다. 따라서 퇴출되는 데이터에 대해서는 루트에 대해 크기와 왼쪽, 오른쪽 집합체를 처리해주면 된다.

  참고) 하지만 이 방식은 시간이 매우 오래걸린다(1.7 ms). 다른 솔의 경우 빠르면 1.1 ms가 나오는데 리팩토링이 필요할 것 같다.


코드 :

#include <cstdio>
void swap_v(int& a, int& b) { int t = a; a = b; b = t; }
int *l, *r, *gl, *gr, *rank, *size;
int find(int u, int *p) {
	if (u == p[u]) return u;
	return p[u] = find(p[u], p);
}

int merge(int u, int v, int *p) {
	u = find(u, p), v = find(v, p);
	if (u == v) return u;
	if (p == gl && rank[u] > rank[v]) swap_v(u, v);
	p[u] = v;
	if (p == gl && rank[u] == rank[v]) ++rank[v];
	if (p == gl) size[v] += size[u];
	return v;
}

int main() {
	int K, N, *A, i, a, b, Q;
	scanf("%d%d", &K, &N);
	A = new int[N + 2]{}, l = new int[N + 2]{}, r = new int[N + 2]{}; // index 0, N+1 are dummies.
	rank = new int[N + 2]{}, size = new int[N + 2]{}, gl = new int[N + 2]{};
	long long ans = 0;

	l[N + 1] = r[N + 1] = N + 1;
	for (i = 1; i <= N; i++) {
		scanf("%d", &A[i]);
		gl[i] = l[i] = r[i] = i, size[i] = 1;
		if (A[i - 1] == A[i])
			merge(i - 1, i, gl);
	}

	scanf("%d", &Q);
	while (Q--) {
		scanf("%d%d", &a, &b);
		switch (a) {
		case 1:
			size[find(b, gl)]--;
			merge(b, b - 1, l);
			merge(b, b + 1, r);
			if (A[find(b, l)] == A[find(b, r)])
				merge(find(b, l), find(b, r), gl);
			break;
		case 2:
			ans += size[find(b, gl)];
			break;
		}
	}
	printf("%lld", ans);
	return 0;
}



  


  2-4. 가로수 (https://www.acmicpc.net/problem/15674) (Needed to refactoring proc)

  종만북에 나오는 예제 문제와 매우 유사하다. 하지만 쿼리 과정이 붙었다.

  쿼리당 O(1) 또는 O(lgN)의 시간을 가져야 한다.


  알고리즘 전체 구조는 아래와 같다.

    1. 같은 나무를 심어야 하는 A, B가 있을 때

-1. A와 같은 나무를 심어야 하는 또다른 나무 C와 B는 같은 집합이 되어야 한다.

-2. A와 다른 나무를 심어야 하는 또다른 나무 C와 B는 다른 집합이 되어야 한다.

    2. 다른 나무를 심어야 하는 A, B가 있을 때

-1. A와 같은 나무를 심어야 하는 또다른 나무 C와 B는 다른 집합이 되어야 한다.

-2. A와 다른 나무를 심어야 하는 또다른 나무 C와 B는 같은 집합이 되어야 한다.


  최소 비용 구하는 방식은 상수 시간으로 구한다.


- Process

  1. A, B 위치의 나무 심는 방식에 대한 쿼리

     A와 B에 연관된 데이터에 대해 각 A,B 집합에 대해 최소값을 제거해준다. 이 때, A와 같은 나무, 다른 나무의 집합체에 대해 최소값과 B와 같은 나무, 다른 나무의 집합체에 대해 최소값을 각각 구하여 결과에서 빼줘야 한다.

1-1. A와 B가 같은 나무를 심어야 한다.

1-1-1. A와 B를 union 해주고, A와 다른 나무와 심어야 하는 diff[A] 집합과 B와 다른 나무를 심어야 하는 diff[B]를 union 한다.

1-1-2. 새로 설정된 각 집합을 대표하는 루트를 얻고 diff배열을 설정한다.

   여기서 diff 배열을 초기값 -1 의미하는 것은 다른 나무를 심어야 하는 위치가 없음을 나타낸다.

1-2. A와 B가 다른 나무를 심어야 한다.

1-2-1. A와 diff[B]를 union 해주고, A와 다른 나무와 심어야 하는 diff[A] 집합과 B와 같은 나무를 심어야 하는 B를 union 한다.

1-2-2. 새로 설정된 각 집합을 대표하는 루트를 얻고 diff배열을 설정한다.

   여기서는 diff 배열에 대해 상호 배제하는 관계이기 때문에 새로 설정된 각 집합체에 대해 diff가 -1 값을 가질 일이 없다.

  2. 금액이 변경되었다.

      2-1. 은행나무의 가격이 변경되었다.

원래 있던 가격에 대해 빼준 다음, 변경된 값으로 은행나무 가격 배열 및 해당 위치에 대한 집합의 루트에 반영한다.

2-2. 플라타너스의 가격이 변경되었다.

원래 있던 가격에 대해 빼준 다음, 변경된 값으로 플라타너스 가격 배열 및 해당 위치에 대한 집합의 루트에 반영한다.

  

  3. 전체 가격을 계산한다.



소스 코드는 첨부하지만 리팩토링을 하지 않아 읽기 어려울 것 같다. 아래에 각 변수가 무엇을 의미하는지 짧게 서술하겠다 :)


 - p : parent

 - diff : store difference index from i-th index

 - r : rank (To saving proc time)

 - to1 : 은행나무 기존 가격(The price of original tree1)

 - to2 : 플라타너스 기존 가격(The price of original tree2)

 - ans : 결과 값

 - t1 : 은행 나무 집합체에 대한 가격의 합 (집합을 이루는 것끼리의 합, 집합 내에 원소가 하나라면 t1[i] = to1[i])

 - t2 : 플라타너스 집합체에 대한 가격의 합 (집합을 이루는 것끼리의 합, 집합 내에 원소가 하나라면 t2[i] = to2[i])


코드 : 

#include <cstdio>
using ll = long long;
ll min_v(ll a, ll b) { return a < b ? a : b; }
ll max_v(ll a, ll b) { return a > b ? a : b; }
void swap_v(int& a, int& b) { int t = a; a = b; b = t; }
int *p, *diff, *r, unused_set = 0, *to1,*to2;
ll ans, *t1, *t2;

int find(int u) {
	if (u == p[u]) return u;
	return p[u] = find(p[u]);
}

int merge(int u, int v) {
	if (u == -1 || v == -1) return max_v(u, v);
	u = find(u), v = find(v);
	if (u == v) return u;
	if (r[u] > r[v]) swap_v(u, v);
	if (r[u] == r[v]) ++r[v];
	p[u] = v;
	t1[v] += t1[u];
	t2[v] += t2[u];	
	return v;
}

ll getMin(int u) {
	if (diff[u] != -1)
		return min_v(t1[u] + t2[diff[u]], t2[u] + t1[diff[u]]);
	return min_v(t1[u], t2[u]);
}

void process(int c, int a, int b) {
	int root = -1,rt1,rt2;
	ll sub = 0;
	a--;
	switch (c) {
	case 0: // same
		b--;
		a = find(a), b = find(b);
		if (a != b) {
			sub = getMin(a) + getMin(b);
			rt1 = merge(a, b), rt2 = merge(diff[a], diff[b]);
			diff[rt1] = rt2;
			if (rt2 != -1) diff[rt2] = rt1;
			root = rt1;
		}
		break;
	case 1: // diff
		b--;
		a = find(a), b = find(b);
		if (a != diff[b]) {
			sub = getMin(a) + getMin(b);
			rt1 = merge(a, diff[b]), rt2 = merge(b, diff[a]);
			diff[rt1] = rt2, diff[rt2] = rt1;
			root = rt1;
		}
		break;
	case 2: // t1
		rt1 = find(a);
		sub = getMin(rt1);
		t1[rt1] = t1[rt1] - to1[a] + b;
		to1[a] = b;
		root = rt1;
		break;
	case 3: // t2
		rt1 = find(a);
		sub = getMin(rt1);
		t2[rt1] = t2[rt1] - to2[a] + b;
		to2[a] = b;
		root = rt1;
		break;
	}
	if (root != -1)
		ans = ans - sub + getMin(root);
}


int main() {
	int N, D, Q, i, A, B, C;
	scanf("%d%d", &N, &D);
	diff = new int[N], p = new int[N], to1 = new int[N], to2 = new int[N], r = new int[N] {};
	t1 = new ll[N], t2 = new ll[N];
	for (i = 0; i<N; i++) p[i] = i, diff[i] = -1,r[i] = 1;
	for (i = 0; i < N; i++) {
		scanf("%d%d", &to1[i], &to2[i]);
		t1[i] = to1[i], t2[i] = to2[i];
		ans += min_v(t1[i], t2[i]);
	}
	while (D--) {
		scanf("%d%d%d", &C, &A, &B);
		process(C, A, B);
	}
	
	printf("%lld\n", ans);
	scanf("%d", &Q);
	while (Q--) {
		scanf("%d%d%d", &C, &A, &B);
		process(C, A, B);
		printf("%lld\n",ans);
	}
	return 0;
}


Comments