28_advanced_dp.c

Download
c 403 lines 11.3 KB
/*
 * Advanced DP Optimization
 * Convex hull trick (CHT), Li Chao tree, divide-and-conquer (D&C) optimization,
 * Knuth optimization, monotonic queue, and SOS DP.
 *
 * Advanced techniques that reduce the time complexity of dynamic programming.
 */
  7
#include <limits.h>
#include <math.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
 13
 14#define INF LLONG_MAX
 15#define MAX_N 100001
 16
 17typedef long long ll;
 18
 19/* =============================================================================
 20 * 1. Convex Hull Trick (CHT)
 21 * ============================================================================= */
 22
 23/* ์ง์„  ๊ตฌ์กฐ์ฒด */
 24typedef struct {
 25    ll m;  /* ๊ธฐ์šธ๊ธฐ */
 26    ll b;  /* y์ ˆํŽธ */
 27} Line;
 28
 29/* CHT ๊ตฌ์กฐ์ฒด (๊ธฐ์šธ๊ธฐ ๋‹จ์กฐ ๊ฐ์†Œ ๋ฒ„์ „) */
 30typedef struct {
 31    Line* lines;
 32    int size;
 33    int capacity;
 34} CHT;
 35
 36CHT* cht_create(int capacity) {
 37    CHT* cht = malloc(sizeof(CHT));
 38    cht->lines = malloc(capacity * sizeof(Line));
 39    cht->size = 0;
 40    cht->capacity = capacity;
 41    return cht;
 42}
 43
 44void cht_free(CHT* cht) {
 45    free(cht->lines);
 46    free(cht);
 47}
 48
 49/* ๊ต์ฐจ์  x์ขŒํ‘œ ๋น„๊ต */
 50bool cht_bad(Line l1, Line l2, Line l3) {
 51    /* l2๊ฐ€ ๋ถˆํ•„์š”ํ•œ์ง€ ํ™•์ธ */
 52    /* (l3.b - l1.b) / (l1.m - l3.m) <= (l2.b - l1.b) / (l1.m - l2.m) */
 53    return (l3.b - l1.b) * (l1.m - l2.m) <= (l2.b - l1.b) * (l1.m - l3.m);
 54}
 55
 56/* ์ง์„  ์ถ”๊ฐ€ (๊ธฐ์šธ๊ธฐ ๋‹จ์กฐ ๊ฐ์†Œ) */
 57void cht_add(CHT* cht, ll m, ll b) {
 58    Line new_line = {m, b};
 59
 60    while (cht->size >= 2 &&
 61           cht_bad(cht->lines[cht->size - 2], cht->lines[cht->size - 1], new_line)) {
 62        cht->size--;
 63    }
 64
 65    cht->lines[cht->size++] = new_line;
 66}
 67
 68/* ์ตœ์†Ÿ๊ฐ’ ์ฟผ๋ฆฌ */
 69ll cht_query(CHT* cht, ll x) {
 70    int lo = 0, hi = cht->size - 1;
 71    while (lo < hi) {
 72        int mid = (lo + hi) / 2;
 73        ll y1 = cht->lines[mid].m * x + cht->lines[mid].b;
 74        ll y2 = cht->lines[mid + 1].m * x + cht->lines[mid + 1].b;
 75        if (y1 > y2) lo = mid + 1;
 76        else hi = mid;
 77    }
 78    return cht->lines[lo].m * x + cht->lines[lo].b;
 79}
 80
 81/* =============================================================================
 82 * 2. Li Chao Tree
 83 * ============================================================================= */
 84
 85typedef struct LCNode {
 86    ll m, b;
 87    struct LCNode* left;
 88    struct LCNode* right;
 89} LCNode;
 90
 91LCNode* lc_create(void) {
 92    LCNode* node = malloc(sizeof(LCNode));
 93    node->m = 0;
 94    node->b = INF;
 95    node->left = NULL;
 96    node->right = NULL;
 97    return node;
 98}
 99
