1"""
2세그먼트 트리 (Segment Tree)
3Segment Tree for Range Queries
4
5구간 쿼리와 점 업데이트를 효율적으로 처리하는 자료구조입니다.
6"""
7
8from typing import List, Callable, Optional
9
10
11# =============================================================================
12# 1. 기본 세그먼트 트리 (구간 합)
13# =============================================================================
14
class SegmentTree:
    """
    Recursive segment tree for range sums with point assignment.

    - point update: O(log n)
    - range query: O(log n)
    - space: O(n) -- the node array has 4 * n slots

    Nodes are 1-indexed: node ``k`` has children ``2k`` and ``2k + 1`` and
    covers the closed index range ``[start, end]`` of the input array.
    """

    def __init__(self, arr: List[int]):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)
        if self.n > 0:
            self._build(arr, 1, 0, self.n - 1)

    def _build(self, arr: List[int], node: int, start: int, end: int):
        """Fill ``node`` (covering ``arr[start..end]``) from the leaves up - O(n) overall."""
        if start == end:
            self.tree[node] = arr[start]
        else:
            mid = (start + end) // 2
            self._build(arr, 2 * node, start, mid)
            self._build(arr, 2 * node + 1, mid + 1, end)
            self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def update(self, idx: int, val: int):
        """Set element ``idx`` to ``val`` and refresh the sums above it - O(log n).

        Raises:
            IndexError: if ``idx`` is not in ``[0, n)``. Without this check an
                out-of-range index would silently overwrite the first or last
                element, and an empty tree would recurse without end.
        """
        if not 0 <= idx < self.n:
            raise IndexError(f"index {idx} out of range for size {self.n}")
        self._update(1, 0, self.n - 1, idx, val)

    def _update(self, node: int, start: int, end: int, idx: int, val: int):
        """Descend to the leaf for ``idx``, assign it, and recompute parents on the way back."""
        if start == end:
            self.tree[node] = val
        else:
            mid = (start + end) // 2
            if idx <= mid:
                self._update(2 * node, start, mid, idx, val)
            else:
                self._update(2 * node + 1, mid + 1, end, idx, val)
            self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def query(self, left: int, right: int) -> int:
        """Return the sum of elements in the closed range ``[left, right]`` - O(log n).

        Positions outside ``[0, n)`` contribute nothing, and ``left > right``
        yields 0.
        """
        return self._query(1, 0, self.n - 1, left, right)

    def _query(self, node: int, start: int, end: int, left: int, right: int) -> int:
        if right < start or end < left:
            return 0  # node segment is disjoint from the query range
        if left <= start and end <= right:
            return self.tree[node]  # node segment lies fully inside the query range

        mid = (start + end) // 2
        left_sum = self._query(2 * node, start, mid, left, right)
        right_sum = self._query(2 * node + 1, mid + 1, end, left, right)
        return left_sum + right_sum
68
69
70# =============================================================================
71# 2. 일반 세그먼트 트리 (임의 연산)
72# =============================================================================
73
class GenericSegmentTree:
    """
    Segment tree over an arbitrary associative binary operation.

    Any operation satisfying the associative law works (e.g. min, max, gcd,
    +, *). Children are always combined as ``func(left, right)`` in index
    order, so ``func`` does not need to be commutative.
    """

    def __init__(self, arr: List[int], func: Callable[[int, int], int], identity: int):
        """
        Args:
            arr: initial values.
            func: associative combine operation (e.g. min, max, gcd, +, *).
            identity: identity element of ``func`` (e.g. inf for min, 0 for +,
                1 for *); returned for segments disjoint from a query range.
        """
        self.n = len(arr)
        self.func = func
        self.identity = identity
        self.tree = [identity] * (4 * self.n)
        if self.n > 0:
            self._build(arr, 1, 0, self.n - 1)

    def _build(self, arr: List[int], node: int, start: int, end: int):
        """Fill ``node`` (covering ``arr[start..end]``) from the leaves up."""
        if start == end:
            self.tree[node] = arr[start]
        else:
            mid = (start + end) // 2
            self._build(arr, 2 * node, start, mid)
            self._build(arr, 2 * node + 1, mid + 1, end)
            self.tree[node] = self.func(self.tree[2 * node], self.tree[2 * node + 1])

    def update(self, idx: int, val: int):
        """Set element ``idx`` to ``val`` and recombine its ancestors - O(log n).

        Raises:
            IndexError: if ``idx`` is not in ``[0, n)``. Without this check an
                out-of-range index would silently overwrite the first or last
                element, and an empty tree would recurse without end.
        """
        if not 0 <= idx < self.n:
            raise IndexError(f"index {idx} out of range for size {self.n}")
        self._update(1, 0, self.n - 1, idx, val)

    def _update(self, node: int, start: int, end: int, idx: int, val: int):
        if start == end:
            self.tree[node] = val
        else:
            mid = (start + end) // 2
            if idx <= mid:
                self._update(2 * node, start, mid, idx, val)
            else:
                self._update(2 * node + 1, mid + 1, end, idx, val)
            self.tree[node] = self.func(self.tree[2 * node], self.tree[2 * node + 1])

    def query(self, left: int, right: int) -> int:
        """Fold ``func`` over the closed range ``[left, right]`` - O(log n).

        Returns ``identity`` when the range covers no element.
        """
        return self._query(1, 0, self.n - 1, left, right)

    def _query(self, node: int, start: int, end: int, left: int, right: int) -> int:
        if right < start or end < left:
            return self.identity  # disjoint segment contributes the neutral element
        if left <= start and end <= right:
            return self.tree[node]

        mid = (start + end) // 2
        left_val = self._query(2 * node, start, mid, left, right)
        right_val = self._query(2 * node + 1, mid + 1, end, left, right)
        return self.func(left_val, right_val)
