# 10_ppo.py
"""
PPO (Proximal Policy Optimization) implementation
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gymnasium as gym
 10
 11class ActorCritic(nn.Module):
 12    def __init__(self, state_dim, action_dim, hidden_dim=64):
 13        super().__init__()
 14        self.shared = nn.Sequential(
 15            nn.Linear(state_dim, hidden_dim),
 16            nn.Tanh(),
 17            nn.Linear(hidden_dim, hidden_dim),
 18            nn.Tanh()
 19        )
 20        self.actor = nn.Linear(hidden_dim, action_dim)
 21        self.critic = nn.Linear(hidden_dim, 1)
 22
 23    def forward(self, state):
 24        features = self.shared(state)
 25        return F.softmax(self.actor(features), dim=-1), self.critic(features)
 26
 27    def get_action_and_value(self, state, action=None):
 28        probs, value = self.forward(state)
 29        dist = torch.distributions.Categorical(probs)
 30        if action is None:
 31            action = dist.sample()
 32        return action, dist.log_prob(action), dist.entropy(), value.squeeze(-1)
 33
 34
 35class PPO:
 36    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99,
 37                 gae_lambda=0.95, clip_epsilon=0.2, epochs=10, batch_size=64):
 38        self.network = ActorCritic(state_dim, action_dim)
 39        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=lr)
 40
 41        self.gamma = gamma
 42        self.gae_lambda = gae_lambda
 43        self.clip_epsilon = clip_epsilon
 44        self.epochs = epochs
 45        self.batch_size = batch_size
 46
 47    def collect_rollout(self, env, n_steps):
 48        obs_buf, act_buf, rew_buf, done_buf, val_buf, logp_buf = [], [], [], [], [], []
 49        obs, _ = env.reset()
 50
 51        for _ in range(n_steps):
 52            obs_tensor = torch.FloatTensor(obs).unsqueeze(0)
 53            with torch.no_grad():
 54                action, logp, _, value = self.network.get_action_and_value(obs_tensor)
 55
 56            next_obs, reward, terminated, truncated, _ = env.step(action.item())
 57            done = terminated or truncated
 58
 59            obs_buf.append(obs)
 60            act_buf.append(action.item())
 61            rew_buf.append(reward)
 62            done_buf.append(done)
 63            val_buf.append(value.item())
 64            logp_buf.append(logp.item())
 65
 66            obs = next_obs if not done else env.reset()[0]
 67
 68        with torch.no_grad():
 69            _, _, _, last_value = self.network.get_action_and_value(
 70                torch.FloatTensor(obs).unsqueeze(0)
 71            )
 72
 73        return {
 74            'obs': np.array(obs_buf), 'actions': np.array(act_buf),
 75            'rewards': np.array(rew_buf), 'dones': np.array(done_buf),
 76            'values': np.array(val_buf), 'log_probs': np.array(logp_buf),
 77            'last_value': last_value.item()
 78        }
 79
 80    def compute_gae(self, rollout):
 81        rewards, values, dones = rollout['rewards'], rollout['values'], rollout['dones']
 82        advantages = np.zeros_like(rewards)
 83        last_gae = 0
 84
 85        for t in reversed(range(len(rewards))):
 86            next_val = rollout['last_value'] if t == len(rewards) - 1 else values[t + 1]
 87            delta = rewards[t] + self.gamma * next_val * (1 - dones[t]) - values[t]
 88            advantages[t] = last_gae = delta + self.gamma * self.gae_lambda * (1 - dones[t]) * last_gae
 89
 90        return advantages, advantages + values
 91
 92    def update(self, rollout):
 93        advantages, returns = self.compute_gae(rollout)
 94        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
 95
 96        obs = torch.FloatTensor(rollout['obs'])
 97        actions = torch.LongTensor(rollout['actions'])
 98        old_logp = torch.FloatTensor(rollout['log_probs'])
 99        advantages = torch.FloatTensor(advantages)
100        returns = torch.FloatTensor(returns)
101
102        for _ in range(self.epochs):
103            indices = np.random.permutation(len(obs))
104            for start in range(0, len(obs), self.batch_size):
105                idx = indices[start:start + self.batch_size]
106
107                _, new_logp, entropy, values = self.network.get_action_and_value(
108                    obs[idx], actions[idx]
109                )
110
111                ratio = torch.exp(new_logp - old_logp[idx])
112                surr1 = ratio * advantages[idx]
113                surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages[idx]
114
115                actor_loss = -torch.min(surr1, surr2).mean()
116                critic_loss = F.mse_loss(values, returns[idx])
117                loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy.mean()
118
119                self.optimizer.zero_grad()
120                loss.backward()
121                nn.utils.clip_grad_norm_(self.network.parameters(), 0.5)
122                self.optimizer.step()
123
124
def train():
    """Train PPO on CartPole-v1 for 50k environment steps.

    Returns the trained agent and the list of completed-episode returns.
    """
    env = gym.make('CartPole-v1')
    agent = PPO(
        state_dim=env.observation_space.shape[0],
        action_dim=env.action_space.n
    )

    rollout_len = 128
    total_steps = 0
    episode_rewards = []
    running_return = 0

    while total_steps < 50000:
        rollout = agent.collect_rollout(env, rollout_len)
        total_steps += rollout_len

        # Reconstruct per-episode returns from the flat reward/done streams.
        for reward, episode_over in zip(rollout['rewards'], rollout['dones']):
            running_return += reward
            if episode_over:
                episode_rewards.append(running_return)
                running_return = 0

        agent.update(rollout)

        # Log on the first rollout past each multiple of 5000 steps.
        if episode_rewards and total_steps % 5000 < rollout_len:
            print(f"Timesteps: {total_steps}, Avg: {np.mean(episode_rewards[-10:]):.2f}")

    env.close()
    return agent, episode_rewards
154
155
if __name__ == "__main__":
    # Script entry point: run training and keep results in scope.
    agent, rewards = train()