1"""
2TD Learning (Temporal Difference Learning) ꡬν
3TD(0), SARSA, Q-Learning, Expected SARSA ν¬ν¨
4"""
5import numpy as np
6import gymnasium as gym
7from collections import defaultdict
8import matplotlib.pyplot as plt
9
10
class TD0Prediction:
    """Tabular TD(0) policy evaluation.

    Maintains a state-value table ``V`` and refines it one transition at a
    time using the one-step bootstrapped TD target.
    """

    def __init__(self, alpha=0.1, gamma=0.99):
        # alpha: learning rate; gamma: discount factor.
        self.alpha = alpha
        self.gamma = gamma
        # Value table, lazily initialised to 0.0 for unseen states.
        self.V = defaultdict(float)

    def update(self, state, reward, next_state, done):
        """Apply one TD(0) backup and return the TD error."""
        # Terminal transitions bootstrap from 0: the target is just the reward.
        bootstrap = 0.0 if done else self.gamma * self.V[next_state]
        delta = (reward + bootstrap) - self.V[state]
        self.V[state] += self.alpha * delta
        return delta

    def get_value(self, state):
        """Return the current value estimate for *state* (0.0 if unvisited)."""
        return self.V[state]
32
33
class SARSA:
    """SARSA: on-policy TD control with an epsilon-greedy behaviour policy."""

    def __init__(self, n_actions, alpha=0.5, gamma=0.99, epsilon=0.1):
        self.n_actions = n_actions
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        # Q-table: state -> vector of action values, lazily zero-initialised.
        self.Q = defaultdict(lambda: np.zeros(n_actions))

    def choose_action(self, state):
        """Sample an action epsilon-greedily from the current Q-table."""
        explore = np.random.random() < self.epsilon
        if explore:
            return np.random.randint(self.n_actions)
        return np.argmax(self.Q[state])

    def update(self, state, action, reward, next_state, next_action, done):
        """One SARSA backup: Q(s,a) += alpha*[r + gamma*Q(s',a') - Q(s,a)].

        Returns the TD error.
        """
        # On-policy: bootstrap from the action actually chosen next.
        if done:
            target = reward
        else:
            target = reward + self.gamma * self.Q[next_state][next_action]
        delta = target - self.Q[state][action]
        self.Q[state][action] += self.alpha * delta
        return delta
60
61
class QLearning:
    """Q-Learning: off-policy TD control (greedy max bootstrap target)."""

    def __init__(self, n_actions, alpha=0.5, gamma=0.99, epsilon=0.1):
        self.n_actions = n_actions
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        # Q-table: state -> vector of action values, lazily zero-initialised.
        self.Q = defaultdict(lambda: np.zeros(n_actions))

    def choose_action(self, state):
        """Epsilon-greedy action selection."""
        if np.random.random() >= self.epsilon:
            return np.argmax(self.Q[state])
        return np.random.randint(self.n_actions)

    def update(self, state, action, reward, next_state, done):
        """One backup: Q(s,a) += alpha*[r + gamma*max_a' Q(s',a') - Q(s,a)].

        Returns the TD error.
        """
        # Off-policy: bootstrap from the greedy value regardless of the
        # action the behaviour policy will actually take next.
        if done:
            target = reward
        else:
            target = reward + self.gamma * self.Q[next_state].max()
        delta = target - self.Q[state][action]
        self.Q[state][action] += self.alpha * delta
        return delta
88
89
class ExpectedSARSA:
    """Expected SARSA: bootstraps on the policy's *expected* next-state value."""

    def __init__(self, n_actions, alpha=0.5, gamma=0.99, epsilon=0.1):
        self.n_actions = n_actions
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        # Q-table: state -> vector of action values, lazily zero-initialised.
        self.Q = defaultdict(lambda: np.zeros(n_actions))

    def choose_action(self, state):
        """Epsilon-greedy action selection."""
        if np.random.random() < self.epsilon:
            return np.random.randint(self.n_actions)
        return np.argmax(self.Q[state])

    def update(self, state, action, reward, next_state, done):
        """One Expected SARSA backup; returns the TD error.

        The bootstrap term is E_pi[Q(s',·)] under the epsilon-greedy policy,
        rather than the sampled Q(s',a') used by plain SARSA.
        """
        if done:
            target = reward
        else:
            q_next = self.Q[next_state]
            greedy = np.argmax(q_next)
            base = self.epsilon / self.n_actions
            # Each action gets epsilon/n probability mass; the greedy action
            # additionally receives the remaining (1 - epsilon).
            expected = 0.0
            for a in range(self.n_actions):
                p = base + (1 - self.epsilon if a == greedy else 0.0)
                expected += p * q_next[a]
            target = reward + self.gamma * expected
        delta = target - self.Q[state][action]
        self.Q[state][action] += self.alpha * delta
        return delta
