1/*
2 * LCA와 트리 쿼리 (LCA and Tree Queries)
3 * Binary Lifting, Sparse Table, Euler Tour, HLD 기초
4 *
5 * 트리에서 최소 공통 조상을 찾는 알고리즘입니다.
6 */
7
#include <algorithm>
#include <cmath>
#include <functional>
#include <iostream>
#include <limits>
#include <queue>
#include <utility>
#include <vector>

using namespace std;
14
15// =============================================================================
16// 1. Binary Lifting LCA
17// =============================================================================
18
// Lowest common ancestor via binary lifting.
// Preprocessing: O(N log N) time and memory. Queries: O(log N).
class LCABinaryLifting {
private:
    int n, LOG;
    std::vector<std::vector<int>> up;  // up[v][j] = 2^j-th ancestor of v, or -1 past the root
    std::vector<int> depth;            // depth[v] = number of edges from the root to v

    // Number of lifting levels, ceil(log2(count + 1)) + 1, computed with integers
    // (same value as the floating-point formula, without rounding risk).
    static int levelCount(int count) {
        int bits = 0;
        while ((1LL << bits) < static_cast<long long>(count) + 1) ++bits;
        return bits + 1;
    }

    // Fills depth[] and up[][] for every vertex reachable from `root`.
    // An explicit stack replaces recursion so that deep (path-like) trees cannot
    // overflow the call stack. A vertex is popped only after its parent has been
    // processed, so up[parent][*] is complete when the child reads it.
    void build(const std::vector<std::vector<int>>& adj, int root) {
        std::vector<std::pair<int, int>> todo;  // (vertex, parent)
        todo.emplace_back(root, -1);
        while (!todo.empty()) {
            const auto [v, p] = todo.back();
            todo.pop_back();

            depth[v] = (p == -1) ? 0 : depth[p] + 1;
            up[v][0] = p;
            // Once an ancestor level is -1 (past the root), all higher levels stay -1.
            for (int j = 1; j < LOG && up[v][j - 1] != -1; j++) {
                up[v][j] = up[up[v][j - 1]][j - 1];
            }

            for (int u : adj[v]) {
                if (u != p) todo.emplace_back(u, v);
            }
        }
    }

public:
    // adj:  undirected adjacency lists of a tree on vertices 0..n-1.
    // root: vertex the tree is rooted at (expected to be in [0, n) when n > 0).
    LCABinaryLifting(int n, const std::vector<std::vector<int>>& adj, int root = 0)
        : n(n), LOG(levelCount(n)) {
        if (n <= 0) return;  // empty tree: nothing to build
        up.assign(n, std::vector<int>(LOG, -1));
        depth.assign(n, 0);
        build(adj, root);
    }

    // Number of edges between the root and v.
    int getDepth(int v) const {
        return depth[v];
    }

    // Returns the k-th ancestor of v (k = 0 yields v itself), or -1 when it
    // does not exist (k < 0, k > depth[v], or v == -1). The explicit range check
    // also ensures no bit of k is silently ignored by the level loop.
    int kthAncestor(int v, int k) const {
        if (v < 0 || k < 0 || k > depth[v]) return -1;
        for (int j = 0; k > 0; j++, k >>= 1) {
            if (k & 1) v = up[v][j];
        }
        return v;
    }

    // Lowest common ancestor of u and v.
    int lca(int u, int v) const {
        if (depth[u] < depth[v]) std::swap(u, v);

        // Lift the deeper vertex to the same depth.
        u = kthAncestor(u, depth[u] - depth[v]);
        if (u == v) return u;

        // Lift both while their ancestors differ; they end as children of the LCA.
        for (int j = LOG - 1; j >= 0; j--) {
            if (up[u][j] != up[v][j]) {
                u = up[u][j];
                v = up[v][j];
            }
        }
        return up[u][0];
    }

