# 09_actor_critic.py
"""
Actor-Critic (A2C) implementation.

Includes a shared actor-critic architecture, advantage estimation,
and GAE (Generalized Advantage Estimation).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
 11
 12
 13class ActorCriticNetwork(nn.Module):
 14    """Actor-Critic 곡유 λ„€νŠΈμ›Œν¬"""
 15
 16    def __init__(self, state_dim, action_dim, hidden_dim=128):
 17        super().__init__()
 18
 19        # 곡유 νŠΉμ§• μΆ”μΆœ λ ˆμ΄μ–΄
 20        self.shared = nn.Sequential(
 21            nn.Linear(state_dim, hidden_dim),
 22            nn.ReLU()
 23        )
 24
 25        # Actor (μ •μ±… λ„€νŠΈμ›Œν¬)
 26        self.actor = nn.Sequential(
 27            nn.Linear(hidden_dim, hidden_dim),
 28            nn.ReLU(),
 29            nn.Linear(hidden_dim, action_dim)
 30        )
 31
 32        # Critic (κ°€μΉ˜ λ„€νŠΈμ›Œν¬)
 33        self.critic = nn.Sequential(
 34            nn.Linear(hidden_dim, hidden_dim),
 35            nn.ReLU(),
 36            nn.Linear(hidden_dim, 1)
 37        )
 38
 39    def forward(self, state):
 40        """μˆœμ „νŒŒ: μ •μ±…κ³Ό κ°€μΉ˜λ₯Ό λ™μ‹œμ— 좜λ ₯"""
 41        features = self.shared(state)
 42        policy_logits = self.actor(features)
 43        value = self.critic(features)
 44        return policy_logits, value
 45
 46    def get_action(self, state):
 47        """행동 μƒ˜ν”Œλ§"""
 48        policy_logits, value = self.forward(state)
 49        policy = F.softmax(policy_logits, dim=-1)
 50        dist = torch.distributions.Categorical(policy)
 51        action = dist.sample()
 52        log_prob = dist.log_prob(action)
 53        entropy = dist.entropy()
 54        return action.item(), log_prob, value, entropy
 55
 56
 57class A2CAgent:
 58    """A2C (Advantage Actor-Critic) μ—μ΄μ „νŠΈ"""
 59
 60    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99,
 61                 value_coef=0.5, entropy_coef=0.01):
 62        self.network = ActorCriticNetwork(state_dim, action_dim)
 63        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=lr)
 64
 65        self.gamma = gamma
 66        self.value_coef = value_coef  # Critic 손싀 κ°€μ€‘μΉ˜
 67        self.entropy_coef = entropy_coef  # μ—”νŠΈλ‘œν”Ό λ³΄λ„ˆμŠ€ κ°€μ€‘μΉ˜
 68
 69        # μ—ν”Όμ†Œλ“œ 버퍼
 70        self.reset_buffers()
 71
 72    def reset_buffers(self):
 73        """버퍼 μ΄ˆκΈ°ν™”"""
 74        self.log_probs = []
 75        self.values = []
 76        self.rewards = []
 77        self.dones = []
 78        self.entropies = []
 79
 80    def choose_action(self, state):
 81        """행동 선택"""
 82        state_tensor = torch.FloatTensor(state).unsqueeze(0)
 83        action, log_prob, value, entropy = self.network.get_action(state_tensor)
 84
 85        # 버퍼에 μ €μž₯
 86        self.log_probs.append(log_prob)
 87        self.values.append(value)
 88        self.entropies.append(entropy)
 89
 90        return action
 91
 92    def store_transition(self, reward, done):
 93        """전이 μ €μž₯"""
 94        self.rewards.append(reward)
 95        self.dones.append(done)
 96
 97    def compute_returns(self, next_value):
 98        """n-step returns 계산 (λΆ€νŠΈμŠ€νŠΈλž˜ν•‘)"""
 99        returns = []
100        R = next_value
101
102        for reward, done in zip(reversed(self.rewards), reversed(self.dones)):
103            if done:
104                R = 0
105            R = reward + self.gamma * R
106            returns.insert(0, R)
107
108        return torch.tensor(returns, dtype=torch.float32)
109
110    def update(self, next_state):
111        """A2C μ—…λ°μ΄νŠΈ"""
112        if len(self.rewards) == 0:
113            return 0, 0
114
115        # λ‹€μŒ μƒνƒœμ˜ κ°€μΉ˜ (λΆ€νŠΈμŠ€νŠΈλž˜ν•‘μš©)
116        with torch.no_grad():
117            state_tensor = torch.FloatTensor(next_state).unsqueeze(0)
118            _, next_value = self.network(state_tensor)
119            next_value = next_value.item()
120
121        # Returns 계산
122        returns = self.compute_returns(next_value)
123        values = torch.cat(self.values).squeeze()
124        log_probs = torch.stack(self.log_probs)
125        entropies = torch.stack(self.entropies)
126
127        # Advantage 계산: A(s,a) = Q(s,a) - V(s) β‰ˆ R - V(s)
128        advantages = returns - values.detach()
129
130        # 손싀 계산
131        actor_loss = -(log_probs * advantages).mean()  # Policy gradient
132        critic_loss = F.mse_loss(values, returns)  # Value function loss
133        entropy_loss = -entropies.mean()  # 탐색 μž₯λ €
134
135        total_loss = (actor_loss +
136                      self.value_coef * critic_loss +
137                      self.entropy_coef * entropy_loss)
138
139        # κ·Έλž˜λ””μ–ΈνŠΈ μ—…λ°μ΄νŠΈ
140        self.optimizer.zero_grad()
141        total_loss.backward()
142        torch.nn.utils.clip_grad_norm_(self.network.parameters(), max_norm=0.5)
143        self.optimizer.step()
144
145        # 버퍼 μ΄ˆκΈ°ν™”
146        self.reset_buffers()
147
148        return actor_loss.item(), critic_loss.item()
149
150
class A2CWithGAE(A2CAgent):
    """A2C variant that uses Generalized Advantage Estimation (GAE)."""

    def __init__(self, *args, gae_lambda=0.95, **kwargs):
        super().__init__(*args, **kwargs)
        self.gae_lambda = gae_lambda  # bias/variance trade-off parameter

    def compute_gae(self, next_value):
        """Compute GAE advantages and the matching critic targets.

        Args:
            next_value: V(s') of the state after the last stored
                transition (bootstrap value).

        Returns:
            (advantages, returns) as 1-D float32 tensors.
        """
        # view(-1) keeps this a list even for a single-step rollout;
        # the previous squeeze().tolist() produced a bare float there,
        # and the following .append() raised AttributeError.
        values = torch.cat(self.values).view(-1).tolist()
        values.append(next_value)  # bootstrap value at the rollout's end

        advantages = []
        gae = 0

        # Backward recursion over the rollout.
        for t in reversed(range(len(self.rewards))):
            if self.dones[t]:
                # Terminal step: no bootstrap term, and reset the GAE trace.
                delta = self.rewards[t] - values[t]
                gae = delta
            else:
                delta = self.rewards[t] + self.gamma * values[t + 1] - values[t]
                gae = delta + self.gamma * self.gae_lambda * gae

            advantages.insert(0, gae)

        advantages = torch.tensor(advantages, dtype=torch.float32)
        returns = advantages + torch.tensor(values[:-1], dtype=torch.float32)

        return advantages, returns

    def update(self, next_state):
        """Run one gradient step using GAE advantages.

        Returns:
            (actor_loss, critic_loss) as Python floats,
            or (0, 0) if the buffer is empty.
        """
        if len(self.rewards) == 0:
            return 0, 0

        # Value of the next state (no gradient needed).
        with torch.no_grad():
            state_tensor = torch.FloatTensor(next_state).unsqueeze(0)
            _, next_value = self.network(state_tensor)
            next_value = next_value.item()

        # GAE advantages and critic targets.
        advantages, returns = self.compute_gae(next_value)

        # cat / squeeze(-1) keep everything 1-D of length k so the
        # elementwise products below cannot broadcast to (k, k).
        values = torch.cat(self.values).squeeze(-1)
        log_probs = torch.cat(self.log_probs)
        entropies = torch.cat(self.entropies)

        # Loss terms (advantages are treated as constants for the actor).
        actor_loss = -(log_probs * advantages.detach()).mean()
        critic_loss = F.mse_loss(values, returns)
        entropy_loss = -entropies.mean()

        total_loss = (actor_loss +
                      self.value_coef * critic_loss +
                      self.entropy_coef * entropy_loss)

        # Gradient step with norm clipping.
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), max_norm=0.5)
        self.optimizer.step()

        self.reset_buffers()

        return actor_loss.item(), critic_loss.item()
218
219
def train_a2c(env_name='CartPole-v1', n_episodes=1000, n_steps=5, use_gae=False):
    """Train an A2C agent (optionally with GAE) and return its history.

    Returns:
        (agent, scores, actor_losses, critic_losses)
    """
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    # Pick the agent variant.
    if use_gae:
        agent = A2CWithGAE(state_dim, action_dim, lr=3e-4, gamma=0.99,
                           value_coef=0.5, entropy_coef=0.01, gae_lambda=0.95)
        method_name = "A2C with GAE"
    else:
        agent = A2CAgent(state_dim, action_dim, lr=3e-4, gamma=0.99,
                         value_coef=0.5, entropy_coef=0.01)
        method_name = "A2C"

    print(f"=== {method_name} ν•™μŠ΅ μ‹œμž‘ ({env_name}) ===\n")

    scores, actor_losses, critic_losses = [], [], []

    for episode in range(n_episodes):
        obs, _ = env.reset()
        episode_reward = 0
        steps_taken = 0
        finished = False

        while not finished:
            action = agent.choose_action(obs)
            next_obs, reward, terminated, truncated, _ = env.step(action)
            finished = terminated or truncated

            agent.store_transition(reward, finished)
            obs = next_obs
            episode_reward += reward
            steps_taken += 1

            # Update every n_steps, and once more when the episode ends.
            if finished or steps_taken % n_steps == 0:
                a_loss, c_loss = agent.update(next_obs)
                actor_losses.append(a_loss)
                critic_losses.append(c_loss)

        scores.append(episode_reward)

        # Periodic progress report.
        if (episode + 1) % 50 == 0:
            avg_score = np.mean(scores[-50:])
            avg_actor_loss = np.mean(actor_losses[-50:]) if actor_losses else 0
            avg_critic_loss = np.mean(critic_losses[-50:]) if critic_losses else 0
            print(f"Episode {episode + 1:4d} | "
                  f"Avg Score: {avg_score:7.2f} | "
                  f"Actor Loss: {avg_actor_loss:.4f} | "
                  f"Critic Loss: {avg_critic_loss:.4f}")

        # CartPole is considered solved at a 100-episode average of 475+.
        if len(scores) >= 100 and np.mean(scores[-100:]) >= 475:
            print(f"\nν™˜κ²½ ν•΄κ²°! ({episode + 1} μ—ν”Όμ†Œλ“œ)")
            break

    env.close()
    return agent, scores, actor_losses, critic_losses
282
283
def compare_a2c_with_reinforce():
    """Train plain A2C and A2C+GAE on CartPole and compare learning curves.

    NOTE: despite the legacy function name (kept for compatibility), no
    REINFORCE run is performed — the comparison is A2C vs A2C+GAE, and the
    banner below now says so.

    Returns:
        (a2c_scores, a2c_gae_scores) episode-reward histories.
    """
    print("=== A2C vs A2C+GAE 비ꡐ ===\n")

    # Plain A2C run.
    _, a2c_scores, _, _ = train_a2c('CartPole-v1', n_episodes=500, use_gae=False)

    # A2C with GAE run.
    print("\n" + "="*60 + "\n")
    _, a2c_gae_scores, _, _ = train_a2c('CartPole-v1', n_episodes=500, use_gae=True)

    # Visualize both learning curves side by side.
    plot_comparison(a2c_scores, a2c_gae_scores)

    return a2c_scores, a2c_gae_scores
299
300
def plot_comparison(a2c_scores, a2c_gae_scores):
    """Plot raw and smoothed learning curves for both agents, save to PNG."""
    window = 10

    def smooth(series, w):
        # Moving average; leave the series untouched if it is too short.
        if len(series) < w:
            return series
        kernel = np.ones(w) / w
        return np.convolve(series, kernel, mode='valid')

    plt.figure(figsize=(14, 5))

    # Left panel: raw episode rewards.
    plt.subplot(1, 2, 1)
    plt.plot(a2c_scores, alpha=0.3, label='A2C (raw)', color='blue')
    plt.plot(a2c_gae_scores, alpha=0.3, label='A2C+GAE (raw)', color='green')
    plt.axhline(y=475, color='red', linestyle='--', linewidth=1, label='Solved threshold')
    plt.xlabel('Episode')
    plt.ylabel('Episode Reward')
    plt.title('A2C vs A2C+GAE - Raw Data')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Right panel: smoothed rewards.
    plt.subplot(1, 2, 2)
    plt.plot(smooth(a2c_scores, window), label='A2C (smoothed)', linewidth=2, color='blue')
    plt.plot(smooth(a2c_gae_scores, window), label='A2C+GAE (smoothed)', linewidth=2, color='green')
    plt.axhline(y=475, color='red', linestyle='--', linewidth=1, label='Solved threshold')
    plt.xlabel('Episode')
    plt.ylabel('Episode Reward (smoothed)')
    plt.title(f'A2C vs A2C+GAE - Smoothed (window={window})')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('a2c_comparison.png', dpi=150)
    print("\nν•™μŠ΅ 곑선이 'a2c_comparison.png'둜 μ €μž₯λ˜μ—ˆμŠ΅λ‹ˆλ‹€.")
337
338
def train_lunarlander():
    """Train A2C+GAE on LunarLander-v2.

    Returns:
        (agent, scores), or (None, None) if the Box2D environment
        cannot be created (e.g. gymnasium[box2d] is not installed).
    """
    try:
        env = gym.make('LunarLander-v2')
    except Exception:
        # Was a bare `except:`, which would also swallow
        # KeyboardInterrupt/SystemExit; Exception lets those propagate.
        print("LunarLander-v2 ν™˜κ²½μ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€.")
        print("μ„€μΉ˜: pip install gymnasium[box2d]")
        return None, None

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    # A2C with GAE advantages.
    agent = A2CWithGAE(
        state_dim, action_dim,
        lr=7e-4, gamma=0.99,
        value_coef=0.5, entropy_coef=0.01,
        gae_lambda=0.95
    )

    print("=== LunarLander A2C ν•™μŠ΅ μ‹œμž‘ ===\n")

    scores = []
    n_steps = 5
    n_episodes = 2000

    for episode in range(n_episodes):
        state, _ = env.reset()
        total_reward = 0
        steps = 0

        while True:
            action = agent.choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            agent.store_transition(reward, done)
            state = next_state
            total_reward += reward
            steps += 1

            # n-step update, plus a final update when the episode ends.
            if steps % n_steps == 0 or done:
                agent.update(next_state)

            if done:
                break

        scores.append(total_reward)

        if (episode + 1) % 100 == 0:
            avg = np.mean(scores[-100:])
            print(f"Episode {episode + 1:4d} | Avg Score: {avg:.2f}")

            # LunarLander is considered solved at a 100-episode average of
            # 200+ (checked only every 100 episodes, matching the log cadence).
            if avg >= 200:
                print(f"\nν™˜κ²½ ν•΄κ²°! ({episode + 1} μ—ν”Όμ†Œλ“œ)")
                break

    env.close()
    return agent, scores
400
401
if __name__ == "__main__":
    # 1. Compare A2C vs A2C+GAE on CartPole.
    a2c_scores, a2c_gae_scores = compare_a2c_with_reinforce()

    # 2. LunarLander training (optional; enable by running the printed line).
    print("\n" + "="*60)
    print("LunarLander ν•™μŠ΅μ„ μ‹œμž‘ν•˜λ €λ©΄ 주석을 ν•΄μ œν•˜μ„Έμš”:")
    print("# agent, scores = train_lunarlander()")

    # Summary of the training results.
    print("\n" + "="*60)
    print("ν•™μŠ΅ μ™„λ£Œ!")
    print(f"A2C μ΅œμ’… 100 μ—ν”Όμ†Œλ“œ 평균: {np.mean(a2c_scores[-100:]):.2f}")
    print(f"A2C+GAE μ΅œμ’… 100 μ—ν”Όμ†Œλ“œ 평균: {np.mean(a2c_gae_scores[-100:]):.2f}")