16_lca.py

Download
python 424 lines 12.8 KB
  1"""
  2최소 공통 조상 (LCA - Lowest Common Ancestor)
  3LCA and Tree Queries
  4
  5트리에서 두 노드의 최소 공통 조상을 찾는 알고리즘입니다.
  6"""
  7
  8from typing import List, Tuple, Optional
  9from collections import defaultdict, deque
 10import math
 11
 12
 13# =============================================================================
 14# 1. 기본 LCA (Naive)
 15# =============================================================================
 16
 17def lca_naive(n: int, edges: List[Tuple[int, int]], u: int, v: int) -> int:
 18    """
 19    기본 LCA (높이 맞추기)
 20    시간복잡도: O(n) per query
 21    전처리: O(n)
 22    """
 23    # 트리 구성
 24    adj = defaultdict(list)
 25    for a, b in edges:
 26        adj[a].append(b)
 27        adj[b].append(a)
 28
 29    # 부모와 깊이 계산
 30    parent = [-1] * n
 31    depth = [0] * n
 32
 33    def dfs(node: int, par: int, d: int):
 34        parent[node] = par
 35        depth[node] = d
 36        for child in adj[node]:
 37            if child != par:
 38                dfs(child, node, d + 1)
 39
 40    dfs(0, -1, 0)
 41
 42    # 높이 맞추기
 43    while depth[u] > depth[v]:
 44        u = parent[u]
 45    while depth[v] > depth[u]:
 46        v = parent[v]
 47
 48    # 동시에 올라가기
 49    while u != v:
 50        u = parent[u]
 51        v = parent[v]
 52
 53    return u
 54
 55
 56# =============================================================================
 57# 2. Binary Lifting (희소 테이블)
 58# =============================================================================
 59
 60class LCABinaryLifting:
 61    """
 62    Binary Lifting을 이용한 LCA
 63    전처리: O(n log n)
 64    쿼리: O(log n)
 65    """
 66
 67    def __init__(self, n: int, edges: List[Tuple[int, int]], root: int = 0):
 68        self.n = n
 69        self.LOG = max(1, int(math.log2(n)) + 1)
 70
 71        # 그래프 구성
 72        self.adj = defaultdict(list)
 73        for a, b in edges:
 74            self.adj[a].append(b)
 75            self.adj[b].append(a)
 76
 77        # 전처리
 78        self.parent = [[-1] * n for _ in range(self.LOG)]
 79        self.depth = [0] * n
 80
 81        self._preprocess(root)
 82
 83    def _preprocess(self, root: int):
 84        """DFS로 부모/깊이 계산 + 희소 테이블 구성"""
 85        stack = [(root, -1, 0)]
 86
 87        while stack:
 88            node, par, d = stack.pop()
 89            self.parent[0][node] = par
 90            self.depth[node] = d
 91
 92            for child in self.adj[node]:
 93                if child != par:
 94                    stack.append((child, node, d + 1))
 95
 96        # 희소 테이블 구성: parent[i][v] = v의 2^i번째 조상
 97        for i in range(1, self.LOG):
 98            for v in range(self.n):
 99                if self.parent[i - 1][v] != -1:
100                    self.parent[i][v] = self.parent[i - 1][self.parent[i - 1][v]]
101
102    def query(self, u: int, v: int) -> int:
103        """LCA 쿼리 - O(log n)"""
104        # u가 더 깊도록 조정
105        if self.depth[u] < self.depth[v]:
106            u, v = v, u
107
108        # 높이 맞추기
109        diff = self.depth[u] - self.depth[v]
110        for i in range(self.LOG):
111            if (diff >> i) & 1:
112                u = self.parent[i][u]
113
114        # 같으면 완료
115        if u == v:
116            return u
117
118        # 동시에 올라가기
119        for i in range(self.LOG - 1, -1, -1):
120            if self.parent[i][u] != self.parent[i][v]:
121                u = self.parent[i][u]
122                v = self.parent[i][v]
123
124        return self.parent[0][u]
125
126    def kth_ancestor(self, node: int, k: int) -> int:
127        """k번째 조상 찾기 - O(log n)"""
128        for i in range(self.LOG):
129            if node == -1:
130                break
131            if (k >> i) & 1:
132                node = self.parent[i][node]
133        return node
134
135    def distance(self, u: int, v: int) -> int:
136        """두 노드 사이 거리 - O(log n)"""
137        lca = self.query(u, v)
138        return self.depth[u] + self.depth[v] - 2 * self.depth[lca]
139
140
141# =============================================================================
142# 3. Euler Tour + RMQ (Sparse Table)
143# =============================================================================
144
class LCAEulerTour:
    """
    Lowest common ancestor via an Euler tour plus sparse-table RMQ.

    Preprocessing: O(n log n)
    Query:         O(1)
    """

    def __init__(self, n: int, edges: List[Tuple[int, int]], root: int = 0):
        self.n = n
        # Undirected adjacency list.
        self.adj = defaultdict(list)
        for a, b in edges:
            self.adj[a].append(b)
            self.adj[b].append(a)

        # Euler tour as (depth, node) pairs, and each node's first index in it.
        self.euler = []
        self.first = [-1] * n

        self._build_euler_tour(root)
        self._build_sparse_table()

    def _build_euler_tour(self, root: int):
        """Build the Euler tour with an iterative DFS - O(n).

        A node is recorded when it is entered and again every time the
        traversal returns to it from a child, giving 2n - 1 entries for an
        n-node tree. Re-recording the parent between children is what makes
        the minimum-depth entry in euler[first[u]..first[v]] equal LCA(u, v).
        """
        self.first[root] = 0
        self.euler.append((0, root))
        # Frame: (node, its parent, its depth, iterator over unvisited neighbours).
        stack = [(root, -1, 0, iter(self.adj[root]))]

        while stack:
            node, parent, depth, neighbours = stack[-1]
            for child in neighbours:
                if child != parent:
                    self.first[child] = len(self.euler)
                    self.euler.append((depth + 1, child))
                    stack.append((child, node, depth + 1, iter(self.adj[child])))
                    break
            else:
                # Every child of node is finished: return to its parent and
                # record the parent again.
                stack.pop()
                if stack:
                    up_node, _, up_depth, _ = stack[-1]
                    self.euler.append((up_depth, up_node))

    def _build_sparse_table(self):
        """Build the RMQ sparse table over the tour depths - O(n log n)."""
        m = len(self.euler)
        # 2^LOG > m, so every query length fits (same as int(log2(m)) + 1).
        self.LOG = max(1, m.bit_length())

        # sparse[i][j] = index of the minimum-depth entry in euler[j : j + 2^i].
        self.sparse = [[0] * m for _ in range(self.LOG)]

        for j in range(m):
            self.sparse[0][j] = j

        for i in range(1, self.LOG):
            length = 1 << i
            half = length >> 1
            for j in range(m - length + 1):
                left = self.sparse[i - 1][j]
                right = self.sparse[i - 1][j + half]
                if self.euler[left][0] <= self.euler[right][0]:
                    self.sparse[i][j] = left
                else:
                    self.sparse[i][j] = right

    def _rmq(self, left: int, right: int) -> int:
        """Return the index of the minimum-depth entry in euler[left..right] - O(1)."""
        length = right - left + 1
        k = length.bit_length() - 1  # floor(log2(length))
        # Two overlapping power-of-two windows cover the whole range.
        left_idx = self.sparse[k][left]
        right_idx = self.sparse[k][right - (1 << k) + 1]
        if self.euler[left_idx][0] <= self.euler[right_idx][0]:
            return left_idx
        return right_idx

    def query(self, u: int, v: int) -> int:
        """Return the LCA of u and v - O(1)."""
        left = self.first[u]
        right = self.first[v]
        if left > right:
            left, right = right, left
        idx = self._rmq(left, right)
        return self.euler[idx][1]