    // Number of edges on the path between u and v.
    int distance(int u, int v) const {
        return depth[u] + depth[v] - 2 * depth[lca(u, v)];
    }
};
88
89// =============================================================================
90// 2. Euler Tour + RMQ
91// =============================================================================
92
// Lowest common ancestor via Euler tour + sparse-table range-minimum query.
// Preprocessing: O(N log N) time and memory. Queries: O(1).
class LCAEulerTour {
private:
    int n;
    std::vector<int> euler;  // vertices in Euler-tour order (2n - 1 entries for a tree)
    std::vector<int> first;  // first[v] = index of v's first occurrence in euler
    std::vector<int> depth;  // depth[v] = number of edges from the root to v
    std::vector<int> lg;     // lg[i] = floor(log2(i)), exact integers for 1 <= i <= euler.size()
    std::vector<std::vector<int>> sparse;  // sparse[j][i] = shallowest vertex in euler[i, i + 2^j)

    // Returns whichever of a, b is shallower (b on ties).
    int shallower(int a, int b) const {
        return (depth[a] < depth[b]) ? a : b;
    }

    // Builds euler/first/depth with an explicit stack instead of recursion, so
    // deep trees cannot overflow the call stack. Each stack frame remembers the
    // next neighbor to visit; the parent is re-emitted after each child finishes,
    // producing the same tour as the recursive formulation.
    void buildTour(const std::vector<std::vector<int>>& adj, int root) {
        std::vector<int> parent(n, -1);
        std::vector<std::pair<int, std::size_t>> todo;  // (vertex, next neighbor index)

        first[root] = 0;
        depth[root] = 0;
        euler.push_back(root);
        todo.emplace_back(root, 0);

        while (!todo.empty()) {
            const int v = todo.back().first;
            std::size_t& next = todo.back().second;

            if (next == adj[v].size()) {
                // All children done: leave v and return to its parent in the tour.
                todo.pop_back();
                if (!todo.empty()) euler.push_back(todo.back().first);
                continue;
            }

            const int u = adj[v][next++];
            if (u == parent[v]) continue;

            parent[u] = v;
            depth[u] = depth[v] + 1;
            first[u] = static_cast<int>(euler.size());
            euler.push_back(u);
            todo.emplace_back(u, 0);  // may invalidate `next`; it is not used again
        }
    }

    // Builds the integer log table and the sparse table over euler.
    void buildSparseTable() {
        const int m = static_cast<int>(euler.size());

        lg.assign(m + 1, 0);
        for (int i = 2; i <= m; i++) lg[i] = lg[i / 2] + 1;

        const int levels = lg[m] + 1;
        sparse.assign(levels, std::vector<int>(m));
        sparse[0] = euler;

        for (int j = 1; j < levels; j++) {
            const int half = 1 << (j - 1);
            for (int i = 0; i + (1 << j) <= m; i++) {
                sparse[j][i] = shallower(sparse[j - 1][i], sparse[j - 1][i + half]);
            }
        }
    }

public:
    // adj:  undirected adjacency lists of a tree on vertices 0..n-1.
    // root: vertex the tree is rooted at (expected to be in [0, n) when n > 0).
    LCAEulerTour(int n, const std::vector<std::vector<int>>& adj, int root = 0)
        : n(n), first(n), depth(n) {
        if (n <= 0) return;  // empty tree: nothing to build
        buildTour(adj, root);
        buildSparseTable();
    }

