# 03_dynamic_programming.py
"""
Dynamic Programming implementations:
- Policy Evaluation
- Policy Improvement
- Policy Iteration
- Value Iteration
"""
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple
from collections import defaultdict
 14class GridWorld:
 15    """๊ฐ„๋‹จํ•œ ๊ทธ๋ฆฌ๋“œ ์›”๋“œ ํ™˜๊ฒฝ"""
 16
 17    def __init__(self, size=4):
 18        self.size = size
 19        self.actions = ['up', 'down', 'left', 'right']
 20        self.n_actions = len(self.actions)
 21
 22    def get_states(self):
 23        """๋ชจ๋“  ์ƒํƒœ ๋ฐ˜ํ™˜"""
 24        return [(i, j) for i in range(self.size) for j in range(self.size)]
 25
 26    def is_terminal(self, state):
 27        """์ข…๋ฃŒ ์ƒํƒœ ํ™•์ธ"""
 28        return state == (0, 0) or state == (self.size-1, self.size-1)
 29
 30    def get_transitions(self, state, action):
 31        """์ „์ด ํ™•๋ฅ  ๋ฐ˜ํ™˜: [(prob, next_state, reward, done)]"""
 32        if self.is_terminal(state):
 33            return [(1.0, state, 0, True)]
 34
 35        deltas = {
 36            'up': (-1, 0),
 37            'down': (1, 0),
 38            'left': (0, -1),
 39            'right': (0, 1)
 40        }
 41        delta = deltas[action]
 42
 43        # ๊ทธ๋ฆฌ๋“œ ๊ฒฝ๊ณ„ ์ฒ˜๋ฆฌ
 44        new_row = max(0, min(self.size-1, state[0] + delta[0]))
 45        new_col = max(0, min(self.size-1, state[1] + delta[1]))
 46        next_state = (new_row, new_col)
 47
 48        # ๋ณด์ƒ: ๊ฐ ์ด๋™๋งˆ๋‹ค -1
 49        reward = -1
 50        done = self.is_terminal(next_state)
 51
 52        return [(1.0, next_state, reward, done)]
 53
 54
 55def create_uniform_policy(grid):
 56    """๊ท ๋“ฑ ๋žœ๋ค ์ •์ฑ… ์ƒ์„ฑ"""
 57    policy = {}
 58    for s in grid.get_states():
 59        policy[s] = {a: 1.0/len(grid.actions) for a in grid.actions}
 60    return policy
 61
 62
 63def policy_evaluation(grid, policy: Dict, gamma: float = 0.9, theta: float = 1e-6):
 64    """
 65    ์ •์ฑ… ํ‰๊ฐ€: ์ฃผ์–ด์ง„ ์ •์ฑ…์˜ ๊ฐ€์น˜ ํ•จ์ˆ˜ ๊ณ„์‚ฐ
 66
 67    Args:
 68        grid: GridWorld ํ™˜๊ฒฝ
 69        policy: ์ •์ฑ… {state: {action: probability}}
 70        gamma: ํ• ์ธ์œจ
 71        theta: ์ˆ˜๋ ด ์ž„๊ณ„๊ฐ’
 72
 73    Returns:
 74        V: ์ƒํƒœ ๊ฐ€์น˜ ํ•จ์ˆ˜ {state: value}
 75    """
 76    # ๊ฐ€์น˜ ํ•จ์ˆ˜ ์ดˆ๊ธฐํ™”
 77    V = {s: 0.0 for s in grid.get_states()}
 78
 79    iteration = 0
 80    while True:
 81        delta = 0  # ์ตœ๋Œ€ ๋ณ€ํ™”๋Ÿ‰ ์ถ”์ 
 82        iteration += 1
 83
 84        # ๋ชจ๋“  ์ƒํƒœ์— ๋Œ€ํ•ด ์—…๋ฐ์ดํŠธ
 85        for s in grid.get_states():
 86            if grid.is_terminal(s):
 87                continue
 88
 89            v = V[s]  # ์ด์ „ ๊ฐ’ ์ €์žฅ
 90            new_v = 0
 91
 92            # ๋ฒจ๋งŒ ๊ธฐ๋Œ€ ๋ฐฉ์ •์‹ ์ ์šฉ
 93            for a in grid.actions:
 94                action_prob = policy[s].get(a, 0)
 95
 96                for prob, next_s, reward, done in grid.get_transitions(s, a):
 97                    if done:
 98                        new_v += action_prob * prob * reward
 99                    else:
