16_lca.cpp

Download
cpp 412 lines 11.5 KB
/*
 * LCA와 트리 쿼리 (LCA and Tree Queries)
 * Binary Lifting, Euler Tour + Sparse Table (RMQ), 경로 쿼리, 트리 직경, 무게중심, 가중치 LCA
 *
 * 트리에서 최소 공통 조상(LCA)을 찾는 알고리즘과 이를 활용한 트리 쿼리 예제입니다.
 */
  7
#include <algorithm>
#include <cmath>
#include <functional>
#include <iostream>
#include <limits>
#include <queue>
#include <utility>
#include <vector>

using namespace std;
 14
 15// =============================================================================
 16// 1. Binary Lifting LCA
 17// =============================================================================
 18
 19class LCABinaryLifting {
 20private:
 21    int n, LOG;
 22    vector<vector<int>> adj;
 23    vector<vector<int>> up;  // up[v][j] = 2^j번째 조상
 24    vector<int> depth;
 25
 26    void dfs(int v, int p, int d) {
 27        depth[v] = d;
 28        up[v][0] = p;
 29
 30        for (int j = 1; j < LOG; j++) {
 31            if (up[v][j-1] != -1) {
 32                up[v][j] = up[up[v][j-1]][j-1];
 33            }
 34        }
 35
 36        for (int u : adj[v]) {
 37            if (u != p) {
 38                dfs(u, v, d + 1);
 39            }
 40        }
 41    }
 42
 43public:
 44    LCABinaryLifting(int n, const vector<vector<int>>& adj, int root = 0)
 45        : n(n), adj(adj) {
 46        LOG = (int)ceil(log2(n + 1)) + 1;
 47        up.assign(n, vector<int>(LOG, -1));
 48        depth.assign(n, 0);
 49        dfs(root, -1, 0);
 50    }
 51
 52    int getDepth(int v) const {
 53        return depth[v];
 54    }
 55
 56    int kthAncestor(int v, int k) {
 57        for (int j = 0; j < LOG && v != -1; j++) {
 58            if ((k >> j) & 1) {
 59                v = up[v][j];
 60            }
 61        }
 62        return v;
 63    }
 64
 65    int lca(int u, int v) {
 66        if (depth[u] < depth[v]) swap(u, v);
 67
 68        // 같은 깊이로 맞추기
 69        u = kthAncestor(u, depth[u] - depth[v]);
 70
 71        if (u == v) return u;
 72
 73        // 함께 올라가기
 74        for (int j = LOG - 1; j >= 0; j--) {
 75            if (up[u][j] != up[v][j]) {
 76                u = up[u][j];
 77                v = up[v][j];
 78            }
 79        }
 80
 81        return up[u][0];
 82    }
 83
 84    int distance(int u, int v) {
 85        return depth[u] + depth[v] - 2 * depth[lca(u, v)];
 86    }
 87};
 88
 89// =============================================================================
 90// 2. Euler Tour + RMQ
 91// =============================================================================
 92
 93class LCAEulerTour {
 94private:
 95    int n;
 96    vector<vector<int>> adj;
 97    vector<int> euler;       // Euler tour
 98    vector<int> first;       // 첫 등장 위치
 99    vector<int> depth;
100    vector<vector<int>> sparse;  // Sparse table
101
102    void dfs(int v, int p, int d) {
103        first[v] = euler.size();
104        euler.push_back(v);
105        depth[v] = d;
106
107        for (int u : adj[v]) {
108            if (u != p) {
109                dfs(u, v, d + 1);
110                euler.push_back(v);
111            }
112        }
113    }
114
115    void buildSparseTable() {
116        int m = euler.size();
117        int LOG = (int)ceil(log2(m + 1)) + 1;
118        sparse.assign(LOG, vector<int>(m));
119
120        for (int i = 0; i < m; i++) {
121            sparse[0][i] = euler[i];
122        }
123
124        for (int j = 1; j < LOG; j++) {
125            for (int i = 0; i + (1 << j) <= m; i++) {
126                int left = sparse[j-1][i];
127                int right = sparse[j-1][i + (1 << (j-1))];
128                sparse[j][i] = (depth[left] < depth[right]) ? left : right;
129            }
130        }
131    }
132
133public:
134    LCAEulerTour(int n, const vector<vector<int>>& adj, int root = 0)
135        : n(n), adj(adj), first(n), depth(n) {
136        dfs(root, -1, 0);
137        buildSparseTable();
138    }
139
140    int lca(int u, int v) {
141        int l = first[u], r = first[v];
142        if (l > r) swap(l, r);
143
144        int len = r - l + 1;
145        int k = (int)log2(len);
146
147        int left = sparse[k][l];
148        int right = sparse[k][r - (1 << k) + 1];
149
150        return (depth[left] < depth[right]) ? left : right;
151    }
152};
153
154// =============================================================================
155// 3. 트리에서 경로 쿼리
156// =============================================================================
157
158class TreePathQuery {
159private:
160    LCABinaryLifting lca;
161    vector<long long> prefixSum;  // 루트에서 각 노드까지의 합
162    vector<int> value;
163
164public:
165    TreePathQuery(int n, const vector<vector<int>>& adj,
166                  const vector<int>& values, int root = 0)
167        : lca(n, adj, root), prefixSum(n, 0), value(values) {
168
169        // DFS로 prefix sum 계산
170        function<void(int, int, long long)> dfs = [&](int v, int p, long long sum) {
171            sum += value[v];
172            prefixSum[v] = sum;
173            for (int u : adj[v]) {
174                if (u != p) dfs(u, v, sum);
175            }
176        };
177        dfs(root, -1, 0);
178    }
179
180    long long pathSum(int u, int v) {
181        int l = lca.lca(u, v);
182        return prefixSum[u] + prefixSum[v] - 2 * prefixSum[l] + value[l];
183    }
184
185    int pathLength(int u, int v) {
186        return lca.distance(u, v);
187    }
188};
189
190// =============================================================================
191// 4. 트리 직경
192// =============================================================================
193
194pair<int, pair<int, int>> treeDiameter(int n, const vector<vector<int>>& adj) {
195    vector<int> dist(n, -1);
196
197    // 첫 번째 BFS
198    auto bfs = [&](int start) -> int {
199        fill(dist.begin(), dist.end(), -1);
200        queue<int> q;
201        q.push(start);
202        dist[start] = 0;
203        int farthest = start;
204
205        while (!q.empty()) {
206            int v = q.front();
207            q.pop();
208
209            for (int u : adj[v]) {
210                if (dist[u] == -1) {
211                    dist[u] = dist[v] + 1;
212                    q.push(u);
213                    if (dist[u] > dist[farthest]) {
214                        farthest = u;
215                    }
216                }
217            }
218        }
219
220        return farthest;
221    };
222
223    int u = bfs(0);      // 가장 먼 점 찾기
224    int v = bfs(u);      // u에서 가장 먼 점 찾기
225
226    return {dist[v], {u, v}};
227}
228
229// =============================================================================
// 5. 트리의 무게중심 (Centroid)
231// =============================================================================
232
/**
 * Finds a centroid of the tree: a node whose removal leaves no component with
 * more than n / 2 nodes.
 *
 * Roots the tree at 0, computes subtree sizes, then walks down from the root
 * into the (at most one) child whose subtree exceeds n / 2. Both phases are
 * iterative, so deep trees cannot overflow the call stack.
 *
 * @param n   number of nodes (0..n-1)
 * @param adj undirected adjacency lists; assumed to describe a connected tree
 * @return a centroid node, or -1 when n <= 0
 */