128
129
130# =============================================================================
131# 3. Lazy Propagation (구간 업데이트)
132# =============================================================================
133
class LazySegmentTree:
    """
    Segment tree with lazy propagation for range-add / range-sum.

    - update_range(left, right, val): add ``val`` to every element of
      ``[left, right]`` - O(log n)
    - query(left, right): sum of ``[left, right]`` - O(log n)

    ``tree[node]`` already includes every addition applied to that node;
    ``lazy[node]`` is the per-element amount not yet forwarded to its children.
    """

    def __init__(self, arr: List[int]):
        self.n = len(arr)
        slots = 4 * self.n
        self.tree = [0] * slots
        self.lazy = [0] * slots
        if self.n:
            self._build(arr, 1, 0, self.n - 1)

    def _build(self, arr: List[int], node: int, start: int, end: int):
        """Fill ``node`` (covering ``arr[start..end]``) from the leaves up."""
        if start != end:
            mid = (start + end) // 2
            self._build(arr, 2 * node, start, mid)
            self._build(arr, 2 * node + 1, mid + 1, end)
            self._pull_up(node)
            return
        self.tree[node] = arr[start]

    def _pull_up(self, node: int):
        """Recompute ``node``'s sum from its two children."""
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def _apply(self, node: int, length: int, amount):
        """Add ``amount`` to each of the ``length`` elements under ``node`` lazily."""
        self.tree[node] += amount * length
        self.lazy[node] += amount

    def _push_down(self, node: int, start: int, end: int):
        """Forward ``node``'s pending addition to both children, then clear it."""
        pending = self.lazy[node]
        if pending == 0:
            return
        mid = (start + end) // 2
        halves = ((2 * node, mid - start + 1), (2 * node + 1, end - mid))
        for child, length in halves:
            self._apply(child, length, pending)
        self.lazy[node] = 0

    def update_range(self, left: int, right: int, val: int):
        """Add ``val`` to every element in the closed range ``[left, right]``."""
        self._update_range(1, 0, self.n - 1, left, right, val)

    def _update_range(self, node: int, start: int, end: int, left: int, right: int, val: int):
        if end < left or right < start:
            return  # disjoint
        if start >= left and end <= right:
            # Fully covered: record the addition here instead of descending.
            self._apply(node, end - start + 1, val)
            return

        self._push_down(node, start, end)
        mid = (start + end) // 2
        self._update_range(2 * node, start, mid, left, right, val)
        self._update_range(2 * node + 1, mid + 1, end, left, right, val)
        self._pull_up(node)

    def query(self, left: int, right: int) -> int:
        """Return the sum of the closed range ``[left, right]``."""
        return self._query(1, 0, self.n - 1, left, right)

    def _query(self, node: int, start: int, end: int, left: int, right: int) -> int:
        if end < left or right < start:
            return 0
        if start >= left and end <= right:
            return self.tree[node]

        self._push_down(node, start, end)
        mid = (start + end) // 2
        return (self._query(2 * node, start, mid, left, right)
                + self._query(2 * node + 1, mid + 1, end, left, right))
