23_segment_tree.py

Download
python 460 lines 15.7 KB
  1"""
Segment Tree (세그먼트 트리)
Segment Tree for Range Queries

A data structure that efficiently handles range queries and point updates.
  6"""
  7
  8from typing import List, Callable, Optional
  9
 10
 11# =============================================================================
 12# 1. 기본 세그먼트 트리 (구간 합)
 13# =============================================================================
 14
 15class SegmentTree:
 16    """
 17    세그먼트 트리 (구간 합)
 18    - 점 업데이트: O(log n)
 19    - 구간 쿼리: O(log n)
 20    - 공간: O(n)
 21    """
 22
 23    def __init__(self, arr: List[int]):
 24        self.n = len(arr)
 25        self.tree = [0] * (4 * self.n)
 26        if self.n > 0:
 27            self._build(arr, 1, 0, self.n - 1)
 28
 29    def _build(self, arr: List[int], node: int, start: int, end: int):
 30        """트리 구성 - O(n)"""
 31        if start == end:
 32            self.tree[node] = arr[start]
 33        else:
 34            mid = (start + end) // 2
 35            self._build(arr, 2 * node, start, mid)
 36            self._build(arr, 2 * node + 1, mid + 1, end)
 37            self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]
 38
 39    def update(self, idx: int, val: int):
 40        """점 업데이트 - O(log n)"""
 41        self._update(1, 0, self.n - 1, idx, val)
 42
 43    def _update(self, node: int, start: int, end: int, idx: int, val: int):
 44        if start == end:
 45            self.tree[node] = val
 46        else:
 47            mid = (start + end) // 2
 48            if idx <= mid:
 49                self._update(2 * node, start, mid, idx, val)
 50            else:
 51                self._update(2 * node + 1, mid + 1, end, idx, val)
 52            self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]
 53
 54    def query(self, left: int, right: int) -> int:
 55        """구간 합 쿼리 - O(log n)"""
 56        return self._query(1, 0, self.n - 1, left, right)
 57
 58    def _query(self, node: int, start: int, end: int, left: int, right: int) -> int:
 59        if right < start or end < left:
 60            return 0  # 범위 벗어남
 61        if left <= start and end <= right:
 62            return self.tree[node]  # 완전 포함
 63
 64        mid = (start + end) // 2
 65        left_sum = self._query(2 * node, start, mid, left, right)
 66        right_sum = self._query(2 * node + 1, mid + 1, end, left, right)
 67        return left_sum + right_sum
 68
 69
 70# =============================================================================
 71# 2. 일반 세그먼트 트리 (임의 연산)
 72# =============================================================================
 73
 74class GenericSegmentTree:
 75    """
 76    일반 세그먼트 트리 (임의의 결합 연산)
 77    - 결합 법칙을 만족하는 연산이면 사용 가능
 78    """
 79
 80    def __init__(self, arr: List[int], func: Callable[[int, int], int], identity: int):
 81        """
 82        func: 결합 연산 (예: min, max, gcd, +, *)
 83        identity: 항등원 (예: inf for min, 0 for +, 1 for *)
 84        """
 85        self.n = len(arr)
 86        self.func = func
 87        self.identity = identity
 88        self.tree = [identity] * (4 * self.n)
 89        if self.n > 0:
 90            self._build(arr, 1, 0, self.n - 1)
 91
 92    def _build(self, arr: List[int], node: int, start: int, end: int):
 93        if start == end:
 94            self.tree[node] = arr[start]
 95        else:
 96            mid = (start + end) // 2
 97            self._build(arr, 2 * node, start, mid)
 98            self._build(arr, 2 * node + 1, mid + 1, end)
 99            self.tree[node] = self.func(self.tree[2 * node], self.tree[2 * node + 1])
