# 09_tree_bst.py
"""
Trees and Binary Search Trees (Tree & BST)

Implements binary tree structures, traversals, and BST operations.
"""

from collections import deque
from typing import Generator, List, Optional


# =============================================================================
# 1. Binary tree node
# =============================================================================

 16class TreeNode:
 17    """์ด์ง„ ํŠธ๋ฆฌ ๋…ธ๋“œ"""
 18
 19    def __init__(self, val: int, left: 'TreeNode' = None, right: 'TreeNode' = None):
 20        self.val = val
 21        self.left = left
 22        self.right = right
 23
 24    def __repr__(self):
 25        return f"TreeNode({self.val})"
 26
 27
# =============================================================================
# 2. Tree traversal
# =============================================================================

 32def preorder_recursive(root: TreeNode) -> List[int]:
 33    """์ „์œ„ ์ˆœํšŒ (์žฌ๊ท€) - O(n)"""
 34    result = []
 35
 36    def traverse(node):
 37        if not node:
 38            return
 39        result.append(node.val)
 40        traverse(node.left)
 41        traverse(node.right)
 42
 43    traverse(root)
 44    return result
 45
 46
 47def preorder_iterative(root: TreeNode) -> List[int]:
 48    """์ „์œ„ ์ˆœํšŒ (๋ฐ˜๋ณต) - O(n)"""
 49    if not root:
 50        return []
 51
 52    result = []
 53    stack = [root]
 54
 55    while stack:
 56        node = stack.pop()
 57        result.append(node.val)
 58
 59        # ์˜ค๋ฅธ์ชฝ ๋จผ์ € push (์™ผ์ชฝ์ด ๋จผ์ € ์ฒ˜๋ฆฌ๋˜๋„๋ก)
 60        if node.right:
 61            stack.append(node.right)
 62        if node.left:
 63            stack.append(node.left)
 64
 65    return result
 66
 67
 68def inorder_recursive(root: TreeNode) -> List[int]:
 69    """์ค‘์œ„ ์ˆœํšŒ (์žฌ๊ท€) - O(n)"""
 70    result = []
 71
 72    def traverse(node):
 73        if not node:
 74            return
 75        traverse(node.left)
 76        result.append(node.val)
 77        traverse(node.right)
 78
 79    traverse(root)
 80    return result
 81
 82
 83def inorder_iterative(root: TreeNode) -> List[int]:
 84    """์ค‘์œ„ ์ˆœํšŒ (๋ฐ˜๋ณต) - O(n)"""
 85    result = []
 86    stack = []
 87    current = root
 88
 89    while stack or current:
 90        # ์™ผ์ชฝ ๋๊นŒ์ง€ ์ด๋™
 91        while current:
 92            stack.append(current)
 93            current = current.left
 94
 95        current = stack.pop()
 96        result.append(current.val)
 97        current = current.right
 98
 99    return result
100
101
def postorder_recursive(root: Optional[TreeNode]) -> List[int]:
    """Return node values in postorder (left, right, node), recursively - O(n).

    Recursion depth equals the tree height.
    """
    result = []

    def traverse(node):
        if node is None:
            return
        traverse(node.left)
        traverse(node.right)
        result.append(node.val)

    traverse(root)
    return result


def postorder_iterative(root: Optional[TreeNode]) -> List[int]:
    """Return node values in postorder (left, right, node), iteratively - O(n).

    Runs a node-right-left preorder with an explicit stack and then reverses
    the result: the reverse of node-right-left is exactly left-right-node.
    """
    if root is None:
        return []

    result = []
    stack = [root]

    while stack:
        node = stack.pop()
        result.append(node.val)

        # Push left before right so the right child is popped (visited) first.
        if node.left:
            stack.append(node.left)
        if node.right:
            stack.append(node.right)

    result.reverse()
    return result


def level_order(root: Optional[TreeNode]) -> List[List[int]]:
    """Return node values grouped by depth, top to bottom (BFS) - O(n).

    Each inner list holds one level's values from left to right.
    """
    if root is None:
        return []

    result = []
    queue = deque([root])

    while queue:
        # The queue currently holds exactly the nodes of one level.
        level_size = len(queue)
        level = []

        for _ in range(level_size):
            node = queue.popleft()
            level.append(node.val)

            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)

        result.append(level)

    return result


# =============================================================================
# 3. Binary search tree (BST)
# =============================================================================

