1"""
2๋์ ํ๋ก๊ทธ๋๋ฐ (Dynamic Programming) ๊ตฌํ
3- Policy Evaluation (์ ์ฑ
ํ๊ฐ)
4- Policy Improvement (์ ์ฑ
๊ฐ์ )
5- Policy Iteration (์ ์ฑ
๋ฐ๋ณต)
6- Value Iteration (๊ฐ์น ๋ฐ๋ณต)
7"""
from collections import defaultdict
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
12
13
class GridWorld:
    """Simple deterministic grid-world environment.

    Two terminal states: (0, 0) and (size-1, size-1). Every move costs -1,
    so optimal behavior is to reach a terminal corner as fast as possible.
    """

    def __init__(self, size=4):
        self.size = size
        self.actions = ['up', 'down', 'left', 'right']
        self.n_actions = len(self.actions)

    def get_states(self):
        """Return all states as (row, col) tuples."""
        return [(i, j) for i in range(self.size) for j in range(self.size)]

    def is_terminal(self, state):
        """Return True for the two terminal corner states."""
        return state == (0, 0) or state == (self.size-1, self.size-1)

    def get_transitions(self, state, action):
        """Return the transition list: [(prob, next_state, reward, done)].

        Transitions are deterministic; a move off the grid leaves the agent
        in place (still costing -1).
        """
        if self.is_terminal(state):
            # Terminal states self-loop with zero reward.
            return [(1.0, state, 0, True)]

        deltas = {
            'up': (-1, 0),
            'down': (1, 0),
            'left': (0, -1),
            'right': (0, 1)
        }
        delta = deltas[action]

        # Clamp the move to the grid boundaries.
        new_row = max(0, min(self.size-1, state[0] + delta[0]))
        new_col = max(0, min(self.size-1, state[1] + delta[1]))
        next_state = (new_row, new_col)

        # Reward: -1 per move.
        reward = -1
        done = self.is_terminal(next_state)

        return [(1.0, next_state, reward, done)]
53
54
def create_uniform_policy(grid):
    """Return the uniform random policy: equal probability for every action
    in every state, shaped as {state: {action: probability}}."""
    policy = {}
    for s in grid.get_states():
        policy[s] = {a: 1.0/len(grid.actions) for a in grid.actions}
    return policy
61
62
def policy_evaluation(grid, policy: Dict, gamma: float = 0.9, theta: float = 1e-6):
    """
    Policy evaluation: compute the state-value function of a given policy.

    Args:
        grid: environment exposing get_states(), is_terminal(), actions,
            and get_transitions(state, action).
        policy: policy {state: {action: probability}}.
        gamma: discount factor.
        theta: convergence threshold on the largest per-sweep value change.

    Returns:
        V: state-value function {state: value}.
    """
    # Initialize all state values to zero.
    V = {s: 0.0 for s in grid.get_states()}

    iteration = 0
    while True:
        delta = 0  # largest value change in this sweep
        iteration += 1

        # In-place sweep over all states.
        for s in grid.get_states():
            if grid.is_terminal(s):
                continue

            v = V[s]  # previous value
            new_v = 0

            # Bellman expectation backup.
            for a in grid.actions:
                action_prob = policy[s].get(a, 0)

                for prob, next_s, reward, done in grid.get_transitions(s, a):
                    if done:
                        # No bootstrap from terminal successors.
                        new_v += action_prob * prob * reward
                    else:
                        new_v += action_prob * prob * (reward + gamma * V[next_s])

            V[s] = new_v
            delta = max(delta, abs(v - new_v))

        # Converged once no state changed by more than theta.
        if delta < theta:
            print(f"Policy evaluation converged: {iteration} iterations, delta={delta:.8f}")
            break

    return V