100
101    def update(self, idx: int, val: int):
102        self._update(1, 0, self.n - 1, idx, val)
103
104    def _update(self, node: int, start: int, end: int, idx: int, val: int):
105        if start == end:
106            self.tree[node] = val
107        else:
108            mid = (start + end) // 2
109            if idx <= mid:
110                self._update(2 * node, start, mid, idx, val)
111            else:
112                self._update(2 * node + 1, mid + 1, end, idx, val)
113            self.tree[node] = self.func(self.tree[2 * node], self.tree[2 * node + 1])
114
115    def query(self, left: int, right: int) -> int:
116        return self._query(1, 0, self.n - 1, left, right)
117
118    def _query(self, node: int, start: int, end: int, left: int, right: int) -> int:
119        if right < start or end < left:
120            return self.identity
121        if left <= start and end <= right:
122            return self.tree[node]
123
124        mid = (start + end) // 2
125        left_val = self._query(2 * node, start, mid, left, right)
126        right_val = self._query(2 * node + 1, mid + 1, end, left, right)
127        return self.func(left_val, right_val)
128
129
130# =============================================================================
131# 3. Lazy Propagation (구간 업데이트)
132# =============================================================================
133
class LazySegmentTree:
    """
    Range-sum segment tree with lazy propagation.

    - Range update (add a value to every element of a range): O(log n)
    - Range query (sum of a range): O(log n)
    """

    def __init__(self, arr: List[int]):
        size = len(arr)
        self.n = size
        self.tree = [0] * (4 * size)
        # lazy[k] holds an addition already applied to tree[k] but not yet
        # pushed to k's children.
        self.lazy = [0] * (4 * size)
        if size > 0:
            self._build(arr, 1, 0, size - 1)

    def _build(self, arr: List[int], node: int, start: int, end: int):
        """Fill ``node`` (covering ``arr[start..end]``) from its children."""
        if start == end:
            self.tree[node] = arr[start]
            return

        mid = (start + end) // 2
        lch, rch = 2 * node, 2 * node + 1
        self._build(arr, lch, start, mid)
        self._build(arr, rch, mid + 1, end)
        self.tree[node] = self.tree[lch] + self.tree[rch]

    def _push_down(self, node: int, start: int, end: int):
        """Hand the pending addition at ``node`` over to its two children."""
        pending = self.lazy[node]
        if pending == 0:
            return

        mid = (start + end) // 2
        # (child index, number of elements that child covers)
        children = (
            (2 * node, mid - start + 1),
            (2 * node + 1, end - mid),
        )
        for child, width in children:
            self.tree[child] += pending * width
            self.lazy[child] += pending

        self.lazy[node] = 0

    def update_range(self, left: int, right: int, val: int):
        """Add ``val`` to every element in the inclusive range [left, right]."""
        self._update_range(1, 0, self.n - 1, left, right, val)

    def _update_range(self, node: int, start: int, end: int, left: int, right: int, val: int):
        if start > right or left > end:
            return  # no overlap

        if start >= left and right >= end:
            # Fully covered: apply here and defer the children.
            width = end - start + 1
            self.tree[node] += val * width
            self.lazy[node] += val
            return

        # Partial overlap: children must be up to date before descending.
        self._push_down(node, start, end)

        mid = (start + end) // 2
        lch, rch = 2 * node, 2 * node + 1
        self._update_range(lch, start, mid, left, right, val)
        self._update_range(rch, mid + 1, end, left, right, val)
        self.tree[node] = self.tree[lch] + self.tree[rch]

    def query(self, left: int, right: int) -> int:
        """Return the sum of the elements in the inclusive range [left, right]."""
        return self._query(1, 0, self.n - 1, left, right)

    def _query(self, node: int, start: int, end: int, left: int, right: int) -> int:
        if start > right or left > end:
            return 0  # no overlap

        if start >= left and right >= end:
            return self.tree[node]

        self._push_down(node, start, end)

        mid = (start + end) // 2
        lower = self._query(2 * node, start, mid, left, right)
        upper = self._query(2 * node + 1, mid + 1, end, left, right)
        return lower + upper