126
127
def train_sarsa(env_name='CliffWalking-v0', n_episodes=500, alpha=0.5, gamma=1.0, epsilon=0.1):
    """Train a SARSA agent on *env_name*.

    Returns (agent, list of per-episode total rewards).
    """
    env = gym.make(env_name)
    agent = SARSA(env.action_space.n, alpha=alpha, gamma=gamma, epsilon=epsilon)
    episode_rewards = []

    for episode in range(n_episodes):
        state, _ = env.reset()
        # SARSA needs the *next* action before updating, so select the first
        # action up front and carry (state, action) through the loop.
        action = agent.choose_action(state)
        total_reward = 0
        done = False

        while not done:
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            next_action = agent.choose_action(next_state)
            agent.update(state, action, reward, next_state, next_action, done)
            state, action = next_state, next_action
            total_reward += reward

        episode_rewards.append(total_reward)

        # Progress report every 100 episodes.
        if (episode + 1) % 100 == 0:
            avg = np.mean(episode_rewards[-100:])
            print(f"SARSA - Episode {episode + 1}: avg_reward = {avg:.1f}")

    env.close()
    return agent, episode_rewards
160
161
def train_qlearning(env_name='CliffWalking-v0', n_episodes=500, alpha=0.5, gamma=1.0, epsilon=0.1):
    """Train a Q-Learning agent on *env_name*.

    Returns (agent, list of per-episode total rewards).
    """
    env = gym.make(env_name)
    agent = QLearning(env.action_space.n, alpha=alpha, gamma=gamma, epsilon=epsilon)
    episode_rewards = []

    for episode in range(n_episodes):
        state, _ = env.reset()
        total_reward = 0
        done = False

        # Unlike SARSA, the action is chosen inside the loop: the update does
        # not depend on which action is taken next.
        while not done:
            action = agent.choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.update(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward

        episode_rewards.append(total_reward)

        # Progress report every 100 episodes.
        if (episode + 1) % 100 == 0:
            avg = np.mean(episode_rewards[-100:])
            print(f"Q-Learning - Episode {episode + 1}: avg_reward = {avg:.1f}")

    env.close()
    return agent, episode_rewards
192
193
def compare_td_methods():
    """Compare SARSA, Q-Learning and Expected SARSA on CliffWalking.

    Trains each method for 500 episodes, plots the learning curves via
    plot_comparison(), and returns the three per-episode reward lists.

    Fix: the user-facing strings and comments were mojibake-corrupted
    (wrong-codec decode); they are repaired here.
    """
    print("=== Comparing TD methods on CliffWalking ===\n")

    # SARSA (on-policy: tends to learn the safer path away from the cliff).
    print("Training SARSA...")
    _, sarsa_rewards = train_sarsa(n_episodes=500)

    # Q-Learning (off-policy: learns the optimal but riskier cliff-edge path).
    print("\nTraining Q-Learning...")
    _, qlearning_rewards = train_qlearning(n_episodes=500)

    # Expected SARSA (bootstraps on the policy's expected next-state value).
    print("\nTraining Expected SARSA...")
    env = gym.make('CliffWalking-v0')
    expected_sarsa = ExpectedSARSA(env.action_space.n, alpha=0.5, gamma=1.0, epsilon=0.1)
    expected_rewards = []

    for episode in range(500):
        state, _ = env.reset()
        total_reward = 0
        done = False

        while not done:
            action = expected_sarsa.choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            expected_sarsa.update(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward

        expected_rewards.append(total_reward)

        # Progress report every 100 episodes.
        if (episode + 1) % 100 == 0:
            avg = np.mean(expected_rewards[-100:])
            print(f"Expected SARSA - Episode {episode + 1}: avg_reward = {avg:.1f}")

    env.close()

    # Visualize learning curves for all three methods.
    plot_comparison(sarsa_rewards, qlearning_rewards, expected_rewards)

    return sarsa_rewards, qlearning_rewards, expected_rewards
239
240
def plot_comparison(sarsa_rewards, qlearning_rewards, expected_rewards):
    """Plot raw and moving-average learning curves for the three TD methods.

    Saves the figure to 'td_methods_comparison.png'.

    Fixes: the figure is now closed after saving (previously it leaked —
    repeated calls accumulated open figures), and the mojibake-corrupted
    status message is repaired.
    """
    window = 10

    def smooth(data, window):
        # Moving average via convolution; 'valid' mode drops the warm-up edge,
        # so the output is len(data) - window + 1 points long.
        return np.convolve(data, np.ones(window) / window, mode='valid')

    plt.figure(figsize=(12, 5))

    # Left panel: raw per-episode rewards.
    plt.subplot(1, 2, 1)
    plt.plot(sarsa_rewards, alpha=0.3, label='SARSA (raw)')
    plt.plot(qlearning_rewards, alpha=0.3, label='Q-Learning (raw)')
    plt.plot(expected_rewards, alpha=0.3, label='Expected SARSA (raw)')
    plt.xlabel('Episode')
    plt.ylabel('Episode Reward')
    plt.title('TD Methods Comparison - Raw Data')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Right panel: smoothed curves for easier comparison.
    plt.subplot(1, 2, 2)
    plt.plot(smooth(sarsa_rewards, window), label='SARSA (smoothed)', linewidth=2)
    plt.plot(smooth(qlearning_rewards, window), label='Q-Learning (smoothed)', linewidth=2)
    plt.plot(smooth(expected_rewards, window), label='Expected SARSA (smoothed)', linewidth=2)
    plt.xlabel('Episode')
    plt.ylabel('Episode Reward (smoothed)')
    plt.title(f'TD Methods Comparison - Smoothed (window={window})')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('td_methods_comparison.png', dpi=150)
    plt.close()  # release the figure so repeated calls don't accumulate memory
    print("\nLearning curves saved to 'td_methods_comparison.png'.")
275
276
def visualize_policy(agent, env_name='CliffWalking-v0'):
    """Print the greedy policy learned for CliffWalking as a 4x12 grid.

    Only CliffWalking-v0 is supported: the 4x12 layout and the start/goal/
    cliff cell indices below are hard-coded for that environment.

    Fix: the user-facing strings were mojibake-corrupted (some literals were
    even broken across lines); they are repaired here.
    """
    if env_name != 'CliffWalking-v0':
        print("Policy visualization only supports the CliffWalking environment.")
        return

    print("\n=== Learned policy (4x12 grid) ===")
    # CliffWalking action indices: 0=up, 1=right, 2=down, 3=left.
    arrows = {0: '^', 1: '>', 2: 'v', 3: '<'}

    for row in range(4):
        line = ""
        for col in range(12):
            state = row * 12 + col
            if state == 36:  # start cell (bottom-left)
                line += " S "
            elif state == 47:  # goal cell (bottom-right)
                line += " G "
            elif 37 <= state <= 46:  # cliff cells along the bottom row
                line += " C "
            else:
                # Greedy action from the agent's Q-table.
                action = np.argmax(agent.Q[state])
                line += f" {arrows[action]} "
        print(line)

    print("\n(S: start, G: goal, C: cliff, ^>v<: action direction)")
302
303
if __name__ == "__main__":
    # Compare the three TD control methods and collect learning curves.
    # (Mojibake-corrupted comments/strings in this block repaired to English.)
    sarsa_rewards, qlearning_rewards, expected_rewards = compare_td_methods()

    # Retrain a fresh SARSA agent just for policy visualization.
    print("\n" + "=" * 50)
    env = gym.make('CliffWalking-v0')
    sarsa_agent = SARSA(env.action_space.n)

    for episode in range(500):
        state, _ = env.reset()
        action = sarsa_agent.choose_action(state)
        done = False

        while not done:
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            next_action = sarsa_agent.choose_action(next_state)
            sarsa_agent.update(state, action, reward, next_state, next_action, done)
            state = next_state
            action = next_action

    print("\nSARSA training complete - prefers the safe path")
    visualize_policy(sarsa_agent)

    env.close()