/*
 * 23_segment_tree.cpp
 *
 * 세그먼트 트리 (Segment Tree)
 * Range Sum/Min Query, Lazy Propagation, Dynamic / Merge Sort / 2D variants
 *
 * 구간 쿼리와 업데이트를 효율적으로 처리합니다.
 */

#include <algorithm>
#include <climits>
#include <functional>
#include <iostream>
#include <iterator>
#include <memory>
#include <vector>

using namespace std;

// =============================================================================
// 1. 기본 세그먼트 트리 (구간 합)
// =============================================================================

 20class SegmentTree {
 21private:
 22    vector<long long> tree;
 23    int n;
 24
 25    void build(const vector<int>& arr, int node, int start, int end) {
 26        if (start == end) {
 27            tree[node] = arr[start];
 28            return;
 29        }
 30
 31        int mid = (start + end) / 2;
 32        build(arr, 2 * node, start, mid);
 33        build(arr, 2 * node + 1, mid + 1, end);
 34        tree[node] = tree[2 * node] + tree[2 * node + 1];
 35    }
 36
 37    void update(int node, int start, int end, int idx, long long val) {
 38        if (start == end) {
 39            tree[node] = val;
 40            return;
 41        }
 42
 43        int mid = (start + end) / 2;
 44        if (idx <= mid) {
 45            update(2 * node, start, mid, idx, val);
 46        } else {
 47            update(2 * node + 1, mid + 1, end, idx, val);
 48        }
 49        tree[node] = tree[2 * node] + tree[2 * node + 1];
 50    }
 51
 52    long long query(int node, int start, int end, int l, int r) {
 53        if (r < start || end < l) {
 54            return 0;
 55        }
 56        if (l <= start && end <= r) {
 57            return tree[node];
 58        }
 59
 60        int mid = (start + end) / 2;
 61        return query(2 * node, start, mid, l, r) +
 62               query(2 * node + 1, mid + 1, end, l, r);
 63    }
 64
 65public:
 66    SegmentTree(const vector<int>& arr) : n(arr.size()) {
 67        tree.resize(4 * n);
 68        build(arr, 1, 0, n - 1);
 69    }
 70
 71    void update(int idx, long long val) {
 72        update(1, 0, n - 1, idx, val);
 73    }
 74
 75    long long query(int l, int r) {
 76        return query(1, 0, n - 1, l, r);
 77    }
 78};

// =============================================================================
// 2. 구간 최솟값 세그먼트 트리
// =============================================================================

 84class MinSegmentTree {
 85private:
 86    vector<int> tree;
 87    int n;
 88    const int INF = INT_MAX;
 89
 90    void build(const vector<int>& arr, int node, int start, int end) {
 91        if (start == end) {
 92            tree[node] = arr[start];
 93            return;
 94        }
 95
 96        int mid = (start + end) / 2;
 97        build(arr, 2 * node, start, mid);
 98        build(arr, 2 * node + 1, mid + 1, end);
 99        tree[node] = min(tree[2 * node], tree[2 * node + 1]);
100    }
101
102    void update(int node, int start, int end, int idx, int val) {
103        if (start == end) {
104            tree[node] = val;
105            return;
106        }
107
108        int mid = (start + end) / 2;
109        if (idx <= mid) {
110            update(2 * node, start, mid, idx, val);
111        } else {
112            update(2 * node + 1, mid + 1, end, idx, val);
113        }
114        tree[node] = min(tree[2 * node], tree[2 * node + 1]);
115    }
116
117    int query(int node, int start, int end, int l, int r) {
118        if (r < start || end < l) {
119            return INF;
120        }
121        if (l <= start && end <= r) {
122            return tree[node];
123        }
124
125        int mid = (start + end) / 2;
126        return min(query(2 * node, start, mid, l, r),
127                   query(2 * node + 1, mid + 1, end, l, r));
128    }
129
130public:
131    MinSegmentTree(const vector<int>& arr) : n(arr.size()) {
132        tree.resize(4 * n, INF);
133        build(arr, 1, 0, n - 1);
134    }
135
136    void update(int idx, int val) {
137        update(1, 0, n - 1, idx, val);
138    }
139
140    int query(int l, int r) {
141        return query(1, 0, n - 1, l, r);
142    }
143};

// =============================================================================
// 3. Lazy Propagation (구간 업데이트)
// =============================================================================

