15_mst.py

Download
python 358 lines 10.9 KB
"""
Union-Find (Disjoint Set Union)
Union-Find / Disjoint Set Union Data Structure

A data structure that manages a collection of disjoint sets; it is used
mainly for graph connectivity problems.
"""
  7
  8from typing import List, Tuple
  9
 10
 11# =============================================================================
 12# 1. ๊ธฐ๋ณธ ์œ ๋‹ˆ์˜จ ํŒŒ์ธ๋“œ
 13# =============================================================================
 14class UnionFind:
 15    """
 16    ๊ธฐ๋ณธ ์œ ๋‹ˆ์˜จ ํŒŒ์ธ๋“œ ๊ตฌํ˜„
 17    - ๊ฒฝ๋กœ ์••์ถ• (Path Compression)
 18    - ๋žญํฌ ๊ธฐ๋ฐ˜ ํ•ฉ์น˜๊ธฐ (Union by Rank)
 19    """
 20
 21    def __init__(self, n: int):
 22        """
 23        n๊ฐœ์˜ ์š”์†Œ๋กœ ์ดˆ๊ธฐํ™” (0 ~ n-1)
 24        """
 25        self.parent = list(range(n))  # ์ž๊ธฐ ์ž์‹ ์„ ๋ถ€๋ชจ๋กœ
 26        self.rank = [0] * n           # ํŠธ๋ฆฌ์˜ ๋†’์ด (๊ทผ์‚ฌ๊ฐ’)
 27        self.count = n                # ์ง‘ํ•ฉ์˜ ๊ฐœ์ˆ˜
 28
 29    def find(self, x: int) -> int:
 30        """
 31        x๊ฐ€ ์†ํ•œ ์ง‘ํ•ฉ์˜ ๋Œ€ํ‘œ(๋ฃจํŠธ) ์ฐพ๊ธฐ
 32        ๊ฒฝ๋กœ ์••์ถ•์œผ๋กœ ๊ฑฐ์˜ O(1)
 33        """
 34        if self.parent[x] != x:
 35            self.parent[x] = self.find(self.parent[x])  # ๊ฒฝ๋กœ ์••์ถ•
 36        return self.parent[x]
 37
 38    def union(self, x: int, y: int) -> bool:
 39        """
 40        x์™€ y๊ฐ€ ์†ํ•œ ์ง‘ํ•ฉ์„ ํ•ฉ์น˜๊ธฐ
 41        ์ด๋ฏธ ๊ฐ™์€ ์ง‘ํ•ฉ์ด๋ฉด False ๋ฐ˜ํ™˜
 42        """
 43        root_x = self.find(x)
 44        root_y = self.find(y)
 45
 46        if root_x == root_y:
 47            return False  # ์ด๋ฏธ ๊ฐ™์€ ์ง‘ํ•ฉ
 48
 49        # ๋žญํฌ ๊ธฐ๋ฐ˜ ํ•ฉ์น˜๊ธฐ (์ž‘์€ ํŠธ๋ฆฌ๋ฅผ ํฐ ํŠธ๋ฆฌ์— ๋ถ™์ž„)
 50        if self.rank[root_x] < self.rank[root_y]:
 51            self.parent[root_x] = root_y
 52        elif self.rank[root_x] > self.rank[root_y]:
 53            self.parent[root_y] = root_x
 54        else:
 55            self.parent[root_y] = root_x
 56            self.rank[root_x] += 1
 57
 58        self.count -= 1
 59        return True
 60
 61    def connected(self, x: int, y: int) -> bool:
 62        """x์™€ y๊ฐ€ ๊ฐ™์€ ์ง‘ํ•ฉ์— ์žˆ๋Š”์ง€ ํ™•์ธ"""
 63        return self.find(x) == self.find(y)
 64
 65    def get_count(self) -> int:
 66        """ํ˜„์žฌ ์ง‘ํ•ฉ์˜ ๊ฐœ์ˆ˜"""
 67        return self.count
 68
 69
 70# =============================================================================
 71# 2. ํฌ๊ธฐ ๊ธฐ๋ฐ˜ ์œ ๋‹ˆ์˜จ ํŒŒ์ธ๋“œ
 72# =============================================================================
 73class UnionFindWithSize:
 74    """
 75    ์ง‘ํ•ฉ์˜ ํฌ๊ธฐ๋ฅผ ์ถ”์ ํ•˜๋Š” ์œ ๋‹ˆ์˜จ ํŒŒ์ธ๋“œ
 76    """
 77
 78    def __init__(self, n: int):
 79        self.parent = list(range(n))
 80        self.size = [1] * n  # ๊ฐ ์ง‘ํ•ฉ์˜ ํฌ๊ธฐ
 81
 82    def find(self, x: int) -> int:
 83        if self.parent[x] != x:
 84            self.parent[x] = self.find(self.parent[x])
 85        return self.parent[x]
 86
 87    def union(self, x: int, y: int) -> bool:
 88        root_x = self.find(x)
 89        root_y = self.find(y)
 90
 91        if root_x == root_y:
 92            return False
 93
 94        # ํฌ๊ธฐ ๊ธฐ๋ฐ˜ ํ•ฉ์น˜๊ธฐ (์ž‘์€ ์ง‘ํ•ฉ์„ ํฐ ์ง‘ํ•ฉ์—)
 95        if self.size[root_x] < self.size[root_y]:
 96            self.parent[root_x] = root_y
 97            self.size[root_y] += self.size[root_x]
 98        else:
 99            self.parent[root_y] = root_x
