24_fenwick_tree.py

Download
python 415 lines 12.6 KB
  1"""
  2ํŽœ์œ… ํŠธ๋ฆฌ (Fenwick Tree / Binary Indexed Tree)
  3Fenwick Tree (BIT)
  4
  5๊ตฌ๊ฐ„ ํ•ฉ๊ณผ ์  ์—…๋ฐ์ดํŠธ๋ฅผ ํšจ์œจ์ ์œผ๋กœ ์ฒ˜๋ฆฌํ•˜๋Š” ์ž๋ฃŒ๊ตฌ์กฐ์ž…๋‹ˆ๋‹ค.
  6"""
  7
  8from typing import List
  9
 10
 11# =============================================================================
 12# 1. ๊ธฐ๋ณธ ํŽœ์œ… ํŠธ๋ฆฌ (๊ตฌ๊ฐ„ ํ•ฉ)
 13# =============================================================================
 14
 15class FenwickTree:
 16    """
 17    ํŽœ์œ… ํŠธ๋ฆฌ (Binary Indexed Tree)
 18    - ์  ์—…๋ฐ์ดํŠธ: O(log n)
 19    - ์ ‘๋‘์‚ฌ ํ•ฉ: O(log n)
 20    - ๊ตฌ๊ฐ„ ํ•ฉ: O(log n)
 21    - ๊ณต๊ฐ„: O(n)
 22    """
 23
 24    def __init__(self, n: int):
 25        """ํฌ๊ธฐ n์˜ ๋นˆ ํŽœ์œ… ํŠธ๋ฆฌ ์ƒ์„ฑ"""
 26        self.n = n
 27        self.tree = [0] * (n + 1)  # 1-indexed
 28
 29    @classmethod
 30    def from_array(cls, arr: List[int]) -> 'FenwickTree':
 31        """๋ฐฐ์—ด๋กœ๋ถ€ํ„ฐ ํŽœ์œ… ํŠธ๋ฆฌ ์ƒ์„ฑ - O(n)"""
 32        ft = cls(len(arr))
 33
 34        # ํšจ์œจ์ ์ธ ๊ตฌ์„ฑ (O(n))
 35        for i, val in enumerate(arr):
 36            ft.tree[i + 1] += val
 37            parent = i + 1 + (ft._lowbit(i + 1))
 38            if parent <= ft.n:
 39                ft.tree[parent] += ft.tree[i + 1]
 40
 41        return ft
 42
 43    def _lowbit(self, x: int) -> int:
 44        """์ตœํ•˜์œ„ ๋น„ํŠธ (x & -x)"""
 45        return x & (-x)
 46
 47    def update(self, idx: int, delta: int):
 48        """idx ์œ„์น˜์— delta ๋”ํ•˜๊ธฐ (0-indexed) - O(log n)"""
 49        idx += 1  # 1-indexed๋กœ ๋ณ€ํ™˜
 50
 51        while idx <= self.n:
 52            self.tree[idx] += delta
 53            idx += self._lowbit(idx)
 54
 55    def prefix_sum(self, idx: int) -> int:
 56        """[0, idx] ๊ตฌ๊ฐ„ ํ•ฉ (0-indexed) - O(log n)"""
 57        idx += 1
 58        result = 0
 59
 60        while idx > 0:
 61            result += self.tree[idx]
 62            idx -= self._lowbit(idx)
 63
 64        return result
 65
 66    def range_sum(self, left: int, right: int) -> int:
 67        """[left, right] ๊ตฌ๊ฐ„ ํ•ฉ (0-indexed) - O(log n)"""
 68        if left == 0:
 69            return self.prefix_sum(right)
 70        return self.prefix_sum(right) - self.prefix_sum(left - 1)
 71
 72    def get(self, idx: int) -> int:
 73        """idx ์œ„์น˜์˜ ๊ฐ’ (0-indexed)"""
 74        return self.range_sum(idx, idx)
 75
 76    def set(self, idx: int, val: int):
 77        """idx ์œ„์น˜์˜ ๊ฐ’์„ val๋กœ ์„ค์ •"""
 78        current = self.get(idx)
 79        self.update(idx, val - current)
 80
 81
 82# =============================================================================
 83# 2. ๊ตฌ๊ฐ„ ์—…๋ฐ์ดํŠธ + ์  ์ฟผ๋ฆฌ ํŽœ์œ… ํŠธ๋ฆฌ
 84# =============================================================================
 85
 86class FenwickTreeRangeUpdate:
 87    """
 88    ๊ตฌ๊ฐ„ ์—…๋ฐ์ดํŠธ + ์  ์ฟผ๋ฆฌ
 89    ์ฐจ๋ถ„ ๋ฐฐ์—ด ๊ธฐ๋ฒ• ํ™œ์šฉ
 90    """
 91
 92    def __init__(self, n: int):
 93        self.n = n
 94        self.tree = [0] * (n + 1)
 95
 96    def _lowbit(self, x: int) -> int:
 97        return x & (-x)
 98
 99    def _update(self, idx: int, delta: int):