class BST:
    """Binary search tree.

    Invariant maintained by ``insert``: values smaller than a node's value go
    into its left subtree; values greater than or equal to it go into its
    right subtree, so duplicate values are kept (on the right).
    """

    def __init__(self):
        self.root: Optional[TreeNode] = None

    def insert(self, val: int) -> None:
        """Insert ``val`` - average O(log n), worst O(n) on a degenerate tree."""
        if self.root is None:
            self.root = TreeNode(val)
            return

        current = self.root
        while True:
            if val < current.val:
                if current.left is None:
                    current.left = TreeNode(val)
                    return
                current = current.left
            else:
                # Equal values go right, matching the class invariant.
                if current.right is None:
                    current.right = TreeNode(val)
                    return
                current = current.right

    def search(self, val: int) -> Optional[TreeNode]:
        """Return the first node holding ``val`` on its search path, or None.

        Average O(log n), worst O(n).
        """
        current = self.root

        while current:
            if val == current.val:
                return current
            current = current.left if val < current.val else current.right

        return None

    def delete(self, val: int) -> bool:
        """Delete one node holding ``val`` - average O(log n), worst O(n).

        Returns:
            True if a node was removed, False if ``val`` was not in the tree.
        """
        found = False

        def min_node(node: TreeNode) -> TreeNode:
            # Leftmost node of a subtree holds its smallest value.
            while node.left:
                node = node.left
            return node

        def remove(node: Optional[TreeNode], target: int) -> Optional[TreeNode]:
            """Remove ``target`` from the subtree and return the new subtree root."""
            nonlocal found
            if node is None:
                return None

            if target < node.val:
                node.left = remove(node.left, target)
            elif target > node.val:
                node.right = remove(node.right, target)
            else:
                found = True

                # Zero or one child: splice the node out by returning its
                # only child (None for a leaf).
                if node.left is None:
                    return node.right
                if node.right is None:
                    return node.left

                # Two children: copy in the in-order successor (minimum of the
                # right subtree), then remove that successor from the right.
                successor = min_node(node.right)
                node.val = successor.val
                node.right = remove(node.right, successor.val)

            return node

        self.root = remove(self.root, val)
        return found

    def inorder(self) -> List[int]:
        """Return all values in ascending (in-order) order."""
        return inorder_recursive(self.root)

    def find_min(self) -> Optional[int]:
        """Return the smallest value, or None for an empty tree - O(h)."""
        if self.root is None:
            return None

        current = self.root
        while current.left:
            current = current.left
        return current.val

    def find_max(self) -> Optional[int]:
        """Return the largest value, or None for an empty tree - O(h)."""
        if self.root is None:
            return None

        current = self.root
        while current.right:
            current = current.right
        return current.val


# =============================================================================
# 4. Tree property checks
# =============================================================================

def tree_height(root: Optional[TreeNode]) -> int:
    """Return the tree height measured in edges - O(n).

    An empty tree has height -1 and a single node has height 0.
    """
    if root is None:
        return -1

    return 1 + max(tree_height(root.left), tree_height(root.right))


def is_balanced(root: Optional[TreeNode]) -> bool:
    """Return True if every node's subtree heights differ by at most 1 - O(n).

    A single post-order pass computes heights bottom-up, using -1 as an
    "unbalanced" sentinel so the result propagates up without further work.
    """

    def check(node) -> int:
        # Height here counts nodes (empty = 0), unlike tree_height; only the
        # difference between siblings matters, so the offset is irrelevant.
        if node is None:
            return 0

        left_height = check(node.left)
        if left_height == -1:
            return -1

        right_height = check(node.right)
        if right_height == -1:
            return -1

        if abs(left_height - right_height) > 1:
            return -1

        return 1 + max(left_height, right_height)

    return check(root) != -1


def is_valid_bst(root: Optional[TreeNode], *, allow_duplicates: bool = False) -> bool:
    """Return True if the tree satisfies the BST ordering property - O(n).

    Each node is checked against the (low, high) bounds inherited from its
    ancestors, not just against its direct children.

    Args:
        root: tree to validate; an empty tree is valid.
        allow_duplicates: when False (default), values must be strictly
            ordered: left < node < right. When True, equal values are allowed
            in the right subtree (left < node <= right), which is the layout
            ``BST.insert`` produces for duplicates.
    """

    def validate(node, low: float, high: float) -> bool:
        if node is None:
            return True

        # The lower bound is inclusive only when duplicates may sit on the right.
        below_low = node.val < low if allow_duplicates else node.val <= low
        if below_low or node.val >= high:
            return False

        return (validate(node.left, low, node.val) and
                validate(node.right, node.val, high))

    return validate(root, float('-inf'), float('inf'))


