1"""
2최소 공통 조상 (LCA - Lowest Common Ancestor)
3LCA and Tree Queries
4
5트리에서 두 노드의 최소 공통 조상을 찾는 알고리즘입니다.
6"""
7
8from typing import List, Tuple, Optional
9from collections import defaultdict, deque
10import math
11
12
13# =============================================================================
14# 1. 기본 LCA (Naive)
15# =============================================================================
16
def lca_naive(n: int, edges: List[Tuple[int, int]], u: int, v: int) -> int:
    """
    Naive LCA: equalise the depths of u and v, then climb both in lockstep.

    The tree is rooted at node 0 and rebuilt on every call.
    Time: O(n) per query (preprocessing included, since nothing is cached).

    Args:
        n: number of nodes, labelled 0..n-1.
        edges: undirected tree edges (a, b).
        u: first query node.
        v: second query node.

    Returns:
        The lowest common ancestor of u and v. If u != v and either node is
        not reachable from node 0, both climbs end at -1 and -1 is returned.
    """
    # Build the undirected adjacency list.
    adj = defaultdict(list)
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)

    # Parent and depth of every node, rooted at 0.
    parent = [-1] * n
    depth = [0] * n

    # Iterative DFS, so that deep (path-like) trees do not hit Python's
    # recursion limit.
    stack = [(0, -1)]
    while stack:
        node, par = stack.pop()
        parent[node] = par
        for child in adj[node]:
            if child != par:
                depth[child] = depth[node] + 1
                stack.append((child, node))

    # Lift the deeper node until both are at the same depth.
    while depth[u] > depth[v]:
        u = parent[u]
    while depth[v] > depth[u]:
        v = parent[v]

    # Climb both nodes together until they meet.
    while u != v:
        u = parent[u]
        v = parent[v]

    return u
54
55
56# =============================================================================
57# 2. Binary Lifting (희소 테이블)
58# =============================================================================
59
class LCABinaryLifting:
    """
    LCA using binary lifting (an ancestor jump table).

    Preprocessing: O(n log n)
    Query: O(log n)

    parent[i][v] holds the 2**i-th ancestor of v, or -1 if it does not exist.
    Nodes not reachable from ``root`` keep parent -1 and depth 0.
    """

    def __init__(self, n: int, edges: List[Tuple[int, int]], root: int = 0):
        self.n = n
        # Number of jump levels. For n >= 1, bit_length() equals
        # int(log2(n)) + 1 (so 2**LOG > n - 1 >= any depth) but avoids
        # floating-point rounding.
        self.LOG = max(1, n.bit_length())

        # Undirected adjacency list.
        self.adj = defaultdict(list)
        for a, b in edges:
            self.adj[a].append(b)
            self.adj[b].append(a)

        self.parent = [[-1] * n for _ in range(self.LOG)]
        self.depth = [0] * n

        self._preprocess(root)

    def _preprocess(self, root: int):
        """Compute parent/depth with an iterative DFS, then fill the jump table."""
        stack = [(root, -1, 0)]
        while stack:
            node, par, d = stack.pop()
            self.parent[0][node] = par
            self.depth[node] = d
            for child in self.adj[node]:
                if child != par:
                    stack.append((child, node, d + 1))

        # parent[i][v] = parent[i-1][parent[i-1][v]]  (two half-jumps)
        for i in range(1, self.LOG):
            prev = self.parent[i - 1]
            cur = self.parent[i]
            for v in range(self.n):
                mid = prev[v]
                if mid != -1:
                    cur[v] = prev[mid]

    def query(self, u: int, v: int) -> int:
        """Return the LCA of u and v - O(log n)."""
        # Make u the deeper node.
        if self.depth[u] < self.depth[v]:
            u, v = v, u

        # Lift u to v's depth using the binary digits of the depth difference.
        diff = self.depth[u] - self.depth[v]
        for i in range(self.LOG):
            if (diff >> i) & 1:
                u = self.parent[i][u]

        if u == v:
            return u

        # Take the largest jumps that keep u and v below their common ancestor.
        for i in range(self.LOG - 1, -1, -1):
            if self.parent[i][u] != self.parent[i][v]:
                u = self.parent[i][u]
                v = self.parent[i][v]

        return self.parent[0][u]

    def kth_ancestor(self, node: int, k: int) -> int:
        """
        Return the k-th ancestor of node (k = 0 is node itself) - O(log n).

        Returns -1 when no such ancestor exists, including negative k.
        """
        # Range check first: the jump loop only sees bits below LOG, so an
        # unchecked k >= 2**LOG would have its high bits silently ignored.
        if not 0 <= k <= self.depth[node]:
            return -1
        i = 0
        while k:
            if k & 1:
                node = self.parent[i][node]
            k >>= 1
            i += 1
        return node

    def distance(self, u: int, v: int) -> int:
        """Return the number of edges on the path between u and v - O(log n)."""
        lca = self.query(u, v)
        return self.depth[u] + self.depth[v] - 2 * self.depth[lca]