224
225
226# =============================================================================
227# 4. 트리에서 경로 합/최대/최소
228# =============================================================================
229
class TreePathQuery:
    """
    Path queries on a weighted tree using LCA with binary lifting.

    Supports the total weight and the maximum edge weight of the path
    between two nodes, each in O(log n) after O(n log n) preprocessing.
    """

    def __init__(self, n: int, edges: List[Tuple[int, int, int]], root: int = 0):
        """edges: [(u, v, weight), ...]"""
        self.n = n
        # 2^LOG > n, so any depth difference fits in LOG bits
        # (same value as int(log2(n)) + 1 for n >= 1).
        self.LOG = max(1, n.bit_length())

        self.adj = defaultdict(list)
        for a, b, w in edges:
            self.adj[a].append((b, w))
            self.adj[b].append((a, w))

        self.parent = [[-1] * n for _ in range(self.LOG)]
        self.depth = [0] * n
        self.dist_from_root = [0] * n  # weighted distance from the root
        # max_edge[i][v]: max edge weight over the 2^i edges directly above v.
        # Only meaningful when parent[i][v] != -1.
        self.max_edge = [[0] * n for _ in range(self.LOG)]

        self._preprocess(root)

    def _preprocess(self, root: int):
        """Iterative DFS for parents/depths/distances, then fill the jump tables."""
        stack = [(root, -1, 0, 0)]

        while stack:
            node, par, d, dist = stack.pop()
            self.parent[0][node] = par
            self.depth[node] = d
            self.dist_from_root[node] = dist

            for child, weight in self.adj[node]:
                if child != par:
                    # (node, child) is the single edge directly above child.
                    self.max_edge[0][child] = weight
                    stack.append((child, node, d + 1, dist + weight))

        # A 2^i jump is two consecutive 2^(i-1) jumps; its max edge is the
        # max of the two halves.
        for i in range(1, self.LOG):
            for v in range(self.n):
                mid = self.parent[i - 1][v]
                if mid != -1:
                    self.parent[i][v] = self.parent[i - 1][mid]
                    self.max_edge[i][v] = max(self.max_edge[i - 1][v], self.max_edge[i - 1][mid])

    def lca(self, u: int, v: int) -> int:
        """Return the LCA of u and v - O(log n)."""
        if self.depth[u] < self.depth[v]:
            u, v = v, u

        diff = self.depth[u] - self.depth[v]
        for i in range(self.LOG):
            if (diff >> i) & 1:
                u = self.parent[i][u]

        if u == v:
            return u

        for i in range(self.LOG - 1, -1, -1):
            if self.parent[i][u] != self.parent[i][v]:
                u = self.parent[i][u]
                v = self.parent[i][v]

        return self.parent[0][u]

    def path_distance(self, u: int, v: int) -> int:
        """Return the sum of edge weights on the u-v path - O(log n)."""
        ancestor = self.lca(u, v)
        return self.dist_from_root[u] + self.dist_from_root[v] - 2 * self.dist_from_root[ancestor]

    def path_max_edge(self, u: int, v: int) -> int:
        """Return the maximum edge weight on the u-v path - O(log n).

        Returns 0 when u == v (the path has no edges). Otherwise the true
        maximum is returned, even if every weight on the path is negative.
        """
        ancestor = self.lca(u, v)
        best = None  # no edge seen yet; a 0 seed would mask negative weights

        # Climb from each endpoint up to the LCA, folding in each jump's max.
        for start in (u, v):
            curr = start
            diff = self.depth[start] - self.depth[ancestor]
            for i in range(self.LOG):
                if (diff >> i) & 1:
                    weight = self.max_edge[i][curr]
                    best = weight if best is None else max(best, weight)
                    curr = self.parent[i][curr]

        return 0 if best is None else best
