23_segment_tree.c

Download
c 437 lines 14.2 KB
  1/*
  2 * 세그먼트 트리 (Segment Tree)
  3 * 구간 합, 구간 최소/최대, Lazy Propagation
  4 *
  5 * 구간 쿼리와 업데이트를 O(log n)에 처리합니다.
  6 */
  7
  8#include <stdio.h>
  9#include <stdlib.h>
 10#include <limits.h>
 11#include <string.h>
 12
 13#define MAX_N 100001
 14#define INF INT_MAX
 15
 16/* =============================================================================
 17 * 1. 기본 세그먼트 트리 (구간 합)
 18 * ============================================================================= */
 19
 20typedef struct {
 21    long long* tree;
 22    int n;
 23} SegmentTree;
 24
 25SegmentTree* st_create(int n) {
 26    SegmentTree* st = malloc(sizeof(SegmentTree));
 27    st->n = n;
 28    st->tree = calloc(4 * n, sizeof(long long));
 29    return st;
 30}
 31
 32void st_free(SegmentTree* st) {
 33    free(st->tree);
 34    free(st);
 35}
 36
 37void st_build(SegmentTree* st, int arr[], int node, int start, int end) {
 38    if (start == end) {
 39        st->tree[node] = arr[start];
 40        return;
 41    }
 42    int mid = (start + end) / 2;
 43    st_build(st, arr, 2 * node, start, mid);
 44    st_build(st, arr, 2 * node + 1, mid + 1, end);
 45    st->tree[node] = st->tree[2 * node] + st->tree[2 * node + 1];
 46}
 47
 48/* 점 업데이트 */
 49void st_update(SegmentTree* st, int node, int start, int end, int idx, long long val) {
 50    if (start == end) {
 51        st->tree[node] = val;
 52        return;
 53    }
 54    int mid = (start + end) / 2;
 55    if (idx <= mid) {
 56        st_update(st, 2 * node, start, mid, idx, val);
 57    } else {
 58        st_update(st, 2 * node + 1, mid + 1, end, idx, val);
 59    }
 60    st->tree[node] = st->tree[2 * node] + st->tree[2 * node + 1];
 61}
 62
 63/* 구간 합 쿼리 */
 64long long st_query(SegmentTree* st, int node, int start, int end, int l, int r) {
 65    if (r < start || end < l) return 0;
 66    if (l <= start && end <= r) return st->tree[node];
 67
 68    int mid = (start + end) / 2;
 69    return st_query(st, 2 * node, start, mid, l, r) +
 70           st_query(st, 2 * node + 1, mid + 1, end, l, r);
 71}
 72
 73/* =============================================================================
 74 * 2. 구간 최소/최대 세그먼트 트리
 75 * ============================================================================= */
 76
 77typedef struct {
 78    int* tree_min;
 79    int* tree_max;
 80    int n;
 81} MinMaxTree;
 82
 83MinMaxTree* mmt_create(int n) {
 84    MinMaxTree* mmt = malloc(sizeof(MinMaxTree));
 85    mmt->n = n;
 86    mmt->tree_min = malloc(4 * n * sizeof(int));
 87    mmt->tree_max = malloc(4 * n * sizeof(int));
 88    for (int i = 0; i < 4 * n; i++) {
 89        mmt->tree_min[i] = INF;
 90        mmt->tree_max[i] = -INF;
 91    }
 92    return mmt;
 93}
 94
 95void mmt_build(MinMaxTree* mmt, int arr[], int node, int start, int end) {
 96    if (start == end) {
 97        mmt->tree_min[node] = arr[start];
 98        mmt->tree_max[node] = arr[start];
 99        return;
100    }
101    int mid = (start + end) / 2;
102    mmt_build(mmt, arr, 2 * node, start, mid);
103    mmt_build(mmt, arr, 2 * node + 1, mid + 1, end);
104    mmt->tree_min[node] = (mmt->tree_min[2 * node] < mmt->tree_min[2 * node + 1])
105                          ? mmt->tree_min[2 * node] : mmt->tree_min[2 * node + 1];
106    mmt->tree_max[node] = (mmt->tree_max[2 * node] > mmt->tree_max[2 * node + 1])
107                          ? mmt->tree_max[2 * node] : mmt->tree_max[2 * node + 1];
108}
109
110int mmt_query_min(MinMaxTree* mmt, int node, int start, int end, int l, int r) {
111    if (r < start || end < l) return INF;
112    if (l <= start && end <= r) return mmt->tree_min[node];
113
114    int mid = (start + end) / 2;
115    int left = mmt_query_min(mmt, 2 * node, start, mid, l, r);
116    int right = mmt_query_min(mmt, 2 * node + 1, mid + 1, end, l, r);
117    return (left < right) ? left : right;
118}
119
120int mmt_query_max(MinMaxTree* mmt, int node, int start, int end, int l, int r) {
121    if (r < start || end < l) return -INF;
122    if (l <= start && end <= r) return mmt->tree_max[node];
123
124    int mid = (start + end) / 2;
125    int left = mmt_query_max(mmt, 2 * node, start, mid, l, r);
126    int right = mmt_query_max(mmt, 2 * node + 1, mid + 1, end, l, r);
127    return (left > right) ? left : right;
128}
129
130void mmt_free(MinMaxTree* mmt) {
131    free(mmt->tree_min);
132    free(mmt->tree_max);
133    free(mmt);
134}
135
136/* =============================================================================
137 * 3. Lazy Propagation (구간 업데이트)
138 * ============================================================================= */
139
/*
 * Range-add / range-sum segment tree with lazy propagation.
 * Node layout is 1-based: root is node 1, children of k are 2k and 2k+1.
 */