139
140
141# =============================================================================
142# 3. Euler Tour + RMQ (Sparse Table)
143# =============================================================================
144
class LCAEulerTour:
    """
    LCA via Euler tour + range-minimum query (sparse table).

    Preprocessing: O(n log n)
    Query: O(1)

    The tour records a node when the DFS first enters it and again each time
    the DFS returns to it from a child, giving 2n - 1 entries for a connected
    tree. The shallowest entry between the first occurrences of u and v is
    their LCA.
    """

    def __init__(self, n: int, edges: List[Tuple[int, int]], root: int = 0):
        self.n = n
        self.adj = defaultdict(list)
        for a, b in edges:
            self.adj[a].append(b)
            self.adj[b].append(a)

        # Euler tour as (depth, node) pairs.
        self.euler: List[Tuple[int, int]] = []
        # Index of each node's first occurrence in the tour (-1 = not visited).
        self.first = [-1] * n

        self._build_euler_tour(root)
        self._build_sparse_table()

    def _build_euler_tour(self, root: int):
        """Build the Euler tour with an iterative DFS - O(n)."""
        self.first[root] = 0
        self.euler.append((0, root))
        # Each frame: (node, depth, iterator over the node's remaining neighbours).
        stack = [(root, 0, iter(self.adj[root]))]

        while stack:
            node, depth, neighbours = stack[-1]
            for child in neighbours:
                if self.first[child] == -1:  # unvisited (skips the parent)
                    self.first[child] = len(self.euler)
                    self.euler.append((depth + 1, child))
                    stack.append((child, depth + 1, iter(self.adj[child])))
                    break
            else:
                # Subtree finished: return to the parent and record it again so
                # that it separates consecutive child subtrees in the tour.
                stack.pop()
                if stack:
                    up_node, up_depth, _ = stack[-1]
                    self.euler.append((up_depth, up_node))

    def _build_sparse_table(self):
        """Build the sparse table over tour depths - O(n log n)."""
        m = len(self.euler)
        # bit_length() == int(log2(m)) + 1 for m >= 1, computed exactly.
        self.LOG = max(1, m.bit_length())

        # sparse[i][j] = index of a minimum-depth entry in euler[j : j + 2**i]
        self.sparse = [[0] * m for _ in range(self.LOG)]
        self.sparse[0] = list(range(m))

        for i in range(1, self.LOG):
            half = 1 << (i - 1)
            prev = self.sparse[i - 1]
            row = self.sparse[i]
            for j in range(m - (1 << i) + 1):
                left, right = prev[j], prev[j + half]
                row[j] = left if self.euler[left][0] <= self.euler[right][0] else right

    def _rmq(self, left: int, right: int) -> int:
        """Return the index of a minimum-depth entry in euler[left..right] (inclusive) - O(1)."""
        k = (right - left + 1).bit_length() - 1  # floor(log2(length)), exact for ints
        left_idx = self.sparse[k][left]
        right_idx = self.sparse[k][right - (1 << k) + 1]
        if self.euler[left_idx][0] <= self.euler[right_idx][0]:
            return left_idx
        return right_idx

    def query(self, u: int, v: int) -> int:
        """Return the LCA of u and v - O(1)."""
        left = self.first[u]
        right = self.first[v]
        if left > right:
            left, right = right, left
        idx = self._rmq(left, right)
        return self.euler[idx][1]