209
210
211# =============================================================================
212# 4. 반복 세그먼트 트리 (Iterative)
213# =============================================================================
214
class IterativeSegmentTree:
    """
    Iterative (non-recursive) range-sum segment tree.

    Stored in a flat list of ``2 * n`` slots: leaves occupy indices
    ``n .. 2n - 1`` and internal node ``i`` is the sum of nodes ``2 * i``
    and ``2 * i + 1``.
    """

    def __init__(self, arr: List[int]):
        """Build the tree from ``arr``. An empty list yields an empty tree."""
        self.n = len(arr)
        self.tree = [0] * (2 * self.n)

        # Fill the leaves.
        for i in range(self.n):
            self.tree[self.n + i] = arr[i]

        # Build internal nodes from the bottom up.
        for i in range(self.n - 1, 0, -1):
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def update(self, idx: int, val: int):
        """
        Set the element at position ``idx`` to ``val``.

        Raises:
            IndexError: if ``idx`` is not in ``[0, n)``.
        """
        # A negative index would otherwise address an internal node and
        # corrupt the stored sums instead of failing.
        if not 0 <= idx < self.n:
            raise IndexError(
                f"index {idx} out of range for segment tree of size {self.n}"
            )

        pos = idx + self.n
        self.tree[pos] = val

        # Recompute every ancestor up to the root (index 1).
        while pos > 1:
            pos //= 2
            self.tree[pos] = self.tree[2 * pos] + self.tree[2 * pos + 1]

    def query(self, left: int, right: int) -> int:
        """
        Return the sum of the elements in the inclusive range [left, right].

        The range is clamped to ``[0, n - 1]``, so positions outside the
        array contribute nothing and an empty range yields 0.
        """
        # Clamp first: unclamped bounds would index internal nodes or run
        # past the end of the list. Then switch to half-open leaf positions.
        lo = max(left, 0) + self.n
        hi = min(right, self.n - 1) + self.n + 1
        result = 0

        while lo < hi:
            if lo % 2 == 1:
                # lo is a right child: take it and step past its parent.
                result += self.tree[lo]
                lo += 1
            if hi % 2 == 1:
                # hi - 1 is a left child inside the range: take it.
                hi -= 1
                result += self.tree[hi]
            lo //= 2
            hi //= 2

        return result