typedef struct {
    long long* tree; /* tree[node]: sum of node's range, excluding lazy[node] */
    long long* lazy; /* lazy[node]: pending per-element addition for node's range */
    int n;           /* number of elements */
} LazySegTree;

/*
 * Allocate zeroed tree and lazy arrays (4 * n slots each) for n elements.
 * Returns NULL when n <= 0 or on allocation failure; free with lst_free.
 */
LazySegTree* lst_create(int n) {
    if (n <= 0) {
        return NULL;
    }
    LazySegTree* lst = malloc(sizeof *lst);
    if (!lst) {
        return NULL;
    }
    size_t slots = (size_t)4 * (size_t)n;
    lst->tree = calloc(slots, sizeof *lst->tree);
    lst->lazy = calloc(slots, sizeof *lst->lazy);
    if (!lst->tree || !lst->lazy) {
        free(lst->tree);
        free(lst->lazy);
        free(lst);
        return NULL;
    }
    lst->n = n;
    return lst;
}

/* Release a tree obtained from lst_create. Passing NULL is a no-op. */
void lst_free(LazySegTree* lst) {
    if (!lst) {
        return;
    }
    free(lst->tree);
    free(lst->lazy);
    free(lst);
}

/* Fill node's subtree from arr[start..end]; internal nodes hold their children's sum. */
void lst_build(LazySegTree* lst, const int arr[], int node, int start, int end) {
    if (start == end) {
        lst->tree[node] = arr[start];
        return;
    }
    int mid = start + (end - start) / 2;
    lst_build(lst, arr, 2 * node, start, mid);
    lst_build(lst, arr, 2 * node + 1, mid + 1, end);
    lst->tree[node] = lst->tree[2 * node] + lst->tree[2 * node + 1];
}

/*
 * Apply node's pending addition to its own sum (once per element of
 * [start, end]) and hand it down to the children's lazy slots.
 */
void lst_propagate(LazySegTree* lst, int node, int start, int end) {
    if (lst->lazy[node] != 0) {
        lst->tree[node] += lst->lazy[node] * (end - start + 1);
        if (start != end) {
            lst->lazy[2 * node] += lst->lazy[node];
            lst->lazy[2 * node + 1] += lst->lazy[node];
        }
        lst->lazy[node] = 0;
    }
}

/* Add val to every element in [l, r]. */
void lst_update_range(LazySegTree* lst, int node, int start, int end, int l, int r, long long val) {
    lst_propagate(lst, node, start, end);

    if (r < start || end < l) {
        return;
    }

    if (l <= start && end <= r) {
        /* Accumulate rather than overwrite so correctness never depends on
         * lazy[node] having been cleared just above. */
        lst->lazy[node] += val;
        lst_propagate(lst, node, start, end);
        return;
    }

    int mid = start + (end - start) / 2;
    lst_update_range(lst, 2 * node, start, mid, l, r, val);
    lst_update_range(lst, 2 * node + 1, mid + 1, end, l, r, val);
    lst->tree[node] = lst->tree[2 * node] + lst->tree[2 * node + 1];
}

