1"""
2PPO (Proximal Policy Optimization) 구현
3"""
4import torch
5import torch.nn as nn
6import torch.nn.functional as F
7import numpy as np
8import gymnasium as gym
9
10
class ActorCritic(nn.Module):
    """Policy/value network with a shared two-layer tanh torso.

    The torso feeds two heads: ``actor`` (action logits, softmaxed into
    categorical probabilities) and ``critic`` (scalar state value).
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        # Module creation order matters for reproducible seeded init.
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
        )
        self.actor = nn.Linear(hidden_dim, action_dim)
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        """Return (action probabilities, state value) for ``state``."""
        hidden = self.shared(state)
        probs = F.softmax(self.actor(hidden), dim=-1)
        value = self.critic(hidden)
        return probs, value

    def get_action_and_value(self, state, action=None):
        """Sample (or score a provided) action under the current policy.

        Returns a tuple ``(action, log_prob, entropy, value)`` where the
        value tensor has its trailing singleton dimension squeezed.
        """
        probs, value = self(state)
        dist = torch.distributions.Categorical(probs)
        chosen = dist.sample() if action is None else action
        return chosen, dist.log_prob(chosen), dist.entropy(), value.squeeze(-1)
33
34
class PPO:
    """Proximal Policy Optimization (PPO-Clip) for discrete action spaces.

    Per iteration: collect a fixed-length rollout with the current policy,
    estimate advantages with GAE(lambda), then run several epochs of
    clipped-surrogate minibatch updates on that data.
    """

    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99,
                 gae_lambda=0.95, clip_epsilon=0.2, epochs=10, batch_size=64):
        # A single shared network provides both policy and value estimates.
        self.network = ActorCritic(state_dim, action_dim)
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=lr)

        self.gamma = gamma  # discount factor
        self.gae_lambda = gae_lambda  # GAE bias/variance trade-off
        self.clip_epsilon = clip_epsilon  # PPO surrogate clipping range
        self.epochs = epochs  # optimization epochs per rollout
        self.batch_size = batch_size  # minibatch size within each epoch

    def collect_rollout(self, env, n_steps):
        """Run the current policy in ``env`` for exactly ``n_steps`` steps.

        Episodes are reset in place when they finish, so one rollout may
        span several (possibly partial) episodes.

        Returns a dict of numpy arrays ('obs', 'actions', 'rewards',
        'dones', 'values', 'log_probs') plus 'last_value', the value
        estimate of the final observation used for bootstrapping in GAE.
        """
        obs_buf, act_buf, rew_buf, done_buf, val_buf, logp_buf = [], [], [], [], [], []
        obs, _ = env.reset()

        for _ in range(n_steps):
            obs_tensor = torch.FloatTensor(obs).unsqueeze(0)
            # No gradients during data collection; log-probs are recomputed
            # (with gradients) in update() for the surrogate ratio.
            with torch.no_grad():
                action, logp, _, value = self.network.get_action_and_value(obs_tensor)

            next_obs, reward, terminated, truncated, _ = env.step(action.item())
            # NOTE(review): truncation (time limit) is folded into `done`,
            # so GAE stops bootstrapping at time limits as if the episode
            # truly ended. This slightly biases value targets in
            # time-limited envs — confirm this is acceptable here.
            done = terminated or truncated

            obs_buf.append(obs)
            act_buf.append(action.item())
            rew_buf.append(reward)
            done_buf.append(done)
            val_buf.append(value.item())
            logp_buf.append(logp.item())

            # Continue the episode, or start a fresh one after termination.
            obs = next_obs if not done else env.reset()[0]

        # Value of the observation after the final step, for bootstrapping.
        with torch.no_grad():
            _, _, _, last_value = self.network.get_action_and_value(
                torch.FloatTensor(obs).unsqueeze(0)
            )

        return {
            'obs': np.array(obs_buf), 'actions': np.array(act_buf),
            'rewards': np.array(rew_buf), 'dones': np.array(done_buf),
            'values': np.array(val_buf), 'log_probs': np.array(logp_buf),
            'last_value': last_value.item()
        }

    def compute_gae(self, rollout):
        """Compute GAE(lambda) advantages and returns for one rollout.

        Iterates backwards over the trajectory; the ``(1 - dones[t])``
        factor both zeroes the bootstrap value and resets the running
        advantage accumulator at episode boundaries.

        Returns ``(advantages, returns)`` with ``returns = advantages + values``.
        """
        rewards, values, dones = rollout['rewards'], rollout['values'], rollout['dones']
        advantages = np.zeros_like(rewards)
        last_gae = 0

        for t in reversed(range(len(rewards))):
            # The very last step bootstraps from the stored 'last_value'.
            next_val = rollout['last_value'] if t == len(rewards) - 1 else values[t + 1]
            # TD residual: r_t + gamma * V(s_{t+1}) - V(s_t).
            delta = rewards[t] + self.gamma * next_val * (1 - dones[t]) - values[t]
            advantages[t] = last_gae = delta + self.gamma * self.gae_lambda * (1 - dones[t]) * last_gae

        return advantages, advantages + values

    def update(self, rollout):
        """Run one PPO update phase on a collected rollout.

        Performs ``self.epochs`` passes of shuffled minibatch gradient
        steps on the clipped surrogate objective, with a value-function
        loss and an entropy bonus.
        """
        advantages, returns = self.compute_gae(rollout)
        # Normalize advantages once per rollout to stabilize gradient scale.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        obs = torch.FloatTensor(rollout['obs'])
        actions = torch.LongTensor(rollout['actions'])
        old_logp = torch.FloatTensor(rollout['log_probs'])
        advantages = torch.FloatTensor(advantages)
        returns = torch.FloatTensor(returns)

        for _ in range(self.epochs):
            # Fresh random minibatch partition each epoch.
            indices = np.random.permutation(len(obs))
            for start in range(0, len(obs), self.batch_size):
                idx = indices[start:start + self.batch_size]

                # Re-evaluate stored actions under the *current* policy.
                _, new_logp, entropy, values = self.network.get_action_and_value(
                    obs[idx], actions[idx]
                )

                # Probability ratio pi_new / pi_old; clipping keeps the
                # update inside a trust region around the old policy.
                ratio = torch.exp(new_logp - old_logp[idx])
                surr1 = ratio * advantages[idx]
                surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages[idx]

                # Pessimistic (min) surrogate; 0.5 value coefficient and
                # 0.01 entropy bonus for exploration.
                actor_loss = -torch.min(surr1, surr2).mean()
                critic_loss = F.mse_loss(values, returns[idx])
                loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy.mean()

                self.optimizer.zero_grad()
                loss.backward()
                # Global-norm gradient clipping for training stability.
                nn.utils.clip_grad_norm_(self.network.parameters(), 0.5)
                self.optimizer.step()
123
124
def train():
    """Train a PPO agent on CartPole-v1 for 50k timesteps.

    Returns the trained agent and the list of per-episode rewards.
    """
    env = gym.make('CartPole-v1')
    agent = PPO(
        state_dim=env.observation_space.shape[0],
        action_dim=env.action_space.n
    )

    rollout_len = 128
    total_steps = 0
    episode_rewards = []
    running_return = 0

    while total_steps < 50000:
        rollout = agent.collect_rollout(env, rollout_len)
        total_steps += rollout_len

        # Reconstruct per-episode returns from the flat reward/done streams;
        # the accumulator carries over episodes that span rollout boundaries.
        for reward, finished in zip(rollout['rewards'], rollout['dones']):
            running_return += reward
            if finished:
                episode_rewards.append(running_return)
                running_return = 0

        agent.update(rollout)

        # Log roughly every 5000 steps (first rollout crossing the boundary).
        if episode_rewards and total_steps % 5000 < rollout_len:
            print(f"Timesteps: {total_steps}, Avg: {np.mean(episode_rewards[-10:]):.2f}")

    env.close()
    return agent, episode_rewards
154
155
156if __name__ == "__main__":
157 agent, rewards = train()