100            self.size[root_x] += self.size[root_y]
101
102        return True
103
104    def get_size(self, x: int) -> int:
105        """x๊ฐ€ ์†ํ•œ ์ง‘ํ•ฉ์˜ ํฌ๊ธฐ"""
106        return self.size[self.find(x)]
107
108
# =============================================================================
# 3. Number of connected components
# =============================================================================
def count_components(n: int, edges: List[List[int]]) -> int:
    """
    Return the number of connected components of an undirected graph.

    n     -- number of nodes, labelled 0 .. n-1
    edges -- [u, v] pairs; each one joins u and v
    """
    components = UnionFind(n)
    for a, b in edges:
        components.union(a, b)
    return components.get_count()
120
121
# =============================================================================
# 4. Cycle detection in a graph
# =============================================================================
def has_cycle(n: int, edges: List[List[int]]) -> bool:
    """
    Return True if the undirected graph on nodes 0 .. n-1 contains a cycle.

    Edges are added one at a time. An edge whose endpoints are already in
    the same set closes a cycle, and scanning stops at the first such edge.
    """
    forest = UnionFind(n)
    # union() returns False exactly when u and v were already connected;
    # any() short-circuits on the first cycle-closing edge.
    return any(not forest.union(u, v) for u, v in edges)
135
136
# =============================================================================
# 5. Kruskal's MST (minimum spanning tree)
# =============================================================================
def kruskal_mst(
    n: int, edges: List[Tuple[int, int, int]]
) -> Tuple[int, List[Tuple[int, int, int]]]:
    """
    Build a minimum spanning tree with Kruskal's algorithm.

    edges: [(u, v, weight), ...] over nodes 0 .. n-1 (the input list is not
    modified).
    Returns (total weight, list of chosen (u, v, weight) edges).

    If the graph is disconnected, fewer than n-1 edges are chosen and the
    result describes a minimum spanning forest instead.
    """
    # Cheapest edges first; sorted() is stable, so equal weights keep input order.
    by_weight = sorted(edges, key=lambda edge: edge[2])

    dsu = UnionFind(n)
    total = 0
    tree = []

    for u, v, w in by_weight:
        if not dsu.union(u, v):
            continue  # u and v already connected: this edge would close a cycle
        total += w
        tree.append((u, v, w))
        if len(tree) == n - 1:
            break  # a spanning tree on n nodes has exactly n-1 edges

    return total, tree
163
164
# =============================================================================
# 6. Friend relations (account merging)
# =============================================================================
def merge_accounts(accounts: List[List[str]]) -> List[List[str]]:
    """
    Merge accounts that share at least one email address.

    accounts[i] = [name, email1, email2, ...]
    Returns one [name, *sorted_emails] entry per group of linked accounts.
    The name is taken from the group's representative (root) account.
    Accounts that list no email do not appear in the result.
    """
    from collections import defaultdict

    uf = UnionFind(len(accounts))

    # Union each account with the first account that listed the same email.
    email_first_account = {}
    for i, account in enumerate(accounts):
        for email in account[1:]:
            if email in email_first_account:
                uf.union(i, email_first_account[email])
            else:
                email_first_account[email] = i

    # Collect every email under its group's representative account.
    root_to_emails = defaultdict(set)
    for i, account in enumerate(accounts):
        root = uf.find(i)
        for email in account[1:]:
            root_to_emails[root].add(email)

    # Representative's name followed by the group's emails in sorted order.
    return [
        [accounts[root][0]] + sorted(emails)
        for root, emails in root_to_emails.items()
    ]