209
210
211# =============================================================================
212# 4. 반복 세그먼트 트리 (Iterative)
213# =============================================================================
214
class IterativeSegmentTree:
    """
    Bottom-up (non-recursive) segment tree for range sums.

    Uses ``2 * n`` slots: leaves live at ``tree[n:2n]`` and internal node
    ``i`` is the sum of ``tree[2i]`` and ``tree[2i + 1]``; ``tree[0]`` is unused.
    Works for any ``n`` (not only powers of two) because addition is
    commutative: left and right fragments go into one running total.
    """

    def __init__(self, arr: List[int]):
        self.n = len(arr)
        self.tree = [0] * (2 * self.n)

        # Place the leaves.
        for i in range(self.n):
            self.tree[self.n + i] = arr[i]

        # Build internal nodes from the bottom up.
        for i in range(self.n - 1, 0, -1):
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def update(self, idx: int, val: int):
        """Set element ``idx`` to ``val`` and refresh its ancestors - O(log n).

        Raises:
            IndexError: if ``idx`` is not in ``[0, n)``. A negative index would
                otherwise map onto an internal node (``idx + n < n``) and
                silently corrupt the stored sums.
        """
        if not 0 <= idx < self.n:
            raise IndexError(f"index {idx} out of range for size {self.n}")
        idx += self.n
        self.tree[idx] = val

        while idx > 1:
            idx //= 2
            self.tree[idx] = self.tree[2 * idx] + self.tree[2 * idx + 1]

    def query(self, left: int, right: int) -> int:
        """Return the sum of the closed range ``[left, right]``."""
        left += self.n
        right += self.n + 1  # convert to a half-open leaf range [left, right)
        result = 0

        while left < right:
            # An odd left boundary is a right child: take it and step past it.
            if left % 2 == 1:
                result += self.tree[left]
                left += 1
            # An odd right boundary means its left sibling is inside the range.
            if right % 2 == 1:
                right -= 1
                result += self.tree[right]
            left //= 2
            right //= 2

        return result