int treeCentroid(int n, const vector<vector<int>>& adj) {
    if (n <= 0) return -1;

    // Pre-order traversal from root 0: every parent precedes its children.
    vector<int> parent(n, -1);
    vector<int> order;
    order.reserve(n);
    vector<int> pending{0};
    while (!pending.empty()) {
        const int v = pending.back();
        pending.pop_back();
        order.push_back(v);
        for (int u : adj[v]) {
            if (u != parent[v]) {
                parent[u] = v;
                pending.push_back(u);
            }
        }
    }

    // Reverse pre-order visits children before parents, so sizes accumulate upward.
    vector<int> subtreeSize(n, 1);
    for (size_t i = order.size(); i-- > 1;) {
        const int v = order[i];
        subtreeSize[parent[v]] += subtreeSize[v];
    }

    // Descend toward the heavy child until no child subtree exceeds n / 2.
    // Each step's "upward" component has fewer than n / 2 nodes.
    int v = 0;
    int p = -1;
    while (true) {
        int heavy = -1;
        for (int u : adj[v]) {
            if (u != p && subtreeSize[u] > n / 2) {
                heavy = u;
                break;
            }
        }
        if (heavy == -1) return v;
        p = v;
        v = heavy;
    }
}
259
260// =============================================================================
261// 6. 가중치 LCA
262// =============================================================================
263
/**
 * Binary-lifting LCA on an edge-weighted tree. It also reports the maximum edge
 * weight on the path between two nodes.
 *
 * Assumes adj is an undirected tree on 0..n-1, given as {neighbor, weight}
 * pairs listed in both directions. Preprocessing: O(N log N). Each query: O(log N).
 */