/* Sum of the elements in [l, r] within node's range; 0 if disjoint. */
long long lst_query(LazySegTree* lst, int node, int start, int end, int l, int r) {
    lst_propagate(lst, node, start, end);

    if (r < start || end < l) {
        return 0;
    }
    if (l <= start && end <= r) {
        return lst->tree[node];
    }

    int mid = start + (end - start) / 2;
    return lst_query(lst, 2 * node, start, mid, l, r) +
           lst_query(lst, 2 * node + 1, mid + 1, end, l, r);
}
210
211/* =============================================================================
212 * 4. 동적 세그먼트 트리 (좌표 압축 없이)
213 * ============================================================================= */
214
/*
 * Dynamic (sparse) segment tree for point-add / range-sum over a large
 * coordinate range. Nodes are allocated only along paths that receive
 * updates; a NULL child stands for a range whose sum is 0.
 */
typedef struct DynamicNode {
    long long sum;             /* sum of this node's range */
    struct DynamicNode* left;  /* child for [start, mid], or NULL */
    struct DynamicNode* right; /* child for [mid + 1, end], or NULL */
} DynamicNode;

/* Allocate a node with sum 0 and no children; returns NULL on allocation failure. */
DynamicNode* dn_create(void) {
    DynamicNode* node = malloc(sizeof *node);
    if (!node) {
        return NULL;
    }
    node->sum = 0;
    node->left = NULL;
    node->right = NULL;
    return node;
}

/*
 * Floor midpoint of [start, end] (requires start <= end). The naive
 * (start + end) / 2 truncates toward zero, so for a negative range such as
 * [-4, -3] it yields -3 == end and the left child would repeat the parent's
 * range forever. Shared by update and query so both split identically.
 */
static long long dn_mid(long long start, long long end) {
    return start + (end - start) / 2;
}

/*
 * Add val at position idx inside node's range [start, end], creating
 * missing nodes on the way down. A NULL node or an idx outside
 * [start, end] is ignored; if a child cannot be allocated, no sum changes.
 */
void dn_update(DynamicNode* node, long long start, long long end, long long idx, long long val) {
    if (!node || idx < start || idx > end) {
        return;
    }
    if (start == end) {
        node->sum += val;
        return;
    }
    long long mid = dn_mid(start, end);
    if (idx <= mid) {
        if (!node->left) {
            node->left = dn_create();
            if (!node->left) {
                return;
            }
        }
        dn_update(node->left, start, mid, idx, val);
    } else {
        if (!node->right) {
            node->right = dn_create();
            if (!node->right) {
                return;
            }
        }
        dn_update(node->right, mid + 1, end, idx, val);
    }
    node->sum = (node->left ? node->left->sum : 0) +
                (node->right ? node->right->sum : 0);
}

/* Sum over [l, r] within node's range [start, end]; absent nodes contribute 0. */
long long dn_query(DynamicNode* node, long long start, long long end, long long l, long long r) {
    if (!node || r < start || end < l) {
        return 0;
    }
    if (l <= start && end <= r) {
        return node->sum;
    }
    long long mid = dn_mid(start, end);
    return dn_query(node->left, start, mid, l, r) +
           dn_query(node->right, mid + 1, end, l, r);
}

/* Free node and its whole subtree (post-order). Passing NULL is a no-op. */
void dn_free(DynamicNode* node) {
    if (!node) {
        return;
    }
    dn_free(node->left);
    dn_free(node->right);
    free(node);
}
261
262/* =============================================================================
263 * 5. 머지 소트 트리 (Merge Sort Tree)
264 * ============================================================================= */
265
/*
 * Merge sort tree: every node keeps a sorted copy of the elements in its
 * range, so "how many elements in [l, r] are <= k" needs one binary search
 * per covering node. Node layout is 1-based (children of k are 2k, 2k+1).
 */
typedef struct {
    int** tree; /* tree[node]: sorted elements of node's range, or NULL if unused */
    int* sizes; /* sizes[node]: length of tree[node] */
    int n;      /* number of elements; tree and sizes have 4 * n slots */
} MergeSortTree;

/*
 * Merge the sorted runs arr[left..mid] and arr[mid+1..right] in place,
 * using temp[left..right] as scratch. Ties take the left run first.
 */
void mst_merge(int* arr, int* temp, int left, int mid, int right) {
    int i = left, j = mid + 1, k = left;
    while (i <= mid && j <= right) {
        if (arr[i] <= arr[j]) {
            temp[k++] = arr[i++];
        } else {
            temp[k++] = arr[j++];
        }
    }
    while (i <= mid) {
        temp[k++] = arr[i++];
    }
    while (j <= right) {
        temp[k++] = arr[j++];
    }
    for (i = left; i <= right; i++) {
        arr[i] = temp[i];
    }
}