259
260
261# =============================================================================
262# 5. 2D 세그먼트 트리
263# =============================================================================
264
class SegmentTree2D:
    """
    2D segment tree for rectangle sums over a matrix.

    - Query / point update: O(log n * log m)

    ``tree[node_x]`` is itself a segment tree over the columns, holding the
    column-wise sums of all rows covered by ``node_x`` in the row tree.
    """

    def __init__(self, matrix: List[List[int]]):
        """Build the tree from ``matrix``; an empty matrix yields an empty tree."""
        if not matrix or not matrix[0]:
            self.n = self.m = 0
            # Keep the attribute defined so later calls do not hit AttributeError.
            self.tree = []
            return

        self.n = len(matrix)
        self.m = len(matrix[0])
        self.tree = [[0] * (4 * self.m) for _ in range(4 * self.n)]
        self._build_x(matrix, 1, 0, self.n - 1)

    def _build_x(self, matrix: List[List[int]], node_x: int, lx: int, rx: int):
        """Build the row-tree node ``node_x`` covering rows ``lx..rx``."""
        if lx == rx:
            self._build_y(matrix, node_x, lx, rx, 1, 0, self.m - 1, lx)
        else:
            mid = (lx + rx) // 2
            self._build_x(matrix, 2 * node_x, lx, mid)
            self._build_x(matrix, 2 * node_x + 1, mid + 1, rx)
            self._merge_y(node_x, 1, 0, self.m - 1)

    def _build_y(self, matrix, node_x, lx, rx, node_y, ly, ry, row):
        """Build the column tree of a single-row node from ``matrix[row]``.

        ``lx`` and ``rx`` are accepted but not used.
        """
        if ly == ry:
            self.tree[node_x][node_y] = matrix[row][ly]
        else:
            mid = (ly + ry) // 2
            self._build_y(matrix, node_x, lx, rx, 2 * node_y, ly, mid, row)
            self._build_y(matrix, node_x, lx, rx, 2 * node_y + 1, mid + 1, ry, row)
            self.tree[node_x][node_y] = self.tree[node_x][2 * node_y] + self.tree[node_x][2 * node_y + 1]

    def _merge_y(self, node_x, node_y, ly, ry):
        """Build the column tree of ``node_x`` from its two row-tree children."""
        if ly == ry:
            self.tree[node_x][node_y] = self.tree[2 * node_x][node_y] + self.tree[2 * node_x + 1][node_y]
        else:
            mid = (ly + ry) // 2
            self._merge_y(node_x, 2 * node_y, ly, mid)
            self._merge_y(node_x, 2 * node_y + 1, mid + 1, ry)
            self.tree[node_x][node_y] = self.tree[node_x][2 * node_y] + self.tree[node_x][2 * node_y + 1]

    def update(self, x: int, y: int, val: int):
        """
        Set the cell at row ``x``, column ``y`` to ``val``.

        Raises:
            IndexError: if ``(x, y)`` lies outside the matrix.
        """
        if not (0 <= x < self.n and 0 <= y < self.m):
            raise IndexError(
                f"cell ({x}, {y}) out of range for {self.n}x{self.m} segment tree"
            )
        self._update_x(1, 0, self.n - 1, x, y, val)

    def _update_x(self, node_x, lx, rx, x, y, val):
        """Update the row-tree path to row ``x``, deepest node first."""
        if lx != rx:
            mid = (lx + rx) // 2
            if x <= mid:
                self._update_x(2 * node_x, lx, mid, x, y, val)
            else:
                self._update_x(2 * node_x + 1, mid + 1, rx, x, y, val)
        # Children are already up to date here, so this node can be refreshed.
        self._update_y(node_x, lx, rx, 1, 0, self.m - 1, y, val)

    def _update_y(self, node_x, lx, rx, node_y, ly, ry, y, val):
        """Refresh the column-tree path to column ``y`` inside ``node_x``."""
        if ly == ry:
            if lx == rx:
                # Single row, single column: the actual matrix cell.
                self.tree[node_x][node_y] = val
            else:
                # Multi-row node: recombine from the two row-tree children.
                self.tree[node_x][node_y] = self.tree[2 * node_x][node_y] + self.tree[2 * node_x + 1][node_y]
        else:
            mid = (ly + ry) // 2
            if y <= mid:
                self._update_y(node_x, lx, rx, 2 * node_y, ly, mid, y, val)
            else:
                self._update_y(node_x, lx, rx, 2 * node_y + 1, mid + 1, ry, y, val)
            self.tree[node_x][node_y] = self.tree[node_x][2 * node_y] + self.tree[node_x][2 * node_y + 1]

    def query(self, x1: int, y1: int, x2: int, y2: int) -> int:
        """
        Return the sum of the rectangle with corners (x1, y1) and (x2, y2).

        Both corners are inclusive; ``x`` indexes rows and ``y`` columns.
        Cells outside the matrix contribute nothing, and an empty tree
        always yields 0.
        """
        if self.n == 0:
            return 0
        return self._query_x(1, 0, self.n - 1, x1, x2, y1, y2)

    def _query_x(self, node_x, lx, rx, x1, x2, y1, y2):
        """Sum over the rows of [x1, x2] that overlap this node's [lx, rx]."""
        if x2 < lx or rx < x1:
            return 0
        if x1 <= lx and rx <= x2:
            return self._query_y(node_x, 1, 0, self.m - 1, y1, y2)

        mid = (lx + rx) // 2
        left = self._query_x(2 * node_x, lx, mid, x1, x2, y1, y2)
        right = self._query_x(2 * node_x + 1, mid + 1, rx, x1, x2, y1, y2)
        return left + right

    def _query_y(self, node_x, node_y, ly, ry, y1, y2):
        """Sum over the columns of [y1, y2] within the column tree of ``node_x``."""
        if y2 < ly or ry < y1:
            return 0
        if y1 <= ly and ry <= y2:
            return self.tree[node_x][node_y]

        mid = (ly + ry) // 2
        left = self._query_y(node_x, 2 * node_y, ly, mid, y1, y2)
        right = self._query_y(node_x, 2 * node_y + 1, mid + 1, ry, y1, y2)
        return left + right
333
334
335# =============================================================================
336# 6. 응용: 역순 쌍 개수 (Inversion Count)
337# =============================================================================
338
def count_inversions_segtree(arr: List[int]) -> int:
    """
    Count inversions (pairs i < j with arr[i] > arr[j]) using a segment tree.

    Values are coordinate-compressed to ranks; a frequency tree over those
    ranks is then asked, for each element in order, how many already
    inserted values have a strictly larger rank.

    Time complexity: O(n log n)
    """
    if not arr:
        return 0

    # Coordinate compression: distinct values -> 0 .. n - 1 in sorted order.
    rank = {value: pos for pos, value in enumerate(sorted(set(arr)))}
    n = len(rank)

    # Segment tree of occurrence counts per rank.
    freq = [0] * (4 * n)

    def update(node, start, end, idx):
        if start == end:
            freq[node] += 1
            return
        mid = (start + end) // 2
        if idx > mid:
            update(2 * node + 1, mid + 1, end, idx)
        else:
            update(2 * node, start, mid, idx)
        freq[node] = freq[2 * node] + freq[2 * node + 1]

    def query(node, start, end, left, right):
        if start > right or left > end:
            return 0
        if start >= left and right >= end:
            return freq[node]
        mid = (start + end) // 2
        lower = query(2 * node, start, mid, left, right)
        upper = query(2 * node + 1, mid + 1, end, left, right)
        return lower + upper

    last = n - 1
    inversions = 0
    for val in arr:
        r = rank[val]
        # Already inserted values with rank above r each form an inversion.
        inversions += query(1, 0, last, r + 1, last)
        update(1, 0, last, r)

    return inversions