213
214
# =============================================================================
# 7. Connecting islands (2D grid)
# =============================================================================
def num_islands_union_find(grid: List[List[str]]) -> int:
    """
    Count islands in a grid where '1' is land and '0' is water.

    Land cells are joined horizontally and vertically. Every land cell starts
    as its own island; each successful merge with a right or lower neighbour
    removes one. Returns 0 for an empty grid or an empty first row.
    """
    if not grid or not grid[0]:
        return 0

    height, width = len(grid), len(grid[0])
    dsu = UnionFind(height * width)  # cell (r, c) maps to index r * width + c
    islands = 0

    for r in range(height):
        for c in range(width):
            if grid[r][c] != '1':
                continue
            islands += 1
            here = r * width + c
            # Only look right and down so every adjacent pair is checked once.
            if c + 1 < width and grid[r][c + 1] == '1' and dsu.union(here, here + 1):
                islands -= 1
            if r + 1 < height and grid[r + 1][c] == '1' and dsu.union(here, here + width):
                islands -= 1

    return islands
247
248
# =============================================================================
# Tests / demo
# =============================================================================
def main():
    """Print a demonstration of every union-find utility defined in this module."""
    print("=" * 60)
    print("유니온 파인드 (Union-Find) 예제")
    print("=" * 60)

    # 1. Basic usage
    print("\n[1] 기본 유니온 파인드")
    uf = UnionFind(10)
    operations = [(0, 1), (2, 3), (4, 5), (1, 2), (6, 7), (8, 9), (0, 9)]
    for u, v in operations:
        uf.union(u, v)
        print(f"    union({u}, {v}) -> 집합 수: {uf.get_count()}")

    print(f"\n    0과 9 연결됨? {uf.connected(0, 9)}")
    print(f"    0과 6 연결됨? {uf.connected(0, 6)}")

    # 2. Size tracking
    print("\n[2] 크기 기반 유니온 파인드")
    uf_size = UnionFindWithSize(5)
    uf_size.union(0, 1)
    uf_size.union(2, 3)
    uf_size.union(0, 2)
    print(f"    0이 속한 집합 크기: {uf_size.get_size(0)}")
    print(f"    4가 속한 집합 크기: {uf_size.get_size(4)}")

    # 3. Number of connected components
    print("\n[3] 연결 요소 개수")
    edges = [[0, 1], [1, 2], [3, 4]]
    count = count_components(5, edges)
    print(f"    노드 5개, 간선: {edges}")
    print(f"    연결 요소 개수: {count}")

    # 4. Cycle detection
    print("\n[4] 사이클 검출")
    edges_no_cycle = [[0, 1], [1, 2], [2, 3]]
    edges_with_cycle = [[0, 1], [1, 2], [2, 0]]
    print(f"    간선 {edges_no_cycle}: 사이클 = {has_cycle(4, edges_no_cycle)}")
    print(f"    간선 {edges_with_cycle}: 사이클 = {has_cycle(3, edges_with_cycle)}")

    # 5. Kruskal's MST (the diagonal 0-2 edge has weight 4)
    print("\n[5] 크루스칼 MST")
    #     1
    #   0---1
    #   |\  |
    # 4 | \ |2
    #   |  \|
    #   3---2
    #     3
    edges_mst = [
        (0, 1, 1), (0, 2, 4), (0, 3, 4),
        (1, 2, 2), (2, 3, 3)
    ]
    total_weight, mst_edges = kruskal_mst(4, edges_mst)
    print(f"    간선: {edges_mst}")
    print(f"    MST 총 가중치: {total_weight}")
    print(f"    MST 간선: {mst_edges}")

    # 6. Account merging
    print("\n[6] 계정 병합")
    accounts = [
        ["John", "john@mail.com", "john_work@mail.com"],
        ["John", "john@mail.com", "john2@mail.com"],
        ["Mary", "mary@mail.com"],
        ["John", "john3@mail.com"]
    ]
    result = merge_accounts(accounts)
    print("    입력:")
    for acc in accounts:
        print(f"      {acc}")
    print("    병합 결과:")
    for acc in result:
        print(f"      {acc}")

    # 7. Number of islands (union-find)
    print("\n[7] 섬의 개수 (유니온 파인드)")
    grid = [
        ['1', '1', '0', '0', '0'],
        ['1', '1', '0', '0', '0'],
        ['0', '0', '1', '0', '0'],
        ['0', '0', '0', '1', '1']
    ]
    count = num_islands_union_find(grid)
    print("    격자:")
    for row in grid:
        print(f"    {row}")
    print(f"    섬의 개수: {count}")

    print("\n" + "=" * 60)
    print("유니온 파인드 시간 복잡도")
    print("=" * 60)
    print("""
    경로 압축 + 랭크/크기 기반 합치기 사용 시:
    - find(): 거의 O(1) (정확히는 O(α(n)), α는 아커만 역함수)
    - union(): 거의 O(1)
    - 공간 복잡도: O(n)

    주요 활용:
    - 연결 요소 관리
    - 사이클 검출
    - 최소 신장 트리 (크루스칼)
    - 동적 연결성 문제
    """)


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()