/* Release a tree from mst_create, including every per-node array. NULL is a no-op. */
void mst_free(MergeSortTree* mst) {
    if (!mst) {
        return;
    }
    if (mst->tree) {
        size_t slots = (size_t)4 * (size_t)mst->n;
        for (size_t i = 0; i < slots; i++) {
            free(mst->tree[i]);
        }
    }
    free(mst->tree);
    free(mst->sizes);
    free(mst);
}

/*
 * Build node's sorted array for arr[start..end]: build both children, then
 * concatenate them and merge the two sorted runs with mst_merge.
 * temp must hold at least end - start + 1 ints. Returns 0 on success and
 * -1 if an allocation failed (nodes already built are left for mst_free).
 */
static int mst_build(MergeSortTree* mst, const int arr[], int* temp, int node, int start, int end) {
    int len = end - start + 1;
    int* sorted = malloc((size_t)len * sizeof *sorted);
    if (!sorted) {
        return -1;
    }
    mst->tree[node] = sorted;
    mst->sizes[node] = len;

    if (start == end) {
        sorted[0] = arr[start];
        return 0;
    }

    int mid = start + (end - start) / 2;
    int left = 2 * node;
    int right = 2 * node + 1;
    if (mst_build(mst, arr, temp, left, start, mid) != 0 ||
        mst_build(mst, arr, temp, right, mid + 1, end) != 0) {
        return -1;
    }

    memcpy(sorted, mst->tree[left], (size_t)mst->sizes[left] * sizeof *sorted);
    memcpy(sorted + mst->sizes[left], mst->tree[right],
           (size_t)mst->sizes[right] * sizeof *sorted);
    mst_merge(sorted, temp, 0, mst->sizes[left] - 1, len - 1);
    return 0;
}

/*
 * Build a merge sort tree over arr[0..n-1] (arr is only read).
 * Returns NULL when arr is NULL, n <= 0, or any allocation fails.
 * Release the result with mst_free.
 */
MergeSortTree* mst_create(const int arr[], int n) {
    if (!arr || n <= 0) {
        return NULL;
    }
    MergeSortTree* mst = malloc(sizeof *mst);
    if (!mst) {
        return NULL;
    }
    size_t slots = (size_t)4 * (size_t)n;
    mst->n = n;
    mst->tree = malloc(slots * sizeof *mst->tree);
    if (mst->tree) {
        /* Explicit NULLs so mst_free can free every slot, used or not. */
        for (size_t i = 0; i < slots; i++) {
            mst->tree[i] = NULL;
        }
    }
    mst->sizes = calloc(slots, sizeof *mst->sizes);
    int* temp = malloc((size_t)n * sizeof *temp); /* scratch for mst_merge */

    int ok = mst->tree != NULL && mst->sizes != NULL && temp != NULL &&
             mst_build(mst, arr, temp, 1, 0, n - 1) == 0;
    free(temp);
    if (!ok) {
        mst_free(mst);
        return NULL;
    }
    return mst;
}

/* Number of elements <= k in the sorted array arr[0..size-1] (upper-bound binary search). */
int count_le(const int* arr, int size, int k) {
    int lo = 0, hi = size;
    while (lo < hi) {
        int mid = lo + (hi - lo) / 2;
        if (arr[mid] <= k) {
            lo = mid + 1;
        } else {
            hi = mid;
        }
    }
    return lo;
}