// Range-add / range-sum segment tree using lazy propagation.
//
// lazy[node] is an addition (per element) that still has to be applied to
// tree[node] and pushed to the node's children; it is flushed by propagate()
// whenever the node is visited.
class LazySegmentTree {
private:
    std::vector<long long> tree, lazy;
    int n;  // number of elements in the source array

    // Applies node's pending per-element addition to its own sum and defers
    // it to the children (leaves have none). Leaves lazy[node] == 0.
    void propagate(int node, int start, int end) {
        if (lazy[node] == 0) {
            return;
        }
        tree[node] += static_cast<long long>(end - start + 1) * lazy[node];
        if (start != end) {
            lazy[2 * node] += lazy[node];
            lazy[2 * node + 1] += lazy[node];
        }
        lazy[node] = 0;
    }

    // Fills tree[node], which covers arr[start..end], from the leaves up.
    void build(const std::vector<int>& arr, int node, int start, int end) {
        if (start == end) {
            tree[node] = arr[start];
            return;
        }

        int mid = start + (end - start) / 2;
        build(arr, 2 * node, start, mid);
        build(arr, 2 * node + 1, mid + 1, end);
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }

    // Adds val to every element of [l, r] that lies in [start, end].
    void updateRange(int node, int start, int end, int l, int r, long long val) {
        // Flush first so tree[node] is current before it is read or recombined.
        propagate(node, start, end);

        if (r < start || end < l) {
            return;
        }

        if (l <= start && end <= r) {
            // Fully covered: accumulate the tag (never overwrite a pending value),
            // then apply it now so the parent's recompute sees the new sum.
            lazy[node] += val;
            propagate(node, start, end);
            return;
        }

        int mid = start + (end - start) / 2;
        updateRange(2 * node, start, mid, l, r, val);
        updateRange(2 * node + 1, mid + 1, end, l, r, val);
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }

    // Sum over the part of [l, r] that overlaps [start, end]; 0 when disjoint.
    long long query(int node, int start, int end, int l, int r) {
        propagate(node, start, end);

        if (r < start || end < l) {
            return 0;
        }
        if (l <= start && end <= r) {
            return tree[node];
        }

        int mid = start + (end - start) / 2;
        return query(2 * node, start, mid, l, r) +
               query(2 * node + 1, mid + 1, end, l, r);
    }

public:
    // Builds the tree in O(n). An empty array yields an empty tree whose
    // queries return 0 and whose updates are ignored.
    LazySegmentTree(const std::vector<int>& arr) : n(static_cast<int>(arr.size())) {
        tree.resize(4 * n, 0);
        lazy.resize(4 * n, 0);
        if (n > 0) {
            build(arr, 1, 0, n - 1);
        }
    }

    // Adds val to every element in [l, r] (inclusive) in O(log n).
    void updateRange(int l, int r, long long val) {
        if (n == 0) {
            return;
        }
        updateRange(1, 0, n - 1, l, r, val);
    }

    // Returns the sum of elements l..r (inclusive) in O(log n); 0 for an empty tree.
    // Not const: pending additions are pushed down while walking the tree.
    long long query(int l, int r) {
        if (n == 0) {
            return 0;
        }
        return query(1, 0, n - 1, l, r);
    }
};

// =============================================================================
// 4. 동적 세그먼트 트리
// =============================================================================

// Sparse (dynamically allocated) range-sum segment tree over the coordinate
// range [lo, hi]. Nodes are created only along paths touched by update(), so
// memory is O(updates * log(hi - lo)) regardless of the range width.
// update() ADDS to a coordinate (it does not assign).
class DynamicSegmentTree {
private:
    // Children are owned by their parent, so releasing root frees the tree.
    struct Node {
        long long sum = 0;
        std::unique_ptr<Node> left, right;
    };

    std::unique_ptr<Node> root;
    long long lo, hi;  // inclusive coordinate range covered by root

    // Floor midpoint of [start, end]. Rounding toward start (rather than toward
    // zero like (start + end) / 2) keeps both halves strictly smaller than the
    // parent even for negative coordinates, so recursion always terminates.
    static long long midpoint(long long start, long long end) {
        return start + (end - start) / 2;
    }