def count_nodes(root: Optional[TreeNode]) -> int:
    """Return the number of nodes in the tree - O(n)."""
    if root is None:
        return 0
    return 1 + count_nodes(root.left) + count_nodes(root.right)


# =============================================================================
# 5. Tree construction and transformation
# =============================================================================

def build_tree_from_list(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list (LeetCode style) - O(n).

    ``None`` marks a missing child. Children of missing nodes are not listed,
    so each real node, in BFS order, consumes the next two entries.
    Example: ``[1, None, 2, 3]`` builds 1 -> right 2 -> left 3.

    Returns:
        The root node, or None if ``values`` is empty or starts with None.
    """
    if not values or values[0] is None:
        return None

    root = TreeNode(values[0])
    queue = deque([root])
    i = 1

    while queue and i < len(values):
        node = queue.popleft()

        # Left child slot.
        if values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1

        # Right child slot (the list may end after the left slot).
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1

    return root


def sorted_array_to_bst(nums: List[int]) -> Optional[TreeNode]:
    """Build a height-balanced BST from an ascending list - O(n).

    The middle element of each index range becomes the subtree root (the lower
    middle for even-length ranges), and the halves are built recursively.

    Returns:
        The root node, or None for an empty list.
    """
    if not nums:
        return None

    def build(left: int, right: int) -> Optional[TreeNode]:
        # Builds the subtree for the inclusive index range nums[left..right].
        if left > right:
            return None

        mid = (left + right) // 2
        node = TreeNode(nums[mid])
        node.left = build(left, mid - 1)
        node.right = build(mid + 1, right)
        return node

    return build(0, len(nums) - 1)


def invert_tree(root: Optional[TreeNode]) -> Optional[TreeNode]:
    """Mirror the tree in place by swapping every node's children - O(n).

    Returns:
        The same root object (now mirrored), or None for an empty tree.
    """
    if root is None:
        return None

    # Both recursive calls are evaluated before the tuple assignment happens.
    root.left, root.right = invert_tree(root.right), invert_tree(root.left)
    return root


# =============================================================================
# 6. Practice problems
# =============================================================================

def lowest_common_ancestor(root: Optional[TreeNode], p: int, q: int) -> Optional[TreeNode]:
    """Return the lowest common ancestor of values ``p`` and ``q`` in a BST - O(h).

    Walks down from the root. While both values lie strictly on the same side
    of the current node, it moves to that side. The first node that separates
    them, or equals one of them, is the LCA.

    Assumes both values are present in the tree. Otherwise the returned node
    is only where their search paths diverge, or None if both values fall off
    the same side of a leaf. Returns None for an empty tree.
    """
    current = root

    while current:
        if p < current.val and q < current.val:
            current = current.left
        elif p > current.val and q > current.val:
            current = current.right
        else:
            return current

    return None


def kth_smallest(root: Optional[TreeNode], k: int) -> int:
    """Return the k-th smallest value (1-based) in a BST - O(h + k).

    Performs an iterative in-order walk and stops at the k-th visited node.

    Returns:
        The k-th smallest value, or -1 if ``k < 1`` or the tree has fewer
        than ``k`` nodes. Note: -1 is ambiguous if the tree itself stores -1.
    """
    if k < 1:
        return -1  # no such rank; avoid walking the whole tree for nothing

    stack = []
    current = root
    visited = 0

    while stack or current:
        # Walk down to the leftmost unvisited node, remembering the path.
        while current:
            stack.append(current)
            current = current.left

        current = stack.pop()
        visited += 1

        if visited == k:
            return current.val

        current = current.right

    return -1


def path_sum(root: Optional[TreeNode], target: int) -> bool:
    """Return True if some root-to-leaf path sums to ``target`` - O(n).

    Only complete paths ending at a leaf count. An empty tree has no paths,
    so it returns False even for ``target == 0``.
    """
    if root is None:
        return False

    if root.left is None and root.right is None:
        return root.val == target

    remaining = target - root.val
    return path_sum(root.left, remaining) or path_sum(root.right, remaining)


def serialize(root: Optional[TreeNode]) -> str:
    """Serialize a tree to a level-order string such as ``"[1,null,2]"`` - O(n).

    Missing children are written as ``null`` and trailing nulls are dropped.
    This is the textual form of the level-order lists that
    build_tree_from_list accepts (with None for null).
    """
    if root is None:
        return "[]"

    result = []
    queue = deque([root])

    while queue:
        node = queue.popleft()
        if node:
            result.append(str(node.val))
            queue.append(node.left)
            queue.append(node.right)
        else:
            result.append("null")

    # Drop the trailing nulls that represent the last level's missing children.
    while result and result[-1] == "null":
        result.pop()

    return "[" + ",".join(result) + "]"


# =============================================================================
# Utility: tree visualization
# =============================================================================

def print_tree(root: Optional[TreeNode], prefix: str = "", is_left: bool = True) -> None:
    """Print the tree as ASCII art, one node per line.

    Args:
        root: subtree to print; nothing is printed for None.
        prefix: text printed before this node's connector (the accumulated
            indentation of its ancestors).
        is_left: True if this node is followed by a sibling. It is then drawn
            with "├── " and a "│" rail continues past its subtree. False if it
            is the last child, drawn with "└── ". Pass False for a root that
            is printed alone.
    """
    if not root:
        return

    connector = "├── " if is_left else "└── "
    print(prefix + connector + str(root.val))

    children = [child for child in (root.left, root.right) if child]

    # The rail below this node continues only if a sibling follows this node;
    # this depends on the current node, not on which child is being drawn.
    child_prefix = prefix + ("│   " if is_left else "    ")
    for i, child in enumerate(children):
        print_tree(child, child_prefix, i < len(children) - 1)


# =============================================================================
# Tests (console demo)
# =============================================================================

def main():
    """Print a console walkthrough of the tree and BST helpers in this module."""
    print("=" * 60)
    print("트리와 이진 탐색 트리 (Tree & BST) 예제")
    print("=" * 60)

    # 1. Build a tree from a level-order list
    print("\n[1] 트리 구성")
    #       4
    #      / \
    #     2   6
    #    / \ / \
    #   1  3 5  7
    root = build_tree_from_list([4, 2, 6, 1, 3, 5, 7])
    print("    레벨 순서: [4, 2, 6, 1, 3, 5, 7]")
    print("    트리 구조:")
    # The root has no siblings, so draw it with the closing "└── " connector.
    print_tree(root, "    ", False)

    # 2. Traversals
    print("\n[2] 트리 순회")
    print(f"    전위 (Preorder):  {preorder_recursive(root)}")
    print(f"    중위 (Inorder):   {inorder_recursive(root)}")
    print(f"    후위 (Postorder): {postorder_recursive(root)}")
    print(f"    레벨 (Level):     {level_order(root)}")

    # 3. Tree properties
    print("\n[3] 트리 속성")
    print(f"    높이: {tree_height(root)}")
    print(f"    노드 수: {count_nodes(root)}")
    print(f"    균형 트리: {is_balanced(root)}")
    print(f"    유효한 BST: {is_valid_bst(root)}")

    # 4. BST operations
    print("\n[4] BST 연산")
    bst = BST()
    inserted = [5, 3, 7, 1, 4, 6, 8]
    for val in inserted:
        bst.insert(val)
    print(f"    삽입: {inserted}")
    print(f"    중위 순회: {bst.inorder()}")
    print(f"    검색 4: {bst.search(4)}")
    print(f"    최솟값: {bst.find_min()}, 최댓값: {bst.find_max()}")

    bst.delete(3)
    print(f"    삭제 3 후: {bst.inorder()}")

    # 5. Sorted array -> balanced BST
    print("\n[5] 정렬 배열 → 균형 BST")
    arr = [1, 2, 3, 4, 5, 6, 7]
    balanced_bst = sorted_array_to_bst(arr)
    print(f"    입력: {arr}")
    print(f"    레벨 순회: {level_order(balanced_bst)}")

    # 6. Lowest common ancestor
    print("\n[6] 최소 공통 조상 (LCA)")
    lca = lowest_common_ancestor(root, 1, 3)
    print(f"    노드 1, 3의 LCA: {lca.val if lca else None}")
    lca = lowest_common_ancestor(root, 1, 6)
    print(f"    노드 1, 6의 LCA: {lca.val if lca else None}")

    # 7. k-th smallest value
    print("\n[7] k번째 작은 값")
    for k in [1, 3, 5]:
        print(f"    {k}번째 작은 값: {kth_smallest(root, k)}")

    # 8. Root-to-leaf path sums (4->6 is not a full root-to-leaf path, so 10 is False)
    print("\n[8] 루트~리프 경로 합")
    print(f"    합 7 (4→2→1): {path_sum(root, 7)}")
    print(f"    합 10 (4→6): {path_sum(root, 10)}")

    # 9. Serialization
    print("\n[9] 트리 직렬화")
    print(f"    {serialize(root)}")

    print("\n" + "=" * 60)


if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()