20_bitmask_dp.py

Download
python 479 lines 13.3 KB
  1"""
  2๋น„ํŠธ๋งˆ์Šคํฌ DP (Bitmask Dynamic Programming)
  3Bitmask DP
  4
  5๋น„ํŠธ ์—ฐ์‚ฐ์„ ํ™œ์šฉํ•˜์—ฌ ์ง‘ํ•ฉ ์ƒํƒœ๋ฅผ ํ‘œํ˜„ํ•˜๋Š” DP ๊ธฐ๋ฒ•์ž…๋‹ˆ๋‹ค.
  6"""
  7
  8from typing import List, Tuple
  9from functools import lru_cache
 10
 11
 12# =============================================================================
 13# 1. ๋น„ํŠธ ์—ฐ์‚ฐ ๊ธฐ์ดˆ
 14# =============================================================================
 15
 16def bit_operations_demo():
 17    """๋น„ํŠธ ์—ฐ์‚ฐ ๊ธฐ๋ณธ"""
 18    n = 5  # ์ง‘ํ•ฉ ํฌ๊ธฐ
 19
 20    # ๋นˆ ์ง‘ํ•ฉ
 21    empty = 0
 22
 23    # ์ „์ฒด ์ง‘ํ•ฉ {0, 1, 2, 3, 4}
 24    full = (1 << n) - 1  # 11111 (์ด์ง„์ˆ˜)
 25
 26    # i๋ฒˆ์งธ ์›์†Œ ์ถ”๊ฐ€
 27    def add(mask: int, i: int) -> int:
 28        return mask | (1 << i)
 29
 30    # i๋ฒˆ์งธ ์›์†Œ ์ œ๊ฑฐ
 31    def remove(mask: int, i: int) -> int:
 32        return mask & ~(1 << i)
 33
 34    # i๋ฒˆ์งธ ์›์†Œ ํ† ๊ธ€
 35    def toggle(mask: int, i: int) -> int:
 36        return mask ^ (1 << i)
 37
 38    # i๋ฒˆ์งธ ์›์†Œ ํฌํ•จ ์—ฌ๋ถ€
 39    def contains(mask: int, i: int) -> bool:
 40        return bool(mask & (1 << i))
 41
 42    # ์›์†Œ ๊ฐœ์ˆ˜
 43    def count(mask: int) -> int:
 44        return bin(mask).count('1')
 45
 46    # ์ตœํ•˜์œ„ ๋น„ํŠธ (๊ฐ€์žฅ ์ž‘์€ ์›์†Œ)
 47    def lowest_bit(mask: int) -> int:
 48        return mask & (-mask)
 49
 50    # ๋ถ€๋ถ„์ง‘ํ•ฉ ์ˆœํšŒ
 51    def subsets(mask: int):
 52        """mask์˜ ๋ชจ๋“  ๋ถ€๋ถ„์ง‘ํ•ฉ์„ ์ˆœํšŒ"""
 53        subset = mask
 54        while True:
 55            yield subset
 56            if subset == 0:
 57                break
 58            subset = (subset - 1) & mask
 59
 60    return {
 61        'empty': empty,
 62        'full': full,
 63        'add': add,
 64        'remove': remove,
 65        'toggle': toggle,
 66        'contains': contains,
 67        'count': count,
 68        'lowest_bit': lowest_bit,
 69        'subsets': subsets
 70    }
 71
 72
 73# =============================================================================
 74# 2. ์™ธํŒ์› ๋ฌธ์ œ (TSP - Traveling Salesman Problem)
 75# =============================================================================
 76
 77def tsp(dist: List[List[int]]) -> int:
 78    """
 79    ์™ธํŒ์› ๋ฌธ์ œ (TSP)
 80    ๋ชจ๋“  ๋„์‹œ๋ฅผ ๋ฐฉ๋ฌธํ•˜๊ณ  ์‹œ์ž‘์ ์œผ๋กœ ๋Œ์•„์˜ค๋Š” ์ตœ์†Œ ๋น„์šฉ
 81
 82    ์‹œ๊ฐ„๋ณต์žก๋„: O(nยฒ * 2^n)
 83    ๊ณต๊ฐ„๋ณต์žก๋„: O(n * 2^n)
 84    """
 85    n = len(dist)
 86    INF = float('inf')
 87
 88    # dp[mask][i] = mask์— ํฌํ•จ๋œ ๋„์‹œ๋ฅผ ๋ฐฉ๋ฌธํ•˜๊ณ  ํ˜„์žฌ i์— ์žˆ์„ ๋•Œ ์ตœ์†Œ ๋น„์šฉ
 89    dp = [[INF] * n for _ in range(1 << n)]
 90    dp[1][0] = 0  # ์‹œ์ž‘์  (๋„์‹œ 0)
 91
 92    for mask in range(1 << n):
 93        for last in range(n):
 94            if dp[mask][last] == INF:
 95                continue
 96            if not (mask & (1 << last)):
 97                continue
 98
 99            for next_city in range(n):
