1"""
Divide and Conquer Algorithms

Algorithms that solve a problem by dividing it into smaller subproblems.
6"""
7
8from typing import List, Tuple, Optional
9import random
10
11
12# =============================================================================
# 1. Merge Sort
14# =============================================================================
15
def merge_sort(arr: List[int]) -> List[int]:
    """Sort ``arr`` in ascending order using top-down merge sort.

    A list with fewer than two elements is returned unchanged (the very
    same object). Longer lists are cut in half, each half is sorted
    recursively, and the sorted halves are combined with ``merge``.

    Time complexity: O(n log n)
    Space complexity: O(n)
    Stable sort.
    """
    size = len(arr)
    if size < 2:
        return arr

    half = size // 2
    sorted_front = merge_sort(arr[:half])
    sorted_back = merge_sort(arr[half:])
    return merge(sorted_front, sorted_back)
31
32
def merge(left: List[int], right: List[int]) -> List[int]:
    """Merge two ascending lists into one new ascending list.

    On equal elements the one from ``left`` is taken first, so the
    relative order of equal items is preserved. Neither input is modified.
    """
    merged: List[int] = []
    li, ri = 0, 0
    left_len, right_len = len(left), len(right)

    while li < left_len and ri < right_len:
        take_left = left[li] <= right[ri]
        if take_left:
            merged.append(left[li])
            li += 1
        else:
            merged.append(right[ri])
            ri += 1

    # At most one of these tails is non-empty.
    merged += left[li:]
    merged += right[ri:]
    return merged
49
50
51# =============================================================================
# 2. Quick Sort
53# =============================================================================
54
def quick_sort(arr: List[int]) -> List[int]:
    """Return the elements of ``arr`` in ascending order using quick sort.

    A list with at most one element is returned unchanged (the same
    object). Otherwise a copy is sorted in place by ``_quick_sort`` and
    returned, so the caller's list is left untouched.

    Time complexity: O(n log n) on average, O(n^2) in the worst case
    Space complexity: O(log n) recursion stack
    Unstable sort.
    """
    if len(arr) <= 1:
        return arr

    working = list(arr)
    last = len(working) - 1
    _quick_sort(working, 0, last)
    return working
68
69
def _quick_sort(arr: List[int], low: int, high: int) -> None:
    """Sort ``arr[low..high]`` (inclusive bounds) in place.

    Ranges with fewer than two elements are left alone.
    """
    if low >= high:
        return

    # partition() returns the pivot's final index; recurse on both sides.
    split = partition(arr, low, high)
    _quick_sort(arr, low, split - 1)
    _quick_sort(arr, split + 1, high)
75
76
def partition(arr: List[int], low: int, high: int) -> int:
    """Partition ``arr[low..high]`` in place around a randomly chosen pivot.

    The pivot is picked uniformly from the range and moved to ``high``
    first; choosing it at random guards against the worst case. After
    the call every element left of the returned index is <= the pivot
    and every element right of it is greater.

    Returns the final index of the pivot.
    """
    chosen = random.randint(low, high)
    arr[chosen], arr[high] = arr[high], arr[chosen]
    pivot = arr[high]

    # ``boundary`` is the slot where the next "<= pivot" element belongs.
    boundary = low
    for scan in range(low, high):
        if arr[scan] <= pivot:
            arr[boundary], arr[scan] = arr[scan], arr[boundary]
            boundary += 1

    # Drop the pivot between the two groups.
    arr[boundary], arr[high] = arr[high], arr[boundary]
    return boundary