    // Deep-copies the subtree rooted at src (nullptr stays nullptr).
    static std::unique_ptr<Node> clone(const Node* src) {
        if (!src) {
            return nullptr;
        }
        auto copy = std::make_unique<Node>();
        copy->sum = src->sum;
        copy->left = clone(src->left.get());
        copy->right = clone(src->right.get());
        return copy;
    }

    // Sum stored in an optional child; absent children contribute 0.
    static long long sumOf(const std::unique_ptr<Node>& node) {
        return node ? node->sum : 0;
    }

    // Adds val at idx inside [start, end], creating missing nodes on the way.
    void update(std::unique_ptr<Node>& node, long long start, long long end, long long idx, long long val) {
        if (!node) {
            node = std::make_unique<Node>();
        }

        if (start == end) {
            node->sum += val;
            return;
        }

        long long mid = midpoint(start, end);
        if (idx <= mid) {
            update(node->left, start, mid, idx, val);
        } else {
            update(node->right, mid + 1, end, idx, val);
        }

        node->sum = sumOf(node->left) + sumOf(node->right);
    }

    // Sum over the part of [l, r] that overlaps [start, end]; missing nodes count as 0.
    long long query(const Node* node, long long start, long long end, long long l, long long r) const {
        if (!node || r < start || end < l) {
            return 0;
        }
        if (l <= start && end <= r) {
            return node->sum;
        }

        long long mid = midpoint(start, end);
        return query(node->left.get(), start, mid, l, r) +
               query(node->right.get(), mid + 1, end, l, r);
    }

public:
    // Creates an empty tree covering the inclusive coordinate range [lo, hi].
    DynamicSegmentTree(long long lo, long long hi) : root(nullptr), lo(lo), hi(hi) {}

    // Copies are deep: the copy owns its own nodes.
    DynamicSegmentTree(const DynamicSegmentTree& other)
        : root(clone(other.root.get())), lo(other.lo), hi(other.hi) {}

    DynamicSegmentTree& operator=(const DynamicSegmentTree& other) {
        if (this != &other) {
            root = clone(other.root.get());
            lo = other.lo;
            hi = other.hi;
        }
        return *this;
    }

    DynamicSegmentTree(DynamicSegmentTree&&) noexcept = default;
    DynamicSegmentTree& operator=(DynamicSegmentTree&&) noexcept = default;
    ~DynamicSegmentTree() = default;

    // Adds val at coordinate idx in O(log(hi - lo)). Coordinates outside
    // [lo, hi] are ignored rather than being folded into a boundary leaf.
    void update(long long idx, long long val) {
        if (idx < lo || idx > hi) {
            return;
        }
        update(root, lo, hi, idx, val);
    }

    // Returns the sum over coordinates l..r (inclusive) in O(log(hi - lo)).
    long long query(long long l, long long r) const {
        return query(root.get(), lo, hi, l, r);
    }
};

// =============================================================================
// 5. 머지 소트 트리 (구간 K번째 원소)
// =============================================================================

// Merge sort tree: each node stores the sorted contents of its segment
// (O(n log n) memory). Supports counting values <= x in [l, r] in
// O(log^2 n) and k-th smallest in [l, r] by binary search over int values.
class MergeSortTree {
private:
    std::vector<std::vector<int>> tree;  // tree[node] = sorted copy of the node's segment
    int n;                               // number of elements in the source array

    // Fills tree[node] with arr[start..end] in sorted order by merging the children.
    void build(const std::vector<int>& arr, int node, int start, int end) {
        if (start == end) {
            tree[node] = {arr[start]};
            return;
        }

        int mid = start + (end - start) / 2;
        build(arr, 2 * node, start, mid);
        build(arr, 2 * node + 1, mid + 1, end);

        const std::vector<int>& leftHalf = tree[2 * node];
        const std::vector<int>& rightHalf = tree[2 * node + 1];
        tree[node].reserve(leftHalf.size() + rightHalf.size());
        std::merge(leftHalf.begin(), leftHalf.end(),
                   rightHalf.begin(), rightHalf.end(),
                   std::back_inserter(tree[node]));
    }