100                        new_v += action_prob * prob * (reward + gamma * V[next_s])
101
102            V[s] = new_v
103            delta = max(delta, abs(v - new_v))
104
105        # ์ˆ˜๋ ด ์ฒดํฌ
106        if delta < theta:
107            print(f"์ •์ฑ… ํ‰๊ฐ€ ์ˆ˜๋ ด: {iteration} iterations, delta={delta:.8f}")
108            break
109
110    return V
111
112
def policy_improvement(grid, V: Dict, gamma: float = 0.9, old_policy: Dict = None):
    """Greedy policy improvement with respect to a value function V.

    Args:
        grid: environment exposing get_states/is_terminal/actions/get_transitions.
        V: current state-value function {state: value}.
        gamma: discount factor.
        old_policy: optional previous policy {state: {action: prob}}.  When
            given, stability is determined by comparing each non-terminal
            state's greedy action against the old policy's greedy action.

    Returns:
        new_policy: greedy policy {state: {action: probability}}; ties split
            probability evenly among the tied best actions.
        policy_stable: True when the greedy action matches ``old_policy`` in
            every non-terminal state.  Always True when ``old_policy`` is
            None (the previous implementation never updated this flag, so
            its return value was meaningless -- this keeps backward
            compatibility while making the flag usable).
    """
    new_policy = {}
    policy_stable = True

    for s in grid.get_states():
        if grid.is_terminal(s):
            # Action choice is irrelevant in terminal states; keep uniform.
            new_policy[s] = {a: 1.0 / len(grid.actions) for a in grid.actions}
            continue

        # One-step lookahead: Q(s, a) for every action.
        q_values = {}
        for a in grid.actions:
            q = 0
            for prob, next_s, reward, done in grid.get_transitions(s, a):
                if done:
                    q += prob * reward
                else:
                    q += prob * (reward + gamma * V[next_s])
            q_values[a] = q

        best_action = max(q_values, key=q_values.get)
        best_q = q_values[best_action]

        # Actions tied for the best Q value (within numerical tolerance).
        best_actions = [a for a, q in q_values.items()
                        if abs(q - best_q) < 1e-8]

        # Split probability evenly over the tied greedy actions.
        new_policy[s] = {a: 0.0 for a in grid.actions}
        for a in best_actions:
            new_policy[s][a] = 1.0 / len(best_actions)

        # Bug fix: policy_stable previously was never updated and the
        # function unconditionally returned True.  Compare greedy actions
        # against the previous policy when one is supplied.
        if old_policy is not None:
            prev_best = max(old_policy[s], key=old_policy[s].get)
            if prev_best != max(new_policy[s], key=new_policy[s].get):
                policy_stable = False

    return new_policy, policy_stable
def policy_iteration(grid, gamma: float = 0.9, theta: float = 1e-6):
    """Policy iteration: alternate policy evaluation and greedy improvement.

    Returns:
        V: value function of the final (optimal) policy.
        policy: the final policy {state: {action: probability}}.
    """
    # Start from the uniform random policy.
    policy = create_uniform_policy(grid)

    iteration = 0
    while True:
        iteration += 1
        print(f"\n=== ์ •์ฑ… ๋ฐ˜๋ณต {iteration} ===")

        # 1. Evaluate the current policy.
        V = policy_evaluation(grid, policy, gamma, theta)

        # 2. Improve it greedily with respect to V.
        previous = policy
        policy, _ = policy_improvement(grid, V, gamma)

        # 3. Stop once the greedy action is unchanged in every
        #    non-terminal state.
        def _greedy(p, s):
            return max(p[s], key=p[s].get)

        changed = any(
            _greedy(previous, s) != _greedy(policy, s)
            for s in grid.get_states()
            if not grid.is_terminal(s)
        )

        if not changed:
            print(f"\n์ •์ฑ… ๋ฐ˜๋ณต ์ˆ˜๋ ด! (์ด {iteration} iterations)")
            break

    return V, policy