/* Count the elements at positions [l, r] whose value is <= k. */
int mst_query(MergeSortTree* mst, int node, int start, int end, int l, int r, int k) {
    if (r < start || end < l) {
        return 0;
    }
    if (l <= start && end <= r) {
        return count_le(mst->tree[node], mst->sizes[node], k);
    }
    int mid = start + (end - start) / 2;
    return mst_query(mst, 2 * node, start, mid, l, r, k) +
           mst_query(mst, 2 * node + 1, mid + 1, end, l, r, k);
}
349
350/* =============================================================================
351 * 테스트
352 * ============================================================================= */
353
354int main(void) {
355    printf("============================================================\n");
356    printf("세그먼트 트리 예제\n");
357    printf("============================================================\n");
358
359    /* 1. 기본 세그먼트 트리 */
360    printf("\n[1] 기본 세그먼트 트리 (구간 합)\n");
361    int arr1[] = {1, 3, 5, 7, 9, 11};
362    int n1 = 6;
363    SegmentTree* st = st_create(n1);
364    st_build(st, arr1, 1, 0, n1 - 1);
365
366    printf("    배열: [1, 3, 5, 7, 9, 11]\n");
367    printf("    구간 합 [1, 3]: %lld\n", st_query(st, 1, 0, n1 - 1, 1, 3));
368    printf("    구간 합 [0, 5]: %lld\n", st_query(st, 1, 0, n1 - 1, 0, 5));
369
370    st_update(st, 1, 0, n1 - 1, 2, 10);  /* arr[2] = 10 */
371    printf("    arr[2] = 10 후 구간 합 [1, 3]: %lld\n", st_query(st, 1, 0, n1 - 1, 1, 3));
372    st_free(st);
373
374    /* 2. 구간 최소/최대 */
375    printf("\n[2] 구간 최소/최대 세그먼트 트리\n");
376    int arr2[] = {5, 2, 8, 1, 9, 3, 7, 4};
377    int n2 = 8;
378    MinMaxTree* mmt = mmt_create(n2);
379    mmt_build(mmt, arr2, 1, 0, n2 - 1);
380
381    printf("    배열: [5, 2, 8, 1, 9, 3, 7, 4]\n");
382    printf("    구간 [0, 3] 최소: %d\n", mmt_query_min(mmt, 1, 0, n2 - 1, 0, 3));
383    printf("    구간 [0, 3] 최대: %d\n", mmt_query_max(mmt, 1, 0, n2 - 1, 0, 3));
384    printf("    구간 [4, 7] 최소: %d\n", mmt_query_min(mmt, 1, 0, n2 - 1, 4, 7));
385    mmt_free(mmt);
386
387    /* 3. Lazy Propagation */
388    printf("\n[3] Lazy Propagation\n");
389    int arr3[] = {1, 2, 3, 4, 5};
390    int n3 = 5;
391    LazySegTree* lst = lst_create(n3);
392    lst_build(lst, arr3, 1, 0, n3 - 1);
393
394    printf("    배열: [1, 2, 3, 4, 5]\n");
395    printf("    초기 구간 합 [0, 4]: %lld\n", lst_query(lst, 1, 0, n3 - 1, 0, 4));
396
397    lst_update_range(lst, 1, 0, n3 - 1, 1, 3, 10);  /* [1, 3]에 10 더하기 */
398    printf("    [1, 3]에 10 더한 후 구간 합 [0, 4]: %lld\n", lst_query(lst, 1, 0, n3 - 1, 0, 4));
399    printf("    구간 합 [1, 3]: %lld\n", lst_query(lst, 1, 0, n3 - 1, 1, 3));
400    lst_free(lst);
401
402    /* 4. 동적 세그먼트 트리 */
403    printf("\n[4] 동적 세그먼트 트리\n");
404    DynamicNode* root = dn_create();
405    long long max_range = 1000000000LL;
406
407    dn_update(root, 0, max_range, 100, 5);
408    dn_update(root, 0, max_range, 1000000, 3);
409    dn_update(root, 0, max_range, 500, 7);
410
411    printf("    인덱스 100에 5, 1000000에 3, 500에 7 추가\n");
412    printf("    구간 [0, 1000] 합: %lld\n", dn_query(root, 0, max_range, 0, 1000));
413    printf("    구간 [0, 1000000] 합: %lld\n", dn_query(root, 0, max_range, 0, 1000000));
414    dn_free(root);
415
416    /* 5. 복잡도 */
417    printf("\n[5] 복잡도 분석\n");
418    printf("    | 연산           | 시간복잡도 | 공간복잡도 |\n");
419    printf("    |----------------|------------|------------|\n");
420    printf("    | 빌드           | O(n)       | O(n)       |\n");
421    printf("    | 점 업데이트    | O(log n)   | -          |\n");
422    printf("    | 구간 쿼리      | O(log n)   | -          |\n");
423    printf("    | 구간 업데이트  | O(log n)   | O(n)       |\n");
424    printf("    | 동적 트리      | O(log M)   | O(Q log M) |\n");
425
426    printf("\n[6] 응용\n");
427    printf("    - 구간 합/최소/최대 쿼리\n");
428    printf("    - 구간 GCD 쿼리\n");
429    printf("    - 역순 쌍 개수 (Inversion Count)\n");
430    printf("    - K번째 원소 찾기\n");
431    printf("    - 2D 세그먼트 트리 (평면 쿼리)\n");
432
433    printf("\n============================================================\n");
434
435    return 0;
436}