07_divide_conquer.py

Download
python 507 lines 14.0 KB
  1"""
분할 정복 (Divide and Conquer)
  3Divide and Conquer Algorithms
  4
문제를 작은 부분 문제로 나누어 해결하는 알고리즘입니다.
  6"""
  7
import math
import random
from typing import List, Optional, Tuple
 10
 11
 12# =============================================================================
# 1. 병합 정렬 (Merge Sort)
 14# =============================================================================
 15
 16def merge_sort(arr: List[int]) -> List[int]:
 17    """
 18    ๋ณ‘ํ•ฉ ์ •๋ ฌ
 19    ์‹œ๊ฐ„๋ณต์žก๋„: O(n log n)
 20    ๊ณต๊ฐ„๋ณต์žก๋„: O(n)
 21    ์•ˆ์ • ์ •๋ ฌ
 22    """
 23    if len(arr) <= 1:
 24        return arr
 25
 26    mid = len(arr) // 2
 27    left = merge_sort(arr[:mid])
 28    right = merge_sort(arr[mid:])
 29
 30    return merge(left, right)
 31
 32
 33def merge(left: List[int], right: List[int]) -> List[int]:
 34    """๋‘ ์ •๋ ฌ๋œ ๋ฐฐ์—ด ๋ณ‘ํ•ฉ"""
 35    result = []
 36    i = j = 0
 37
 38    while i < len(left) and j < len(right):
 39        if left[i] <= right[j]:
 40            result.append(left[i])
 41            i += 1
 42        else:
 43            result.append(right[j])
 44            j += 1
 45
 46    result.extend(left[i:])
 47    result.extend(right[j:])
 48    return result
 49
 50
 51# =============================================================================
# 2. 퀵 정렬 (Quick Sort)
 53# =============================================================================
 54
 55def quick_sort(arr: List[int]) -> List[int]:
 56    """
 57    ํ€ต ์ •๋ ฌ (Lomuto ํŒŒํ‹ฐ์…˜)
 58    ์‹œ๊ฐ„๋ณต์žก๋„: ํ‰๊ท  O(n log n), ์ตœ์•… O(nยฒ)
 59    ๊ณต๊ฐ„๋ณต์žก๋„: O(log n) - ์žฌ๊ท€ ์Šคํƒ
 60    ๋ถˆ์•ˆ์ • ์ •๋ ฌ
 61    """
 62    if len(arr) <= 1:
 63        return arr
 64
 65    arr = arr.copy()
 66    _quick_sort(arr, 0, len(arr) - 1)
 67    return arr
 68
 69
 70def _quick_sort(arr: List[int], low: int, high: int) -> None:
 71    if low < high:
 72        pivot_idx = partition(arr, low, high)
 73        _quick_sort(arr, low, pivot_idx - 1)
 74        _quick_sort(arr, pivot_idx + 1, high)
 75
 76
 77def partition(arr: List[int], low: int, high: int) -> int:
 78    """Lomuto ํŒŒํ‹ฐ์…˜"""
 79    # ๋žœ๋ค ํ”ผ๋ฒ—์œผ๋กœ ์ตœ์•… ์ผ€์ด์Šค ๋ฐฉ์ง€
 80    pivot_idx = random.randint(low, high)
 81    arr[pivot_idx], arr[high] = arr[high], arr[pivot_idx]
 82
 83    pivot = arr[high]
 84    i = low - 1
 85
 86    for j in range(low, high):
 87        if arr[j] <= pivot:
 88            i += 1
 89            arr[i], arr[j] = arr[j], arr[i]
 90
 91    arr[i + 1], arr[high] = arr[high], arr[i + 1]
 92    return i + 1
 93
 94
 95# =============================================================================
# 3. 거듭제곱 (Power / Exponentiation)
 97# =============================================================================
 98
 99def power(base: int, exp: int, mod: int = None) -> int:
100    """
101    ๋น ๋ฅธ ๊ฑฐ๋“ญ์ œ๊ณฑ
102    ์‹œ๊ฐ„๋ณต์žก๋„: O(log n)
103    """
104    if exp == 0:
105        return 1
106
107    if exp % 2 == 0:
108        half = power(base, exp // 2, mod)
109        result = half * half
110    else:
111        result = base * power(base, exp - 1, mod)
112
113    return result % mod if mod else result
114
115
def power_iterative(base: int, exp: int, mod: Optional[int] = None) -> int:
    """Compute ``base ** exp``, optionally modulo ``mod``, by iterative squaring.

    Scans the bits of ``exp`` from least significant upward: ``base`` holds
    the original base raised to 2**k at step k, and it is multiplied into the
    result whenever bit k of ``exp`` is set. Time complexity: O(log exp).

    Args:
        base: The base.
        exp: Non-negative integer exponent.
        mod: Optional modulus. When truthy, intermediate values are reduced
            modulo it; ``None`` or ``0`` means no reduction.

    Returns:
        ``base ** exp``, or ``(base ** exp) % mod`` when ``mod`` is given.

    Raises:
        ValueError: If ``exp`` is negative (previously this silently returned 1).
    """
    if exp < 0:
        raise ValueError("exponent must be non-negative")

    # Start from 1 reduced by mod, so exp == 0 with mod == 1 gives 0.
    result = 1 % mod if mod else 1

    while exp > 0:
        if exp % 2 == 1:
            result = result * base
            if mod:
                result %= mod
        base = base * base
        if mod:
            base %= mod
        exp //= 2

    return result
131
132
133# =============================================================================
# 4. 행렬 거듭제곱 (Matrix Exponentiation)
135# =============================================================================
136
def matrix_multiply(A: List[List[int]], B: List[List[int]], mod: Optional[int] = None) -> List[List[int]]:
    """Multiply matrix ``A`` (n x m) by matrix ``B`` (m x p), optionally modulo ``mod``.

    Works for any compatible rectangular shapes, including the square
    matrices used by ``matrix_power`` (it is not limited to 2x2).

    Args:
        A: Left operand as a list of n rows, each of length m.
        B: Right operand as a list of m rows, each of length p.
        mod: Optional modulus applied to every entry of the result; ``None``
            or ``0`` means no reduction.

    Returns:
        The n x p product as a new list of lists.

    Raises:
        ValueError: If a row of ``A`` does not have ``len(B)`` entries.
    """
    inner = len(B)
    cols = len(B[0]) if B else 0

    for row in A:
        if len(row) != inner:
            raise ValueError(
                f"incompatible shapes: row of length {len(row)} times {inner} rows"
            )

    C = []
    for row in A:
        out_row = []
        for j in range(cols):
            # Summation order k = 0..m-1 matches the classic triple loop.
            total = sum(row[k] * B[k][j] for k in range(inner))
            out_row.append(total % mod if mod else total)
        C.append(out_row)

    return C
150
151
def matrix_power(M: List[List[int]], exp: int, mod: Optional[int] = None) -> List[List[int]]:
    """Raise the square matrix ``M`` to the power ``exp`` by binary exponentiation.

    Time complexity: O(k^3 log exp) for a k x k matrix.

    Args:
        M: Square matrix (k rows of length k). It is not modified.
        exp: Non-negative integer exponent.
        mod: Optional modulus passed to every multiplication; ``None`` or ``0``
            means no reduction.

    Returns:
        ``M ** exp`` as a new matrix (the identity when ``exp == 0``).

    Raises:
        ValueError: If ``exp`` is negative (no matrix inverse is computed).
    """
    if exp < 0:
        raise ValueError("exponent must be non-negative")

    n = len(M)
    # Identity matrix, reduced by mod so that exp == 0 with mod == 1 yields 0s.
    one = 1 % mod if mod else 1
    result = [[one if i == j else 0 for j in range(n)] for i in range(n)]

    base = M
    while exp > 0:
        if exp % 2 == 1:
            result = matrix_multiply(result, base, mod)
        exp //= 2
        if exp:
            # Skip the squaring after the last bit; its value would be unused.
            base = matrix_multiply(base, base, mod)

    return result
168
169
def fibonacci_matrix(n: int, mod: Optional[int] = None) -> int:
    """Return the n-th Fibonacci number F(n) via matrix exponentiation.

    Uses F(0) = 0, F(1) = 1. Time complexity: O(log n) 2x2 multiplications.

    Args:
        n: Non-negative index.
        mod: Optional modulus; when truthy the result is F(n) % mod.

    Returns:
        F(n), or F(n) % mod when ``mod`` is given.

    Raises:
        ValueError: If ``n`` is negative.
    """
    if n < 0:
        raise ValueError("n must be non-negative")
    if n <= 1:
        # Reduce here too, so small n agree with the reduced results for n >= 2.
        return n % mod if mod else n

    # [[F(n+1), F(n)], [F(n), F(n-1)]] = [[1,1],[1,0]]^n
    M = [[1, 1], [1, 0]]
    result = matrix_power(M, n, mod)
    return result[0][1]
182
183
184# =============================================================================
# 5. 역순 쌍 개수 (Inversion Count)
186# =============================================================================
187
def count_inversions(arr: List[int]) -> int:
    """Count the pairs (i, j) with i < j and arr[i] > arr[j].

    This is a merge-sort variant. When an element of the right half is emitted
    ahead of the elements still waiting in the left half, it forms an inversion
    with each of them. Time complexity: O(n log n).
    """

    def sort_and_count(seq: List[int]) -> Tuple[List[int], int]:
        if len(seq) < 2:
            return list(seq), 0

        half = len(seq) // 2
        lo, lo_count = sort_and_count(seq[:half])
        hi, hi_count = sort_and_count(seq[half:])

        n_lo, n_hi = len(lo), len(hi)
        out: List[int] = []
        count = lo_count + hi_count
        a = b = 0

        while a < n_lo and b < n_hi:
            if lo[a] <= hi[b]:
                out.append(lo[a])
                a += 1
                continue
            # hi[b] jumps over every remaining element of lo.
            out.append(hi[b])
            count += n_lo - a
            b += 1

        out += lo[a:]
        out += hi[b:]
        return out, count

    return sort_and_count(arr)[1]
223
224
225# =============================================================================
# 6. 가장 가까운 점 쌍 (Closest Pair of Points)
227# =============================================================================
228
def distance(p1: Tuple[float, float], p2: Tuple[float, float]) -> float:
    """Return the Euclidean distance between two 2-D points.

    ``math.hypot`` avoids the overflow that squaring large coordinate
    differences by hand can cause (``(1e200) ** 2`` raises OverflowError).
    """
    return math.hypot(p1[0] - p2[0], p1[1] - p2[1])


def closest_pair(points: List[Tuple[float, float]]) -> float:
    """Return the smallest distance between any two of ``points``.

    This is the classic divide and conquer algorithm, O(n log n). Split the
    x-sorted points in half, solve each half, then examine the vertical strip
    around the split line in y order.

    Points are tracked by index. This guarantees that the y-ordered list handed
    to each half contains exactly that half's points, even when several points
    share the split x-coordinate. Splitting by ``x <= mid_x`` would send such
    points to the wrong side.

    Returns:
        The minimum pairwise distance, or ``float('inf')`` for fewer than two
        points.
    """
    if len(points) < 2:
        return float('inf')

    def solve(by_x: List[int], by_y: List[int]) -> float:
        count = len(by_x)

        # Base case: brute force over at most three points.
        if count <= 3:
            best = float('inf')
            for a in range(count):
                for b in range(a + 1, count):
                    best = min(best, distance(points[by_x[a]], points[by_x[b]]))
            return best

        mid = count // 2
        mid_x = points[by_x[mid]][0]

        # Split the y-ordered indices by membership in the left half; this
        # keeps both halves y-sorted in O(count) time.
        left_ids = set(by_x[:mid])
        left_by_y = [i for i in by_y if i in left_ids]
        right_by_y = [i for i in by_y if i not in left_ids]

        d = min(solve(by_x[:mid], left_by_y), solve(by_x[mid:], right_by_y))

        # Only points closer than d to the split line can beat d.
        strip = [points[i] for i in by_y if abs(points[i][0] - mid_x) < d]

        # Compare each strip point with its y-order successors. Stop once the
        # vertical gap alone reaches d; the standard packing argument bounds
        # this to a constant number of successors per point.
        for a in range(len(strip)):
            for b in range(a + 1, len(strip)):
                if strip[b][1] - strip[a][1] >= d:
                    break
                d = min(d, distance(strip[a], strip[b]))

        return d

    indices = range(len(points))
    by_x = sorted(indices, key=lambda i: points[i][0])
    by_y = sorted(indices, key=lambda i: points[i][1])

    return solve(by_x, by_y)
279
280
281# =============================================================================
# 7. 최대 부분 배열 합 (Maximum Subarray - D&C)
283# =============================================================================
284
def max_subarray_dc(arr: List[int]) -> int:
    """Return the largest sum of a non-empty contiguous slice of ``arr``.

    Divide and conquer, O(n log n). The best slice lies entirely in the left
    half, entirely in the right half, or crosses the midpoint. An empty
    ``arr`` yields 0.
    """
    if not arr:
        return 0

    def crossing(lo: int, mid: int, hi: int) -> int:
        # Best sum of a suffix of arr[lo..mid] (always includes arr[mid]).
        best_left = float('-inf')
        running = 0
        for idx in range(mid, lo - 1, -1):
            running += arr[idx]
            if running > best_left:
                best_left = running

        # Best sum of a prefix of arr[mid+1..hi] (always includes arr[mid+1]).
        best_right = float('-inf')
        running = 0
        for idx in range(mid + 1, hi + 1):
            running += arr[idx]
            if running > best_right:
                best_right = running

        return best_left + best_right

    def solve(lo: int, hi: int) -> int:
        if lo == hi:
            return arr[lo]
        mid = (lo + hi) // 2
        return max(solve(lo, mid), solve(mid + 1, hi), crossing(lo, mid, hi))

    return solve(0, len(arr) - 1)
323
324
325# =============================================================================
# 8. 카라츠바 곱셈 (Karatsuba Multiplication)
327# =============================================================================
328
def karatsuba(x: int, y: int) -> int:
    """Multiply two integers with Karatsuba's algorithm.

    Time complexity: O(n^log2(3)) ~ O(n^1.585) for n-bit operands.

    Operands are split by bits (x = a * 2^m + b) rather than by decimal digits.
    Measuring decimal length with ``len(str(x))`` costs quadratic time. On
    CPython releases with the integer string-conversion limit (3.11+ and
    security backports), it also raises ValueError once an operand exceeds
    4300 digits, which is exactly the big-number case this algorithm targets.

    Args:
        x: First factor.
        y: Second factor.

    Returns:
        The exact product ``x * y``.
    """
    # Base case: small (or negative) operands use the builtin product.
    if x < 10 or y < 10:
        return x * y

    # Split at half the larger bit length.
    m = max(x.bit_length(), y.bit_length()) // 2
    mask = (1 << m) - 1

    # x = a * 2^m + b, y = c * 2^m + d
    a, b = x >> m, x & mask
    c, d = y >> m, y & mask

    # Three recursive multiplications instead of four.
    ac = karatsuba(a, c)
    bd = karatsuba(b, d)
    # (a + b)(c + d) - ac - bd == ad + bc
    ad_bc = karatsuba(a + b, c + d) - ac - bd

    return (ac << (2 * m)) + (ad_bc << m) + bd
354
355
356# =============================================================================
# 9. 스트라센 행렬 곱셈 (Strassen's Matrix Multiplication)
358# =============================================================================
359
def strassen(A: List[List[int]], B: List[List[int]], threshold: int = 64) -> List[List[int]]:
    """Multiply two n x n matrices with Strassen's algorithm.

    Time complexity: O(n^log2(7)) ~ O(n^2.807). At or below ``threshold`` the
    naive O(n^3) product is used, because Strassen's bookkeeping dominates on
    small matrices.

    Odd sizes are handled by padding with one zero row and one zero column.
    Without that padding, the ``n // 2`` quadrants of an odd matrix do not
    line up and the result would have the wrong size.

    Args:
        A: Left n x n operand.
        B: Right n x n operand.
        threshold: Size at or below which the naive product is used (values
            below 1 are treated as 1).

    Returns:
        The n x n product as a new list of lists.
    """
    n = len(A)

    # Base case
    if n <= max(threshold, 1):
        return naive_matrix_multiply(A, B)

    # Odd size: embed in an (n+1) x (n+1) zero-padded product and trim.
    if n % 2 == 1:
        size = n + 1
        A_pad = [list(row) + [0] for row in A] + [[0] * size]
        B_pad = [list(row) + [0] for row in B] + [[0] * size]
        C_pad = strassen(A_pad, B_pad, threshold)
        return [row[:n] for row in C_pad[:n]]

    # Split into four equal quadrants.
    mid = n // 2

    A11 = [row[:mid] for row in A[:mid]]
    A12 = [row[mid:] for row in A[:mid]]
    A21 = [row[:mid] for row in A[mid:]]
    A22 = [row[mid:] for row in A[mid:]]

    B11 = [row[:mid] for row in B[:mid]]
    B12 = [row[mid:] for row in B[:mid]]
    B21 = [row[:mid] for row in B[mid:]]
    B22 = [row[mid:] for row in B[mid:]]

    # Strassen's seven products
    M1 = strassen(matrix_add(A11, A22), matrix_add(B11, B22), threshold)
    M2 = strassen(matrix_add(A21, A22), B11, threshold)
    M3 = strassen(A11, matrix_sub(B12, B22), threshold)
    M4 = strassen(A22, matrix_sub(B21, B11), threshold)
    M5 = strassen(matrix_add(A11, A12), B22, threshold)
    M6 = strassen(matrix_sub(A21, A11), matrix_add(B11, B12), threshold)
    M7 = strassen(matrix_sub(A12, A22), matrix_add(B21, B22), threshold)

    # Recombine into the quadrants of the product.
    C11 = matrix_add(matrix_sub(matrix_add(M1, M4), M5), M7)
    C12 = matrix_add(M3, M5)
    C21 = matrix_add(M2, M4)
    C22 = matrix_add(matrix_sub(matrix_add(M1, M3), M2), M6)

    return combine_matrices(C11, C12, C21, C22)


def naive_matrix_multiply(A: List[List[int]], B: List[List[int]]) -> List[List[int]]:
    """Return the product of two n x n matrices using the O(n^3) triple loop."""
    n = len(A)
    C = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            for k in range(n):
                C[i][j] += A[i][k] * B[k][j]
    return C


def matrix_add(A: List[List[int]], B: List[List[int]]) -> List[List[int]]:
    """Return the element-wise sum of two n x n matrices."""
    n = len(A)
    return [[A[i][j] + B[i][j] for j in range(n)] for i in range(n)]


def matrix_sub(A: List[List[int]], B: List[List[int]]) -> List[List[int]]:
    """Return the element-wise difference ``A - B`` of two n x n matrices."""
    n = len(A)
    return [[A[i][j] - B[i][j] for j in range(n)] for i in range(n)]


def combine_matrices(C11, C12, C21, C22) -> List[List[int]]:
    """Assemble four n x n quadrants into one 2n x 2n matrix.

    Layout: [[C11, C12], [C21, C22]].
    """
    n = len(C11)
    result = [[0] * (2 * n) for _ in range(2 * n)]
    for i in range(n):
        for j in range(n):
            result[i][j] = C11[i][j]
            result[i][j + n] = C12[i][j]
            result[i + n][j] = C21[i][j]
            result[i + n][j + n] = C22[i][j]
    return result
437
438
439# =============================================================================
# 테스트
441# =============================================================================
442
def main():
    """Print a short demonstration of every algorithm in this module."""
    print("=" * 60)
    print("분할 정복 (Divide and Conquer) 예제")
    print("=" * 60)

    # 1. Merge sort
    print("\n[1] 병합 정렬")
    arr = [64, 34, 25, 12, 22, 11, 90]
    sorted_arr = merge_sort(arr)
    print(f"    원본: {arr}")
    print(f"    정렬: {sorted_arr}")

    # 2. Quick sort
    print("\n[2] 퀵 정렬")
    arr = [64, 34, 25, 12, 22, 11, 90]
    sorted_arr = quick_sort(arr)
    print(f"    원본: {arr}")
    print(f"    정렬: {sorted_arr}")

    # 3. Fast exponentiation
    print("\n[3] 빠른 거듭제곱")
    print(f"    2^10 = {power(2, 10)}")
    print(f"    2^10 (반복) = {power_iterative(2, 10)}")
    print(f"    3^7 mod 1000 = {power(3, 7, 1000)}")

    # 4. Fibonacci via matrix exponentiation
    print("\n[4] 피보나치 (행렬 거듭제곱)")
    for n in [10, 20, 50]:
        fib = fibonacci_matrix(n)
        print(f"    F({n}) = {fib}")

    # 5. Inversion count
    print("\n[5] 역순 쌍 개수")
    arr = [2, 4, 1, 3, 5]
    inv = count_inversions(arr)
    print(f"    배열: {arr}")
    print(f"    역순 쌍: {inv}개")

    # 6. Closest pair of points
    print("\n[6] 가장 가까운 점 쌍")
    points = [(2, 3), (12, 30), (40, 50), (5, 1), (12, 10), (3, 4)]
    dist = closest_pair(points)
    print(f"    점들: {points}")
    print(f"    최소 거리: {dist:.4f}")

    # 7. Maximum subarray sum (divide and conquer)
    print("\n[7] 최대 부분 배열 합 (분할 정복)")
    arr = [-2, 1, -3, 4, -1, 2, 1, -5, 4]
    max_sum = max_subarray_dc(arr)
    print(f"    배열: {arr}")
    print(f"    최대 합: {max_sum}")

    # 8. Karatsuba multiplication
    print("\n[8] 카라츠바 곱셈")
    x, y = 1234, 5678
    result = karatsuba(x, y)
    print(f"    {x} × {y} = {result}")
    print(f"    검증: {x * y}")

    # 9. Strassen matrix multiplication (small inputs use the naive product)
    print("\n[9] 스트라센 행렬 곱셈")
    A = [[1, 2], [3, 4]]
    B = [[5, 6], [7, 8]]
    print(f"    A × B = {strassen(A, B)}")

    print("\n" + "=" * 60)


if __name__ == "__main__":
    main()