    // Number of elements <= x in the part of [l, r] that overlaps [start, end].
    int countLessEqual(int node, int start, int end, int l, int r, int x) const {
        if (r < start || end < l) {
            return 0;
        }
        if (l <= start && end <= r) {
            return static_cast<int>(
                std::upper_bound(tree[node].begin(), tree[node].end(), x) - tree[node].begin());
        }

        int mid = start + (end - start) / 2;
        return countLessEqual(2 * node, start, mid, l, r, x) +
               countLessEqual(2 * node + 1, mid + 1, end, l, r, x);
    }

public:
    // Builds the tree in O(n log n). An empty array yields an empty tree.
    MergeSortTree(const std::vector<int>& arr) : n(static_cast<int>(arr.size())) {
        tree.resize(4 * n);
        if (n > 0) {
            build(arr, 1, 0, n - 1);
        }
    }

    // Returns the k-th smallest (1-based) value among elements l..r (inclusive).
    // If fewer than k elements are in range the search converges to INT_MAX;
    // for k <= 0 it converges to INT_MIN.
    int kthSmallest(int l, int r, int k) const {
        // Search in 64-bit so the midpoint is exact and floored for negative
        // bounds: lo/2 + hi/2 can equal hi (e.g. lo=-3, hi=-2 -> -2), which
        // stalls the loop forever when the answer is negative.
        long long lo = INT_MIN;
        long long hi = INT_MAX;

        while (lo < hi) {
            const long long mid = lo + (hi - lo) / 2;
            const int atMostMid =
                (n == 0) ? 0 : countLessEqual(1, 0, n - 1, l, r, static_cast<int>(mid));
            if (atMostMid < k) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }

        return static_cast<int>(lo);
    }
};

// =============================================================================
// 6. 2D 세그먼트 트리
// =============================================================================

// Static 2D range-sum segment tree (a segment tree over rows whose nodes are
// segment trees over columns). Uses a (4n) x (4m) table of long long.
// Precondition: the grid is rectangular (every row has arr[0].size() columns).
class SegmentTree2D {
private:
    std::vector<std::vector<long long>> tree;  // tree[nx][ny]: sum over row-node nx x column-node ny
    int n, m;                                  // rows, columns

    // Builds the column tree stored in row-node nx (covering rows lx..rx).
    // Leaf row-nodes copy from arr; internal row-nodes add their two row children.
    void buildY(const std::vector<std::vector<int>>& arr, int nx, int lx, int rx, int ny, int ly, int ry) {
        if (ly == ry) {
            if (lx == rx) {
                tree[nx][ny] = arr[lx][ly];
            } else {
                tree[nx][ny] = tree[2 * nx][ny] + tree[2 * nx + 1][ny];
            }
            return;
        }

        int my = ly + (ry - ly) / 2;
        buildY(arr, nx, lx, rx, 2 * ny, ly, my);
        buildY(arr, nx, lx, rx, 2 * ny + 1, my + 1, ry);
        tree[nx][ny] = tree[nx][2 * ny] + tree[nx][2 * ny + 1];
    }

    // Builds row children first so buildY can combine them at this row-node.
    void buildX(const std::vector<std::vector<int>>& arr, int nx, int lx, int rx) {
        if (lx != rx) {
            int mx = lx + (rx - lx) / 2;
            buildX(arr, 2 * nx, lx, mx);
            buildX(arr, 2 * nx + 1, mx + 1, rx);
        }
        buildY(arr, nx, lx, rx, 1, 0, m - 1);
    }

    // Column-range sum [y1, y2] inside row-node nx.
    long long queryY(int nx, int ny, int ly, int ry, int y1, int y2) const {
        if (y2 < ly || ry < y1) return 0;
        if (y1 <= ly && ry <= y2) return tree[nx][ny];

        int my = ly + (ry - ly) / 2;
        return queryY(nx, 2 * ny, ly, my, y1, y2) +
               queryY(nx, 2 * ny + 1, my + 1, ry, y1, y2);
    }

    // Row-range decomposition of [x1, x2]; each fully covered row-node answers via queryY.
    long long queryX(int nx, int lx, int rx, int x1, int x2, int y1, int y2) const {
        if (x2 < lx || rx < x1) return 0;
        if (x1 <= lx && rx <= x2) return queryY(nx, 1, 0, m - 1, y1, y2);

        int mx = lx + (rx - lx) / 2;
        return queryX(2 * nx, lx, mx, x1, x2, y1, y2) +
               queryX(2 * nx + 1, mx + 1, rx, x1, x2, y1, y2);
    }

public:
    // Builds the tree. A grid with no rows or no columns yields an empty tree
    // whose queries return 0.
    SegmentTree2D(const std::vector<std::vector<int>>& arr)
        : n(static_cast<int>(arr.size())),
          m(arr.empty() ? 0 : static_cast<int>(arr[0].size())) {
        if (n == 0 || m == 0) {
            return;
        }
        tree.assign(4 * n, std::vector<long long>(4 * m, 0));
        buildX(arr, 1, 0, n - 1);
    }