100ll lc_eval(ll m, ll b, ll x) {
101    return m * x + b;
102}
103
104void lc_insert(LCNode* node, ll lo, ll hi, ll m, ll b) {
105    if (lo == hi) {
106        if (lc_eval(m, b, lo) < lc_eval(node->m, node->b, lo)) {
107            node->m = m;
108            node->b = b;
109        }
110        return;
111    }
112
113    ll mid = (lo + hi) / 2;
114    bool left_better = lc_eval(m, b, lo) < lc_eval(node->m, node->b, lo);
115    bool mid_better = lc_eval(m, b, mid) < lc_eval(node->m, node->b, mid);
116
117    if (mid_better) {
118        ll tmp_m = node->m, tmp_b = node->b;
119        node->m = m; node->b = b;
120        m = tmp_m; b = tmp_b;
121    }
122
123    if (left_better != mid_better) {
124        if (!node->left) node->left = lc_create();
125        lc_insert(node->left, lo, mid, m, b);
126    } else {
127        if (!node->right) node->right = lc_create();
128        lc_insert(node->right, mid + 1, hi, m, b);
129    }
130}
131
132ll lc_query(LCNode* node, ll lo, ll hi, ll x) {
133    if (!node) return INF;
134
135    ll result = lc_eval(node->m, node->b, x);
136    if (lo == hi) return result;
137
138    ll mid = (lo + hi) / 2;
139    if (x <= mid) {
140        ll left_val = lc_query(node->left, lo, mid, x);
141        if (left_val < result) result = left_val;
142    } else {
143        ll right_val = lc_query(node->right, mid + 1, hi, x);
144        if (right_val < result) result = right_val;
145    }
146    return result;
147}
148
/* =============================================================================
 * 3. D&C Optimization (divide-and-conquer optimization)
 * ============================================================================= */
152
153/* ์กฐ๊ฑด: opt[i][j] <= opt[i][j+1]
154 * ์‹œ๊ฐ„๋ณต์žก๋„: O(kn log n) */
155
156ll** cost;  /* cost[i][j]: i๋ถ€ํ„ฐ j๊นŒ์ง€์˜ ๋น„์šฉ */
157ll** dp_dc;
158int n_dc;
159
160void compute_dp(int k, int lo, int hi, int opt_lo, int opt_hi) {
161    if (lo > hi) return;
162
163    int mid = (lo + hi) / 2;
164    int opt = opt_lo;
165    dp_dc[k][mid] = INF;
166
167    for (int i = opt_lo; i <= opt_hi && i < mid; i++) {
168        ll val = dp_dc[k - 1][i] + cost[i + 1][mid];
169        if (val < dp_dc[k][mid]) {
170            dp_dc[k][mid] = val;
171            opt = i;
172        }
173    }
174
175    compute_dp(k, lo, mid - 1, opt_lo, opt);
176    compute_dp(k, mid + 1, hi, opt, opt_hi);
177}
178
179ll divide_conquer_opt(int n, int k, ll** cost_matrix) {
180    cost = cost_matrix;
181    n_dc = n;
182
183    dp_dc = malloc((k + 1) * sizeof(ll*));
184    for (int i = 0; i <= k; i++) {
185        dp_dc[i] = malloc((n + 1) * sizeof(ll));
186        for (int j = 0; j <= n; j++) {
187            dp_dc[i][j] = INF;
188        }
189    }
190
191    dp_dc[0][0] = 0;
192
193    /* ์ฒซ ๋ฒˆ์งธ ๊ทธ๋ฃน */
194    for (int j = 1; j <= n; j++) {
195        dp_dc[1][j] = cost[1][j];
196    }
197
198    for (int i = 2; i <= k; i++) {
199        compute_dp(i, 1, n, 0, n - 1);
200    }
201
202    ll result = dp_dc[k][n];
203
204    for (int i = 0; i <= k; i++) free(dp_dc[i]);
205    free(dp_dc);
206
207    return result;
208}
209
/* =============================================================================
 * 4. Knuth Optimization
 * ============================================================================= */