    // Lowest common ancestor of u and v: the shallowest vertex in the Euler
    // tour between their first occurrences, found with two overlapping
    // power-of-two windows.
    int lca(int u, int v) const {
        int l = first[u];
        int r = first[v];
        if (l > r) std::swap(l, r);

        const int k = lg[r - l + 1];
        return shallower(sparse[k][l], sparse[k][r - (1 << k) + 1]);
    }
};
153
154// =============================================================================
155// 3. 트리에서 경로 쿼리
156// =============================================================================
157
158class TreePathQuery {
159private:
160 LCABinaryLifting lca;
161 vector<long long> prefixSum; // 루트에서 각 노드까지의 합
162 vector<int> value;
163
164public:
165 TreePathQuery(int n, const vector<vector<int>>& adj,
166 const vector<int>& values, int root = 0)
167 : lca(n, adj, root), prefixSum(n, 0), value(values) {
168
169 // DFS로 prefix sum 계산
170 function<void(int, int, long long)> dfs = [&](int v, int p, long long sum) {
171 sum += value[v];
172 prefixSum[v] = sum;
173 for (int u : adj[v]) {
174 if (u != p) dfs(u, v, sum);
175 }
176 };
177 dfs(root, -1, 0);
178 }
179
180 long long pathSum(int u, int v) {
181 int l = lca.lca(u, v);
182 return prefixSum[u] + prefixSum[v] - 2 * prefixSum[l] + value[l];
183 }
184
185 int pathLength(int u, int v) {
186 return lca.distance(u, v);
187 }
188};
189
190// =============================================================================
191// 4. 트리 직경
192// =============================================================================
193
// Diameter of an unweighted tree via two BFS passes: the vertex farthest from
// 0 is one endpoint of a longest path, and the vertex farthest from it is the other.
// Returns {length in edges, {endpoint u, endpoint v}}; {0, {-1, -1}} when n <= 0.
// Only the component containing vertex 0 is examined.
std::pair<int, std::pair<int, int>> treeDiameter(int n, const std::vector<std::vector<int>>& adj) {
    if (n <= 0) return {0, {-1, -1}};  // no vertices: avoid indexing dist[0]

    std::vector<int> dist(n, -1);

    // BFS from `start`, filling dist[]; returns the first vertex reached at the
    // maximum distance.
    auto farthestFrom = [&](int start) {
        std::fill(dist.begin(), dist.end(), -1);
        std::queue<int> frontier;
        frontier.push(start);
        dist[start] = 0;
        int farthest = start;

        while (!frontier.empty()) {
            const int v = frontier.front();
            frontier.pop();
            for (int u : adj[v]) {
                if (dist[u] != -1) continue;
                dist[u] = dist[v] + 1;
                if (dist[u] > dist[farthest]) farthest = u;
                frontier.push(u);
            }
        }
        return farthest;
    };

    const int u = farthestFrom(0);  // one end of a longest path
    const int v = farthestFrom(u);  // the other end; dist[] now holds distances from u
    return {dist[v], {u, v}};
}
228
229// =============================================================================
230// 5. 트리의 중심 (Centroid)
231// =============================================================================
232
// Returns a centroid of the tree rooted at vertex 0: a vertex whose removal
// leaves no component with more than n / 2 vertices. Starting at 0, it walks
// into the child whose subtree has more than n / 2 vertices until no such
// child exists. Returns -1 for an empty tree (n <= 0).
int treeCentroid(int n, const std::vector<std::vector<int>>& adj) {
    if (n <= 0) return -1;

    // Iterative DFS from 0 recording parents and a pre-order (replaces the
    // recursive std::function helpers, which could overflow the stack on deep trees).
    std::vector<int> parent(n, -1);
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> todo{0};
    while (!todo.empty()) {
        const int v = todo.back();
        todo.pop_back();
        order.push_back(v);
        for (int u : adj[v]) {
            if (u != parent[v]) {
                parent[u] = v;
                todo.push_back(u);
            }
        }
    }

    // Subtree sizes: every child follows its parent in pre-order, so a reverse
    // sweep finalizes each child before adding it to its parent.
    std::vector<int> subtreeSize(n, 1);
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        if (parent[*it] != -1) subtreeSize[parent[*it]] += subtreeSize[*it];
    }

    // Descend toward the heavy child (subtree size > n / 2) while one exists.
    int v = 0;
    int p = -1;
    for (bool moved = true; moved;) {
        moved = false;
        for (int u : adj[v]) {
            if (u != p && subtreeSize[u] > n / 2) {
                p = v;
                v = u;
                moved = true;
                break;
            }
        }
    }
    return v;
}
259
260// =============================================================================
261// 6. 가중치 LCA
262// =============================================================================
263
// Binary-lifting LCA on an edge-weighted tree that also reports the maximum
// edge weight on the path between two vertices.
// Preprocessing: O(N log N). Queries: O(log N).
class WeightedLCA {
private:
    int n, LOG;
    std::vector<std::vector<int>> up;         // up[v][j] = 2^j-th ancestor of v, or -1 past the root
    std::vector<std::vector<int>> maxWeight;  // maxWeight[v][j] = max edge weight over the 2^j edges above v
    std::vector<int> depth;                   // depth[v] = number of edges from the root to v