259
260
261# =============================================================================
262# 5. 2D 세그먼트 트리
263# =============================================================================
264
class SegmentTree2D:
    """
    2D segment tree (a segment tree of segment trees) for rectangle sums.

    - query / update: O(log n * log m)
    - space: a (4 * n) x (4 * m) node table

    ``x`` indexes rows (``0 <= x < n``) and ``y`` indexes columns
    (``0 <= y < m``). ``tree[node_x][node_y]`` holds the sum over the row
    range of x-node ``node_x`` and the column range of y-node ``node_y``.
    """

    def __init__(self, matrix: List[List[int]]):
        if not matrix or not matrix[0]:
            # Empty input: keep a defined (empty) table; every query returns 0.
            self.n = self.m = 0
            self.tree = []
            return

        self.n = len(matrix)
        self.m = len(matrix[0])  # assumes all rows have the length of row 0
        self.tree = [[0] * (4 * self.m) for _ in range(4 * self.n)]
        self._build_x(matrix, 1, 0, self.n - 1)

    def _build_x(self, matrix: List[List[int]], node_x: int, lx: int, rx: int):
        """Build the x-node covering rows ``lx..rx``."""
        if lx == rx:
            self._build_y(matrix, node_x, 1, 0, self.m - 1, lx)
        else:
            mid = (lx + rx) // 2
            self._build_x(matrix, 2 * node_x, lx, mid)
            self._build_x(matrix, 2 * node_x + 1, mid + 1, rx)
            self._merge_y(node_x, 1, 0, self.m - 1)

    def _build_y(self, matrix, node_x, node_y, ly, ry, row):
        """Build the column tree of a leaf x-node directly from ``matrix[row]``."""
        if ly == ry:
            self.tree[node_x][node_y] = matrix[row][ly]
        else:
            mid = (ly + ry) // 2
            self._build_y(matrix, node_x, 2 * node_y, ly, mid, row)
            self._build_y(matrix, node_x, 2 * node_y + 1, mid + 1, ry, row)
            self.tree[node_x][node_y] = (
                self.tree[node_x][2 * node_y] + self.tree[node_x][2 * node_y + 1]
            )

    def _merge_y(self, node_x, node_y, ly, ry):
        """Build the column tree of an internal x-node from its two x-children."""
        if ly == ry:
            self.tree[node_x][node_y] = (
                self.tree[2 * node_x][node_y] + self.tree[2 * node_x + 1][node_y]
            )
        else:
            mid = (ly + ry) // 2
            self._merge_y(node_x, 2 * node_y, ly, mid)
            self._merge_y(node_x, 2 * node_y + 1, mid + 1, ry)
            self.tree[node_x][node_y] = (
                self.tree[node_x][2 * node_y] + self.tree[node_x][2 * node_y + 1]
            )

    def update(self, x: int, y: int, val: int):
        """Set cell ``(x, y)`` to ``val`` - O(log n * log m).

        Raises:
            IndexError: if ``(x, y)`` lies outside the ``n`` x ``m`` matrix.
        """
        if not (0 <= x < self.n and 0 <= y < self.m):
            raise IndexError(f"cell ({x}, {y}) out of range for {self.n}x{self.m}")
        self._update_x(1, 0, self.n - 1, x, y, val)

    def _update_x(self, node_x, lx, rx, x, y, val):
        """Walk down the row tree to ``x``, then refresh each visited x-node's column path."""
        if lx != rx:
            mid = (lx + rx) // 2
            if x <= mid:
                self._update_x(2 * node_x, lx, mid, x, y, val)
            else:
                self._update_x(2 * node_x + 1, mid + 1, rx, x, y, val)
        # Any x-children are already up to date here, so this node can be refreshed.
        self._update_y(node_x, lx == rx, 1, 0, self.m - 1, y, val)

    def _update_y(self, node_x, is_leaf_x, node_y, ly, ry, y, val):
        """Refresh the column path to ``y`` inside x-node ``node_x``."""
        if ly == ry:
            if is_leaf_x:
                self.tree[node_x][node_y] = val
            else:
                self.tree[node_x][node_y] = (
                    self.tree[2 * node_x][node_y] + self.tree[2 * node_x + 1][node_y]
                )
        else:
            mid = (ly + ry) // 2
            if y <= mid:
                self._update_y(node_x, is_leaf_x, 2 * node_y, ly, mid, y, val)
            else:
                self._update_y(node_x, is_leaf_x, 2 * node_y + 1, mid + 1, ry, y, val)
            self.tree[node_x][node_y] = (
                self.tree[node_x][2 * node_y] + self.tree[node_x][2 * node_y + 1]
            )

    def query(self, x1: int, y1: int, x2: int, y2: int) -> int:
        """Sum of the rectangle with corners ``(x1, y1)`` and ``(x2, y2)`` (inclusive)."""
        return self._query_x(1, 0, self.n - 1, x1, x2, y1, y2)

    def _query_x(self, node_x, lx, rx, x1, x2, y1, y2):
        if x2 < lx or rx < x1:
            return 0
        if x1 <= lx and rx <= x2:
            return self._query_y(node_x, 1, 0, self.m - 1, y1, y2)

        mid = (lx + rx) // 2
        left = self._query_x(2 * node_x, lx, mid, x1, x2, y1, y2)
        right = self._query_x(2 * node_x + 1, mid + 1, rx, x1, x2, y1, y2)
        return left + right

    def _query_y(self, node_x, node_y, ly, ry, y1, y2):
        if y2 < ly or ry < y1:
            return 0
        if y1 <= ly and ry <= y2:
            return self.tree[node_x][node_y]

        mid = (ly + ry) // 2
        left = self._query_y(node_x, 2 * node_y, ly, mid, y1, y2)
        right = self._query_y(node_x, 2 * node_y + 1, mid + 1, ry, y1, y2)
        return left + right
333
334
335# =============================================================================
336# 6. 응용: 역순 쌍 개수 (Inversion Count)
337# =============================================================================
338
def count_inversions_segtree(arr: List[int]) -> int:
    """
    Count pairs ``i < j`` with ``arr[i] > arr[j]`` using a counting segment tree.

    Values are rank-compressed, then scanned left to right. For each value,
    the tree reports how many strictly greater values were already seen.
    Equal values do not form an inversion.
    Time complexity: O(n log n).
    """
    if not arr:
        return 0

    # Coordinate compression: each distinct value -> dense rank 0..size-1.
    ranks = {value: pos for pos, value in enumerate(sorted(set(arr)))}
    size = len(ranks)

    # Bottom-up tree of frequencies: leaves at counts[size:2*size].
    counts = [0] * (2 * size)

    def add_one(pos):
        pos += size
        counts[pos] += 1
        while pos > 1:
            pos >>= 1
            counts[pos] = counts[2 * pos] + counts[2 * pos + 1]

    def count_range(lo, hi):
        """Number of recorded values whose rank lies in the half-open range [lo, hi)."""
        total = 0
        lo += size
        hi += size
        while lo < hi:
            if lo & 1:
                total += counts[lo]
                lo += 1
            if hi & 1:
                hi -= 1
                total += counts[hi]
            lo >>= 1
            hi >>= 1
        return total

    inversions = 0
    for value in arr:
        rank = ranks[value]
        # Already-inserted values ranked strictly higher than this one.
        inversions += count_range(rank + 1, size)
        add_one(rank)

    return inversions