383
384
385# =============================================================================
386# 테스트
387# =============================================================================
388
def main():
    """Print a short demonstration of every structure in this module."""
    banner = "=" * 60
    print(banner)
    print("세그먼트 트리 (Segment Tree) 예제")
    print(banner)

    # 1. Basic segment tree
    print("\n[1] 기본 세그먼트 트리 (구간 합)")
    arr = [1, 3, 5, 7, 9, 11]
    st = SegmentTree(arr)
    print(f"    배열: {arr}")
    before = st.query(1, 3)  # 3+5+7=15
    print(f"    query(1, 3): {before}")
    st.update(2, 6)  # 5 -> 6
    after = st.query(1, 3)  # 3+6+7=16
    print(f"    update(2, 6) 후 query(1, 3): {after}")

    # 2. Generic segment tree (minimum)
    print("\n[2] 일반 세그먼트 트리 (구간 최소)")
    arr = [5, 2, 8, 1, 9, 3]
    min_st = GenericSegmentTree(arr, min, float('inf'))
    print(f"    배열: {arr}")
    before = min_st.query(1, 4)  # min(2,8,1,9)=1
    print(f"    min(1, 4): {before}")
    min_st.update(3, 10)  # 1 -> 10
    after = min_st.query(1, 4)  # min(2,8,10,9)=2
    print(f"    update(3, 10) 후 min(1, 4): {after}")

    # 3. Lazy propagation
    print("\n[3] Lazy Propagation (구간 업데이트)")
    arr = [1, 2, 3, 4, 5]
    lazy_st = LazySegmentTree(arr)
    print(f"    배열: {arr}")
    before = lazy_st.query(0, 4)  # 15
    print(f"    query(0, 4): {before}")
    lazy_st.update_range(1, 3, 10)  # add 10 to the range [1, 3]
    after = lazy_st.query(0, 4)  # 45
    print(f"    update_range(1, 3, +10) 후 query(0, 4): {after}")

    # 4. Iterative segment tree
    print("\n[4] 반복 세그먼트 트리")
    arr = [1, 3, 5, 7, 9, 11]
    iter_st = IterativeSegmentTree(arr)
    print(f"    배열: {arr}")
    total = iter_st.query(1, 4)  # 3+5+7+9=24
    print(f"    query(1, 4): {total}")

    # 5. 2D segment tree
    print("\n[5] 2D 세그먼트 트리")
    matrix = [
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9],
    ]
    st2d = SegmentTree2D(matrix)
    print(f"    행렬: {matrix}")
    top_left = st2d.query(0, 0, 1, 1)  # 1+2+4+5=12
    print(f"    query(0,0,1,1): {top_left}")
    bottom_right = st2d.query(1, 1, 2, 2)  # 5+6+8+9=28
    print(f"    query(1,1,2,2): {bottom_right}")

    # 6. Inversion count
    print("\n[6] 역순 쌍 개수")
    arr = [2, 4, 1, 3, 5]
    inv = count_inversions_segtree(arr)
    print(f"    배열: {arr}")
    print(f"    역순 쌍 개수: {inv}")  # (2,1), (4,1), (4,3) = 3

    # 7. Complexity comparison table
    print("\n[7] 복잡도 비교")
    table_rows = (
        "    | 연산       | 배열    | 세그먼트 트리 | Lazy      |",
        "    |------------|---------|---------------|-----------|",
        "    | 점 업데이트| O(1)    | O(log n)      | O(log n)  |",
        "    | 구간 업데이트| O(n)  | O(n)          | O(log n)  |",
        "    | 구간 쿼리  | O(n)    | O(log n)      | O(log n)  |",
    )
    for row in table_rows:
        print(row)

    print("\n" + banner)


# Run the demonstration only when executed as a script.
if __name__ == "__main__":
    main()