1"""
2์ ๋์จ ํ์ธ๋ (Union-Find / Disjoint Set Union)
3Union-Find / Disjoint Set Union Data Structure
4
5์๋ก์ ์งํฉ์ ๊ด๋ฆฌํ๋ ์๋ฃ๊ตฌ์กฐ๋ก, ๊ทธ๋ํ์ ์ฐ๊ฒฐ์ฑ ๋ฌธ์ ์ ์ฃผ๋ก ์ฌ์ฉ๋ฉ๋๋ค.
6"""
7
8from typing import List, Tuple
9
10
11# =============================================================================
12# 1. ๊ธฐ๋ณธ ์ ๋์จ ํ์ธ๋
13# =============================================================================
14class UnionFind:
15 """
16 ๊ธฐ๋ณธ ์ ๋์จ ํ์ธ๋ ๊ตฌํ
17 - ๊ฒฝ๋ก ์์ถ (Path Compression)
18 - ๋ญํฌ ๊ธฐ๋ฐ ํฉ์น๊ธฐ (Union by Rank)
19 """
20
21 def __init__(self, n: int):
22 """
23 n๊ฐ์ ์์๋ก ์ด๊ธฐํ (0 ~ n-1)
24 """
25 self.parent = list(range(n)) # ์๊ธฐ ์์ ์ ๋ถ๋ชจ๋ก
26 self.rank = [0] * n # ํธ๋ฆฌ์ ๋์ด (๊ทผ์ฌ๊ฐ)
27 self.count = n # ์งํฉ์ ๊ฐ์
28
29 def find(self, x: int) -> int:
30 """
31 x๊ฐ ์ํ ์งํฉ์ ๋ํ(๋ฃจํธ) ์ฐพ๊ธฐ
32 ๊ฒฝ๋ก ์์ถ์ผ๋ก ๊ฑฐ์ O(1)
33 """
34 if self.parent[x] != x:
35 self.parent[x] = self.find(self.parent[x]) # ๊ฒฝ๋ก ์์ถ
36 return self.parent[x]
37
38 def union(self, x: int, y: int) -> bool:
39 """
40 x์ y๊ฐ ์ํ ์งํฉ์ ํฉ์น๊ธฐ
41 ์ด๋ฏธ ๊ฐ์ ์งํฉ์ด๋ฉด False ๋ฐํ
42 """
43 root_x = self.find(x)
44 root_y = self.find(y)
45
46 if root_x == root_y:
47 return False # ์ด๋ฏธ ๊ฐ์ ์งํฉ
48
49 # ๋ญํฌ ๊ธฐ๋ฐ ํฉ์น๊ธฐ (์์ ํธ๋ฆฌ๋ฅผ ํฐ ํธ๋ฆฌ์ ๋ถ์)
50 if self.rank[root_x] < self.rank[root_y]:
51 self.parent[root_x] = root_y
52 elif self.rank[root_x] > self.rank[root_y]:
53 self.parent[root_y] = root_x
54 else:
55 self.parent[root_y] = root_x
56 self.rank[root_x] += 1
57
58 self.count -= 1
59 return True
60
61 def connected(self, x: int, y: int) -> bool:
62 """x์ y๊ฐ ๊ฐ์ ์งํฉ์ ์๋์ง ํ์ธ"""
63 return self.find(x) == self.find(y)
64
65 def get_count(self) -> int:
66 """ํ์ฌ ์งํฉ์ ๊ฐ์"""
67 return self.count
68
69
70# =============================================================================
71# 2. ํฌ๊ธฐ ๊ธฐ๋ฐ ์ ๋์จ ํ์ธ๋
72# =============================================================================
73class UnionFindWithSize:
74 """
75 ์งํฉ์ ํฌ๊ธฐ๋ฅผ ์ถ์ ํ๋ ์ ๋์จ ํ์ธ๋
76 """
77
78 def __init__(self, n: int):
79 self.parent = list(range(n))
80 self.size = [1] * n # ๊ฐ ์งํฉ์ ํฌ๊ธฐ
81
82 def find(self, x: int) -> int:
83 if self.parent[x] != x:
84 self.parent[x] = self.find(self.parent[x])
85 return self.parent[x]
86
87 def union(self, x: int, y: int) -> bool:
88 root_x = self.find(x)
89 root_y = self.find(y)
90
91 if root_x == root_y:
92 return False
93
94 # ํฌ๊ธฐ ๊ธฐ๋ฐ ํฉ์น๊ธฐ (์์ ์งํฉ์ ํฐ ์งํฉ์)
95 if self.size[root_x] < self.size[root_y]:
96 self.parent[root_x] = root_y
97 self.size[root_y] += self.size[root_x]
98 else:
99 self.parent[root_y] = root_x
100 self.size[root_x] += self.size[root_y]
101
102 return True
103
104 def get_size(self, x: int) -> int:
105 """x๊ฐ ์ํ ์งํฉ์ ํฌ๊ธฐ"""
106 return self.size[self.find(x)]
107
108
109# =============================================================================
110# 3. ์ฐ๊ฒฐ ์์ ๊ฐ์
111# =============================================================================
112def count_components(n: int, edges: List[List[int]]) -> int:
113 """
114 n๊ฐ์ ๋
ธ๋์ ๊ฐ์ ๋ชฉ๋ก์ด ์ฃผ์ด์ง ๋ ์ฐ๊ฒฐ ์์ ๊ฐ์
115 """
116 uf = UnionFind(n)
117 for u, v in edges:
118 uf.union(u, v)
119 return uf.get_count()
120
121
122# =============================================================================
123# 4. ๊ทธ๋ํ์์ ์ฌ์ดํด ๊ฒ์ถ
124# =============================================================================
125def has_cycle(n: int, edges: List[List[int]]) -> bool:
126 """
127 ๋ฌด๋ฐฉํฅ ๊ทธ๋ํ์์ ์ฌ์ดํด ์กด์ฌ ์ฌ๋ถ
128 ๊ฐ์ ์ ์ถ๊ฐํ ๋ ์ด๋ฏธ ๊ฐ์ ์งํฉ์ด๋ฉด ์ฌ์ดํด
129 """
130 uf = UnionFind(n)
131 for u, v in edges:
132 if not uf.union(u, v):
133 return True # ์ด๋ฏธ ์ฐ๊ฒฐ๋จ = ์ฌ์ดํด
134 return False
135
136
137# =============================================================================
138# 5. ํฌ๋ฃจ์ค์นผ MST (์ต์ ์ ์ฅ ํธ๋ฆฌ)
139# =============================================================================
140def kruskal_mst(n: int, edges: List[Tuple[int, int, int]]) -> Tuple[int, List[Tuple[int, int, int]]]:
141 """
142 ํฌ๋ฃจ์ค์นผ ์๊ณ ๋ฆฌ์ฆ์ผ๋ก ์ต์ ์ ์ฅ ํธ๋ฆฌ ๊ตฌํ๊ธฐ
143 edges: [(u, v, weight), ...]
144 ๋ฐํ: (์ด ๊ฐ์ค์น, MST ๊ฐ์ ๋ฆฌ์คํธ)
145 """
146 # ๊ฐ์ค์น ๊ธฐ์ค ์ ๋ ฌ
147 edges = sorted(edges, key=lambda x: x[2])
148
149 uf = UnionFind(n)
150 mst_weight = 0
151 mst_edges = []
152
153 for u, v, w in edges:
154 if uf.union(u, v):
155 mst_weight += w
156 mst_edges.append((u, v, w))
157
158 # n-1๊ฐ์ ๊ฐ์ ์ ์ ํํ๋ฉด ์๋ฃ
159 if len(mst_edges) == n - 1:
160 break
161
162 return mst_weight, mst_edges
163
164
165# =============================================================================
166# 6. ์น๊ตฌ ๊ด๊ณ (๊ณ์ ๋ณํฉ)
167# =============================================================================
168def merge_accounts(accounts: List[List[str]]) -> List[List[str]]:
169 """
170 ๊ฐ์ ์ด๋ฉ์ผ์ ๊ฐ์ง ๊ณ์ ๋ณํฉ
171 accounts[i] = [์ด๋ฆ, ์ด๋ฉ์ผ1, ์ด๋ฉ์ผ2, ...]
172 """
173 from collections import defaultdict
174
175 # ์ด๋ฉ์ผ -> ๊ณ์ ์ธ๋ฑ์ค ๋งคํ
176 email_to_id = {}
177 email_to_name = {}
178
179 for i, account in enumerate(accounts):
180 name = account[0]
181 for email in account[1:]:
182 if email in email_to_id:
183 pass # ๋์ค์ union
184 email_to_id[email] = i
185 email_to_name[email] = name
186
187 # ์ ๋์จ ํ์ธ๋๋ก ๊ฐ์ ์ฌ๋์ ๊ณ์ ์ฐ๊ฒฐ
188 n = len(accounts)
189 uf = UnionFind(n)
190
191 email_first_account = {}
192 for i, account in enumerate(accounts):
193 for email in account[1:]:
194 if email in email_first_account:
195 uf.union(i, email_first_account[email])
196 else:
197 email_first_account[email] = i
198
199 # ๊ฒฐ๊ณผ ์ง๊ณ
200 root_to_emails = defaultdict(set)
201 for i, account in enumerate(accounts):
202 root = uf.find(i)
203 for email in account[1:]:
204 root_to_emails[root].add(email)
205
206 # ๊ฒฐ๊ณผ ํฌ๋งทํ
207 result = []
208 for root, emails in root_to_emails.items():
209 name = accounts[root][0]
210 result.append([name] + sorted(emails))
211
212 return result
213
214
215# =============================================================================
216# 7. ์ฌ ์ฐ๊ฒฐํ๊ธฐ (2D ๊ทธ๋ฆฌ๋)
217# =============================================================================
218def num_islands_union_find(grid: List[List[str]]) -> int:
219 """
220 '1'์ ๋
, '0'์ ๋ฌผ
221 ์ฐ๊ฒฐ๋ ๋
๋ฉ์ด๋ฆฌ(์ฌ)์ ๊ฐ์
222 """
223 if not grid or not grid[0]:
224 return 0
225
226 rows, cols = len(grid), len(grid[0])
227
228 # 2D -> 1D ์ขํ ๋ณํ
229 def get_index(r, c):
230 return r * cols + c
231
232 uf = UnionFind(rows * cols)
233 land_count = 0
234
235 for r in range(rows):
236 for c in range(cols):
237 if grid[r][c] == '1':
238 land_count += 1
239 # ์ค๋ฅธ์ชฝ, ์๋ ๋ฐฉํฅ๋ง ํ์ธ (์ค๋ณต ๋ฐฉ์ง)
240 for dr, dc in [(0, 1), (1, 0)]:
241 nr, nc = r + dr, c + dc
242 if 0 <= nr < rows and 0 <= nc < cols and grid[nr][nc] == '1':
243 if uf.union(get_index(r, c), get_index(nr, nc)):
244 land_count -= 1
245
246 return land_count
247
248
249# =============================================================================
250# ํ
์คํธ
251# =============================================================================
252def main():
253 print("=" * 60)
254 print("์ ๋์จ ํ์ธ๋ (Union-Find) ์์ ")
255 print("=" * 60)
256
257 # 1. ๊ธฐ๋ณธ ์ฌ์ฉ
258 print("\n[1] ๊ธฐ๋ณธ ์ ๋์จ ํ์ธ๋")
259 uf = UnionFind(10)
260 operations = [(0, 1), (2, 3), (4, 5), (1, 2), (6, 7), (8, 9), (0, 9)]
261 for u, v in operations:
262 uf.union(u, v)
263 print(f" union({u}, {v}) -> ์งํฉ ์: {uf.get_count()}")
264
265 print(f"\n 0๊ณผ 9 ์ฐ๊ฒฐ๋จ? {uf.connected(0, 9)}")
266 print(f" 0๊ณผ 6 ์ฐ๊ฒฐ๋จ? {uf.connected(0, 6)}")
267
268 # 2. ํฌ๊ธฐ ์ถ์
269 print("\n[2] ํฌ๊ธฐ ๊ธฐ๋ฐ ์ ๋์จ ํ์ธ๋")
270 uf_size = UnionFindWithSize(5)
271 uf_size.union(0, 1)
272 uf_size.union(2, 3)
273 uf_size.union(0, 2)
274 print(f" 0์ด ์ํ ์งํฉ ํฌ๊ธฐ: {uf_size.get_size(0)}")
275 print(f" 4๊ฐ ์ํ ์งํฉ ํฌ๊ธฐ: {uf_size.get_size(4)}")
276
277 # 3. ์ฐ๊ฒฐ ์์ ๊ฐ์
278 print("\n[3] ์ฐ๊ฒฐ ์์ ๊ฐ์")
279 edges = [[0, 1], [1, 2], [3, 4]]
280 count = count_components(5, edges)
281 print(f" ๋
ธ๋ 5๊ฐ, ๊ฐ์ : {edges}")
282 print(f" ์ฐ๊ฒฐ ์์ ๊ฐ์: {count}")
283
284 # 4. ์ฌ์ดํด ๊ฒ์ถ
285 print("\n[4] ์ฌ์ดํด ๊ฒ์ถ")
286 edges_no_cycle = [[0, 1], [1, 2], [2, 3]]
287 edges_with_cycle = [[0, 1], [1, 2], [2, 0]]
288 print(f" ๊ฐ์ {edges_no_cycle}: ์ฌ์ดํด = {has_cycle(4, edges_no_cycle)}")
289 print(f" ๊ฐ์ {edges_with_cycle}: ์ฌ์ดํด = {has_cycle(3, edges_with_cycle)}")
290
291 # 5. ํฌ๋ฃจ์ค์นผ MST
292 print("\n[5] ํฌ๋ฃจ์ค์นผ MST")
293 # 1
294 # 0---1
295 # |\ |
296 # 4 | \ |2
297 # | \|
298 # 3---2
299 # 3
300 edges_mst = [
301 (0, 1, 1), (0, 2, 4), (0, 3, 4),
302 (1, 2, 2), (2, 3, 3)
303 ]
304 total_weight, mst_edges = kruskal_mst(4, edges_mst)
305 print(f" ๊ฐ์ : {edges_mst}")
306 print(f" MST ์ด ๊ฐ์ค์น: {total_weight}")
307 print(f" MST ๊ฐ์ : {mst_edges}")
308
309 # 6. ๊ณ์ ๋ณํฉ
310 print("\n[6] ๊ณ์ ๋ณํฉ")
311 accounts = [
312 ["John", "john@mail.com", "john_work@mail.com"],
313 ["John", "john@mail.com", "john2@mail.com"],
314 ["Mary", "mary@mail.com"],
315 ["John", "john3@mail.com"]
316 ]
317 result = merge_accounts(accounts)
318 print(f" ์
๋ ฅ:")
319 for acc in accounts:
320 print(f" {acc}")
321 print(f" ๋ณํฉ ๊ฒฐ๊ณผ:")
322 for acc in result:
323 print(f" {acc}")
324
325 # 7. ์ฌ ๊ฐ์ (์ ๋์จ ํ์ธ๋)
326 print("\n[7] ์ฌ์ ๊ฐ์ (์ ๋์จ ํ์ธ๋)")
327 grid = [
328 ['1', '1', '0', '0', '0'],
329 ['1', '1', '0', '0', '0'],
330 ['0', '0', '1', '0', '0'],
331 ['0', '0', '0', '1', '1']
332 ]
333 count = num_islands_union_find(grid)
334 print(f" ๊ฒฉ์:")
335 for row in grid:
336 print(f" {row}")
337 print(f" ์ฌ์ ๊ฐ์: {count}")
338
339 print("\n" + "=" * 60)
340 print("์ ๋์จ ํ์ธ๋ ์๊ฐ ๋ณต์ก๋")
341 print("=" * 60)
342 print("""
343 ๊ฒฝ๋ก ์์ถ + ๋ญํฌ/ํฌ๊ธฐ ๊ธฐ๋ฐ ํฉ์น๊ธฐ ์ฌ์ฉ ์:
344 - find(): ๊ฑฐ์ O(1) (์ ํํ๋ O(ฮฑ(n)), ฮฑ๋ ์์ปค๋ง ์ญํจ์)
345 - union(): ๊ฑฐ์ O(1)
346 - ๊ณต๊ฐ ๋ณต์ก๋: O(n)
347
348 ์ฃผ์ ํ์ฉ:
349 - ์ฐ๊ฒฐ ์์ ๊ด๋ฆฌ
350 - ์ฌ์ดํด ๊ฒ์ถ
351 - ์ต์ ์ ์ฅ ํธ๋ฆฌ (ํฌ๋ฃจ์ค์นผ)
352 - ๋์ ์ฐ๊ฒฐ์ฑ ๋ฌธ์
353 """)
354
355
356if __name__ == "__main__":
357 main()