111
112
def policy_improvement(grid, V: Dict, gamma: float = 0.9, old_policy: Dict = None):
    """
    Policy improvement: build a greedy policy with respect to V.

    Args:
        grid: environment exposing get_states(), is_terminal(), actions,
            and get_transitions(state, action).
        V: current state-value function {state: value}.
        gamma: discount factor.
        old_policy: optional previous policy. When provided, stability is
            checked by testing whether each state's previous greedy action
            is still among the new greedy actions. (The original code always
            returned True; callers that pass no old_policy see unchanged
            behavior.)

    Returns:
        new_policy: greedy policy {state: {action: probability}}; ties are
            split uniformly among the best actions.
        policy_stable: True if the greedy action set did not change
            (always True when old_policy is None).
    """
    new_policy = {}
    policy_stable = True

    for s in grid.get_states():
        if grid.is_terminal(s):
            # Terminal states get a uniform placeholder policy.
            new_policy[s] = {a: 1.0/len(grid.actions) for a in grid.actions}
            continue

        # Compute the Q value of each action.
        q_values = {}
        for a in grid.actions:
            q = 0
            for prob, next_s, reward, done in grid.get_transitions(s, a):
                if done:
                    q += prob * reward
                else:
                    q += prob * (reward + gamma * V[next_s])
            q_values[a] = q

        # Find the best action.
        best_action = max(q_values, key=q_values.get)
        best_q = q_values[best_action]

        # Collect ties (allowing for floating-point error).
        best_actions = [a for a, q in q_values.items()
                        if abs(q - best_q) < 1e-8]

        # Deterministic (tie-split) greedy policy.
        new_policy[s] = {a: 0.0 for a in grid.actions}
        for a in best_actions:
            new_policy[s][a] = 1.0 / len(best_actions)

        # Stability check against the previous policy, if supplied.
        if old_policy is not None:
            prev_best = max(old_policy[s], key=old_policy[s].get)
            if prev_best not in best_actions:
                policy_stable = False

    return new_policy, policy_stable
159
160
def policy_iteration(grid, gamma: float = 0.9, theta: float = 1e-6):
    """
    Policy iteration: alternate policy evaluation and greedy improvement
    until the greedy action stops changing in every non-terminal state.

    Args:
        grid: environment.
        gamma: discount factor.
        theta: convergence threshold passed to policy evaluation.

    Returns:
        V: optimal state-value function.
        policy: optimal policy {state: {action: probability}}.
    """
    # Start from the uniform random policy.
    policy = create_uniform_policy(grid)

    iteration = 0
    while True:
        iteration += 1
        print(f"\n=== Policy iteration {iteration} ===")

        # 1. Policy evaluation.
        V = policy_evaluation(grid, policy, gamma, theta)

        # 2. Policy improvement (greedy with respect to V).
        old_policy = policy.copy()
        policy, _ = policy_improvement(grid, V, gamma)

        # 3. Stability check: stop when the greedy action is unchanged
        #    in every non-terminal state.
        policy_stable = True
        for s in grid.get_states():
            if grid.is_terminal(s):
                continue

            old_best = max(old_policy[s], key=old_policy[s].get)
            new_best = max(policy[s], key=policy[s].get)

            if old_best != new_best:
                policy_stable = False
                break

        if policy_stable:
            print(f"\nPolicy iteration converged! ({iteration} iterations total)")
            break

    return V, policy