    // Sum over rows x1..x2 and columns y1..y2 (inclusive) in O(log n * log m).
    long long query(int x1, int y1, int x2, int y2) const {
        if (n == 0 || m == 0) {
            return 0;
        }
        return queryX(1, 0, n - 1, x1, x2, y1, y2);
    }
};

// =============================================================================
// 테스트
// =============================================================================

407int main() {
408    cout << "============================================================" << endl;
409    cout << "세그먼트 트리 예제" << endl;
410    cout << "============================================================" << endl;
411
412    vector<int> arr = {1, 3, 5, 7, 9, 11};
413
414    // 1. 기본 세그먼트 트리
415    cout << "\n[1] 구간 합 세그먼트 트리" << endl;
416    cout << "    배열: [1, 3, 5, 7, 9, 11]" << endl;
417    SegmentTree st(arr);
418    cout << "    sum[1, 3] = " << st.query(1, 3) << endl;
419    st.update(2, 10);
420    cout << "    arr[2] = 10 업데이트 후" << endl;
421    cout << "    sum[1, 3] = " << st.query(1, 3) << endl;
422
423    // 2. 구간 최솟값
424    cout << "\n[2] 구간 최솟값 세그먼트 트리" << endl;
425    MinSegmentTree minSt(arr);
426    cout << "    min[0, 5] = " << minSt.query(0, 5) << endl;
427    cout << "    min[2, 4] = " << minSt.query(2, 4) << endl;
428
429    // 3. Lazy Propagation
430    cout << "\n[3] Lazy Propagation" << endl;
431    vector<int> arr2 = {1, 2, 3, 4, 5};
432    LazySegmentTree lazySt(arr2);
433    cout << "    배열: [1, 2, 3, 4, 5]" << endl;
434    cout << "    sum[0, 4] = " << lazySt.query(0, 4) << endl;
435    lazySt.updateRange(1, 3, 10);  // [1, 3] 구간에 10 더하기
436    cout << "    [1, 3] += 10 후" << endl;
437    cout << "    sum[0, 4] = " << lazySt.query(0, 4) << endl;
438
439    // 4. 동적 세그먼트 트리
440    cout << "\n[4] 동적 세그먼트 트리" << endl;
441    DynamicSegmentTree dynSt(0, 1000000000);
442    dynSt.update(100, 5);
443    dynSt.update(500000000, 10);
444    cout << "    범위: [0, 10^9]" << endl;
445    cout << "    update(100, 5), update(5×10^8, 10)" << endl;
446    cout << "    sum[0, 10^9] = " << dynSt.query(0, 1000000000) << endl;
447
448    // 5. 2D 세그먼트 트리
449    cout << "\n[5] 2D 세그먼트 트리" << endl;
450    vector<vector<int>> arr2d = {
451        {1, 2, 3},
452        {4, 5, 6},
453        {7, 8, 9}
454    };
455    SegmentTree2D st2d(arr2d);
456    cout << "    3x3 행렬" << endl;
457    cout << "    sum[(0,0) to (2,2)] = " << st2d.query(0, 0, 2, 2) << endl;
458    cout << "    sum[(0,0) to (1,1)] = " << st2d.query(0, 0, 1, 1) << endl;
459
460    // 6. 복잡도 요약
461    cout << "\n[6] 복잡도 요약" << endl;
462    cout << "    | 연산            | 시간복잡도 | 공간복잡도 |" << endl;
463    cout << "    |-----------------|------------|------------|" << endl;
464    cout << "    | 빌드            | O(n)       | O(n)       |" << endl;
465    cout << "    | 점 업데이트     | O(log n)   | O(1)       |" << endl;
466    cout << "    | 구간 쿼리       | O(log n)   | O(1)       |" << endl;
467    cout << "    | 구간 업데이트   | O(log n)   | O(n) lazy  |" << endl;
468    cout << "    | 2D 쿼리         | O(log² n)  | O(n²)      |" << endl;
469
470    cout << "\n============================================================" << endl;
471
472    return 0;
473}