321
322
323# =============================================================================
324# 5. 실전 문제: 트리에서 두 노드 사이 경로
325# =============================================================================
326
def find_path(n: int, edges: List[Tuple[int, int]], u: int, v: int) -> List[int]:
    """
    Return the node sequence of the unique tree path from u to v.

    Builds an LCABinaryLifting index (rooted at its default root, 0), then
    follows direct-parent pointers from each endpoint up to the LCA. The
    result starts at u, ends at v, and contains the LCA exactly once.
    """
    solver = LCABinaryLifting(n, edges)
    meeting_point = solver.query(u, v)
    up = solver.parent[0]

    def climb(start: int) -> List[int]:
        # Nodes from start up to, but excluding, the LCA.
        trail = []
        node = start
        while node != meeting_point:
            trail.append(node)
            node = up[node]
        return trail

    left_half = climb(u)
    right_half = climb(v)
    return left_half + [meeting_point] + right_half[::-1]
348
349
350# =============================================================================
351# 테스트
352# =============================================================================
353
def main():
    """Run a small demo of every LCA technique in this module and print the results."""
    rule = "=" * 60
    print(rule)
    print("최소 공통 조상 (LCA) 예제")
    print(rule)

    # Sample tree:
    #        0
    #      / | \
    #     1  2  3
    #    / \    |
    #   4   5   6
    #  /
    # 7
    node_count = 8
    tree_edges = [(0, 1), (0, 2), (0, 3), (1, 4), (1, 5), (3, 6), (4, 7)]

    # [1] Naive LCA (rebuilds the tree on every call).
    print("\n[1] 기본 LCA (Naive)")
    for a, b in ((7, 5), (7, 6)):
        print(f"    LCA({a}, {b}) = {lca_naive(node_count, tree_edges, a, b)}")

    # [2] Binary lifting.
    print("\n[2] Binary Lifting")
    lifting = LCABinaryLifting(node_count, tree_edges)
    for a, b in ((7, 5), (7, 6), (4, 6)):
        print(f"    LCA({a}, {b}) = {lifting.query(a, b)}")
    print(f"    거리(7, 5) = {lifting.distance(7, 5)}")
    print(f"    7의 2번째 조상 = {lifting.kth_ancestor(7, 2)}")

    # [3] Euler tour + sparse-table RMQ.
    print("\n[3] Euler Tour + RMQ (O(1) 쿼리)")
    euler = LCAEulerTour(node_count, tree_edges)
    for a, b in ((7, 5), (7, 6)):
        print(f"    LCA({a}, {b}) = {euler.query(a, b)}")

    # [4] Weighted path queries on the same tree shape.
    print("\n[4] 가중치 트리 경로 쿼리")
    weighted = [
        (0, 1, 3), (0, 2, 5), (0, 3, 4),
        (1, 4, 2), (1, 5, 6), (3, 6, 1), (4, 7, 8),
    ]
    paths = TreePathQuery(node_count, weighted)
    print(f"    경로 거리(7, 5) = {paths.path_distance(7, 5)}")
    print(f"    경로 최대 간선(7, 5) = {paths.path_max_edge(7, 5)}")
    print(f"    경로 거리(7, 6) = {paths.path_distance(7, 6)}")

    # [5] Explicit node-by-node paths.
    print("\n[5] 두 노드 사이 경로")
    for a, b in ((7, 6), (5, 2)):
        print(f"    경로({a}, {b}) = {find_path(node_count, tree_edges, a, b)}")

    # [6] Complexity summary table.
    print("\n[6] 복잡도 비교")
    table_rows = (
        "    | 방법           | 전처리     | 쿼리    |",
        "    |----------------|------------|---------|",
        "    | Naive          | O(n)       | O(n)    |",
        "    | Binary Lifting | O(n log n) | O(log n)|",
        "    | Euler + RMQ    | O(n log n) | O(1)    |",
    )
    for row in table_rows:
        print(row)

    print("\n" + rule)


if __name__ == "__main__":
    main()