"""
Actor-Critic (A2C) implementation.

Includes the shared Actor-Critic network architecture, advantage
estimation with n-step bootstrapped returns, and GAE
(Generalized Advantage Estimation).
"""
5import torch
6import torch.nn as nn
7import torch.nn.functional as F
8import numpy as np
9import gymnasium as gym
10import matplotlib.pyplot as plt
11
12
class ActorCriticNetwork(nn.Module):
    """Shared-trunk network producing both a policy and a state value.

    One feature extractor feeds two heads: the actor emits action
    logits and the critic emits a scalar state-value estimate.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()

        # Shared feature-extraction trunk.
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU()
        )

        # Actor head (policy network): unnormalized action logits.
        self.actor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

        # Critic head (value network): a single state-value output.
        self.critic = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, state):
        """Forward pass: return (policy_logits, value) for `state`."""
        trunk_out = self.shared(state)
        return self.actor(trunk_out), self.critic(trunk_out)

    def get_action(self, state):
        """Sample an action; return (action, log_prob, value, entropy)."""
        logits, value = self.forward(state)
        probs = F.softmax(logits, dim=-1)
        dist = torch.distributions.Categorical(probs)
        chosen = dist.sample()
        return chosen.item(), dist.log_prob(chosen), value, dist.entropy()
55
56
class A2CAgent:
    """A2C (Advantage Actor-Critic) agent using n-step bootstrapped returns.

    Collects transitions for a short rollout, then updates the shared
    actor-critic network with policy-gradient, value, and entropy losses.
    """

    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99,
                 value_coef=0.5, entropy_coef=0.01):
        self.network = ActorCriticNetwork(state_dim, action_dim)
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=lr)

        self.gamma = gamma
        self.value_coef = value_coef      # weight of the critic (value) loss
        self.entropy_coef = entropy_coef  # weight of the entropy bonus

        # Rollout buffers.
        self.reset_buffers()

    def reset_buffers(self):
        """Clear the rollout buffers."""
        self.log_probs = []
        self.values = []
        self.rewards = []
        self.dones = []
        self.entropies = []

    def choose_action(self, state):
        """Select an action and record its log-prob / value / entropy."""
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        action, log_prob, value, entropy = self.network.get_action(state_tensor)

        # Store for the upcoming update.
        self.log_probs.append(log_prob)
        self.values.append(value)
        self.entropies.append(entropy)

        return action

    def store_transition(self, reward, done):
        """Record the reward and terminal flag of the last step."""
        self.rewards.append(reward)
        self.dones.append(done)

    def compute_returns(self, next_value):
        """Compute n-step bootstrapped returns (reset at episode ends)."""
        returns = []
        R = next_value

        for reward, done in zip(reversed(self.rewards), reversed(self.dones)):
            if done:
                # Never bootstrap across an episode boundary.
                R = 0
            R = reward + self.gamma * R
            returns.insert(0, R)

        return torch.tensor(returns, dtype=torch.float32)

    def update(self, next_state):
        """Run one A2C update over the buffered rollout.

        Returns (actor_loss, critic_loss) as Python floats.
        """
        if len(self.rewards) == 0:
            return 0, 0

        # Value of the state following the rollout (bootstrap target).
        with torch.no_grad():
            state_tensor = torch.FloatTensor(next_state).unsqueeze(0)
            _, next_value = self.network(state_tensor)
            next_value = next_value.item()

        # Critic targets.
        returns = self.compute_returns(next_value)

        # Flatten everything to shape (T,). The original
        # stack()/squeeze() version had two bugs: stacked (T, 1)
        # log-probs broadcast against (T,) advantages into a (T, T)
        # matrix (silently corrupting the actor loss), and squeeze()
        # collapsed a one-step rollout to a 0-dim tensor.
        values = torch.cat(self.values).view(-1)
        log_probs = torch.cat(self.log_probs).view(-1)
        entropies = torch.cat(self.entropies).view(-1)

        # Advantage: A(s,a) = Q(s,a) - V(s) ~= R - V(s).
        advantages = returns - values.detach()

        # Losses.
        actor_loss = -(log_probs * advantages).mean()   # policy gradient
        critic_loss = F.mse_loss(values, returns)       # value regression
        entropy_loss = -entropies.mean()                # encourage exploration

        total_loss = (actor_loss +
                      self.value_coef * critic_loss +
                      self.entropy_coef * entropy_loss)

        # Gradient step with clipping for stability.
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), max_norm=0.5)
        self.optimizer.step()

        # Start a fresh rollout.
        self.reset_buffers()

        return actor_loss.item(), critic_loss.item()
149
150
class A2CWithGAE(A2CAgent):
    """A2C variant that uses GAE (Generalized Advantage Estimation)."""

    def __init__(self, *args, gae_lambda=0.95, **kwargs):
        super().__init__(*args, **kwargs)
        self.gae_lambda = gae_lambda  # lambda in GAE(gamma, lambda)

    def compute_gae(self, next_value):
        """Compute GAE advantages and the matching value targets.

        Returns (advantages, returns), both shape (T,).
        """
        # view(-1) keeps this a list even for a one-step rollout.
        # The original squeeze().tolist() collapsed (1, 1) to 0-dim,
        # so tolist() returned a bare float and .append() crashed.
        values = torch.cat(self.values).view(-1).tolist()
        values.append(next_value)  # bootstrap value after the rollout

        advantages = []
        gae = 0

        # Backward recursion over the rollout.
        for t in reversed(range(len(self.rewards))):
            if self.dones[t]:
                # Terminal step: no bootstrapping, restart the trace.
                delta = self.rewards[t] - values[t]
                gae = delta
            else:
                delta = self.rewards[t] + self.gamma * values[t + 1] - values[t]
                gae = delta + self.gamma * self.gae_lambda * gae

            advantages.insert(0, gae)

        advantages = torch.tensor(advantages, dtype=torch.float32)
        returns = advantages + torch.tensor(values[:-1], dtype=torch.float32)

        return advantages, returns

    def update(self, next_state):
        """Run one A2C update using GAE advantages.

        Returns (actor_loss, critic_loss) as Python floats.
        """
        if len(self.rewards) == 0:
            return 0, 0

        # Value of the state following the rollout (bootstrap target).
        with torch.no_grad():
            state_tensor = torch.FloatTensor(next_state).unsqueeze(0)
            _, next_value = self.network(state_tensor)
            next_value = next_value.item()

        # GAE advantages and critic targets.
        advantages, returns = self.compute_gae(next_value)

        # Flatten to (T,) so the losses do not broadcast (T, 1)
        # against (T,) into a (T, T) matrix.
        values = torch.cat(self.values).view(-1)
        log_probs = torch.cat(self.log_probs).view(-1)
        entropies = torch.cat(self.entropies).view(-1)

        # Losses.
        actor_loss = -(log_probs * advantages.detach()).mean()
        critic_loss = F.mse_loss(values, returns)
        entropy_loss = -entropies.mean()

        total_loss = (actor_loss +
                      self.value_coef * critic_loss +
                      self.entropy_coef * entropy_loss)

        # Gradient step with clipping.
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), max_norm=0.5)
        self.optimizer.step()

        self.reset_buffers()

        return actor_loss.item(), critic_loss.item()
218
219
def train_a2c(env_name='CartPole-v1', n_episodes=1000, n_steps=5, use_gae=False):
    """Train an A2C agent (optionally with GAE) on `env_name`.

    Returns (agent, scores, actor_losses, critic_losses).
    """
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    # Build the requested agent variant.
    if use_gae:
        agent = A2CWithGAE(state_dim, action_dim, lr=3e-4, gamma=0.99,
                           value_coef=0.5, entropy_coef=0.01, gae_lambda=0.95)
        method_name = "A2C with GAE"
    else:
        agent = A2CAgent(state_dim, action_dim, lr=3e-4, gamma=0.99,
                         value_coef=0.5, entropy_coef=0.01)
        method_name = "A2C"

    print(f"=== {method_name} νμ΅ μμ ({env_name}) ===\n")

    scores, actor_losses, critic_losses = [], [], []

    for episode in range(n_episodes):
        state, _ = env.reset()
        episode_return = 0
        step_count = 0
        done = False

        while not done:
            action = agent.choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            agent.store_transition(reward, done)
            state = next_state
            episode_return += reward
            step_count += 1

            # Update every n_steps, and when the episode ends.
            if done or step_count % n_steps == 0:
                a_loss, c_loss = agent.update(next_state)
                actor_losses.append(a_loss)
                critic_losses.append(c_loss)

        scores.append(episode_return)

        # Periodic progress report.
        if (episode + 1) % 50 == 0:
            avg_score = np.mean(scores[-50:])
            avg_actor_loss = np.mean(actor_losses[-50:]) if actor_losses else 0
            avg_critic_loss = np.mean(critic_losses[-50:]) if critic_losses else 0
            print(f"Episode {episode + 1:4d} | "
                  f"Avg Score: {avg_score:7.2f} | "
                  f"Actor Loss: {avg_actor_loss:.4f} | "
                  f"Critic Loss: {avg_critic_loss:.4f}")

        # CartPole counts as solved at a 100-episode average >= 475.
        if len(scores) >= 100 and np.mean(scores[-100:]) >= 475:
            print(f"\nνκ²½ ν΄κ²°! ({episode + 1} μνΌμλ)")
            break

    env.close()
    return agent, scores, actor_losses, critic_losses
282
283
def compare_a2c_with_reinforce():
    """Train plain A2C and A2C+GAE on CartPole and plot both curves.

    NOTE(review): despite the name, this compares A2C against A2C+GAE;
    no REINFORCE agent is involved. The name is kept for callers.
    Returns (a2c_scores, a2c_gae_scores).
    """
    print("=== A2C vs REINFORCE λΉκ΅ ===\n")

    # Plain A2C run.
    _, a2c_scores, _, _ = train_a2c('CartPole-v1', n_episodes=500, use_gae=False)

    # A2C-with-GAE run.
    print("\n" + "="*60 + "\n")
    _, a2c_gae_scores, _, _ = train_a2c('CartPole-v1', n_episodes=500, use_gae=True)

    # Side-by-side learning-curve visualization.
    plot_comparison(a2c_scores, a2c_gae_scores)

    return a2c_scores, a2c_gae_scores
299
300
def plot_comparison(a2c_scores, a2c_gae_scores):
    """Plot raw and smoothed learning curves for A2C vs A2C+GAE side by side."""
    window = 10

    def smooth(data, window):
        # Moving average; series shorter than the window pass through unchanged.
        if len(data) < window:
            return data
        return np.convolve(data, np.ones(window) / window, mode='valid')

    def decorate(title, ylabel):
        # Decorations shared by both panels (threshold line, labels, grid).
        plt.axhline(y=475, color='red', linestyle='--', linewidth=1, label='Solved threshold')
        plt.xlabel('Episode')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.grid(True, alpha=0.3)

    plt.figure(figsize=(14, 5))

    # Left panel: raw per-episode rewards.
    plt.subplot(1, 2, 1)
    plt.plot(a2c_scores, alpha=0.3, label='A2C (raw)', color='blue')
    plt.plot(a2c_gae_scores, alpha=0.3, label='A2C+GAE (raw)', color='green')
    decorate('A2C vs A2C+GAE - Raw Data', 'Episode Reward')

    # Right panel: moving-average-smoothed rewards.
    plt.subplot(1, 2, 2)
    plt.plot(smooth(a2c_scores, window), label='A2C (smoothed)', linewidth=2, color='blue')
    plt.plot(smooth(a2c_gae_scores, window), label='A2C+GAE (smoothed)', linewidth=2, color='green')
    decorate(f'A2C vs A2C+GAE - Smoothed (window={window})', 'Episode Reward (smoothed)')

    plt.tight_layout()
    plt.savefig('a2c_comparison.png', dpi=150)
    print("\nνμ΅ κ³‘μ μ΄ 'a2c_comparison.png'λ‘ μ μ₯λμμ΅λλ€.")
337
338
def train_lunarlander():
    """Train A2C+GAE on LunarLander-v2.

    Returns (agent, scores), or (None, None) when the Box2D environment
    is not installed.
    """
    try:
        env = gym.make('LunarLander-v2')
    except Exception:
        # Was a bare `except:`, which also swallows KeyboardInterrupt and
        # SystemExit; `Exception` keeps the best-effort fallback without that.
        print("LunarLander-v2 νκ²½μ μ°Ύμ μ μμ΅λλ€.")
        print("μ€μΉ: pip install gymnasium[box2d]")
        return None, None

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    # A2C using GAE advantages.
    agent = A2CWithGAE(
        state_dim, action_dim,
        lr=7e-4, gamma=0.99,
        value_coef=0.5, entropy_coef=0.01,
        gae_lambda=0.95
    )

    print("=== LunarLander A2C νμ΅ μμ ===\n")

    scores = []
    n_steps = 5
    n_episodes = 2000

    for episode in range(n_episodes):
        state, _ = env.reset()
        total_reward = 0
        steps = 0

        while True:
            action = agent.choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            agent.store_transition(reward, done)
            state = next_state
            total_reward += reward
            steps += 1

            # Update every n_steps and at episode end.
            if steps % n_steps == 0 or done:
                agent.update(next_state)

            if done:
                break

        scores.append(total_reward)

        if (episode + 1) % 100 == 0:
            avg = np.mean(scores[-100:])
            print(f"Episode {episode + 1:4d} | Avg Score: {avg:.2f}")

            # LunarLander counts as solved at a 100-episode average >= 200.
            # (Checked only on report episodes, since `avg` is computed here.)
            if avg >= 200:
                print(f"\nνκ²½ ν΄κ²°! ({episode + 1} μνΌμλ)")
                break

    env.close()
    return agent, scores
400
401
if __name__ == "__main__":
    # 1. Compare A2C vs A2C+GAE on CartPole.
    a2c_scores, a2c_gae_scores = compare_a2c_with_reinforce()

    # 2. LunarLander training (optional; uncomment the call below to run).
    print("\n" + "="*60)
    print("LunarLander νμ΅μ μμνλ €λ©΄ μ£Όμμ ν΄μ νμΈμ:")
    print("# agent, scores = train_lunarlander()")

    # Summary of training results (mean of the final 100 episodes).
    print("\n" + "="*60)
    print("νμ΅ μλ£!")
    print(f"A2C μ΅μ’ 100 μνΌμλ νκ· : {np.mean(a2c_scores[-100:]):.2f}")
    print(f"A2C+GAE μ΅μ’ 100 μνΌμλ νκ· : {np.mean(a2c_gae_scores[-100:]):.2f}")