100        idx += 1
101        while idx <= self.n:
102            self.tree[idx] += delta
103            idx += self._lowbit(idx)
104
105    def update_range(self, left: int, right: int, delta: int):
106        """[left, right] ๊ตฌ๊ฐ„์— delta ๋”ํ•˜๊ธฐ"""
107        self._update(left, delta)
108        if right + 1 < self.n:
109            self._update(right + 1, -delta)
110
111    def query(self, idx: int) -> int:
112        """idx ์œ„์น˜์˜ ๊ฐ’"""
113        idx += 1
114        result = 0
115        while idx > 0:
116            result += self.tree[idx]
117            idx -= self._lowbit(idx)
118        return result
119
120
121# =============================================================================
122# 3. ๊ตฌ๊ฐ„ ์—…๋ฐ์ดํŠธ + ๊ตฌ๊ฐ„ ์ฟผ๋ฆฌ ํŽœ์œ… ํŠธ๋ฆฌ
123# =============================================================================
124
125class FenwickTreeRangeUpdateRangeQuery:
126    """
127    ๊ตฌ๊ฐ„ ์—…๋ฐ์ดํŠธ + ๊ตฌ๊ฐ„ ์ฟผ๋ฆฌ
128    ๋‘ ๊ฐœ์˜ BIT ์‚ฌ์šฉ
129    """
130
131    def __init__(self, n: int):
132        self.n = n
133        self.tree1 = [0] * (n + 1)  # B[i]
134        self.tree2 = [0] * (n + 1)  # B[i] * i
135
136    def _lowbit(self, x: int) -> int:
137        return x & (-x)
138
139    def _update(self, tree: List[int], idx: int, delta: int):
140        while idx <= self.n:
141            tree[idx] += delta
142            idx += self._lowbit(idx)
143
144    def _prefix_sum(self, tree: List[int], idx: int) -> int:
145        result = 0
146        while idx > 0:
147            result += tree[idx]
148            idx -= self._lowbit(idx)
149        return result
150
151    def update_range(self, left: int, right: int, delta: int):
152        """[left, right] ๊ตฌ๊ฐ„์— delta ๋”ํ•˜๊ธฐ (1-indexed)"""
153        self._update(self.tree1, left, delta)
154        self._update(self.tree1, right + 1, -delta)
155        self._update(self.tree2, left, delta * (left - 1))
156        self._update(self.tree2, right + 1, -delta * right)
157
158    def prefix_sum(self, idx: int) -> int:
159        """[1, idx] ๊ตฌ๊ฐ„ ํ•ฉ (1-indexed)"""
160        return self._prefix_sum(self.tree1, idx) * idx - self._prefix_sum(self.tree2, idx)
161
162    def range_sum(self, left: int, right: int) -> int:
163        """[left, right] ๊ตฌ๊ฐ„ ํ•ฉ (1-indexed)"""
164        return self.prefix_sum(right) - self.prefix_sum(left - 1)
165
166
167# =============================================================================
168# 4. 2D ํŽœ์œ… ํŠธ๋ฆฌ
169# =============================================================================
170
171class FenwickTree2D:
172    """
173    2D ํŽœ์œ… ํŠธ๋ฆฌ
174    - ์  ์—…๋ฐ์ดํŠธ: O(log n * log m)
175    - ์ง์‚ฌ๊ฐํ˜• ํ•ฉ: O(log n * log m)
176    """
177
178    def __init__(self, n: int, m: int):
179        self.n = n
180        self.m = m
181        self.tree = [[0] * (m + 1) for _ in range(n + 1)]
182
183    def _lowbit(self, x: int) -> int:
184        return x & (-x)
185
186    def update(self, x: int, y: int, delta: int):
187        """(x, y)์— delta ๋”ํ•˜๊ธฐ (0-indexed)"""
188        x += 1
189        while x <= self.n:
190            y_idx = y + 1
191            while y_idx <= self.m:
192                self.tree[x][y_idx] += delta
193                y_idx += self._lowbit(y_idx)
194            x += self._lowbit(x)
195
196    def prefix_sum(self, x: int, y: int) -> int:
197        """(0,0) ~ (x,y) ํ•ฉ (0-indexed)"""
198        x += 1
199        result = 0
200        while x > 0:
201            y_idx = y + 1
202            while y_idx > 0:
203                result += self.tree[x][y_idx]
204                y_idx -= self._lowbit(y_idx)
205            x -= self._lowbit(x)
206        return result
207
208    def range_sum(self, x1: int, y1: int, x2: int, y2: int) -> int:
209        """(x1,y1) ~ (x2,y2) ์ง์‚ฌ๊ฐํ˜• ํ•ฉ (0-indexed)"""
210        result = self.prefix_sum(x2, y2)
211        if x1 > 0:
212            result -= self.prefix_sum(x1 - 1, y2)
213        if y1 > 0:
214            result -= self.prefix_sum(x2, y1 - 1)
215        if x1 > 0 and y1 > 0:
216            result += self.prefix_sum(x1 - 1, y1 - 1)
217        return result
218
219
220# =============================================================================
221# 5. ์—ญ์ˆœ ์Œ ๊ฐœ์ˆ˜ (Inversion Count)
222# =============================================================================
223
224def count_inversions(arr: List[int]) -> int:
225    """
226    ์—ญ์ˆœ ์Œ ๊ฐœ์ˆ˜ (ํŽœ์œ… ํŠธ๋ฆฌ ํ™œ์šฉ)
227    ์‹œ๊ฐ„๋ณต์žก๋„: O(n log n)
228    """
229    if not arr:
230        return 0
231
232    # ์ขŒํ‘œ ์••์ถ•
233    sorted_arr = sorted(set(arr))
234    rank = {v: i for i, v in enumerate(sorted_arr)}
235    n = len(sorted_arr)
236
237    ft = FenwickTree(n)
238    inversions = 0
239
240    for val in arr:
241        r = rank[val]
242        # r๋ณด๋‹ค ํฐ ์ธ๋ฑ์Šค์˜ ๊ฐœ์ˆ˜ (์ด๋ฏธ ์‚ฝ์ž…๋œ ๊ฒƒ ์ค‘)
243        inversions += ft.prefix_sum(n - 1) - ft.prefix_sum(r)
244        ft.update(r, 1)
245
246    return inversions
247
248
249# =============================================================================
250# 6. K๋ฒˆ์งธ ์›์†Œ ์ฐพ๊ธฐ
251# =============================================================================
252
253class FenwickTreeKth:
254    """K๋ฒˆ์งธ ์›์†Œ ์ฐพ๊ธฐ๋ฅผ ์ง€์›ํ•˜๋Š” ํŽœ์œ… ํŠธ๋ฆฌ"""
255
256    def __init__(self, n: int):
257        self.n = n
258        self.tree = [0] * (n + 1)
259
260    def _lowbit(self, x: int) -> int:
261        return x & (-x)
262
263    def update(self, idx: int, delta: int):
264        """idx์— delta ๋”ํ•˜๊ธฐ (1-indexed)"""
265        while idx <= self.n:
266            self.tree[idx] += delta
267            idx += self._lowbit(idx)
268
269    def prefix_sum(self, idx: int) -> int:
270        """[1, idx] ํ•ฉ"""
271        result = 0
272        while idx > 0:
273            result += self.tree[idx]
274            idx -= self._lowbit(idx)
275        return result
276
277    def find_kth(self, k: int) -> int:
278        """
279        k๋ฒˆ์งธ ์›์†Œ์˜ ์ธ๋ฑ์Šค ์ฐพ๊ธฐ (1-indexed)
280        prefix_sum(idx) >= k์ธ ์ตœ์†Œ idx
281        ์‹œ๊ฐ„๋ณต์žก๋„: O(log n)
282        """
283        idx = 0
284        bit_mask = 1
285
286        while bit_mask <= self.n:
287            bit_mask <<= 1
288        bit_mask >>= 1
289
290        while bit_mask > 0:
291            next_idx = idx + bit_mask
292            if next_idx <= self.n and self.tree[next_idx] < k:
293                idx = next_idx
294                k -= self.tree[idx]
295            bit_mask >>= 1
296
297        return idx + 1
298
299
300# =============================================================================
301# 7. ์‘์šฉ: ๊ตฌ๊ฐ„์—์„œ K๋ณด๋‹ค ์ž‘์€ ์›์†Œ ๊ฐœ์ˆ˜
302# =============================================================================
303
304def count_smaller_in_range(arr: List[int], queries: List[tuple]) -> List[int]:
305    """
306    ์ฟผ๋ฆฌ: (left, right, k) - arr[left:right+1]์—์„œ k๋ณด๋‹ค ์ž‘์€ ์›์†Œ ๊ฐœ์ˆ˜
307    ์˜คํ”„๋ผ์ธ ์ฟผ๋ฆฌ + ํŽœ์œ… ํŠธ๋ฆฌ
308    ์‹œ๊ฐ„๋ณต์žก๋„: O((n + q) log n)
309    """
310    n = len(arr)
311    q = len(queries)
312
313    # ์ขŒํ‘œ ์••์ถ•
314    all_vals = sorted(set(arr) | set(k for _, _, k in queries))
315    val_to_idx = {v: i + 1 for i, v in enumerate(all_vals)}
316    m = len(all_vals)
317
318    # (๊ฐ’, ์ธ๋ฑ์Šค, ํƒ€์ž…) ์ด๋ฒคํŠธ ์ƒ์„ฑ
319    events = []
320    for i, val in enumerate(arr):
321        events.append((val_to_idx[val], i, 'arr', None))
322
323    for qi, (left, right, k) in enumerate(queries):
324        k_idx = val_to_idx.get(k, m + 1)
325        events.append((k_idx, right, 'query_end', (qi, left, right)))
326
327    events.sort()
328
329    # ๊ฒฐ๊ณผ
330    results = [0] * q
331    ft = FenwickTree(n)
332
333    for val_idx, pos, event_type, data in events:
334        if event_type == 'arr':
335            ft.update(pos, 1)
336        else:
337            qi, left, right = data
338            results[qi] = ft.range_sum(left, right)
339
340    return results
341
342
343# =============================================================================
344# ํ…Œ์ŠคํŠธ
345# =============================================================================
346
347def main():
348    print("=" * 60)
349    print("ํŽœ์œ… ํŠธ๋ฆฌ (Fenwick Tree / BIT) ์˜ˆ์ œ")
350    print("=" * 60)
351
352    # 1. ๊ธฐ๋ณธ ํŽœ์œ… ํŠธ๋ฆฌ
353    print("\n[1] ๊ธฐ๋ณธ ํŽœ์œ… ํŠธ๋ฆฌ")
354    arr = [1, 3, 5, 7, 9, 11]
355    ft = FenwickTree.from_array(arr)
356    print(f"    ๋ฐฐ์—ด: {arr}")
357    print(f"    prefix_sum(3): {ft.prefix_sum(3)}")  # 1+3+5+7=16
358    print(f"    range_sum(1, 4): {ft.range_sum(1, 4)}")  # 3+5+7+9=24
359    ft.update(2, 5)  # 5 โ†’ 10
360    print(f"    update(2, +5) ํ›„ range_sum(1, 4): {ft.range_sum(1, 4)}")  # 3+10+7+9=29
361
362    # 2. ๊ตฌ๊ฐ„ ์—…๋ฐ์ดํŠธ + ์  ์ฟผ๋ฆฌ
363    print("\n[2] ๊ตฌ๊ฐ„ ์—…๋ฐ์ดํŠธ + ์  ์ฟผ๋ฆฌ")
364    ft_ru = FenwickTreeRangeUpdate(6)
365    ft_ru.update_range(1, 3, 5)  # [1,3]์— 5 ๋”ํ•˜๊ธฐ
366    ft_ru.update_range(2, 4, 3)  # [2,4]์— 3 ๋”ํ•˜๊ธฐ
367    print(f"    update_range(1, 3, +5), update_range(2, 4, +3)")
368    for i in range(6):
369        print(f"    query({i}): {ft_ru.query(i)}")
370
371    # 3. 2D ํŽœ์œ… ํŠธ๋ฆฌ
372    print("\n[3] 2D ํŽœ์œ… ํŠธ๋ฆฌ")
373    ft2d = FenwickTree2D(3, 3)
374    # ํ–‰๋ ฌ ์ฑ„์šฐ๊ธฐ
375    matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
376    for i in range(3):
377        for j in range(3):
378            ft2d.update(i, j, matrix[i][j])
379    print(f"    ํ–‰๋ ฌ: {matrix}")
380    print(f"    range_sum(0,0,1,1): {ft2d.range_sum(0, 0, 1, 1)}")  # 1+2+4+5=12
381    print(f"    range_sum(1,1,2,2): {ft2d.range_sum(1, 1, 2, 2)}")  # 5+6+8+9=28
382
383    # 4. ์—ญ์ˆœ ์Œ ๊ฐœ์ˆ˜
384    print("\n[4] ์—ญ์ˆœ ์Œ ๊ฐœ์ˆ˜")
385    arr = [2, 4, 1, 3, 5]
386    inv = count_inversions(arr)
387    print(f"    ๋ฐฐ์—ด: {arr}")
388    print(f"    ์—ญ์ˆœ ์Œ: {inv}")  # (2,1), (4,1), (4,3) = 3
389
390    # 5. K๋ฒˆ์งธ ์›์†Œ
391    print("\n[5] K๋ฒˆ์งธ ์›์†Œ")
392    ft_kth = FenwickTreeKth(10)
393    for val in [3, 5, 7, 1, 9]:
394        ft_kth.update(val, 1)
395    print(f"    ์‚ฝ์ž…๋œ ์›์†Œ: [3, 5, 7, 1, 9]")
396    for k in [1, 2, 3, 4, 5]:
397        print(f"    {k}๋ฒˆ์งธ ์›์†Œ: {ft_kth.find_kth(k)}")
398
399    # 6. ํŽœ์œ… ํŠธ๋ฆฌ vs ์„ธ๊ทธ๋จผํŠธ ํŠธ๋ฆฌ ๋น„๊ต
400    print("\n[6] ํŽœ์œ… ํŠธ๋ฆฌ vs ์„ธ๊ทธ๋จผํŠธ ํŠธ๋ฆฌ")
401    print("    | ํŠน์„ฑ         | ํŽœ์œ… ํŠธ๋ฆฌ | ์„ธ๊ทธ๋จผํŠธ ํŠธ๋ฆฌ |")
402    print("    |--------------|-----------|---------------|")
403    print("    | ๊ณต๊ฐ„         | O(n)      | O(4n)         |")
404    print("    | ๊ตฌํ˜„ ๋‚œ์ด๋„  | ์‰ฌ์›€      | ๋ณดํ†ต          |")
405    print("    | ์  ์—…๋ฐ์ดํŠธ  | O(log n)  | O(log n)      |")
406    print("    | ๊ตฌ๊ฐ„ ์ฟผ๋ฆฌ    | O(log n)  | O(log n)      |")
407    print("    | ๊ตฌ๊ฐ„ ์—…๋ฐ์ดํŠธ| 2๊ฐœ BIT   | Lazy          |")
408    print("    | ์ง€์› ์—ฐ์‚ฐ    | ๊ฐ€์—ญ๋งŒ    | ์ž„์˜          |")
409
410    print("\n" + "=" * 60)
411
412
413if __name__ == "__main__":
414    main()