100                if mask & (1 << next_city):
101                    continue
102
103                new_mask = mask | (1 << next_city)
104                dp[new_mask][next_city] = min(
105                    dp[new_mask][next_city],
106                    dp[mask][last] + dist[last][next_city]
107                )
108
109    # ๋ชจ๋“  ๋„์‹œ ๋ฐฉ๋ฌธ ํ›„ ์‹œ์ž‘์ ์œผ๋กœ
110    full_mask = (1 << n) - 1
111    result = min(dp[full_mask][i] + dist[i][0] for i in range(n))
112
113    return result if result != INF else -1
114
115
116def tsp_path(dist: List[List[int]]) -> Tuple[int, List[int]]:
117    """TSP ์ตœ์†Œ ๋น„์šฉ๊ณผ ๊ฒฝ๋กœ ๋ฐ˜ํ™˜"""
118    n = len(dist)
119    INF = float('inf')
120
121    dp = [[INF] * n for _ in range(1 << n)]
122    parent = [[-1] * n for _ in range(1 << n)]
123
124    dp[1][0] = 0
125
126    for mask in range(1 << n):
127        for last in range(n):
128            if dp[mask][last] == INF:
129                continue
130
131            for next_city in range(n):
132                if mask & (1 << next_city):
133                    continue
134
135                new_mask = mask | (1 << next_city)
136                new_cost = dp[mask][last] + dist[last][next_city]
137
138                if new_cost < dp[new_mask][next_city]:
139                    dp[new_mask][next_city] = new_cost
140                    parent[new_mask][next_city] = last
141
142    full_mask = (1 << n) - 1
143    min_cost = INF
144    last_city = -1
145
146    for i in range(n):
147        cost = dp[full_mask][i] + dist[i][0]
148        if cost < min_cost:
149            min_cost = cost
150            last_city = i
151
152    # ๊ฒฝ๋กœ ๋ณต์›
153    path = []
154    mask = full_mask
155    city = last_city
156
157    while city != -1:
158        path.append(city)
159        prev_city = parent[mask][city]
160        mask ^= (1 << city)
161        city = prev_city
162
163    path.reverse()
164    path.append(0)  # ์‹œ์ž‘์ ์œผ๋กœ ๋ณต๊ท€
165
166    return min_cost, path
167
168
169# =============================================================================
170# 3. ์ง‘ํ•ฉ ๋ถ„ํ•  ๋ฌธ์ œ (Set Partition)
171# =============================================================================
172
173def can_partition_k_subsets(nums: List[int], k: int) -> bool:
174    """
175    ๋ฐฐ์—ด์„ ํ•ฉ์ด ๊ฐ™์€ k๊ฐœ์˜ ๋ถ€๋ถ„์ง‘ํ•ฉ์œผ๋กœ ๋ถ„ํ•  ๊ฐ€๋Šฅํ•œ์ง€
176    ์‹œ๊ฐ„๋ณต์žก๋„: O(n * 2^n)
177    """
178    total = sum(nums)
179    if total % k != 0:
180        return False
181
182    target = total // k
183    n = len(nums)
184
185    # dp[mask] = mask ์ง‘ํ•ฉ์„ ์‚ฌ์šฉํ–ˆ์„ ๋•Œ ํ˜„์žฌ ๋ฒ„ํ‚ท์˜ ํ•ฉ (target์œผ๋กœ ๋‚˜๋ˆˆ ๋‚˜๋จธ์ง€)
186    dp = [-1] * (1 << n)
187    dp[0] = 0
188
189    for mask in range(1 << n):
190        if dp[mask] == -1:
191            continue
192
193        for i in range(n):
194            if mask & (1 << i):
195                continue
196
197            if dp[mask] + nums[i] <= target:
198                new_mask = mask | (1 << i)
199                dp[new_mask] = (dp[mask] + nums[i]) % target
200
201    return dp[(1 << n) - 1] == 0
202
203
204# =============================================================================
205# 4. ์ตœ์†Œ ๋น„์šฉ ์ž‘์—… ํ• ๋‹น (Assignment Problem)
206# =============================================================================
207
208def min_cost_assignment(cost: List[List[int]]) -> int:
209    """
210    n๋ช…์˜ ์‚ฌ๋žŒ์—๊ฒŒ n๊ฐœ์˜ ์ž‘์—…์„ 1:1 ํ• ๋‹นํ•˜๋Š” ์ตœ์†Œ ๋น„์šฉ
211    cost[i][j] = ์‚ฌ๋žŒ i๊ฐ€ ์ž‘์—… j๋ฅผ ์ˆ˜ํ–‰ํ•˜๋Š” ๋น„์šฉ
212
213    ์‹œ๊ฐ„๋ณต์žก๋„: O(n * 2^n)
214    """
215    n = len(cost)
216
217    @lru_cache(maxsize=None)
218    def dp(mask: int) -> int:
219        person = bin(mask).count('1')
220
221        if person == n:
222            return 0
223
224        min_cost = float('inf')
225        for job in range(n):
226            if mask & (1 << job):
227                continue
228
229            min_cost = min(min_cost, cost[person][job] + dp(mask | (1 << job)))
230
231        return min_cost
232
233    return dp(0)
234
235
236# =============================================================================
237# 5. ํ•ด๋ฐ€ํ„ด ๊ฒฝ๋กœ (Hamiltonian Path)
238# =============================================================================
239
240def hamiltonian_path_count(adj: List[List[int]]) -> int:
241    """
242    ํ•ด๋ฐ€ํ„ด ๊ฒฝ๋กœ์˜ ๊ฐœ์ˆ˜ (๋ชจ๋“  ์ •์ ์„ ํ•œ ๋ฒˆ์”ฉ ๋ฐฉ๋ฌธํ•˜๋Š” ๊ฒฝ๋กœ)
243    adj: ์ธ์ ‘ ํ–‰๋ ฌ (adj[i][j] = 1์ด๋ฉด iโ†’j ๊ฐ„์„  ์กด์žฌ)
244
245    ์‹œ๊ฐ„๋ณต์žก๋„: O(nยฒ * 2^n)
246    """
247    n = len(adj)
248
249    # dp[mask][i] = mask ์ •์ ๋“ค์„ ๋ฐฉ๋ฌธํ•˜๊ณ  i์—์„œ ๋๋‚˜๋Š” ๊ฒฝ๋กœ ์ˆ˜
250    dp = [[0] * n for _ in range(1 << n)]
251
252    # ์ดˆ๊ธฐํ™”: ๊ฐ ์ •์ ์—์„œ ์‹œ์ž‘
253    for i in range(n):
254        dp[1 << i][i] = 1
255
256    for mask in range(1 << n):
257        for last in range(n):
258            if dp[mask][last] == 0:
259                continue
260            if not (mask & (1 << last)):
261                continue
262
263            for next_v in range(n):
264                if mask & (1 << next_v):
265                    continue
266                if not adj[last][next_v]:
267                    continue
268
269                new_mask = mask | (1 << next_v)
270                dp[new_mask][next_v] += dp[mask][last]
271
272    # ๋ชจ๋“  ์ •์  ๋ฐฉ๋ฌธํ•œ ๊ฒฝ๋กœ ํ•ฉ
273    full_mask = (1 << n) - 1
274    return sum(dp[full_mask])
275
276
277# =============================================================================
278# 6. ์Šคํ‹ฐ์ปค ์ตœ์  ๋ฐฐ์น˜ (SOS DP ์ „์ฒ˜๋ฆฌ)
279# =============================================================================
280
281def sos_dp(arr: List[int]) -> List[int]:
282    """
283    Sum over Subsets DP
284    ๊ฐ ๋งˆ์Šคํฌ์— ๋Œ€ํ•ด ๋ถ€๋ถ„์ง‘ํ•ฉ๋“ค์˜ ๊ฐ’ ํ•ฉ ๊ณ„์‚ฐ
285
286    result[mask] = sum(arr[subset]) for all subset of mask
287
288    ์‹œ๊ฐ„๋ณต์žก๋„: O(n * 2^n)
289    """
290    n = len(arr).bit_length()
291    dp = arr.copy()
292
293    # 0~(len(arr)-1)๊นŒ์ง€ ํ™•์žฅ
294    while len(dp) < (1 << n):
295        dp.append(0)
296
297    for i in range(n):
298        for mask in range(1 << n):
299            if mask & (1 << i):
300                dp[mask] += dp[mask ^ (1 << i)]
301
302    return dp
303
304
305# =============================================================================
306# 7. ์ตœ๋Œ€ ๋…๋ฆฝ ์ง‘ํ•ฉ (Maximum Independent Set on Trees - Bitmask)
307# =============================================================================
308
309def max_independent_set(adj: List[List[int]]) -> int:
310    """
311    ๊ทธ๋ž˜ํ”„์—์„œ ์ตœ๋Œ€ ๋…๋ฆฝ ์ง‘ํ•ฉ ํฌ๊ธฐ (์„œ๋กœ ์ธ์ ‘ํ•˜์ง€ ์•Š์€ ์ •์  ์ง‘ํ•ฉ)
312    ์ž‘์€ ๊ทธ๋ž˜ํ”„์—์„œ ๋น„ํŠธ๋งˆ์Šคํฌ๋กœ brute force
313
314    ์‹œ๊ฐ„๋ณต์žก๋„: O(2^n * nยฒ)
315    """
316    n = len(adj)
317    max_size = 0
318
319    for mask in range(1 << n):
320        # mask๊ฐ€ ๋…๋ฆฝ ์ง‘ํ•ฉ์ธ์ง€ ํ™•์ธ
321        valid = True
322        for i in range(n):
323            if not (mask & (1 << i)):
324                continue
325            for j in range(i + 1, n):
326                if not (mask & (1 << j)):
327                    continue
328                if adj[i][j]:
329                    valid = False
330                    break
331            if not valid:
332                break
333
334        if valid:
335            max_size = max(max_size, bin(mask).count('1'))
336
337    return max_size
338
339
340# =============================================================================
341# 8. ๊ฒฉ์ž ์ฑ„์šฐ๊ธฐ (Broken Profile DP)
342# =============================================================================
343
344def domino_tiling(m: int, n: int) -> int:
345    """
346    mร—n ๊ฒฉ์ž๋ฅผ 1ร—2 ๋„๋ฏธ๋…ธ๋กœ ์ฑ„์šฐ๋Š” ๊ฒฝ์šฐ์˜ ์ˆ˜
347    ๋น„ํŠธ๋งˆ์Šคํฌ DP (profile ๋ฐฉ์‹)
348
349    ์‹œ๊ฐ„๋ณต์žก๋„: O(n * 2^m * 2^m)
350    """
351    if m > n:
352        m, n = n, m
353
354    # dp[col][profile] = ํ˜„์žฌ ์—ด๊นŒ์ง€ ์ฑ„์šฐ๊ณ  ํ”„๋กœํŒŒ์ผ์ด profile์ธ ๊ฒฝ์šฐ์˜ ์ˆ˜
355    dp = {0: 1}
356
357    for col in range(n):
358        for row in range(m):
359            new_dp = {}
360
361            for profile, count in dp.items():
362                # ํ˜„์žฌ ์…€์ด ์ด๋ฏธ ์ฑ„์›Œ์ง„ ๊ฒฝ์šฐ
363                if profile & (1 << row):
364                    new_profile = profile ^ (1 << row)
365                    new_dp[new_profile] = new_dp.get(new_profile, 0) + count
366                else:
367                    # ์ˆ˜ํ‰ ๋„๋ฏธ๋…ธ (๋‹ค์Œ ์—ด๋กœ ํ™•์žฅ)
368                    new_profile = profile | (1 << row)
369                    new_dp[new_profile] = new_dp.get(new_profile, 0) + count
370
371                    # ์ˆ˜์ง ๋„๋ฏธ๋…ธ (์•„๋ž˜ ์…€๊ณผ ํ•จ๊ป˜)
372                    if row + 1 < m and not (profile & (1 << (row + 1))):
373                        new_dp[profile] = new_dp.get(profile, 0) + count
374
375            dp = new_dp
376
377    return dp.get(0, 0)
378
379
380# =============================================================================
381# ํ…Œ์ŠคํŠธ
382# =============================================================================
383
384def main():
385    print("=" * 60)
386    print("๋น„ํŠธ๋งˆ์Šคํฌ DP (Bitmask DP) ์˜ˆ์ œ")
387    print("=" * 60)
388
389    # 1. ๋น„ํŠธ ์—ฐ์‚ฐ ๊ธฐ์ดˆ
390    print("\n[1] ๋น„ํŠธ ์—ฐ์‚ฐ ๊ธฐ์ดˆ")
391    ops = bit_operations_demo()
392    mask = 0b10110  # {1, 2, 4}
393    print(f"    mask = {bin(mask)} ({mask})")
394    print(f"    ์›์†Œ ๊ฐœ์ˆ˜: {ops['count'](mask)}")
395    print(f"    3 ํฌํ•จ: {ops['contains'](mask, 3)}")
396    print(f"    2 ํฌํ•จ: {ops['contains'](mask, 2)}")
397    print(f"    3 ์ถ”๊ฐ€: {bin(ops['add'](mask, 3))}")
398    print(f"    ๋ถ€๋ถ„์ง‘ํ•ฉ: ", end="")
399    for s in ops['subsets'](mask):
400        print(f"{bin(s)} ", end="")
401    print()
402
403    # 2. TSP
404    print("\n[2] ์™ธํŒ์› ๋ฌธ์ œ (TSP)")
405    dist = [
406        [0, 10, 15, 20],
407        [10, 0, 35, 25],
408        [15, 35, 0, 30],
409        [20, 25, 30, 0]
410    ]
411    min_cost, path = tsp_path(dist)
412    print(f"    ๊ฑฐ๋ฆฌ ํ–‰๋ ฌ: 4x4")
413    print(f"    ์ตœ์†Œ ๋น„์šฉ: {min_cost}")
414    print(f"    ๊ฒฝ๋กœ: {path}")
415
416    # 3. ์ง‘ํ•ฉ ๋ถ„ํ• 
417    print("\n[3] K๊ฐœ ๋ถ€๋ถ„์ง‘ํ•ฉ ๋ถ„ํ• ")
418    nums = [4, 3, 2, 3, 5, 2, 1]
419    k = 4
420    result = can_partition_k_subsets(nums, k)
421    print(f"    ๋ฐฐ์—ด: {nums}, k={k}")
422    print(f"    ๋ถ„ํ•  ๊ฐ€๋Šฅ: {result}")
423
424    # 4. ์ž‘์—… ํ• ๋‹น
425    print("\n[4] ์ตœ์†Œ ๋น„์šฉ ์ž‘์—… ํ• ๋‹น")
426    cost = [
427        [9, 2, 7, 8],
428        [6, 4, 3, 7],
429        [5, 8, 1, 8],
430        [7, 6, 9, 4]
431    ]
432    min_assign = min_cost_assignment(cost)
433    print(f"    ๋น„์šฉ ํ–‰๋ ฌ: 4x4")
434    print(f"    ์ตœ์†Œ ๋น„์šฉ: {min_assign}")
435
436    # 5. ํ•ด๋ฐ€ํ„ด ๊ฒฝ๋กœ
437    print("\n[5] ํ•ด๋ฐ€ํ„ด ๊ฒฝ๋กœ ๊ฐœ์ˆ˜")
438    adj = [
439        [0, 1, 1, 1],
440        [1, 0, 1, 0],
441        [1, 1, 0, 1],
442        [1, 0, 1, 0]
443    ]
444    count = hamiltonian_path_count(adj)
445    print(f"    ์ธ์ ‘ ํ–‰๋ ฌ: 4x4")
446    print(f"    ํ•ด๋ฐ€ํ„ด ๊ฒฝ๋กœ ์ˆ˜: {count}")
447
448    # 6. ์ตœ๋Œ€ ๋…๋ฆฝ ์ง‘ํ•ฉ
449    print("\n[6] ์ตœ๋Œ€ ๋…๋ฆฝ ์ง‘ํ•ฉ")
450    adj2 = [
451        [0, 1, 0, 1],
452        [1, 0, 1, 0],
453        [0, 1, 0, 1],
454        [1, 0, 1, 0]
455    ]
456    mis = max_independent_set(adj2)
457    print(f"    ๊ทธ๋ž˜ํ”„: 4-cycle")
458    print(f"    ์ตœ๋Œ€ ๋…๋ฆฝ ์ง‘ํ•ฉ ํฌ๊ธฐ: {mis}")
459
460    # 7. ๋„๋ฏธ๋…ธ ํƒ€์ผ๋ง
461    print("\n[7] ๋„๋ฏธ๋…ธ ํƒ€์ผ๋ง")
462    for m, n in [(2, 3), (2, 4), (3, 4)]:
463        count = domino_tiling(m, n)
464        print(f"    {m}ร—{n} ๊ฒฉ์ž: {count}๊ฐ€์ง€")
465
466    # 8. SOS DP
467    print("\n[8] SOS DP (๋ถ€๋ถ„์ง‘ํ•ฉ ํ•ฉ)")
468    arr = [1, 2, 4, 8]  # ๊ฐ ์›์†Œ๋Š” ํ•ด๋‹น ๋น„ํŠธ์˜ ๊ฐ’
469    result = sos_dp(arr)
470    print(f"    ๋ฐฐ์—ด: {arr}")
471    print(f"    result[0b0111] = result[7] = {result[7]}")
472    print(f"    (๋ถ€๋ถ„์ง‘ํ•ฉ: {{0},{1},{0,1},{2},{0,2},{1,2},{0,1,2}} ํ•ฉ)")
473
474    print("\n" + "=" * 60)
475
476
477if __name__ == "__main__":
478    main()