383
384
385# =============================================================================
386# 테스트
387# =============================================================================
388
def main():
    """Print a short demonstration of every segment-tree variant in this module."""
    rule = "=" * 60
    print(rule)
    print("세그먼트 트리 (Segment Tree) 예제")
    print(rule)

    # [1] Recursive range-sum tree with a point update.
    print("\n[1] 기본 세그먼트 트리 (구간 합)")
    data = [1, 3, 5, 7, 9, 11]
    sum_tree = SegmentTree(data)
    print(f" 배열: {data}")
    print(f" query(1, 3): {sum_tree.query(1, 3)}")  # 3+5+7=15
    sum_tree.update(2, 6)  # 5 -> 6
    print(f" update(2, 6) 후 query(1, 3): {sum_tree.query(1, 3)}")  # 3+6+7=16

    # [2] Generic tree configured for range minimum.
    print("\n[2] 일반 세그먼트 트리 (구간 최소)")
    data = [5, 2, 8, 1, 9, 3]
    range_min = GenericSegmentTree(data, min, float('inf'))
    print(f" 배열: {data}")
    print(f" min(1, 4): {range_min.query(1, 4)}")  # min(2,8,1,9)=1
    range_min.update(3, 10)  # 1 -> 10
    print(f" update(3, 10) 후 min(1, 4): {range_min.query(1, 4)}")  # min(2,8,10,9)=2

    # [3] Range add via lazy propagation.
    print("\n[3] Lazy Propagation (구간 업데이트)")
    data = [1, 2, 3, 4, 5]
    lazy_tree = LazySegmentTree(data)
    print(f" 배열: {data}")
    print(f" query(0, 4): {lazy_tree.query(0, 4)}")  # 15
    lazy_tree.update_range(1, 3, 10)  # add 10 to every element of [1, 3]
    print(f" update_range(1, 3, +10) 후 query(0, 4): {lazy_tree.query(0, 4)}")  # 45

    # [4] Bottom-up (non-recursive) tree.
    print("\n[4] 반복 세그먼트 트리")
    data = [1, 3, 5, 7, 9, 11]
    flat_tree = IterativeSegmentTree(data)
    print(f" 배열: {data}")
    print(f" query(1, 4): {flat_tree.query(1, 4)}")  # 3+5+7+9=24

    # [5] Rectangle sums on a 3x3 grid.
    print("\n[5] 2D 세그먼트 트리")
    grid = [
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9],
    ]
    grid_tree = SegmentTree2D(grid)
    print(f" 행렬: {grid}")
    print(f" query(0,0,1,1): {grid_tree.query(0, 0, 1, 1)}")  # 1+2+4+5=12
    print(f" query(1,1,2,2): {grid_tree.query(1, 1, 2, 2)}")  # 5+6+8+9=28

    # [6] Inversion counting.
    print("\n[6] 역순 쌍 개수")
    data = [2, 4, 1, 3, 5]
    inversion_total = count_inversions_segtree(data)
    print(f" 배열: {data}")
    print(f" 역순 쌍 개수: {inversion_total}")  # (2,1), (4,1), (4,3) = 3

    # [7] Complexity summary table.
    print("\n[7] 복잡도 비교")
    table_lines = (
        " | 연산 | 배열 | 세그먼트 트리 | Lazy |",
        " |------------|---------|---------------|-----------|",
        " | 점 업데이트| O(1) | O(log n) | O(log n) |",
        " | 구간 업데이트| O(n) | O(n) | O(log n) |",
        " | 구간 쿼리 | O(n) | O(log n) | O(log n) |",
    )
    for line in table_lines:
        print(line)

    print("\n" + rule)


if __name__ == "__main__":
    main()