01_root_finding.py

Download
python 390 lines 10.9 KB
  1"""
  2๊ทผ ์ฐพ๊ธฐ (Root Finding)
  3Numerical Root Finding Methods
  4
  5f(x) = 0 ์„ ๋งŒ์กฑํ•˜๋Š” x๋ฅผ ์ˆ˜์น˜์ ์œผ๋กœ ์ฐพ๋Š” ๋ฐฉ๋ฒ•๋“ค์ž…๋‹ˆ๋‹ค.
  6"""
  7
  8import numpy as np
  9import matplotlib.pyplot as plt
 10from typing import Callable, Tuple, Optional
 11
 12
 13# =============================================================================
 14# 1. ์ด๋ถ„๋ฒ• (Bisection Method)
 15# =============================================================================
 16def bisection(
 17    f: Callable[[float], float],
 18    a: float,
 19    b: float,
 20    tol: float = 1e-10,
 21    max_iter: int = 100
 22) -> Tuple[float, int, list]:
 23    """
 24    ์ด๋ถ„๋ฒ•์œผ๋กœ f(x) = 0์˜ ๊ทผ ์ฐพ๊ธฐ
 25
 26    ์กฐ๊ฑด: f(a)์™€ f(b)์˜ ๋ถ€ํ˜ธ๊ฐ€ ๋‹ฌ๋ผ์•ผ ํ•จ (์ค‘๊ฐ„๊ฐ’ ์ •๋ฆฌ)
 27    ์ˆ˜๋ ด ์†๋„: ์„ ํ˜• (๋งค ๋ฐ˜๋ณต๋งˆ๋‹ค ๊ตฌ๊ฐ„ ์ ˆ๋ฐ˜)
 28
 29    Args:
 30        f: ๋ชฉํ‘œ ํ•จ์ˆ˜
 31        a, b: ์ดˆ๊ธฐ ๊ตฌ๊ฐ„
 32        tol: ํ—ˆ์šฉ ์˜ค์ฐจ
 33        max_iter: ์ตœ๋Œ€ ๋ฐ˜๋ณต ํšŸ์ˆ˜
 34
 35    Returns:
 36        (๊ทผ, ๋ฐ˜๋ณต ํšŸ์ˆ˜, ์ค‘๊ฐ„๊ฐ’ ํžˆ์Šคํ† ๋ฆฌ)
 37    """
 38    if f(a) * f(b) >= 0:
 39        raise ValueError("f(a)์™€ f(b)์˜ ๋ถ€ํ˜ธ๊ฐ€ ๋‹ฌ๋ผ์•ผ ํ•ฉ๋‹ˆ๋‹ค")
 40
 41    history = []
 42
 43    for i in range(max_iter):
 44        c = (a + b) / 2
 45        history.append(c)
 46
 47        if abs(f(c)) < tol or (b - a) / 2 < tol:
 48            return c, i + 1, history
 49
 50        if f(a) * f(c) < 0:
 51            b = c
 52        else:
 53            a = c
 54
 55    return (a + b) / 2, max_iter, history
 56
 57
 58# =============================================================================
 59# 2. ๋‰ดํ„ด-๋žฉ์Šจ ๋ฐฉ๋ฒ• (Newton-Raphson Method)
 60# =============================================================================
 61def newton_raphson(
 62    f: Callable[[float], float],
 63    df: Callable[[float], float],
 64    x0: float,
 65    tol: float = 1e-10,
 66    max_iter: int = 100
 67) -> Tuple[float, int, list]:
 68    """
 69    ๋‰ดํ„ด-๋žฉ์Šจ ๋ฐฉ๋ฒ•์œผ๋กœ f(x) = 0์˜ ๊ทผ ์ฐพ๊ธฐ
 70
 71    x_{n+1} = x_n - f(x_n) / f'(x_n)
 72
 73    ์ˆ˜๋ ด ์†๋„: 2์ฐจ (์ œ๊ณฑ ์ˆ˜๋ ด)
 74    ๋‹จ์ : ๋„ํ•จ์ˆ˜ ํ•„์š”, ์ดˆ๊ธฐ๊ฐ’์— ๋ฏผ๊ฐ
 75
 76    Args:
 77        f: ๋ชฉํ‘œ ํ•จ์ˆ˜
 78        df: ๋„ํ•จ์ˆ˜
 79        x0: ์ดˆ๊ธฐ๊ฐ’
 80        tol: ํ—ˆ์šฉ ์˜ค์ฐจ
 81        max_iter: ์ตœ๋Œ€ ๋ฐ˜๋ณต ํšŸ์ˆ˜
 82
 83    Returns:
 84        (๊ทผ, ๋ฐ˜๋ณต ํšŸ์ˆ˜, ํžˆ์Šคํ† ๋ฆฌ)
 85    """
 86    x = x0
 87    history = [x]
 88
 89    for i in range(max_iter):
 90        fx = f(x)
 91        dfx = df(x)
 92
 93        if abs(dfx) < 1e-15:
 94            raise ValueError("๋„ํ•จ์ˆ˜๊ฐ€ 0์— ๊ฐ€๊นŒ์›€: ๋ฐœ์‚ฐ ์œ„ํ—˜")
 95
 96        x_new = x - fx / dfx
 97        history.append(x_new)
 98
 99        if abs(x_new - x) < tol:
100            return x_new, i + 1, history
101
102        x = x_new
103
104    return x, max_iter, history
105
106
107# =============================================================================
108# 3. ํ• ์„ ๋ฒ• (Secant Method)
109# =============================================================================
110def secant(
111    f: Callable[[float], float],
112    x0: float,
113    x1: float,
114    tol: float = 1e-10,
115    max_iter: int = 100
116) -> Tuple[float, int, list]:
117    """
118    ํ• ์„ ๋ฒ•์œผ๋กœ f(x) = 0์˜ ๊ทผ ์ฐพ๊ธฐ
119
120    ๋‰ดํ„ด๋ฒ•์˜ ๋„ํ•จ์ˆ˜๋ฅผ ์ฐจ๋ถ„์œผ๋กœ ๊ทผ์‚ฌ:
121    x_{n+1} = x_n - f(x_n) * (x_n - x_{n-1}) / (f(x_n) - f(x_{n-1}))
122
123    ์ˆ˜๋ ด ์†๋„: ์•ฝ 1.618์ฐจ (ํ™ฉ๊ธˆ๋น„)
124    ์žฅ์ : ๋„ํ•จ์ˆ˜ ๋ถˆํ•„์š”
125
126    Args:
127        f: ๋ชฉํ‘œ ํ•จ์ˆ˜
128        x0, x1: ๋‘ ์ดˆ๊ธฐ๊ฐ’
129        tol: ํ—ˆ์šฉ ์˜ค์ฐจ
130        max_iter: ์ตœ๋Œ€ ๋ฐ˜๋ณต ํšŸ์ˆ˜
131
132    Returns:
133        (๊ทผ, ๋ฐ˜๋ณต ํšŸ์ˆ˜, ํžˆ์Šคํ† ๋ฆฌ)
134    """
135    history = [x0, x1]
136
137    for i in range(max_iter):
138        f0, f1 = f(x0), f(x1)
139
140        if abs(f1 - f0) < 1e-15:
141            raise ValueError("๋ถ„๋ชจ๊ฐ€ 0์— ๊ฐ€๊นŒ์›€")
142
143        x2 = x1 - f1 * (x1 - x0) / (f1 - f0)
144        history.append(x2)
145
146        if abs(x2 - x1) < tol:
147            return x2, i + 1, history
148
149        x0, x1 = x1, x2
150
151    return x1, max_iter, history
152
153
154# =============================================================================
155# 4. ๊ณ ์ •์  ๋ฐ˜๋ณต๋ฒ• (Fixed-Point Iteration)
156# =============================================================================
157def fixed_point(
158    g: Callable[[float], float],
159    x0: float,
160    tol: float = 1e-10,
161    max_iter: int = 100
162) -> Tuple[float, int, list]:
163    """
164    ๊ณ ์ •์  ๋ฐ˜๋ณต: x = g(x)๋ฅผ ๋งŒ์กฑํ•˜๋Š” x ์ฐพ๊ธฐ
165
166    f(x) = 0 ์„ x = g(x) ํ˜•ํƒœ๋กœ ๋ณ€ํ™˜
167    ์˜ˆ: xยฒ - 2 = 0  โ†’  x = 2/x ๋˜๋Š” x = (x + 2/x)/2
168
169    ์ˆ˜๋ ด ์กฐ๊ฑด: |g'(x*)| < 1
170
171    Args:
172        g: ๋ฐ˜๋ณต ํ•จ์ˆ˜
173        x0: ์ดˆ๊ธฐ๊ฐ’
174        tol: ํ—ˆ์šฉ ์˜ค์ฐจ
175        max_iter: ์ตœ๋Œ€ ๋ฐ˜๋ณต ํšŸ์ˆ˜
176
177    Returns:
178        (๊ณ ์ •์ , ๋ฐ˜๋ณต ํšŸ์ˆ˜, ํžˆ์Šคํ† ๋ฆฌ)
179    """
180    x = x0
181    history = [x]
182
183    for i in range(max_iter):
184        x_new = g(x)
185        history.append(x_new)
186
187        if abs(x_new - x) < tol:
188            return x_new, i + 1, history
189
190        x = x_new
191
192    return x, max_iter, history
193
194
195# =============================================================================
196# 5. Brent's Method (scipy์™€ ๋น„๊ต)
197# =============================================================================
198def brents_method(
199    f: Callable[[float], float],
200    a: float,
201    b: float,
202    tol: float = 1e-10,
203    max_iter: int = 100
204) -> Tuple[float, int]:
205    """
206    Brent's Method (๊ฐ„์†Œํ™” ๋ฒ„์ „)
207    ์ด๋ถ„๋ฒ•, ํ• ์„ ๋ฒ•, ์—ญ2์ฐจ๋ณด๊ฐ„์„ ์กฐํ•ฉํ•œ ๋ฐฉ๋ฒ•
208
209    ์‹ค๋ฌด์—์„œ๋Š” scipy.optimize.brentq ์‚ฌ์šฉ ๊ถŒ์žฅ
210    """
211    fa, fb = f(a), f(b)
212    if fa * fb >= 0:
213        raise ValueError("f(a)์™€ f(b)์˜ ๋ถ€ํ˜ธ๊ฐ€ ๋‹ฌ๋ผ์•ผ ํ•ฉ๋‹ˆ๋‹ค")
214
215    if abs(fa) < abs(fb):
216        a, b = b, a
217        fa, fb = fb, fa
218
219    c, fc = a, fa
220    d = c  # d ์ดˆ๊ธฐํ™” (์ฒซ iteration์—์„œ ์‚ฌ์šฉ๋จ)
221    mflag = True
222
223    for i in range(max_iter):
224        if fa != fc and fb != fc:
225            # ์—ญ2์ฐจ ๋ณด๊ฐ„
226            s = (a * fb * fc / ((fa - fb) * (fa - fc)) +
227                 b * fa * fc / ((fb - fa) * (fb - fc)) +
228                 c * fa * fb / ((fc - fa) * (fc - fb)))
229        else:
230            # ํ• ์„ ๋ฒ•
231            s = b - fb * (b - a) / (fb - fa)
232
233        # ์กฐ๊ฑด ์ฒดํฌ ํ›„ ์ด๋ถ„๋ฒ•์œผ๋กœ ๋Œ€์ฒด
234        conditions = [
235            not ((3 * a + b) / 4 <= s <= b or b <= s <= (3 * a + b) / 4),
236            mflag and abs(s - b) >= abs(b - c) / 2,
237            not mflag and abs(s - b) >= abs(c - d) / 2,
238        ]
239
240        if any(conditions):
241            s = (a + b) / 2
242            mflag = True
243        else:
244            mflag = False
245
246        fs = f(s)
247        d, c, fc = c, b, fb
248
249        if fa * fs < 0:
250            b, fb = s, fs
251        else:
252            a, fa = s, fs
253
254        if abs(fa) < abs(fb):
255            a, b = b, a
256            fa, fb = fb, fa
257
258        if abs(b - a) < tol or abs(fb) < tol:
259            return b, i + 1
260
261    return b, max_iter
262
263
264# =============================================================================
265# ์‹œ๊ฐํ™”
266# =============================================================================
267def plot_convergence(f, methods_data, x_range, title="๊ทผ ์ฐพ๊ธฐ ์ˆ˜๋ ด ๋น„๊ต"):
268    """์ˆ˜๋ ด ๊ณผ์ • ์‹œ๊ฐํ™”"""
269    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
270
271    # ์™ผ์ชฝ: ํ•จ์ˆ˜์™€ ๊ทผ
272    x = np.linspace(x_range[0], x_range[1], 500)
273    y = [f(xi) for xi in x]
274    axes[0].plot(x, y, 'b-', label='f(x)')
275    axes[0].axhline(y=0, color='k', linestyle='-', linewidth=0.5)
276
277    colors = plt.cm.tab10.colors
278    for i, (name, root, _, history) in enumerate(methods_data):
279        axes[0].scatter([root], [0], s=100, color=colors[i], zorder=5, label=f'{name}: x={root:.6f}')
280
281    axes[0].set_xlabel('x')
282    axes[0].set_ylabel('f(x)')
283    axes[0].set_title('ํ•จ์ˆ˜์™€ ๊ทผ')
284    axes[0].legend()
285    axes[0].grid(True, alpha=0.3)
286
287    # ์˜ค๋ฅธ์ชฝ: ์ˆ˜๋ ด ์†๋„
288    for i, (name, root, iters, history) in enumerate(methods_data):
289        if history:
290            errors = [abs(h - root) for h in history]
291            errors = [e if e > 1e-16 else 1e-16 for e in errors]
292            axes[1].semilogy(errors, 'o-', color=colors[i], label=f'{name} ({iters}ํšŒ)')
293
294    axes[1].set_xlabel('๋ฐ˜๋ณต ํšŸ์ˆ˜')
295    axes[1].set_ylabel('์˜ค์ฐจ (log scale)')
296    axes[1].set_title('์ˆ˜๋ ด ์†๋„ ๋น„๊ต')
297    axes[1].legend()
298    axes[1].grid(True, alpha=0.3)
299
300    plt.suptitle(title)
301    plt.tight_layout()
302    plt.savefig('/opt/projects/01_Personal/03_Study/Numerical_Simulation/examples/root_finding.png', dpi=150)
303    plt.close()
304    print("    ๊ทธ๋ž˜ํ”„ ์ €์žฅ: root_finding.png")
305
306
307# =============================================================================
308# ํ…Œ์ŠคํŠธ
309# =============================================================================
310def main():
311    print("=" * 60)
312    print("๊ทผ ์ฐพ๊ธฐ (Root Finding) ์˜ˆ์ œ")
313    print("=" * 60)
314
315    # ์˜ˆ์ œ 1: f(x) = xยณ - x - 2 = 0 (๊ทผ โ‰ˆ 1.5214)
316    print("\n[์˜ˆ์ œ 1] f(x) = xยณ - x - 2 = 0")
317    print("-" * 40)
318
319    f = lambda x: x**3 - x - 2
320    df = lambda x: 3*x**2 - 1
321
322    methods_data = []
323
324    # ์ด๋ถ„๋ฒ•
325    root, iters, hist = bisection(f, 1, 2)
326    methods_data.append(("Bisection", root, iters, hist))
327    print(f"์ด๋ถ„๋ฒ•:     ๊ทผ = {root:.10f}, ๋ฐ˜๋ณต = {iters}")
328
329    # ๋‰ดํ„ด-๋žฉ์Šจ
330    root, iters, hist = newton_raphson(f, df, 1.5)
331    methods_data.append(("Newton", root, iters, hist))
332    print(f"๋‰ดํ„ด-๋žฉ์Šจ: ๊ทผ = {root:.10f}, ๋ฐ˜๋ณต = {iters}")
333
334    # ํ• ์„ ๋ฒ•
335    root, iters, hist = secant(f, 1, 2)
336    methods_data.append(("Secant", root, iters, hist))
337    print(f"ํ• ์„ ๋ฒ•:     ๊ทผ = {root:.10f}, ๋ฐ˜๋ณต = {iters}")
338
339    # ์‹œ๊ฐํ™”
340    try:
341        plot_convergence(f, methods_data, (0, 3), "f(x) = xยณ - x - 2")
342    except Exception as e:
343        print(f"    ๊ทธ๋ž˜ํ”„ ์ƒ์„ฑ ์‹คํŒจ: {e}")
344
345    # ์˜ˆ์ œ 2: โˆš2 ๊ตฌํ•˜๊ธฐ (xยฒ - 2 = 0)
346    print("\n[์˜ˆ์ œ 2] โˆš2 ๊ตฌํ•˜๊ธฐ (xยฒ - 2 = 0)")
347    print("-" * 40)
348
349    f2 = lambda x: x**2 - 2
350    df2 = lambda x: 2*x
351    g = lambda x: (x + 2/x) / 2  # ๋ฐ”๋นŒ๋กœ๋‹ˆ์•„ ๋ฐฉ๋ฒ•
352
353    root, iters, _ = newton_raphson(f2, df2, 1.0)
354    print(f"๋‰ดํ„ด-๋žฉ์Šจ:   โˆš2 = {root:.15f}, ๋ฐ˜๋ณต = {iters}")
355
356    root, iters, _ = fixed_point(g, 1.0)
357    print(f"๊ณ ์ •์  ๋ฐ˜๋ณต: โˆš2 = {root:.15f}, ๋ฐ˜๋ณต = {iters}")
358
359    print(f"์‹ค์ œ โˆš2:        {np.sqrt(2):.15f}")
360
361    # ์˜ˆ์ œ 3: cos(x) = x ๊ณ ์ •์ 
362    print("\n[์˜ˆ์ œ 3] cos(x) = x (Dottie Number)")
363    print("-" * 40)
364
365    g_cos = lambda x: np.cos(x)
366    root, iters, _ = fixed_point(g_cos, 0.5)
367    print(f"๊ณ ์ •์  x = cos(x): {root:.10f}, ๋ฐ˜๋ณต = {iters}")
368
369    print("\n" + "=" * 60)
370    print("๊ทผ ์ฐพ๊ธฐ ๋ฐฉ๋ฒ• ๋น„๊ต")
371    print("=" * 60)
372    print("""
373    | ๋ฐฉ๋ฒ•        | ์ˆ˜๋ ด ์†๋„ | ์žฅ์                 | ๋‹จ์                 |
374    |------------|----------|---------------------|---------------------|
375    | ์ด๋ถ„๋ฒ•      | ์„ ํ˜•     | ํ•ญ์ƒ ์ˆ˜๋ ด, ์•ˆ์ •์     | ๋А๋ฆผ, ๊ตฌ๊ฐ„ ํ•„์š”      |
376    | ๋‰ดํ„ด-๋žฉ์Šจ   | 2์ฐจ      | ๋งค์šฐ ๋น ๋ฆ„           | ๋„ํ•จ์ˆ˜ ํ•„์š”, ๋ฐœ์‚ฐ ๊ฐ€๋Šฅ|
377    | ํ• ์„ ๋ฒ•      | ~1.618์ฐจ | ๋„ํ•จ์ˆ˜ ๋ถˆํ•„์š”        | ๋‰ดํ„ด๋ณด๋‹ค ๋А๋ฆผ        |
378    | ๊ณ ์ •์ ๋ฐ˜๋ณต  | ์„ ํ˜•~2์ฐจ | ๊ฐ„๋‹จ                | ์ˆ˜๋ ด ์กฐ๊ฑด ํ™•์ธ ํ•„์š”   |
379    | Brent      | ์กฐํ•ฉ     | ์•ˆ์ • + ๋น ๋ฆ„         | ๊ตฌํ˜„ ๋ณต์žก           |
380
381    ์‹ค๋ฌด ๊ถŒ์žฅ:
382    - scipy.optimize.brentq: ์•ˆ์ •์ ์ธ ๊ทผ ์ฐพ๊ธฐ
383    - scipy.optimize.newton: ๋‰ดํ„ด-๋žฉ์Šจ/ํ• ์„ ๋ฒ•
384    - scipy.optimize.fsolve: ๋‹ค๋ณ€์ˆ˜ ๋ฐฉ์ •์‹
385    """)
386
387
388if __name__ == "__main__":
389    main()