224
225
226# =============================================================================
227# 4. 트리에서 경로 합/최대/최소
228# =============================================================================
229
class TreePathQuery:
    """
    Weighted tree path queries (LCA + edge weights) via binary lifting.

    Supports path weight sums and the maximum edge weight on a path.
    Preprocessing: O(n log n); each query: O(log n).
    """

    def __init__(self, n: int, edges: List[Tuple[int, int, int]], root: int = 0):
        """edges: [(u, v, weight), ...]"""
        self.n = n
        # bit_length() == int(log2(n)) + 1 for n >= 1, without float rounding.
        self.LOG = max(1, n.bit_length())

        self.adj = defaultdict(list)
        for a, b, w in edges:
            self.adj[a].append((b, w))
            self.adj[b].append((a, w))

        self.parent = [[-1] * n for _ in range(self.LOG)]
        self.depth = [0] * n
        self.dist_from_root = [0] * n  # sum of edge weights from the root
        # max_edge[i][v] = max weight among the 2**i edges directly above v.
        # Only meaningful when that ancestor exists; the root's slot is a 0
        # placeholder that valid queries never read.
        self.max_edge = [[0] * n for _ in range(self.LOG)]

        self._preprocess(root)

    def _preprocess(self, root: int):
        """Iterative DFS for parent/depth/distance, then fill the jump tables."""
        stack = [(root, -1, 0, 0)]

        while stack:
            node, par, d, dist = stack.pop()
            self.parent[0][node] = par
            self.depth[node] = d
            self.dist_from_root[node] = dist

            for child, weight in self.adj[node]:
                if child != par:
                    self.max_edge[0][child] = weight
                    stack.append((child, node, d + 1, dist + weight))

        # Combine two half-jumps for both the ancestor and the edge maximum.
        for i in range(1, self.LOG):
            for v in range(self.n):
                mid = self.parent[i - 1][v]
                if mid != -1:
                    self.parent[i][v] = self.parent[i - 1][mid]
                    self.max_edge[i][v] = max(self.max_edge[i - 1][v], self.max_edge[i - 1][mid])

    def lca(self, u: int, v: int) -> int:
        """Return the LCA of u and v - O(log n)."""
        if self.depth[u] < self.depth[v]:
            u, v = v, u

        diff = self.depth[u] - self.depth[v]
        for i in range(self.LOG):
            if (diff >> i) & 1:
                u = self.parent[i][u]

        if u == v:
            return u

        for i in range(self.LOG - 1, -1, -1):
            if self.parent[i][u] != self.parent[i][v]:
                u = self.parent[i][u]
                v = self.parent[i][v]

        return self.parent[0][u]

    def path_distance(self, u: int, v: int) -> int:
        """Return the sum of edge weights on the path between u and v."""
        ancestor = self.lca(u, v)
        return self.dist_from_root[u] + self.dist_from_root[v] - 2 * self.dist_from_root[ancestor]

    def _climb_max_edges(self, node: int, steps: int) -> List[int]:
        """Lift node up by `steps` edges; return the max edge weight of each jump taken."""
        maxima = []
        for i in range(self.LOG):
            if (steps >> i) & 1:
                maxima.append(self.max_edge[i][node])
                node = self.parent[i][node]
        return maxima

    def path_max_edge(self, u: int, v: int) -> int:
        """
        Return the maximum edge weight on the path between u and v.

        Returns 0 when u == v (the path has no edges). Negative weights are
        handled correctly because the maximum is not seeded with 0.
        """
        ancestor = self.lca(u, v)
        candidates = self._climb_max_edges(u, self.depth[u] - self.depth[ancestor])
        candidates += self._climb_max_edges(v, self.depth[v] - self.depth[ancestor])
        return max(candidates, default=0)