213
214/* ์กฐ๊ฑด: opt[i][j-1] <= opt[i][j] <= opt[i+1][j]
215 * ์‹œ๊ฐ„๋ณต์žก๋„: O(nยฒ) */
216
217ll knuth_opt(int n, ll** cost_matrix) {
218    ll** dp = malloc((n + 2) * sizeof(ll*));
219    int** opt = malloc((n + 2) * sizeof(int*));
220
221    for (int i = 0; i <= n + 1; i++) {
222        dp[i] = calloc(n + 2, sizeof(ll));
223        opt[i] = calloc(n + 2, sizeof(int));
224    }
225
226    /* ๊ธฐ์ € ์กฐ๊ฑด */
227    for (int i = 1; i <= n; i++) {
228        opt[i][i] = i;
229    }
230
231    /* ๊ธธ์ด ์ˆœ์œผ๋กœ ๊ณ„์‚ฐ */
232    for (int len = 2; len <= n; len++) {
233        for (int i = 1; i + len - 1 <= n; i++) {
234            int j = i + len - 1;
235            dp[i][j] = INF;
236
237            int lo = opt[i][j - 1];
238            int hi = opt[i + 1][j];
239            if (lo < i) lo = i;
240            if (hi > j) hi = j;
241
242            for (int k = lo; k <= hi; k++) {
243                ll val = dp[i][k - 1] + dp[k + 1][j] + cost_matrix[i][j];
244                if (val < dp[i][j]) {
245                    dp[i][j] = val;
246                    opt[i][j] = k;
247                }
248            }
249        }
250    }
251
252    ll result = dp[1][n];
253
254    for (int i = 0; i <= n + 1; i++) {
255        free(dp[i]);
256        free(opt[i]);
257    }
258    free(dp);
259    free(opt);
260
261    return result;
262}
263
/* =============================================================================
 * 5. Monotonic Queue Optimization
 * ============================================================================= */
267
268/* dp[i] = min(dp[j] + C[j]) for j in [i-k, i-1]
269 * ์Šฌ๋ผ์ด๋”ฉ ์œˆ๋„์šฐ ์ตœ์†Ÿ๊ฐ’ ํ™œ์šฉ */
270
271ll monotonic_queue_dp(int n, int k, ll arr[]) {
272    ll* dp = malloc((n + 1) * sizeof(ll));
273    int* deque = malloc((n + 1) * sizeof(int));
274    int front = 0, rear = 0;
275
276    dp[0] = 0;
277    deque[rear++] = 0;
278
279    for (int i = 1; i <= n; i++) {
280        /* ์œˆ๋„์šฐ ๋ฒ—์–ด๋‚œ ์›์†Œ ์ œ๊ฑฐ */
281        while (front < rear && deque[front] < i - k) {
282            front++;
283        }
284
285        /* ์ตœ์†Ÿ๊ฐ’ ์‚ฌ์šฉ */
286        dp[i] = dp[deque[front]] + arr[i];
287
288        /* ์ƒˆ ์›์†Œ ์ถ”๊ฐ€ (๋ชจ๋…ธํ† ๋‹‰ ์œ ์ง€) */
289        while (front < rear && dp[deque[rear - 1]] >= dp[i]) {
290            rear--;
291        }
292        deque[rear++] = i;
293    }
294
295    ll result = dp[n];
296    free(dp);
297    free(deque);
298    return result;
299}
300
301/* =============================================================================
302 * 6. SOS DP (Sum over Subsets)
303 * ============================================================================= */
304
305/* f[mask] = sum of a[submask] for all submask of mask */
306void sos_dp(ll a[], int n) {
307    int size = 1 << n;
308
309    for (int i = 0; i < n; i++) {
310        for (int mask = 0; mask < size; mask++) {
311            if (mask & (1 << i)) {
312                a[mask] += a[mask ^ (1 << i)];
313            }
314        }
315    }
316}
317
/* =============================================================================
 * Tests (demo driver)
 * ============================================================================= */