93
94
95# =============================================================================
# 3. Power / Exponentiation
97# =============================================================================
98
def power(base: int, exp: int, mod: Optional[int] = None) -> int:
    """Compute ``base ** exp`` by recursive repeated squaring.

    Args:
        base: The number to raise.
        exp: A non-negative exponent.
        mod: Optional modulus. When given and non-zero the result is
            reduced modulo ``mod``; ``None`` or ``0`` means no reduction.

    Returns:
        ``base ** exp``, or ``base ** exp % mod`` when a modulus is used.

    Raises:
        ValueError: If ``exp`` is negative.

    Time complexity: O(log exp) multiplications.
    """
    if exp < 0:
        raise ValueError("exp must be non-negative")

    if exp == 0:
        # Reduce here as well so that e.g. mod == 1 yields 0, matching pow().
        return 1 % mod if mod else 1

    if exp % 2 == 0:
        half = power(base, exp // 2, mod)
        result = half * half
    else:
        result = base * power(base, exp - 1, mod)

    return result % mod if mod else result
114
115
def power_iterative(base: int, exp: int, mod: Optional[int] = None) -> int:
    """Compute ``base ** exp`` by iterative binary exponentiation.

    Args:
        base: The number to raise.
        exp: A non-negative exponent.
        mod: Optional modulus. When given and non-zero the result is
            reduced modulo ``mod``; ``None`` or ``0`` means no reduction.

    Returns:
        ``base ** exp``, or ``base ** exp % mod`` when a modulus is used.

    Raises:
        ValueError: If ``exp`` is negative.
    """
    if exp < 0:
        raise ValueError("exp must be non-negative")

    # Start from 1 reduced by the modulus so exp == 0 with mod == 1 gives 0.
    result = 1 % mod if mod else 1
    if mod:
        base %= mod

    while exp > 0:
        if exp & 1:
            result = result * base
            if mod:
                result %= mod
        exp >>= 1
        # Square only while bits remain; the last squaring would be unused.
        if exp:
            base = base * base
            if mod:
                base %= mod

    return result
131
132
133# =============================================================================
# 4. Matrix Exponentiation
135# =============================================================================
136
def matrix_multiply(A: List[List[int]], B: List[List[int]], mod: int = None) -> List[List[int]]:
    """Multiply two n x n matrices, where n is ``len(A)``.

    When ``mod`` is given and non-zero, each cell's running total is
    reduced modulo ``mod`` after every term is added.

    Returns a new n x n matrix; the inputs are not modified.
    """
    size = len(A)
    product = []

    for row in range(size):
        out_row = []
        for col in range(size):
            cell = 0
            for k in range(size):
                cell += A[row][k] * B[k][col]
                if mod:
                    cell %= mod
            out_row.append(cell)
        product.append(out_row)

    return product
150
151
def matrix_power(M: List[List[int]], exp: int, mod: int = None) -> List[List[int]]:
    """Raise the square matrix ``M`` to the power ``exp`` by repeated squaring.

    ``mod`` is passed through to ``matrix_multiply``. For ``exp <= 0`` the
    loop never runs and the identity matrix is returned.

    Time complexity: O(k^3 log exp), k = matrix size
    """
    size = len(M)
    # Start from the identity matrix.
    acc = [[int(r == c) for c in range(size)] for r in range(size)]
    current = M

    while exp > 0:
        exp, low_bit = divmod(exp, 2)
        if low_bit == 1:
            acc = matrix_multiply(acc, current, mod)
        current = matrix_multiply(current, current, mod)

    return acc
168
169
def fibonacci_matrix(n: int, mod: int = None) -> int:
    """Return the n-th Fibonacci number via matrix exponentiation.

    Uses the identity [[1, 1], [1, 0]]^n = [[F(n+1), F(n)], [F(n), F(n-1)]]
    and reads F(n) from the top-right entry. For ``n <= 1`` the value
    ``n`` itself is returned without any matrix work.

    Time complexity: O(log n)
    """
    if n <= 1:
        return n

    step = [[1, 1], [1, 0]]
    return matrix_power(step, n, mod)[0][1]
182
183
184# =============================================================================
# 5. Inversion Count
186# =============================================================================
187
def count_inversions(arr: List[int]) -> int:
    """Count pairs (i, j) with i < j and arr[i] > arr[j].

    Implemented as a merge sort that also tallies inversions while
    merging. The input list is not modified.

    Time complexity: O(n log n)
    """

    def sort_and_count(seq: List[int]) -> Tuple[List[int], int]:
        """Return (sorted copy of seq, number of inversions in seq)."""
        if len(seq) < 2:
            return seq, 0

        half = len(seq) // 2
        lo, lo_count = sort_and_count(seq[:half])
        hi, hi_count = sort_and_count(seq[half:])

        total = lo_count + hi_count
        out = []
        a = b = 0

        while a < len(lo) and b < len(hi):
            if lo[a] <= hi[b]:
                out.append(lo[a])
                a += 1
                continue
            # hi[b] is smaller than every remaining element of lo,
            # so each of those elements forms one inversion with it.
            out.append(hi[b])
            total += len(lo) - a
            b += 1

        out += lo[a:]
        out += hi[b:]
        return out, total

    return sort_and_count(arr)[1]
223
224
225# =============================================================================
# 6. Closest Pair of Points
227# =============================================================================
228
def distance(p1: Tuple[float, float], p2: Tuple[float, float]) -> float:
    """Return the Euclidean distance between two 2-D points."""
    dx = p1[0] - p2[0]
    dy = p1[1] - p2[1]
    return (dx ** 2 + dy ** 2) ** 0.5
232
233
def closest_pair(points: List[Tuple[float, float]]) -> float:
    """Return the smallest Euclidean distance between any two points.

    The points are sorted by x once. Each recursive call receives an
    x-sorted slice, solves both halves, and returns its points sorted by
    y (built by merging the halves' y-sorted lists). The y-sorted list
    of a call therefore always holds exactly the points of that call,
    including points that share an x-coordinate with the dividing point.

    Args:
        points: 2-D points as (x, y) tuples. Duplicates are allowed and
            give a distance of 0.

    Returns:
        The minimum pairwise distance, or ``float('inf')`` when fewer
        than two points are given.

    Time complexity: O(n log n)
    """

    def solve(px: List[Tuple[float, float]]) -> Tuple[float, List[Tuple[float, float]]]:
        """Return (min distance, points sorted by y) for an x-sorted slice."""
        n = len(px)

        # Base case: brute force over at most three points.
        if n <= 3:
            best = float('inf')
            for i in range(n):
                for j in range(i + 1, n):
                    best = min(best, distance(px[i], px[j]))
            return best, sorted(px, key=lambda p: p[1])

        mid = n // 2
        mid_x = px[mid][0]

        d_left, left_by_y = solve(px[:mid])
        d_right, right_by_y = solve(px[mid:])
        d = min(d_left, d_right)

        # Merge the two y-sorted halves into one y-sorted list.
        by_y = []
        i = j = 0
        while i < len(left_by_y) and j < len(right_by_y):
            if left_by_y[i][1] <= right_by_y[j][1]:
                by_y.append(left_by_y[i])
                i += 1
            else:
                by_y.append(right_by_y[j])
                j += 1
        by_y.extend(left_by_y[i:])
        by_y.extend(right_by_y[j:])

        # Only points closer than d to the dividing line can form a
        # cross pair that beats d.
        strip = [p for p in by_y if abs(p[0] - mid_x) < d]

        # Scan upward in y and stop as soon as the vertical gap alone
        # reaches d; only a constant number of strip points fit below that.
        for i, p in enumerate(strip):
            for k in range(i + 1, len(strip)):
                q = strip[k]
                if q[1] - p[1] >= d:
                    break
                d = min(d, distance(p, q))

        return d, by_y

    if len(points) < 2:
        return float('inf')

    best, _ = solve(sorted(points, key=lambda p: p[0]))
    return best
279
280
281# =============================================================================
# 7. Maximum Subarray Sum (D&C)
283# =============================================================================
284
def max_subarray_dc(arr: List[int]) -> int:
    """Return the largest sum of a non-empty contiguous subarray of ``arr``.

    Divide and conquer: the best subarray lies entirely in the left
    half, entirely in the right half, or crosses the midpoint. An empty
    input yields 0.

    Time complexity: O(n log n)
    """

    def best_crossing(values: List[int], low: int, mid: int, high: int) -> int:
        """Best sum of a subarray that contains both mid and mid + 1."""
        # Best suffix of values[low..mid], growing leftwards from mid.
        best_left = float('-inf')
        running = 0
        for idx in reversed(range(low, mid + 1)):
            running += values[idx]
            if running > best_left:
                best_left = running

        # Best prefix of values[mid+1..high], growing rightwards.
        best_right = float('-inf')
        running = 0
        for idx in range(mid + 1, high + 1):
            running += values[idx]
            if running > best_right:
                best_right = running

        return best_left + best_right

    def solve(values: List[int], low: int, high: int) -> int:
        """Best subarray sum within values[low..high] (inclusive)."""
        if low == high:
            return values[low]

        mid = (low + high) // 2
        candidates = (
            solve(values, low, mid),
            solve(values, mid + 1, high),
            best_crossing(values, low, mid, high),
        )
        return max(candidates)

    if not arr:
        return 0
    return solve(arr, 0, len(arr) - 1)
323
324
325# =============================================================================
# 8. Karatsuba Multiplication
327# =============================================================================
328
def karatsuba(x: int, y: int) -> int:
    """Multiply two integers with Karatsuba's algorithm.

    Both operands are split at half the bit length of the larger one,
    and the product is assembled from three recursive multiplications
    instead of four. Splitting uses shifts and masks, so no decimal
    string conversion is performed and operand size is not limited by
    the interpreter's int-to-str digit limit.

    Args:
        x: First factor.
        y: Second factor.

    Returns:
        The product ``x * y``. If either operand is below 10 (which
        includes every negative value) the product is computed directly.

    Time complexity: O(n^1.585) for n-digit operands
    """
    # Base case: a single-digit or negative operand is multiplied directly.
    if x < 10 or y < 10:
        return x * y

    # x = a * 2**m + b,  y = c * 2**m + d
    m = max(x.bit_length(), y.bit_length()) // 2
    low_mask = (1 << m) - 1

    a, b = x >> m, x & low_mask
    c, d = y >> m, y & low_mask

    # Three multiplications: (a + b)(c + d) - ac - bd == ad + bc.
    ac = karatsuba(a, c)
    bd = karatsuba(b, d)
    ad_plus_bc = karatsuba(a + b, c + d) - ac - bd

    return (ac << (2 * m)) + (ad_plus_bc << m) + bd
354
355
356# =============================================================================
# 9. Strassen's Matrix Multiplication
358# =============================================================================
359
def strassen(A: List[List[int]], B: List[List[int]]) -> List[List[int]]:
    """Multiply two square matrices of the same size with Strassen's algorithm.

    Matrices of dimension 64 or less are handed to
    ``naive_matrix_multiply``, because the recursion overhead outweighs
    the savings on small inputs. Larger matrices are split into four
    quadrants and combined from seven recursive products. A matrix with
    an odd dimension is first padded with one zero row and one zero
    column so it can be halved evenly; the padding is trimmed from the
    result.

    Args:
        A: Left n x n matrix.
        B: Right n x n matrix.

    Returns:
        A new n x n matrix holding the product of A and B.

    Time complexity: O(n^2.807)
    """
    n = len(A)

    # Base case: below the cutoff the plain triple loop is used.
    if n <= 64:  # cutoff size
        return naive_matrix_multiply(A, B)

    # Odd dimension: quadrants would have unequal sizes, so pad to n + 1.
    # The extra zero row/column does not change the top-left n x n product.
    if n % 2 == 1:
        padded_a = [list(row) + [0] for row in A] + [[0] * (n + 1)]
        padded_b = [list(row) + [0] for row in B] + [[0] * (n + 1)]
        padded_product = strassen(padded_a, padded_b)
        return [row[:n] for row in padded_product[:n]]

    # Split both matrices into quadrants.
    mid = n // 2

    A11 = [row[:mid] for row in A[:mid]]
    A12 = [row[mid:] for row in A[:mid]]
    A21 = [row[:mid] for row in A[mid:]]
    A22 = [row[mid:] for row in A[mid:]]

    B11 = [row[:mid] for row in B[:mid]]
    B12 = [row[mid:] for row in B[:mid]]
    B21 = [row[:mid] for row in B[mid:]]
    B22 = [row[mid:] for row in B[mid:]]

    # The seven Strassen products.
    M1 = strassen(matrix_add(A11, A22), matrix_add(B11, B22))
    M2 = strassen(matrix_add(A21, A22), B11)
    M3 = strassen(A11, matrix_sub(B12, B22))
    M4 = strassen(A22, matrix_sub(B21, B11))
    M5 = strassen(matrix_add(A11, A12), B22)
    M6 = strassen(matrix_sub(A21, A11), matrix_add(B11, B12))
    M7 = strassen(matrix_sub(A12, A22), matrix_add(B21, B22))

    # Assemble the result quadrants.
    C11 = matrix_add(matrix_sub(matrix_add(M1, M4), M5), M7)
    C12 = matrix_add(M3, M5)
    C21 = matrix_add(M2, M4)
    C22 = matrix_add(matrix_sub(matrix_add(M1, M3), M2), M6)

    return combine_matrices(C11, C12, C21, C22)
401
402
def naive_matrix_multiply(A: List[List[int]], B: List[List[int]]) -> List[List[int]]:
    """Multiply two n x n matrices (n = ``len(A)``) with the plain triple loop.

    Returns a new matrix; the inputs are not modified.
    """
    size = len(A)
    product = []
    for row in range(size):
        out_row = []
        for col in range(size):
            cell = 0
            for k in range(size):
                cell += A[row][k] * B[k][col]
            out_row.append(cell)
        product.append(out_row)
    return product
412
413
def matrix_add(A: List[List[int]], B: List[List[int]]) -> List[List[int]]:
    """Return the element-wise sum of two n x n matrices (n = ``len(A)``)."""
    size = len(A)
    total = []
    for r in range(size):
        row_a, row_b = A[r], B[r]
        total.append([row_a[c] + row_b[c] for c in range(size)])
    return total
418
419
def matrix_sub(A: List[List[int]], B: List[List[int]]) -> List[List[int]]:
    """Return the element-wise difference A - B of two n x n matrices (n = ``len(A)``)."""
    size = len(A)
    diff = []
    for r in range(size):
        row_a, row_b = A[r], B[r]
        diff.append([row_a[c] - row_b[c] for c in range(size)])
    return diff
424
425
def combine_matrices(C11, C12, C21, C22) -> List[List[int]]:
    """Assemble four n x n quadrants into one 2n x 2n matrix.

    ``C11`` is the top-left quadrant, ``C12`` top-right, ``C21``
    bottom-left and ``C22`` bottom-right; n is ``len(C11)``.
    """
    half = len(C11)
    cols = range(half)

    upper = [
        [C11[r][c] for c in cols] + [C12[r][c] for c in cols]
        for r in range(half)
    ]
    lower = [
        [C21[r][c] for c in cols] + [C22[r][c] for c in cols]
        for r in range(half)
    ]
    return upper + lower
437
438
439# =============================================================================
# Tests
441# =============================================================================
442
def main():
    """Print a short demonstration of each algorithm in this module.

    The labels are printed in Korean, with readable text restored in
    place of the mis-encoded characters.
    """
    print("=" * 60)
    print("분할 정복 (Divide and Conquer) 예제")
    print("=" * 60)

    # 1. Merge sort
    print("\n[1] 병합 정렬")
    arr = [64, 34, 25, 12, 22, 11, 90]
    sorted_arr = merge_sort(arr)
    print(f" 원본: {arr}")
    print(f" 정렬: {sorted_arr}")

    # 2. Quick sort
    print("\n[2] 퀵 정렬")
    arr = [64, 34, 25, 12, 22, 11, 90]
    sorted_arr = quick_sort(arr)
    print(f" 원본: {arr}")
    print(f" 정렬: {sorted_arr}")

    # 3. Fast exponentiation
    print("\n[3] 빠른 거듭제곱")
    print(f" 2^10 = {power(2, 10)}")
    print(f" 2^10 (반복) = {power_iterative(2, 10)}")
    print(f" 3^7 mod 1000 = {power(3, 7, 1000)}")

    # 4. Fibonacci (matrix exponentiation)
    print("\n[4] 피보나치 (행렬 거듭제곱)")
    for n in [10, 20, 50]:
        fib = fibonacci_matrix(n)
        print(f" F({n}) = {fib}")

    # 5. Inversion count
    print("\n[5] 역전 쌍 개수")
    arr = [2, 4, 1, 3, 5]
    inv = count_inversions(arr)
    print(f" 배열: {arr}")
    print(f" 역전 쌍: {inv}개")

    # 6. Closest pair of points
    print("\n[6] 가장 가까운 점 쌍")
    points = [(2, 3), (12, 30), (40, 50), (5, 1), (12, 10), (3, 4)]
    dist = closest_pair(points)
    print(f" 점들: {points}")
    print(f" 최소 거리: {dist:.4f}")

    # 7. Maximum subarray sum (D&C)
    print("\n[7] 최대 부분 배열 합 (분할 정복)")
    arr = [-2, 1, -3, 4, -1, 2, 1, -5, 4]
    max_sum = max_subarray_dc(arr)
    print(f" 배열: {arr}")
    print(f" 최대 합: {max_sum}")

    # 8. Karatsuba multiplication
    print("\n[8] 카라츠바 곱셈")
    x, y = 1234, 5678
    result = karatsuba(x, y)
    print(f" {x} × {y} = {result}")
    print(f" 검증: {x * y}")

    print("\n" + "=" * 60)
504
# Run the demonstrations only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()