def value_iteration(grid, gamma: float = 0.9, theta: float = 1e-6):
    """Value iteration via Bellman optimality backups.

    Returns:
        V: optimal state-value function {state: value}.
        policy: deterministic greedy policy extracted from V.
    """
    V = {s: 0.0 for s in grid.get_states()}

    iteration = 0
    while True:
        iteration += 1
        delta = 0

        for s in grid.get_states():
            if grid.is_terminal(s):
                continue

            previous = V[s]

            # V(s) <- max_a sum_s' p(s'|s,a) * [r + gamma * V(s')]
            best = None
            for a in grid.actions:
                q = 0
                for prob, next_s, reward, done in grid.get_transitions(s, a):
                    q += prob * reward if done else prob * (reward + gamma * V[next_s])
                if best is None or q > best:
                    best = q

            V[s] = best
            delta = max(delta, abs(previous - best))

        if iteration % 10 == 0:
            print(f"๋ฐ˜๋ณต {iteration}: delta = {delta:.8f}")

        if delta < theta:
            print(f"\n๊ฐ€์น˜ ๋ฐ˜๋ณต ์ˆ˜๋ ด: {iteration} iterations")
            break

    # Extract the greedy policy from the converged values.
    policy = {}
    for s in grid.get_states():
        if grid.is_terminal(s):
            # Action choice is irrelevant in terminal states; keep uniform.
            policy[s] = {a: 1.0 / len(grid.actions) for a in grid.actions}
            continue

        q_values = {}
        for a in grid.actions:
            q = 0
            for prob, next_s, reward, done in grid.get_transitions(s, a):
                q += prob * reward if done else prob * (reward + gamma * V[next_s])
            q_values[a] = q

        greedy = max(q_values, key=q_values.get)
        policy[s] = {a: (1.0 if a == greedy else 0.0) for a in grid.actions}

    return V, policy
def print_value_function(grid, V):
    """Print V as a size x size table, one formatted row per line."""
    print("\n๊ฐ€์น˜ ํ•จ์ˆ˜:")
    for row in range(grid.size):
        cells = []
        for col in range(grid.size):
            cells.append(f"{V[(row,col)]:7.2f}")
        print(" ".join(cells))
def print_policy(grid, policy):
    """Print the greedy action of each cell as an arrow; '*' marks terminals."""
    print("\n์ตœ์  ์ •์ฑ…:")
    arrows = {'up': 'โ†‘', 'down': 'โ†“', 'left': 'โ†', 'right': 'โ†’'}

    for i in range(grid.size):
        cells = []
        for j in range(grid.size):
            state = (i, j)
            if grid.is_terminal(state):
                cells.append('  *  ')
                continue
            greedy = max(policy[state], key=policy[state].get)
            cells.append(f'  {arrows[greedy]}  ')
        print(" ".join(cells))
def visualize_value_function(grid, V, title="Value Function"):
    """Render V as a heatmap and save it to value_function.png."""
    values = np.zeros((grid.size, grid.size))
    for r in range(grid.size):
        for c in range(grid.size):
            values[r, c] = V[(r, c)]

    plt.figure(figsize=(8, 6))
    plt.imshow(values, cmap='coolwarm', interpolation='nearest')
    plt.colorbar(label='Value')
    plt.title(title)

    # Overlay the numeric value on each cell.
    for r in range(grid.size):
        for c in range(grid.size):
            plt.text(c, r, f'{values[r, c]:.1f}',
                     ha='center', va='center', color='black', fontsize=12)

    plt.xticks(range(grid.size))
    plt.yticks(range(grid.size))
    plt.grid(False)
    plt.tight_layout()
    plt.savefig('value_function.png', dpi=150)
    print(f"๊ฐ€์น˜ ํ•จ์ˆ˜ ์‹œ๊ฐํ™” ์ €์žฅ: value_function.png")
def compare_algorithms():
    """Run and compare the DP algorithms on a 4x4 grid world."""
    print("=" * 60)
    print("๋™์  ํ”„๋กœ๊ทธ๋ž˜๋ฐ ์•Œ๊ณ ๋ฆฌ์ฆ˜ ๋น„๊ต")
    print("=" * 60)

    grid = GridWorld(size=4)
    gamma = 0.9

    # 1. Evaluate the uniform random policy.
    print("\n[1] ์ •์ฑ… ํ‰๊ฐ€ - ๊ท ๋“ฑ ๋žœ๋ค ์ •์ฑ…")
    print("-" * 60)
    uniform_policy = create_uniform_policy(grid)
    V_uniform = policy_evaluation(grid, uniform_policy, gamma)
    print_value_function(grid, V_uniform)

    # 2. Policy iteration.
    print("\n[2] ์ •์ฑ… ๋ฐ˜๋ณต (Policy Iteration)")
    print("-" * 60)
    V_pi, policy_pi = policy_iteration(grid, gamma)
    print_value_function(grid, V_pi)
    print_policy(grid, policy_pi)

    # 3. Value iteration.
    print("\n[3] ๊ฐ€์น˜ ๋ฐ˜๋ณต (Value Iteration)")
    print("-" * 60)
    V_vi, policy_vi = value_iteration(grid, gamma)
    print_value_function(grid, V_vi)
    print_policy(grid, policy_vi)

    # 4. Both methods should converge to the same value function.
    print("\n[4] ๊ฒฐ๊ณผ ๋น„๊ต")
    print("-" * 60)
    print("์ •์ฑ… ๋ฐ˜๋ณต๊ณผ ๊ฐ€์น˜ ๋ฐ˜๋ณต์˜ ๊ฐ€์น˜ ํ•จ์ˆ˜ ์ฐจ์ด:")
    max_diff = max(abs(V_pi[s] - V_vi[s]) for s in grid.get_states())
    print(f"์ตœ๋Œ€ ์ฐจ์ด: {max_diff:.10f}")

    # Save a heatmap of the policy-iteration result.
    visualize_value_function(grid, V_pi, "Policy Iteration - Value Function")

    return V_pi, policy_pi, V_vi, policy_vi