202
203
def value_iteration(grid, gamma: float = 0.9, theta: float = 1e-6):
    """
    Value iteration: iterate the Bellman optimality backup to convergence,
    then extract the greedy policy.

    Args:
        grid: environment.
        gamma: discount factor.
        theta: convergence threshold on the largest per-sweep value change.

    Returns:
        V: optimal state-value function {state: value}.
        policy: optimal deterministic policy {state: {action: probability}}.
    """
    # Initialize all state values to zero.
    V = {s: 0.0 for s in grid.get_states()}

    iteration = 0
    while True:
        delta = 0
        iteration += 1

        for s in grid.get_states():
            if grid.is_terminal(s):
                continue

            v = V[s]

            # Bellman optimality backup: max over actions.
            q_values = []
            for a in grid.actions:
                q = 0
                for prob, next_s, reward, done in grid.get_transitions(s, a):
                    if done:
                        q += prob * reward
                    else:
                        q += prob * (reward + gamma * V[next_s])
                q_values.append(q)

            V[s] = max(q_values)
            delta = max(delta, abs(v - V[s]))

        if iteration % 10 == 0:
            print(f"Iteration {iteration}: delta = {delta:.8f}")

        if delta < theta:
            print(f"\nValue iteration converged: {iteration} iterations")
            break

    # Extract the greedy policy from the converged V.
    policy = {}
    for s in grid.get_states():
        if grid.is_terminal(s):
            # Terminal states get a uniform placeholder policy.
            policy[s] = {a: 1.0/len(grid.actions) for a in grid.actions}
            continue

        q_values = {}
        for a in grid.actions:
            q = 0
            for prob, next_s, reward, done in grid.get_transitions(s, a):
                if done:
                    q += prob * reward
                else:
                    q += prob * (reward + gamma * V[next_s])
            q_values[a] = q

        best_action = max(q_values, key=q_values.get)
        policy[s] = {a: 0.0 for a in grid.actions}
        policy[s][best_action] = 1.0

    return V, policy
269
270
def print_value_function(grid, V):
    """Print V as an aligned size x size table, one row per grid row."""
    print("\nValue function:")
    for i in range(grid.size):
        row = [f"{V[(i,j)]:7.2f}" for j in range(grid.size)]
        print(" ".join(row))
277
278
def print_policy(grid, policy):
    """Print the greedy action of each cell as an arrow ('*' for terminals)."""
    print("\nOptimal policy:")
    arrows = {'up': '↑', 'down': '↓', 'left': '←', 'right': '→'}

    for i in range(grid.size):
        row = []
        for j in range(grid.size):
            s = (i, j)
            if grid.is_terminal(s):
                row.append(' * ')
            else:
                best_a = max(policy[s], key=policy[s].get)
                row.append(f' {arrows[best_a]} ')
        print(" ".join(row))
294
295
def visualize_value_function(grid, V, title="Value Function"):
    """Render V as a heatmap and save it to value_function.png."""
    value_grid = np.zeros((grid.size, grid.size))
    for i in range(grid.size):
        for j in range(grid.size):
            value_grid[i, j] = V[(i, j)]

    plt.figure(figsize=(8, 6))
    plt.imshow(value_grid, cmap='coolwarm', interpolation='nearest')
    plt.colorbar(label='Value')
    plt.title(title)

    # Annotate each cell with its value.
    for i in range(grid.size):
        for j in range(grid.size):
            plt.text(j, i, f'{value_grid[i, j]:.1f}',
                     ha='center', va='center', color='black', fontsize=12)

    plt.xticks(range(grid.size))
    plt.yticks(range(grid.size))
    plt.grid(False)
    plt.tight_layout()
    plt.savefig('value_function.png', dpi=150)
    print("Saved value-function plot: value_function.png")
320
321
def compare_algorithms():
    """Run and compare the DP algorithms on a 4x4 grid world.

    Returns:
        (V_pi, policy_pi, V_vi, policy_vi): value functions and policies
        from policy iteration and value iteration respectively.
    """
    print("=" * 60)
    print("Dynamic programming algorithm comparison")
    print("=" * 60)

    grid = GridWorld(size=4)
    gamma = 0.9

    # 1. Policy evaluation of the uniform random policy.
    print("\n[1] Policy evaluation - uniform random policy")
    print("-" * 60)
    uniform_policy = create_uniform_policy(grid)
    V_uniform = policy_evaluation(grid, uniform_policy, gamma)
    print_value_function(grid, V_uniform)

    # 2. Policy iteration.
    print("\n[2] Policy Iteration")
    print("-" * 60)
    V_pi, policy_pi = policy_iteration(grid, gamma)
    print_value_function(grid, V_pi)
    print_policy(grid, policy_pi)

    # 3. Value iteration.
    print("\n[3] Value Iteration")
    print("-" * 60)
    V_vi, policy_vi = value_iteration(grid, gamma)
    print_value_function(grid, V_vi)
    print_policy(grid, policy_vi)

    # 4. Compare the two value functions (should agree to within theta).
    print("\n[4] Result comparison")
    print("-" * 60)
    print("Value-function difference between policy and value iteration:")
    max_diff = 0
    for s in grid.get_states():
        diff = abs(V_pi[s] - V_vi[s])
        max_diff = max(max_diff, diff)
    print(f"Max difference: {max_diff:.10f}")

    # Visualization
    visualize_value_function(grid, V_pi, "Policy Iteration - Value Function")

    return V_pi, policy_pi, V_vi, policy_vi
