cht.py

Download
python 302 lines 9.2 KB
  1"""
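Convex Hull Trick (CHT) - DP Optimization Technique

The Convex Hull Trick speeds up certain DP recurrences from O(n^2) to O(n log n).
It applies when the recurrence has the form:
    dp[i] = min/max over j < i of (dp[j] + cost(j, i))
where, for each fixed j, cost(j, i) is a linear function of some per-i value x = a[i]
(its slope and intercept may depend on j), plus terms that depend only on i.

Typical form: dp[i] = min over j < i of (dp[j] + a[i] * b[j]) + c[i]
Because c[i] does not depend on j, it is added after the minimum is taken:
    dp[i] = c[i] + min over j < i of (b[j] * x + dp[j]),  where x = a[i]
Each earlier state j is therefore the line y = m*x + k with slope m = b[j] and
intercept k = dp[j], and dp[i] is the lowest of these lines evaluated at x = a[i].

Time Complexity: O(n log n) with a monotone-slope hull queried by binary search, or
with a Li Chao Tree (O(log C) per operation over an integer x-range of size C);
O(n) total is possible when the queries are also monotonic (pointer walk instead
of binary search).
Space Complexity: O(n)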
 15"""
 16
 17from typing import List, Tuple, Optional
 18from dataclasses import dataclass
 19
 20
 21@dataclass
 22class Line:
 23    """Represents a line y = m*x + c"""
 24    m: float  # slope
 25    c: float  # y-intercept
 26
 27    def eval(self, x: float) -> float:
 28        """Evaluate line at position x"""
 29        return self.m * x + self.c
 30
 31    def intersect_x(self, other: 'Line') -> float:
 32        """Find x-coordinate where this line intersects with other"""
 33        if self.m == other.m:
 34            return float('inf')
 35        return (other.c - self.c) / (self.m - other.m)
 36
 37
 38class ConvexHullTrick:
 39    """
 40    Convex Hull Trick for minimum queries.
 41    Maintains a lower envelope of lines.
 42    Assumes lines are added in decreasing order of slope (for online variant).
 43    """
 44
 45    def __init__(self):
 46        self.lines: List[Line] = []
 47
 48    def _bad_line(self, l1: Line, l2: Line, l3: Line) -> bool:
 49        """
 50        Check if line l2 is redundant (will never be optimal).
 51        l2 is bad if the intersection of (l1, l3) is left of intersection of (l1, l2).
 52        """
 53        # Cross product method to avoid division
 54        # (l3.c - l1.c) * (l1.m - l2.m) < (l2.c - l1.c) * (l1.m - l3.m)
 55        return (l3.c - l1.c) * (l1.m - l2.m) <= (l2.c - l1.c) * (l1.m - l3.m)
 56
 57    def add_line(self, line: Line) -> None:
 58        """
 59        Add a line to the convex hull.
 60        Assumes lines are added in decreasing order of slope.
 61        """
 62        # Remove lines that become irrelevant
 63        while len(self.lines) >= 2:
 64            if self._bad_line(self.lines[-2], self.lines[-1], line):
 65                self.lines.pop()
 66            else:
 67                break
 68        self.lines.append(line)
 69
 70    def query(self, x: float) -> float:
 71        """
 72        Find minimum value at position x.
 73        Uses binary search if queries are not monotonic.
 74        """
 75        if not self.lines:
 76            return float('inf')
 77
 78        # Binary search for the best line
 79        left, right = 0, len(self.lines) - 1
 80        while left < right:
 81            mid = (left + right) // 2
 82            # Check if line at mid or mid+1 is better at x
 83            if self.lines[mid].eval(x) > self.lines[mid + 1].eval(x):
 84                left = mid + 1
 85            else:
 86                right = mid
 87
 88        return self.lines[left].eval(x)
 89
 90
 91class ConvexHullTrickMax:
 92    """
 93    Convex Hull Trick for maximum queries.
 94    Maintains an upper envelope of lines.
 95    """
 96
 97    def __init__(self):
 98        self.cht_min = ConvexHullTrick()
 99
100    def add_line(self, line: Line) -> None:
101        """Add line by negating for maximum query"""
102        self.cht_min.add_line(Line(-line.m, -line.c))
103
104    def query(self, x: float) -> float:
105        """Find maximum value at position x"""
106        return -self.cht_min.query(x)
107
108
class LiChaoTree:
    """
    Li Chao Tree: minimum over dynamically inserted lines on an integer range.

    Lines may be inserted in any slope order.  The tree is a segment tree over
    the integer interval [x_min, x_max]; each node keeps the line that wins at
    its midpoint, and the losing line is pushed into the one child where it can
    still win.  Insertion and query each walk a single root-to-leaf path, so
    they take O(log(x_max - x_min)) steps.

    NOTE(review): points outside [x_min, x_max] are not covered by the tree and
    query() does not reject them; results there are presumably not meaningful.
    """

    def __init__(self, x_min: int, x_max: int):
        self.x_min = x_min
        self.x_max = x_max
        # Heap-style numbering: the root is node 1 and node k has children
        # 2k (left half) and 2k + 1 (right half).
        self.tree: dict = {}  # node_id -> Line

    def _update(self, line: Line, node_id: int, left: int, right: int) -> None:
        """Insert ``line`` into the subtree rooted at node_id covering [left, right]."""
        while node_id in self.tree:
            incumbent = self.tree[node_id]
            mid = (left + right) // 2
            wins_at_left = line.eval(left) < incumbent.eval(left)
            wins_at_mid = line.eval(mid) < incumbent.eval(mid)
            # Keep the midpoint winner at this node and carry the loser down.
            if wins_at_mid:
                self.tree[node_id], line = line, incumbent
            if left == right:
                return
            if wins_at_left != wins_at_mid:
                # The lines cross within [left, mid]; the loser may win there.
                node_id, right = 2 * node_id, mid
            else:
                # The loser is no better at `left` or at `mid`, so it can only
                # win to the right of mid.
                node_id, left = 2 * node_id + 1, mid + 1
        self.tree[node_id] = line

    def add_line(self, line: Line) -> None:
        """Insert a line into the tree, starting at the root."""
        self._update(line, 1, self.x_min, self.x_max)

    def _query(self, x: int, node_id: int, left: int, right: int) -> float:
        """Return the minimum at x over lines stored on the path below node_id."""
        # Gather the values of the stored lines along the single path that x
        # selects, stopping at the first missing node or at a leaf.
        path_values = []
        while node_id in self.tree:
            path_values.append(self.tree[node_id].eval(x))
            if left == right:
                break
            mid = (left + right) // 2
            if x <= mid:
                node_id, right = 2 * node_id, mid
            else:
                node_id, left = 2 * node_id + 1, mid + 1
        # Fold from the deepest node upward, mirroring min(node, child_result)
        # with an absent child contributing inf.
        best = float('inf')
        for value in reversed(path_values):
            best = min(value, best)
        return best

    def query(self, x: int) -> float:
        """Return the minimum over all inserted lines at x (inf when empty)."""
        return self._query(x, 1, self.x_min, self.x_max)
171
172
def solve_machine_cost_problem(n: int, costs: List[int], machines: List[int]) -> int:
    """
    Example Problem: Machine Cost Minimization

    Problem: You have n tasks (0-indexed). For task i, you can:
    - Use a new machine with cost costs[i]
    - Use a previous machine j (j < i) with cost dp[j] + (machines[i] - machines[j])^2

    DP recurrence (task 0 has no predecessor, so dp[0] = costs[0]):
        dp[i] = min(costs[i], min over j < i of (dp[j] + (machines[i] - machines[j])^2))
    Expanding the square:
        dp[i] = min(costs[i], machines[i]^2 + min over j < i of
                    (-2*machines[j] * machines[i] + dp[j] + machines[j]^2))
    so each earlier task j contributes the line y = m*x + c with
    m = -2*machines[j] and c = dp[j] + machines[j]^2, queried at x = machines[i].

    Args:
        n: number of tasks; only the first n entries of costs/machines are used.
        costs: costs[i] is the cost of serving task i with a new machine.
        machines: machine positions; machines[:n] must be sorted in
            non-decreasing order so the line slopes reach ConvexHullTrick in
            monotone (non-increasing) order.

    Returns:
        dp[n - 1] converted with int(), or 0 when n <= 0 (no tasks).

    Raises:
        ValueError: if costs or machines has fewer than n entries, or if
            machines[:n] is not sorted in non-decreasing order.
    """
    if n <= 0:
        return 0
    if len(costs) < n or len(machines) < n:
        raise ValueError("costs and machines must each contain at least n entries")
    if any(machines[k] > machines[k + 1] for k in range(n - 1)):
        raise ValueError("machines must be sorted in non-decreasing order")

    dp = [0] * n
    dp[0] = costs[0]  # the first task has no earlier machine to reuse

    cht = ConvexHullTrick()
    cht.add_line(Line(-2 * machines[0], dp[0] + machines[0] ** 2))

    for i in range(1, n):
        x = machines[i]
        # Option 1: new machine.  Option 2: best earlier task via the hull,
        # re-adding the machines[i]^2 term that was factored out of the lines.
        dp[i] = min(costs[i], cht.query(x) + x ** 2)
        # Slopes -2*machines[i] are non-increasing because machines is sorted.
        cht.add_line(Line(-2 * x, dp[i] + x ** 2))

    return int(dp[n - 1])
208
209
def solve_slope_optimization(n: int, a: List[int], b: List[int]) -> int:
    """
    Example Problem: Slope Optimization

    DP recurrence (base case dp[0] = 0):
        dp[i] = min over j < i of (dp[j] + a[i] * b[j])    for 1 <= i < n
    Each earlier state j is the line y = b[j] * x + dp[j], queried at x = a[i].

    Args:
        n: number of DP states; must be at least 1.
        a: query points; a[i] is used for state i (a[0] is never read).
        b: line slopes; b[:n] must be sorted in non-increasing order, as
            ConvexHullTrick requires lines in monotone slope order.

    Returns:
        dp[n - 1] converted with int().

    Raises:
        ValueError: if n < 1 or b[:n] is not sorted in non-increasing order.
    """
    if n < 1:
        raise ValueError("n must be at least 1")
    if any(b[k] < b[k + 1] for k in range(n - 1)):
        raise ValueError("b must be sorted in non-increasing order")

    dp = [0] * n  # dp[0] = 0 is the base case; later entries are overwritten

    cht = ConvexHullTrick()
    cht.add_line(Line(b[0], dp[0]))

    for i in range(1, n):
        dp[i] = cht.query(a[i])
        cht.add_line(Line(b[i], dp[i]))

    return int(dp[n - 1])
228
229
if __name__ == "__main__":
    # Demo driver: runs each example and data structure and prints the results.
    print("=== Convex Hull Trick Examples ===\n")

    # Test 1: Machine Cost Problem
    print("Test 1: Machine Cost Problem")
    n = 5
    costs = [10, 15, 20, 12, 18]
    machines = [1, 2, 3, 4, 5]
    result = solve_machine_cost_problem(n, costs, machines)
    print(f"Tasks: {n}, Costs: {costs}, Machines: {machines}")
    print(f"Minimum cost: {result}")
    print()

    # Test 2: Slope Optimization
    print("Test 2: Slope Optimization")
    n = 6
    a = [1, 2, 3, 4, 5, 6]
    b = [6, 5, 4, 3, 2, 1]
    result = solve_slope_optimization(n, a, b)
    print(f"n: {n}, a: {a}, b: {b}")
    print(f"Minimum dp value: {result}")
    print()

    # Test 3: Li Chao Tree
    print("Test 3: Li Chao Tree (dynamic line insertion)")
    li_chao = LiChaoTree(0, 100)

    # Add lines in arbitrary order
    lines = [
        Line(2, 5),    # y = 2x + 5
        Line(-1, 20),  # y = -x + 20
        Line(0.5, 10), # y = 0.5x + 10
    ]

    for line in lines:
        li_chao.add_line(line)
        print(f"Added line: y = {line.m}x + {line.c}")

    # Query at different points
    query_points = [0, 5, 10, 15, 20]
    print("\nQuery results:")
    for x in query_points:
        min_val = li_chao.query(x)
        print(f"  x = {x}: min value = {min_val:.2f}")
    print()

    # Test 4: ConvexHullTrickMax
    print("Test 4: Maximum Query with CHT")
    cht_max = ConvexHullTrickMax()

    # ConvexHullTrickMax stores each line negated in a ConvexHullTrick, which
    # needs slopes in monotone order.  Adding these lines in increasing slope
    # order gives negated slopes in decreasing order, which that hull accepts.
    max_lines = [
        Line(1, 20),   # y = x + 20
        Line(3, 15),   # y = 3x + 15
        Line(5, 10),   # y = 5x + 10
    ]

    for line in max_lines:
        cht_max.add_line(line)
        print(f"Added line: y = {line.m}x + {line.c}")

    print("\nMaximum query results:")
    for x in [0, 2, 5, 10]:
        max_val = cht_max.query(x)
        print(f"  x = {x}: max value = {max_val:.2f}")
    print()

    print("=== Complexity Analysis ===")
    print("CHT with monotonic queries: O(n) total")
    print("CHT with binary search: O(n log n)")
    print("Li Chao Tree: O(n log n)")
    print("\nOptimization: Reduces O(n^2) DP to O(n log n)")