def frozen_lake_example():
    """Solve FrozenLake-v1 (non-slippery) with value iteration, then test it."""
    import gymnasium as gym

    print("\n" + "=" * 60)
    print("Frozen Lake ์˜ˆ์ œ")
    print("=" * 60)

    # Deterministic (non-slippery) variant of the environment.
    env = gym.make('FrozenLake-v1', is_slippery=False)

    n_states = env.observation_space.n
    n_actions = env.action_space.n
    gamma = 0.99
    theta = 1e-8

    # P[s][a] = [(prob, next_state, reward, done), ...]
    P = env.unwrapped.P

    def lookahead(s, values):
        """One-step Q value for every action in state s."""
        return [
            sum(prob * (reward + gamma * values[next_s] * (not done))
                for prob, next_s, reward, done in P[s][a])
            for a in range(n_actions)
        ]

    # Value iteration.
    V = np.zeros(n_states)
    iteration = 0

    print("\n๊ฐ€์น˜ ๋ฐ˜๋ณต ์‹œ์ž‘...")
    while True:
        iteration += 1
        delta = 0

        for s in range(n_states):
            old = V[s]
            V[s] = max(lookahead(s, V))
            delta = max(delta, abs(old - V[s]))

        if delta < theta:
            print(f"์ˆ˜๋ ด: {iteration} iterations")
            break

    # Greedy policy extraction.
    policy = np.zeros(n_states, dtype=int)
    for s in range(n_states):
        policy[s] = np.argmax(lookahead(s, V))

    # Pretty-print the policy over the 4x4 layout.
    action_names = ['โ†', 'โ†“', 'โ†’', 'โ†‘']
    print("\n์ตœ์  ์ •์ฑ… (4x4 ๊ทธ๋ฆฌ๋“œ):")
    print("S: ์‹œ์ž‘, H: ๊ตฌ๋ฉ, F: ์–ผ์Œ, G: ๋ชฉํ‘œ")
    for i in range(4):
        row = ""
        for j in range(4):
            s = i * 4 + j
            if s == 0:
                row += "  S  "
            elif s in [5, 7, 11, 12]:  # holes in the default 4x4 map
                row += "  H  "
            elif s == 15:  # goal
                row += "  G  "
            else:
                row += f"  {action_names[policy[s]]}  "
        print(row)

    print("\n๊ฐ€์น˜ ํ•จ์ˆ˜:")
    print(V.reshape(4, 4).round(3))

    # Roll out the greedy policy to measure the success rate.
    print("\n์ •์ฑ… ํ…Œ์ŠคํŠธ ์ค‘...")
    success = 0
    n_tests = 100

    for _ in range(n_tests):
        state, _ = env.reset()
        done = False

        while not done:
            state, reward, terminated, truncated, _ = env.step(policy[state])
            done = terminated or truncated

        if reward > 0:
            success += 1

    print(f"์„ฑ๊ณต๋ฅ : {success}/{n_tests} = {success/n_tests*100:.1f}%")

    env.close()
    return V, policy
if __name__ == "__main__":
    # Grid-world comparison of the DP algorithms.
    V_pi, policy_pi, V_vi, policy_vi = compare_algorithms()

    # Frozen Lake demo; skipped gracefully if gymnasium is unavailable.
    try:
        V_fl, policy_fl = frozen_lake_example()
    except Exception as e:
        print(f"\nFrozen Lake ์˜ˆ์ œ ์‹คํ–‰ ์‹คํŒจ: {e}")
        print("gymnasium ํŒจํ‚ค์ง€๊ฐ€ ์„ค์น˜๋˜์–ด ์žˆ๋Š”์ง€ ํ™•์ธํ•˜์„ธ์š”: pip install gymnasium")

    print("\n" + "=" * 60)
    print("๋™์  ํ”„๋กœ๊ทธ๋ž˜๋ฐ ์˜ˆ์ œ ์™„๋ฃŒ!")
    print("=" * 60)