366
367
def frozen_lake_example():
    """Apply value iteration to the Frozen Lake environment.

    Requires the third-party `gymnasium` package (imported lazily so the
    rest of the module works without it).

    Returns:
        V: optimal state values (numpy array of length n_states).
        policy: greedy action index per state (numpy int array).
    """
    import gymnasium as gym

    print("\n" + "=" * 60)
    print("Frozen Lake example")
    print("=" * 60)

    # Non-slippery version: transitions are deterministic.
    env = gym.make('FrozenLake-v1', is_slippery=False)

    n_states = env.observation_space.n
    n_actions = env.action_space.n
    gamma = 0.99
    theta = 1e-8

    # P[s][a] = [(prob, next_state, reward, done), ...]
    P = env.unwrapped.P

    # Value iteration
    V = np.zeros(n_states)
    iteration = 0

    print("\nStarting value iteration...")
    while True:
        delta = 0
        iteration += 1

        for s in range(n_states):
            v = V[s]

            # Value of each action; terminal successors contribute no bootstrap.
            q_values = []
            for a in range(n_actions):
                q = sum(prob * (reward + gamma * V[next_s] * (not done))
                        for prob, next_s, reward, done in P[s][a])
                q_values.append(q)

            V[s] = max(q_values)
            delta = max(delta, abs(v - V[s]))

        if delta < theta:
            print(f"Converged: {iteration} iterations")
            break

    # Extract the greedy policy.
    policy = np.zeros(n_states, dtype=int)
    for s in range(n_states):
        q_values = []
        for a in range(n_actions):
            q = sum(prob * (reward + gamma * V[next_s] * (not done))
                    for prob, next_s, reward, done in P[s][a])
            q_values.append(q)
        policy[s] = np.argmax(q_values)

    # Show the policy on the 4x4 grid.
    # FrozenLake action encoding: 0=left, 1=down, 2=right, 3=up.
    action_names = ['←', '↓', '→', '↑']
    print("\nOptimal policy (4x4 grid):")
    print("S: start, H: hole, F: frozen, G: goal")
    for i in range(4):
        row = ""
        for j in range(4):
            s = i * 4 + j
            if s == 0:
                row += " S "
            elif s in [5, 7, 11, 12]:  # holes in the default 4x4 map
                row += " H "
            elif s == 15:  # goal
                row += " G "
            else:
                row += f" {action_names[policy[s]]} "
        print(row)

    print("\nValue function:")
    print(V.reshape(4, 4).round(3))

    # Roll out the greedy policy and measure success rate.
    print("\nTesting policy...")
    success = 0
    n_tests = 100

    for _ in range(n_tests):
        state, _ = env.reset()
        done = False

        while not done:
            action = policy[state]
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

        # A positive final reward means the goal was reached.
        if reward > 0:
            success += 1

    print(f"Success rate: {success}/{n_tests} = {success/n_tests*100:.1f}%")

    env.close()
    return V, policy
465
466
if __name__ == "__main__":
    # Compare the DP algorithms on the grid world.
    V_pi, policy_pi, V_vi, policy_vi = compare_algorithms()

    # Frozen Lake example (best-effort: requires gymnasium).
    try:
        V_fl, policy_fl = frozen_lake_example()
    except Exception as e:
        print(f"\nFrozen Lake example failed: {e}")
        print("Check that gymnasium is installed: pip install gymnasium")

    print("\n" + "=" * 60)
    print("Dynamic programming examples complete!")
    print("=" * 60)