class WeightedLCA {
private:
    int n, LOG;
    vector<vector<pair<int, int>>> adj;  // {neighbor, weight}
    vector<vector<int>> up;              // up[v][j] = 2^j-th ancestor of v, -1 if none
    vector<vector<int>> maxWeight;       // max edge weight on the jump v -> up[v][j]
    vector<int> depth;                   // number of edges between root and v

    // Fills depth/up/maxWeight for every node reachable from root. Uses an
    // explicit stack instead of recursion so that deep trees cannot overflow
    // the call stack.
    void build(int root) {
        depth[root] = 0;
        up[root][0] = -1;
        maxWeight[root][0] = 0;
        vector<int> pending{root};
        while (!pending.empty()) {
            const int v = pending.back();
            pending.pop_back();

            // Ancestors of v were completed before v was pushed.
            for (int j = 1; j < LOG; j++) {
                const int mid = up[v][j - 1];
                if (mid != -1) {
                    up[v][j] = up[mid][j - 1];
                    maxWeight[v][j] = max(maxWeight[v][j - 1], maxWeight[mid][j - 1]);
                }
            }

            const int parent = up[v][0];
            for (auto [u, weight] : adj[v]) {
                if (u != parent) {
                    depth[u] = depth[v] + 1;
                    up[u][0] = v;
                    maxWeight[u][0] = weight;  // weight of edge u -> parent
                    pending.push_back(u);
                }
            }
        }
    }

public:
    /**
     * @param n    number of nodes (0..n-1); n == 0 builds an empty structure
     * @param adj  weighted undirected adjacency lists of the tree
     * @param root node used as the root of the tree
     */
    WeightedLCA(int n, const vector<vector<pair<int, int>>>& adj, int root = 0)
        : n(n), adj(adj) {
        LOG = (int)ceil(log2(n + 1)) + 1;
        up.assign(n, vector<int>(LOG, -1));
        maxWeight.assign(n, vector<int>(LOG, 0));
        depth.assign(n, 0);
        if (root >= 0 && root < n) {
            build(root);
        }
    }