321
322
323# =============================================================================
324# 5. 실전 문제: 트리에서 두 노드 사이 경로
325# =============================================================================
326
def find_path(n: int, edges: List[Tuple[int, int]], u: int, v: int) -> List[int]:
    """
    Return the nodes on the unique tree path from u to v, both ends included.

    A binary-lifting index rooted at node 0 is built, the LCA is located, and
    parent pointers are walked from each endpoint up to it. The v-side walk
    is reversed so the result reads u -> ... -> LCA -> ... -> v.
    """
    solver = LCABinaryLifting(n, edges)
    meet = solver.query(u, v)
    up = solver.parent[0]

    def climb(start: int) -> List[int]:
        # Nodes from `start` up to, but not including, the meeting point.
        trail = []
        node = start
        while node != meet:
            trail.append(node)
            node = up[node]
        return trail

    return climb(u) + [meet] + list(reversed(climb(v)))
348
349
350# =============================================================================
351# 테스트
352# =============================================================================
353
def main():
    """Run a small demo of each LCA technique in this module on a fixed tree."""
    rule = "=" * 60
    print(rule)
    print("최소 공통 조상 (LCA) 예제")
    print(rule)

    # Sample tree (rooted at 0):
    #
    #         0
    #       / | \
    #      1  2  3
    #     / \     \
    #    4   5     6
    #   /
    #  7
    n = 8
    edges = [(0, 1), (0, 2), (0, 3), (1, 4), (1, 5), (3, 6), (4, 7)]

    # [1] Naive LCA: rebuilds the tree on every query.
    print("\n[1] 기본 LCA (Naive)")
    for a, b in ((7, 5), (7, 6)):
        print(f" LCA({a}, {b}) = {lca_naive(n, edges, a, b)}")

    # [2] Binary lifting: O(log n) per query.
    print("\n[2] Binary Lifting")
    lifting = LCABinaryLifting(n, edges)
    for a, b in ((7, 5), (7, 6), (4, 6)):
        print(f" LCA({a}, {b}) = {lifting.query(a, b)}")
    print(f" 거리(7, 5) = {lifting.distance(7, 5)}")
    print(f" 7의 2번째 조상 = {lifting.kth_ancestor(7, 2)}")

    # [3] Euler tour + sparse-table RMQ: O(1) per query.
    print("\n[3] Euler Tour + RMQ (O(1) 쿼리)")
    tour = LCAEulerTour(n, edges)
    for a, b in ((7, 5), (7, 6)):
        print(f" LCA({a}, {b}) = {tour.query(a, b)}")

    # [4] Path sums and maximum edge on a weighted copy of the tree.
    print("\n[4] 가중치 트리 경로 쿼리")
    weighted_edges = [
        (0, 1, 3), (0, 2, 5), (0, 3, 4),
        (1, 4, 2), (1, 5, 6), (3, 6, 1), (4, 7, 8),
    ]
    paths = TreePathQuery(n, weighted_edges)
    print(f" 경로 거리(7, 5) = {paths.path_distance(7, 5)}")
    print(f" 경로 최대 간선(7, 5) = {paths.path_max_edge(7, 5)}")
    print(f" 경로 거리(7, 6) = {paths.path_distance(7, 6)}")

    # [5] Explicit node sequence between two nodes.
    print("\n[5] 두 노드 사이 경로")
    for a, b in ((7, 6), (5, 2)):
        print(f" 경로({a}, {b}) = {find_path(n, edges, a, b)}")

    # [6] Complexity summary table.
    print("\n[6] 복잡도 비교")
    table = (
        " | 방법 | 전처리 | 쿼리 |",
        " |----------------|------------|---------|",
        " | Naive | O(n) | O(n) |",
        " | Binary Lifting | O(n log n) | O(log n)|",
        " | Euler + RMQ | O(n log n) | O(1) |",
    )
    for row in table:
        print(row)

    print("\n" + rule)
420
421
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()