    // Number of lifting levels, ceil(log2(count + 1)) + 1, computed with integers.
    static int levelCount(int count) {
        int bits = 0;
        while ((1LL << bits) < static_cast<long long>(count) + 1) ++bits;
        return bits + 1;
    }

    // Fills depth/up/maxWeight for every vertex reachable from `root` using an
    // explicit stack (no recursion, so deep trees cannot overflow the call stack).
    // A vertex is popped only after its parent, so the parent's tables are final.
    void build(const std::vector<std::vector<std::pair<int, int>>>& adj, int root) {
        struct Frame {
            int v;
            int parent;
            int weight;  // weight of edge (parent, v); 0 for the root
        };
        std::vector<Frame> todo;
        todo.push_back({root, -1, 0});

        while (!todo.empty()) {
            const Frame f = todo.back();
            todo.pop_back();
            const int v = f.v;

            depth[v] = (f.parent == -1) ? 0 : depth[f.parent] + 1;
            up[v][0] = f.parent;
            maxWeight[v][0] = f.weight;
            for (int j = 1; j < LOG && up[v][j - 1] != -1; j++) {
                const int mid = up[v][j - 1];
                up[v][j] = up[mid][j - 1];
                maxWeight[v][j] = std::max(maxWeight[v][j - 1], maxWeight[mid][j - 1]);
            }

            for (const auto& [u, w] : adj[v]) {
                if (u != f.parent) todo.push_back({u, v, w});
            }
        }
    }

public:
    // adj:  undirected adjacency lists of {neighbor, weight} pairs for a tree on 0..n-1.
    // root: vertex the tree is rooted at (expected to be in [0, n) when n > 0).
    WeightedLCA(int n, const std::vector<std::vector<std::pair<int, int>>>& adj, int root = 0)
        : n(n), LOG(levelCount(n)) {
        if (n <= 0) return;  // empty tree: nothing to build
        up.assign(n, std::vector<int>(LOG, -1));
        maxWeight.assign(n, std::vector<int>(LOG, 0));
        depth.assign(n, 0);
        build(adj, root);
    }