    /**
     * Returns {lca(u, v), maximum edge weight on the u-v path}. When u == v the
     * path has no edges and the reported maximum is 0.
     */
    pair<int, int> lcaWithMaxWeight(int u, int v) const {
        if (u == v) return {u, 0};

        // The path has at least one edge, so start from the smallest int.
        // Starting from 0 would mask paths whose weights are all negative.
        int maxW = numeric_limits<int>::min();

        if (depth[u] < depth[v]) swap(u, v);
        const int diff = depth[u] - depth[v];

        // Lift u to v's depth, tracking the edges climbed.
        for (int j = 0; j < LOG; j++) {
            if ((diff >> j) & 1) {
                maxW = max(maxW, maxWeight[u][j]);
                u = up[u][j];
            }
        }

        if (u == v) return {u, maxW};

        // Climb both while their ancestors differ.
        for (int j = LOG - 1; j >= 0; j--) {
            if (up[u][j] != up[v][j]) {
                maxW = max(maxW, max(maxWeight[u][j], maxWeight[v][j]));
                u = up[u][j];
                v = up[v][j];
            }
        }

        // Final edges from the two children up to the LCA.
        maxW = max(maxW, max(maxWeight[u][0], maxWeight[v][0]));
        return {up[u][0], maxW};
    }
};
329
330// =============================================================================
331// 테스트
332// =============================================================================
333
334#include <queue>
335
336int main() {
337    cout << "============================================================" << endl;
338    cout << "LCA와 트리 쿼리 예제" << endl;
339    cout << "============================================================" << endl;
340
341    // 테스트 트리
342    //        0
343    //       /|\
344    //      1 2 3
345    //     /|   |
346    //    4 5   6
347    //    |
348    //    7
349
350    int n = 8;
351    vector<vector<int>> adj(n);
352    adj[0] = {1, 2, 3};
353    adj[1] = {0, 4, 5};
354    adj[2] = {0};
355    adj[3] = {0, 6};
356    adj[4] = {1, 7};
357    adj[5] = {1};
358    adj[6] = {3};
359    adj[7] = {4};
360
361    // 1. Binary Lifting LCA
362    cout << "\n[1] Binary Lifting LCA" << endl;
363    LCABinaryLifting lcaBL(n, adj, 0);
364    cout << "    LCA(4, 5) = " << lcaBL.lca(4, 5) << endl;
365    cout << "    LCA(7, 6) = " << lcaBL.lca(7, 6) << endl;
366    cout << "    LCA(7, 5) = " << lcaBL.lca(7, 5) << endl;
367    cout << "    거리(7, 6) = " << lcaBL.distance(7, 6) << endl;
368
369    // 2. K번째 조상
370    cout << "\n[2] K번째 조상" << endl;
371    cout << "    7의 1번째 조상: " << lcaBL.kthAncestor(7, 1) << endl;
372    cout << "    7의 2번째 조상: " << lcaBL.kthAncestor(7, 2) << endl;
373    cout << "    7의 3번째 조상: " << lcaBL.kthAncestor(7, 3) << endl;
374
375    // 3. Euler Tour LCA
376    cout << "\n[3] Euler Tour LCA" << endl;
377    LCAEulerTour lcaET(n, adj, 0);
378    cout << "    LCA(4, 5) = " << lcaET.lca(4, 5) << endl;
379    cout << "    LCA(7, 6) = " << lcaET.lca(7, 6) << endl;
380
381    // 4. 트리 직경
382    cout << "\n[4] 트리 직경" << endl;
383    auto [diameter, endpoints] = treeDiameter(n, adj);
384    cout << "    직경: " << diameter << endl;
385    cout << "    끝점: (" << endpoints.first << ", " << endpoints.second << ")" << endl;
386
387    // 5. 트리 중심
388    cout << "\n[5] 트리 중심" << endl;
389    int centroid = treeCentroid(n, adj);
390    cout << "    중심: " << centroid << endl;
391
392    // 6. 경로 쿼리
393    cout << "\n[6] 경로 쿼리" << endl;
394    vector<int> values = {1, 2, 3, 4, 5, 6, 7, 8};
395    TreePathQuery tpq(n, adj, values, 0);
396    cout << "    노드 값: [1, 2, 3, 4, 5, 6, 7, 8]" << endl;
397    cout << "    경로 합(7, 6): " << tpq.pathSum(7, 6) << endl;
398    cout << "    경로 길이(7, 6): " << tpq.pathLength(7, 6) << endl;
399
400    // 7. 복잡도 요약
401    cout << "\n[7] 복잡도 요약" << endl;
402    cout << "    | 알고리즘        | 전처리      | 쿼리      |" << endl;
403    cout << "    |-----------------|-------------|-----------|" << endl;
404    cout << "    | Binary Lifting  | O(N log N)  | O(log N)  |" << endl;
405    cout << "    | Euler Tour+RMQ  | O(N log N)  | O(1)      |" << endl;
406    cout << "    | Tarjan's Offline| O(N + Q)    | O(1)      |" << endl;
407
408    cout << "\n============================================================" << endl;
409
410    return 0;
411}