1"""
2๊ณ ๊ธ DP ์ต์ ํ (Advanced DP Optimization)
3Advanced Dynamic Programming Optimization Techniques
4
5DP์ ์๊ฐ๋ณต์ก๋๋ฅผ ๊ฐ์ ํ๋ ์ต์ ํ ๊ธฐ๋ฒ๋ค์
๋๋ค.
6"""
7
8from typing import List, Tuple, Callable
9from collections import deque
10from math import inf
11
12
13# =============================================================================
14# 1. Convex Hull Trick (CHT)
15# =============================================================================
16
17class ConvexHullTrick:
18 """
19 ๋ณผ๋ก ๊ป์ง ํธ๋ฆญ
20 ์ต์๊ฐ ์ฟผ๋ฆฌ: min(a[i] * x + b[i]) for all i
21 ์กฐ๊ฑด: a[i]๊ฐ ๋จ์กฐ ๊ฐ์ (๋๋ ์ฆ๊ฐ)
22
23 ์๊ฐ๋ณต์ก๋: ์ฝ์
O(1) ํ๊ท , ์ฟผ๋ฆฌ O(log n) ๋๋ O(1)
24 """
25
26 def __init__(self):
27 self.lines = deque() # (๊ธฐ์ธ๊ธฐ, y์ ํธ)
28
29 def is_bad(self, l1: Tuple[int, int], l2: Tuple[int, int], l3: Tuple[int, int]) -> bool:
30 """l2๊ฐ ๋ถํ์ํ์ง ํ์ธ (l1๊ณผ l3 ์ฌ์ด์์)"""
31 # ๊ต์ ๋น๊ต: (l1, l2) ๊ต์ >= (l2, l3) ๊ต์ ์ด๋ฉด l2 ๋ถํ์
32 # (b2 - b1) / (a1 - a2) >= (b3 - b2) / (a2 - a3)
33 # (b2 - b1) * (a2 - a3) >= (b3 - b2) * (a1 - a2)
34 a1, b1 = l1
35 a2, b2 = l2
36 a3, b3 = l3
37 return (b2 - b1) * (a2 - a3) >= (b3 - b2) * (a1 - a2)
38
39 def add_line(self, a: int, b: int):
40 """
41 ์ง์ y = ax + b ์ถ๊ฐ
42 a๋ ๋จ์กฐ ๊ฐ์ํด์ผ ํจ
43 """
44 line = (a, b)
45
46 while len(self.lines) >= 2 and self.is_bad(self.lines[-2], self.lines[-1], line):
47 self.lines.pop()
48
49 self.lines.append(line)
50
51 def query_min(self, x: int) -> int:
52 """
53 x์์์ ์ต์๊ฐ ์ฟผ๋ฆฌ
54 x๊ฐ ๋จ์กฐ ์ฆ๊ฐํ ๋ O(1)
55 """
56 while len(self.lines) >= 2:
57 a1, b1 = self.lines[0]
58 a2, b2 = self.lines[1]
59 if a1 * x + b1 >= a2 * x + b2:
60 self.lines.popleft()
61 else:
62 break
63
64 a, b = self.lines[0]
65 return a * x + b
66
67
68class LiChaoTree:
69 """
70 Li Chao Tree (์ธ๊ทธ๋จผํธ ํธ๋ฆฌ ๊ธฐ๋ฐ CHT)
71 ์์์ ์ง์ ์ถ๊ฐ์ ์์์ x์์ ์ฟผ๋ฆฌ ์ง์
72
73 ์๊ฐ๋ณต์ก๋: ์ฝ์
O(log C), ์ฟผ๋ฆฌ O(log C)
74 C = ์ขํ ๋ฒ์
75 """
76
77 def __init__(self, lo: int, hi: int):
78 self.lo = lo
79 self.hi = hi
80 self.tree = {} # ๋
ธ๋๋ณ ์ฐ์ธ ์ง์ ์ ์ฅ
81
82 def _eval(self, line: Tuple[int, int], x: int) -> int:
83 """์ง์ ๊ฐ ๊ณ์ฐ"""
84 if line is None:
85 return inf
86 a, b = line
87 return a * x + b
88
89 def add_line(self, a: int, b: int):
90 """์ง์ ์ถ๊ฐ"""
91 self._add_line_impl((a, b), self.lo, self.hi, 1)
92
93 def _add_line_impl(self, new_line: Tuple[int, int], lo: int, hi: int, node: int):
94 if lo > hi:
95 return
96
97 mid = (lo + hi) // 2
98 cur_line = self.tree.get(node)
99
100 # ์ค๊ฐ์ ์์ ๋น๊ต
101 new_better_at_mid = self._eval(new_line, mid) < self._eval(cur_line, mid)
102
103 if cur_line is None or new_better_at_mid:
104 self.tree[node], new_line = new_line, cur_line
105
106 if lo == hi or new_line is None:
107 return
108
109 # ์ผ์ชฝ/์ค๋ฅธ์ชฝ ์์์ผ๋ก ์ ํ
110 new_better_at_lo = self._eval(new_line, lo) < self._eval(self.tree.get(node), lo)
111
112 if new_better_at_lo:
113 self._add_line_impl(new_line, lo, mid - 1, 2 * node)
114 else:
115 self._add_line_impl(new_line, mid + 1, hi, 2 * node + 1)
116
117 def query(self, x: int) -> int:
118 """x์์์ ์ต์๊ฐ"""
119 return self._query_impl(x, self.lo, self.hi, 1)
120
121 def _query_impl(self, x: int, lo: int, hi: int, node: int) -> int:
122 if lo > hi:
123 return inf
124
125 result = self._eval(self.tree.get(node), x)
126
127 if lo == hi:
128 return result
129
130 mid = (lo + hi) // 2
131 if x <= mid:
132 return min(result, self._query_impl(x, lo, mid - 1, 2 * node))
133 else:
134 return min(result, self._query_impl(x, mid + 1, hi, 2 * node + 1))
135
136
137# =============================================================================
138# 2. Divide and Conquer Optimization
139# =============================================================================
140
141def dc_optimization(n: int, m: int, cost: Callable[[int, int], int]) -> List[List[int]]:
142 """
143 ๋ถํ ์ ๋ณต ์ต์ ํ
144 ์กฐ๊ฑด: opt[k][i] <= opt[k][i+1] (๋จ์กฐ์ฑ)
145 ์ ํ์: dp[k][j] = min(dp[k-1][i] + cost(i, j)) for i < j
146
147 ์๊ฐ๋ณต์ก๋: O(k * n log n) (์ผ๋ฐ O(k * n^2)์์ ๊ฐ์ )
148
149 n: ์์ ๊ฐ์
150 m: ๋ถํ ๊ทธ๋ฃน ์
151 cost(i, j): i+1 ~ j ๊ตฌ๊ฐ์ ๋น์ฉ
152 """
153 INF = float('inf')
154 dp = [[INF] * (n + 1) for _ in range(m + 1)]
155 dp[0][0] = 0
156
157 def compute(k: int, lo: int, hi: int, opt_lo: int, opt_hi: int):
158 """dp[k][lo:hi+1] ๊ณ์ฐ, ์ต์ ์ ๋ถํ ์ ์ opt_lo ~ opt_hi ๋ฒ์"""
159 if lo > hi:
160 return
161
162 mid = (lo + hi) // 2
163 best_cost = INF
164 best_opt = opt_lo
165
166 for i in range(opt_lo, min(opt_hi, mid) + 1):
167 curr_cost = dp[k - 1][i] + cost(i, mid)
168 if curr_cost < best_cost:
169 best_cost = curr_cost
170 best_opt = i
171
172 dp[k][mid] = best_cost
173
174 # ๋ถํ ์ ๋ณต
175 compute(k, lo, mid - 1, opt_lo, best_opt)
176 compute(k, mid + 1, hi, best_opt, opt_hi)
177
178 for k in range(1, m + 1):
179 compute(k, k, n, k - 1, n - 1)
180
181 return dp
182
183
184def dc_optimization_example():
185 """
186 ์์ : ๋ฐฐ์ด์ k๊ฐ ๊ทธ๋ฃน์ผ๋ก ๋๋๊ธฐ
187 ๊ฐ ๊ทธ๋ฃน์ ๋น์ฉ = ๊ตฌ๊ฐ ๋ด ์์ ์ฐจ์ด์ ์ ๊ณฑ ํฉ
188 """
189 arr = [1, 5, 2, 8, 3, 7, 4, 6]
190 n = len(arr)
191 k = 3
192
193 # ์ ์ฒ๋ฆฌ: prefix sum for cost calculation
194 prefix = [0] * (n + 1)
195 prefix_sq = [0] * (n + 1)
196 for i in range(n):
197 prefix[i + 1] = prefix[i] + arr[i]
198 prefix_sq[i + 1] = prefix_sq[i] + arr[i] * arr[i]
199
200 def cost(l: int, r: int) -> int:
201 """๊ตฌ๊ฐ [l+1, r]์ ๋ถ์ฐ (์ ๊ณฑํฉ - ํ๊ท *ํฉ)"""
202 if l >= r:
203 return 0
204 length = r - l
205 s = prefix[r] - prefix[l]
206 sq = prefix_sq[r] - prefix_sq[l]
207 # ๋ถ์ฐ = E[X^2] - E[X]^2, ์ฌ๊ธฐ์๋ ์ ๊ณฑํฉ - ํฉ^2/n
208 return sq * length - s * s
209
210 dp = dc_optimization(n, k, cost)
211 return dp[k][n]
212
213
214# =============================================================================
215# 3. Knuth Optimization
216# =============================================================================
217
218def knuth_optimization(n: int, cost: List[List[int]]) -> Tuple[List[List[int]], List[List[int]]]:
219 """
220 Knuth ์ต์ ํ
221 ์กฐ๊ฑด: cost๊ฐ ์ฌ๊ฐ ๋ถ๋ฑ์ ๋ง์กฑ (Quadrangle Inequality)
222 cost[a][c] + cost[b][d] <= cost[a][d] + cost[b][c] (a <= b <= c <= d)
223 ์ ํ์: dp[i][j] = min(dp[i][k] + dp[k][j]) + cost[i][j] for i < k < j
224
225 ์๊ฐ๋ณต์ก๋: O(n^2) (์ผ๋ฐ O(n^3)์์ ๊ฐ์ )
226
227 ์: ์ต์ ์ด์ง ํ์ ํธ๋ฆฌ, ํ๋ ฌ ์ฒด์ธ ๊ณฑ์
228 """
229 INF = float('inf')
230 dp = [[0] * n for _ in range(n)]
231 opt = [[0] * n for _ in range(n)]
232
233 # ๊ธฐ์ : ๊ธธ์ด 1
234 for i in range(n):
235 opt[i][i] = i
236
237 # ๊ธธ์ด 2 ์ด์
238 for length in range(2, n + 1):
239 for i in range(n - length + 1):
240 j = i + length - 1
241 dp[i][j] = INF
242
243 # opt[i][j-1] <= opt[i][j] <= opt[i+1][j]
244 lo = opt[i][j - 1] if j > 0 else i
245 hi = opt[i + 1][j] if i + 1 < n else j
246
247 for k in range(lo, min(hi, j) + 1):
248 curr = dp[i][k] + dp[k + 1][j] + cost[i][j]
249 if curr < dp[i][j]:
250 dp[i][j] = curr
251 opt[i][j] = k
252
253 return dp, opt
254
255
256def optimal_bst(keys: List[int], freq: List[int]) -> int:
257 """
258 ์ต์ ์ด์ง ํ์ ํธ๋ฆฌ
259 freq[i] = keys[i]์ ์ ๊ทผ ๋น๋
260 """
261 n = len(keys)
262
263 # cost[i][j] = freq[i] + ... + freq[j]
264 cost = [[0] * n for _ in range(n)]
265 for i in range(n):
266 total = 0
267 for j in range(i, n):
268 total += freq[j]
269 cost[i][j] = total
270
271 dp, opt = knuth_optimization(n, cost)
272 return dp[0][n - 1]
273
274
275# =============================================================================
276# 4. 1D/1D DP ์ต์ ํ (Monotone Queue)
277# =============================================================================
278
279def sliding_window_max(arr: List[int], k: int) -> List[int]:
280 """
281 ์ฌ๋ผ์ด๋ฉ ์๋์ฐ ์ต๋๊ฐ
282 ์๊ฐ๋ณต์ก๋: O(n)
283 """
284 result = []
285 dq = deque() # (์ธ๋ฑ์ค, ๊ฐ)
286
287 for i, val in enumerate(arr):
288 # ์๋์ฐ ๋ฒ์ ๋ฐ ์ ๊ฑฐ
289 while dq and dq[0][0] <= i - k:
290 dq.popleft()
291
292 # ํ์ฌ ๊ฐ๋ณด๋ค ์์ ์์ ์ ๊ฑฐ
293 while dq and dq[-1][1] <= val:
294 dq.pop()
295
296 dq.append((i, val))
297
298 if i >= k - 1:
299 result.append(dq[0][1])
300
301 return result
302
303
304def dp_with_monotone_queue(arr: List[int], k: int) -> List[int]:
305 """
306 DP with Monotone Queue
307 dp[i] = max(dp[j] + arr[i]) for i-k <= j < i
308 ์๊ฐ๋ณต์ก๋: O(n)
309 """
310 n = len(arr)
311 dp = [0] * n
312 dq = deque() # (์ธ๋ฑ์ค, dp๊ฐ)
313
314 for i in range(n):
315 # ๋ฒ์ ๋ฐ ์ ๊ฑฐ
316 while dq and dq[0][0] < i - k:
317 dq.popleft()
318
319 # ์ต๋๊ฐ์ผ๋ก dp ๊ณ์ฐ
320 if dq:
321 dp[i] = dq[0][1] + arr[i]
322 else:
323 dp[i] = arr[i]
324
325 # ํ์ฌ dp๊ฐ ์ฝ์
326 while dq and dq[-1][1] <= dp[i]:
327 dq.pop()
328 dq.append((i, dp[i]))
329
330 return dp
331
332
333# =============================================================================
334# 5. Slope Trick
335# =============================================================================
336
337class SlopeTrick:
338 """
339 Slope Trick
340 ๋ณผ๋ก ํจ์์ ํจ์จ์ ์ธ ๊ด๋ฆฌ
341 ์ ๋๊ฐ ํจ์์ ํฉ ์ต์ ํ์ ์ ์ฉ
342 """
343
344 def __init__(self):
345 import heapq
346 self.left = [] # ์ต๋ ํ (์์๋ก ์ ์ฅ)
347 self.right = [] # ์ต์ ํ
348 self.min_f = 0
349 self.add_l = 0 # left์ ๋ํ ๊ฐ
350 self.add_r = 0 # right์ ๋ํ ๊ฐ
351
352 def add_abs(self, a: int):
353 """
354 f(x) += |x - a|
355 """
356 import heapq
357
358 # a ์์น์์ ๊ธฐ์ธ๊ธฐ ๋ณํ: ์ผ์ชฝ +1, ์ค๋ฅธ์ชฝ -1
359 l = -self.left[0] + self.add_l if self.left else -inf
360 r = self.right[0] + self.add_r if self.right else inf
361
362 if a <= l:
363 # a๊ฐ ์ผ์ชฝ์ ์์น
364 self.min_f += l - a
365 heapq.heappush(self.left, -(a - self.add_l))
366 # ์ผ์ชฝ ์ต๋๊ฐ์ ์ค๋ฅธ์ชฝ์ผ๋ก
367 val = -heapq.heappop(self.left) + self.add_l
368 heapq.heappush(self.right, val - self.add_r)
369 elif a >= r:
370 # a๊ฐ ์ค๋ฅธ์ชฝ์ ์์น
371 self.min_f += a - r
372 heapq.heappush(self.right, a - self.add_r)
373 # ์ค๋ฅธ์ชฝ ์ต์๊ฐ์ ์ผ์ชฝ์ผ๋ก
374 val = heapq.heappop(self.right) + self.add_r
375 heapq.heappush(self.left, -(val - self.add_l))
376 else:
377 # a๊ฐ ํํํ ๊ตฌ๊ฐ์ ์์น
378 heapq.heappush(self.left, -(a - self.add_l))
379 heapq.heappush(self.right, a - self.add_r)
380
381 def shift(self, a: int, b: int):
382 """
383 f(x) โ f(x-a) (์ผ์ชฝ ์ด๋), f(x) โ f(x-b) (์ค๋ฅธ์ชฝ ์ด๋)
384 ํํํ ๊ตฌ๊ฐ ํ์ฅ
385 """
386 self.add_l += a
387 self.add_r += b
388
389 def get_min(self) -> int:
390 """์ต์๊ฐ ๋ฐํ"""
391 return self.min_f
392
393
394# =============================================================================
395# 6. Alien Trick (WQS Binary Search)
396# =============================================================================
397
398def alien_trick_example(arr: List[int], k: int) -> int:
399 """
400 Alien Trick (WQS Binary Search / Lagrange Relaxation)
401 ์ ํํ k๊ฐ์ ์์๋ฅผ ์ ํํ๋ ๋ฌธ์ ๋ฅผ ์ด์
402
403 ์: ๋ฐฐ์ด์์ ์ ํํ k๊ฐ ์ ํ, ์ธ์ ํ ๊ฒ ๋ถ๊ฐ, ํฉ ์ต๋ํ
404 """
405
406 def check(penalty: float) -> Tuple[float, int]:
407 """
408 penalty๋ฅผ ์ฌ์ฉํ ์ด์ ๋ฌธ์ ํด๊ฒฐ
409 ๋ฐํ: (์ต์ ๊ฐ, ์ ํํ ๊ฐ์)
410 """
411 n = len(arr)
412 # dp[i][0]: i๊น์ง, arr[i] ์ ํ ์ํจ
413 # dp[i][1]: i๊น์ง, arr[i] ์ ํํจ
414
415 dp = [[-inf, -inf] for _ in range(n)]
416 cnt = [[0, 0] for _ in range(n)]
417
418 dp[0][0] = 0
419 dp[0][1] = arr[0] - penalty
420 cnt[0][1] = 1
421
422 for i in range(1, n):
423 # ์ ํ ์ํจ
424 if dp[i - 1][0] > dp[i - 1][1]:
425 dp[i][0] = dp[i - 1][0]
426 cnt[i][0] = cnt[i - 1][0]
427 else:
428 dp[i][0] = dp[i - 1][1]
429 cnt[i][0] = cnt[i - 1][1]
430
431 # ์ ํํจ (์ด์ ์ ์ ํ ์ํ ์ํ์์๋ง)
432 dp[i][1] = dp[i - 1][0] + arr[i] - penalty
433 cnt[i][1] = cnt[i - 1][0] + 1
434
435 if dp[n - 1][0] > dp[n - 1][1]:
436 return dp[n - 1][0], cnt[n - 1][0]
437 return dp[n - 1][1], cnt[n - 1][1]
438
439 # ์ด๋ถ ํ์
440 lo, hi = -10**9, 10**9
441
442 while hi - lo > 1e-6:
443 mid = (lo + hi) / 2
444 _, count = check(mid)
445 if count >= k:
446 lo = mid
447 else:
448 hi = mid
449
450 result, _ = check(lo)
451 return int(result + lo * k)
452
453
454# =============================================================================
455# ํ
์คํธ
456# =============================================================================
457
458def main():
459 print("=" * 60)
460 print("๊ณ ๊ธ DP ์ต์ ํ (Advanced DP Optimization) ์์ ")
461 print("=" * 60)
462
463 # 1. Convex Hull Trick
464 print("\n[1] Convex Hull Trick (CHT)")
465 cht = ConvexHullTrick()
466 # ์ง์ : y = -3x + 10, y = -2x + 5, y = -1x + 3
467 cht.add_line(-3, 10)
468 cht.add_line(-2, 5)
469 cht.add_line(-1, 3)
470 print(" ์ง์ ๋ค: y=-3x+10, y=-2x+5, y=-x+3")
471 for x in [0, 1, 2, 3, 4, 5]:
472 print(f" min at x={x}: {cht.query_min(x)}")
473
474 # 2. Li Chao Tree
475 print("\n[2] Li Chao Tree")
476 lct = LiChaoTree(-100, 100)
477 lct.add_line(2, 5) # y = 2x + 5
478 lct.add_line(-1, 10) # y = -x + 10
479 lct.add_line(1, 0) # y = x
480 print(" ์ง์ ๋ค: y=2x+5, y=-x+10, y=x")
481 for x in [-5, 0, 3, 7]:
482 print(f" min at x={x}: {lct.query(x)}")
483
484 # 3. Divide and Conquer Optimization
485 print("\n[3] ๋ถํ ์ ๋ณต ์ต์ ํ")
486 result = dc_optimization_example()
487 print(f" ๋ฐฐ์ด [1,5,2,8,3,7,4,6]์ 3๊ทธ๋ฃน์ผ๋ก ๋ถํ ")
488 print(f" ์ต์ ๋น์ฉ: {result}")
489
490 # 4. Knuth Optimization
491 print("\n[4] Knuth ์ต์ ํ (์ต์ BST)")
492 keys = [10, 20, 30, 40]
493 freq = [4, 2, 6, 3]
494 cost = optimal_bst(keys, freq)
495 print(f" ํค: {keys}")
496 print(f" ๋น๋: {freq}")
497 print(f" ์ต์ ํ์ ๋น์ฉ: {cost}")
498
499 # 5. Monotone Queue
500 print("\n[5] ๋ชจ๋
ธํค ํ ์ต์ ํ")
501 arr = [1, 3, -1, -3, 5, 3, 6, 7]
502 k = 3
503 result = sliding_window_max(arr, k)
504 print(f" ๋ฐฐ์ด: {arr}")
505 print(f" ์๋์ฐ ํฌ๊ธฐ: {k}")
506 print(f" ์ฌ๋ผ์ด๋ฉ ์ต๋๊ฐ: {result}")
507
508 # 6. DP with Monotone Queue
509 print("\n[6] ๋ชจ๋
ธํค ํ DP")
510 arr = [2, 1, 5, 1, 3, 2]
511 k = 2
512 dp = dp_with_monotone_queue(arr, k)
513 print(f" ๋ฐฐ์ด: {arr}, k={k}")
514 print(f" dp[i] = max(dp[j] + arr[i]) for i-k <= j < i")
515 print(f" DP: {dp}")
516
517 # 7. Slope Trick
518 print("\n[7] Slope Trick")
519 st = SlopeTrick()
520 points = [1, 5, 2, 8]
521 for p in points:
522 st.add_abs(p)
523 print(f" ์ ๋ค: {points}")
524 print(f" f(x) = sum(|x - p|) ์ต์๊ฐ: {st.get_min()}")
525
526 # 8. ๋ณต์ก๋ ๋น๊ต
527 print("\n[8] ์ต์ ํ ๊ธฐ๋ฒ ๋น๊ต")
528 print(" | ๊ธฐ๋ฒ | ์๋ ๋ณต์ก๋ | ์ต์ ํ ํ | ์กฐ๊ฑด |")
529 print(" |-------------------|-------------|-------------|-------------------------|")
530 print(" | CHT | O(nยฒ) | O(n) | ๊ธฐ์ธ๊ธฐ ๋จ์กฐ |")
531 print(" | Li Chao Tree | O(nยฒ) | O(n log C) | ์์ |")
532 print(" | D&C Optimization | O(knยฒ) | O(kn log n) | opt ๋จ์กฐ์ฑ |")
533 print(" | Knuth Optimization| O(nยณ) | O(nยฒ) | ์ฌ๊ฐ ๋ถ๋ฑ์ |")
534 print(" | Monotone Queue | O(nk) | O(n) | ์๋์ฐ ์ต์ ํ |")
535 print(" | Alien Trick | ์ ์ฝ ๋ฌธ์ | ์ด์ | ๋ณผ๋ก์ฑ |")
536
537 print("\n" + "=" * 60)
538
539
540if __name__ == "__main__":
541 main()