    // Returns {lca(u, v), maximum edge weight on the u..v path}.
    // For u == v the path has no edges and the weight reported is 0.
    // The running maximum starts at INT_MIN rather than 0 so that paths whose
    // edge weights are all negative report their true maximum.
    std::pair<int, int> lcaWithMaxWeight(int u, int v) const {
        if (u == v) return {u, 0};

        int maxW = std::numeric_limits<int>::min();

        // Lift the deeper vertex to the same depth, folding in the skipped edges.
        if (depth[u] < depth[v]) std::swap(u, v);
        const int diff = depth[u] - depth[v];
        for (int j = 0; j < LOG; j++) {
            if ((diff >> j) & 1) {
                maxW = std::max(maxW, maxWeight[u][j]);
                u = up[u][j];
            }
        }
        if (u == v) return {u, maxW};

        // Lift both while their ancestors differ; they end as children of the LCA.
        for (int j = LOG - 1; j >= 0; j--) {
            if (up[u][j] != up[v][j]) {
                maxW = std::max(maxW, std::max(maxWeight[u][j], maxWeight[v][j]));
                u = up[u][j];
                v = up[v][j];
            }
        }

        // The two final edges into the LCA.
        maxW = std::max(maxW, std::max(maxWeight[u][0], maxWeight[v][0]));
        return {up[u][0], maxW};
    }
};
329
330// =============================================================================
331// 테스트
332// =============================================================================
333
334#include <queue>
335
336int main() {
337 cout << "============================================================" << endl;
338 cout << "LCA와 트리 쿼리 예제" << endl;
339 cout << "============================================================" << endl;
340
341 // 테스트 트리
342 // 0
343 // /|\
344 // 1 2 3
345 // /| |
346 // 4 5 6
347 // |
348 // 7
349
350 int n = 8;
351 vector<vector<int>> adj(n);
352 adj[0] = {1, 2, 3};
353 adj[1] = {0, 4, 5};
354 adj[2] = {0};
355 adj[3] = {0, 6};
356 adj[4] = {1, 7};
357 adj[5] = {1};
358 adj[6] = {3};
359 adj[7] = {4};
360
361 // 1. Binary Lifting LCA
362 cout << "\n[1] Binary Lifting LCA" << endl;
363 LCABinaryLifting lcaBL(n, adj, 0);
364 cout << " LCA(4, 5) = " << lcaBL.lca(4, 5) << endl;
365 cout << " LCA(7, 6) = " << lcaBL.lca(7, 6) << endl;
366 cout << " LCA(7, 5) = " << lcaBL.lca(7, 5) << endl;
367 cout << " 거리(7, 6) = " << lcaBL.distance(7, 6) << endl;
368
369 // 2. K번째 조상
370 cout << "\n[2] K번째 조상" << endl;
371 cout << " 7의 1번째 조상: " << lcaBL.kthAncestor(7, 1) << endl;
372 cout << " 7의 2번째 조상: " << lcaBL.kthAncestor(7, 2) << endl;
373 cout << " 7의 3번째 조상: " << lcaBL.kthAncestor(7, 3) << endl;
374
375 // 3. Euler Tour LCA
376 cout << "\n[3] Euler Tour LCA" << endl;
377 LCAEulerTour lcaET(n, adj, 0);
378 cout << " LCA(4, 5) = " << lcaET.lca(4, 5) << endl;
379 cout << " LCA(7, 6) = " << lcaET.lca(7, 6) << endl;
380
381 // 4. 트리 직경
382 cout << "\n[4] 트리 직경" << endl;
383 auto [diameter, endpoints] = treeDiameter(n, adj);
384 cout << " 직경: " << diameter << endl;
385 cout << " 끝점: (" << endpoints.first << ", " << endpoints.second << ")" << endl;
386
387 // 5. 트리 중심
388 cout << "\n[5] 트리 중심" << endl;
389 int centroid = treeCentroid(n, adj);
390 cout << " 중심: " << centroid << endl;
391
392 // 6. 경로 쿼리
393 cout << "\n[6] 경로 쿼리" << endl;
394 vector<int> values = {1, 2, 3, 4, 5, 6, 7, 8};
395 TreePathQuery tpq(n, adj, values, 0);
396 cout << " 노드 값: [1, 2, 3, 4, 5, 6, 7, 8]" << endl;
397 cout << " 경로 합(7, 6): " << tpq.pathSum(7, 6) << endl;
398 cout << " 경로 길이(7, 6): " << tpq.pathLength(7, 6) << endl;
399
400 // 7. 복잡도 요약
401 cout << "\n[7] 복잡도 요약" << endl;
402 cout << " | 알고리즘 | 전처리 | 쿼리 |" << endl;
403 cout << " |-----------------|-------------|-----------|" << endl;
404 cout << " | Binary Lifting | O(N log N) | O(log N) |" << endl;
405 cout << " | Euler Tour+RMQ | O(N log N) | O(1) |" << endl;
406 cout << " | Tarjan's Offline| O(N + Q) | O(1) |" << endl;
407
408 cout << "\n============================================================" << endl;
409
410 return 0;
411}