321
322int main(void) {
323    printf("============================================================\n");
324    printf("๊ณ ๊ธ‰ DP ์ตœ์ ํ™” ์˜ˆ์ œ\n");
325    printf("============================================================\n");
326
327    /* 1. Convex Hull Trick */
328    printf("\n[1] Convex Hull Trick\n");
329    CHT* cht = cht_create(100);
330
331    /* ์ง์„ ๋“ค: y = -2x + 4, y = -1x + 3, y = -0.5x + 2 */
332    cht_add(cht, -2, 4);
333    cht_add(cht, -1, 3);
334    cht_add(cht, 0, 2);
335
336    printf("    ์ง์„ : y = -2x + 4, y = -x + 3, y = 2\n");
337    printf("    x=0์—์„œ ์ตœ์†Ÿ๊ฐ’: %lld\n", cht_query(cht, 0));
338    printf("    x=1์—์„œ ์ตœ์†Ÿ๊ฐ’: %lld\n", cht_query(cht, 1));
339    printf("    x=2์—์„œ ์ตœ์†Ÿ๊ฐ’: %lld\n", cht_query(cht, 2));
340    printf("    x=3์—์„œ ์ตœ์†Ÿ๊ฐ’: %lld\n", cht_query(cht, 3));
341    cht_free(cht);
342
343    /* 2. Li Chao Tree */
344    printf("\n[2] Li Chao Tree\n");
345    LCNode* root = lc_create();
346    ll lo = 0, hi = 100;
347
348    lc_insert(root, lo, hi, -2, 10);
349    lc_insert(root, lo, hi, 1, 0);
350    lc_insert(root, lo, hi, -1, 8);
351
352    printf("    ์ง์„ : y = -2x + 10, y = x, y = -x + 8\n");
353    printf("    x=0์—์„œ ์ตœ์†Ÿ๊ฐ’: %lld\n", lc_query(root, lo, hi, 0));
354    printf("    x=3์—์„œ ์ตœ์†Ÿ๊ฐ’: %lld\n", lc_query(root, lo, hi, 3));
355    printf("    x=5์—์„œ ์ตœ์†Ÿ๊ฐ’: %lld\n", lc_query(root, lo, hi, 5));
356
357    /* 3. ๋ชจ๋…ธํ† ๋‹‰ ํ DP */
358    printf("\n[3] ๋ชจ๋…ธํ† ๋‹‰ ํ DP\n");
359    ll arr[] = {0, 1, 3, 2, 4, 1, 5};
360    int n = 6, k = 3;
361    printf("    ๋ฐฐ์—ด: [1, 3, 2, 4, 1, 5]\n");
362    printf("    ์œˆ๋„์šฐ ํฌ๊ธฐ k = 3\n");
363    printf("    ์ตœ์†Œ ๋น„์šฉ: %lld\n", monotonic_queue_dp(n, k, arr));
364
365    /* 4. SOS DP */
366    printf("\n[4] SOS DP (Sum over Subsets)\n");
367    ll sos_arr[] = {1, 2, 3, 4, 5, 6, 7, 8};  /* 2^3 = 8 */
368    printf("    ์ดˆ๊ธฐ ๋ฐฐ์—ด: [1, 2, 3, 4, 5, 6, 7, 8]\n");
369    sos_dp(sos_arr, 3);
370    printf("    SOS ๊ฒฐ๊ณผ:\n");
371    for (int mask = 0; mask < 8; mask++) {
372        printf("      f[%d%d%d] = %lld\n",
373               (mask >> 2) & 1, (mask >> 1) & 1, mask & 1, sos_arr[mask]);
374    }
375
376    /* 5. ๋ณต์žก๋„ ๋น„๊ต */
377    printf("\n[5] ๋ณต์žก๋„ ๋น„๊ต\n");
378    printf("    | ๊ธฐ๋ฒ•              | ์›๋ž˜ ๋ณต์žก๋„ | ์ตœ์ ํ™” ํ›„    |\n");
379    printf("    |-------------------|-------------|-------------|\n");
380    printf("    | CHT               | O(nยฒ)       | O(n log n)  |\n");
381    printf("    | Li Chao Tree      | O(nยฒ)       | O(n log C)  |\n");
382    printf("    | D&C ์ตœ์ ํ™”        | O(knยฒ)      | O(kn log n) |\n");
383    printf("    | Knuth ์ตœ์ ํ™”      | O(nยณ)       | O(nยฒ)       |\n");
384    printf("    | ๋ชจ๋…ธํ† ๋‹‰ ํ       | O(nk)       | O(n)        |\n");
385    printf("    | SOS DP            | O(3^n)      | O(n ร— 2^n)  |\n");
386
387    /* 6. ์ ์šฉ ์กฐ๊ฑด */
388    printf("\n[6] ์ ์šฉ ์กฐ๊ฑด\n");
389    printf("    CHT:\n");
390    printf("      - dp[i] = min(dp[j] + a[j] ร— b[i]) ํ˜•ํƒœ\n");
391    printf("      - a[j] ๋˜๋Š” b[i]๊ฐ€ ๋‹จ์กฐ\n");
392    printf("    D&C ์ตœ์ ํ™”:\n");
393    printf("      - opt[i][j] <= opt[i][j+1]\n");
394    printf("      - ๋น„์šฉ ํ•จ์ˆ˜๊ฐ€ Quadrangle Inequality ๋งŒ์กฑ\n");
395    printf("    Knuth ์ตœ์ ํ™”:\n");
396    printf("      - opt[i][j-1] <= opt[i][j] <= opt[i+1][j]\n");
397    printf("      - ๊ตฌ๊ฐ„ DP์—์„œ ์ฃผ๋กœ ์‚ฌ์šฉ\n");
398
399    printf("\n============================================================\n");
400
401    return 0;
402}