10. PPO์ TRPO
10. PPO์ TRPO¶
난이도: ⭐⭐⭐⭐ (고급)
ํ์ต ๋ชฉํ¶
- ์ ์ฑ ์ ๋ฐ์ดํธ์ ์์ ์ฑ ๋ฌธ์ ์ดํด
- TRPO์ ์ ๋ขฐ ์์ญ ๊ฐ๋ ํ์ต
- PPO์ ํด๋ฆฌํ ๋ฉ์ปค๋์ฆ ์ดํด
- PyTorch๋ก PPO ๊ตฌํ
1. ์ ์ฑ ์ต์ ํ์ ๋ฌธ์ ¶
1.1 ํฐ ์ ๋ฐ์ดํธ์ ์ํ์ฑ¶
์ ์ฑ ๊ฒฝ์ฌ์์ ๋๋ฌด ํฐ ์ ๋ฐ์ดํธ๋ ์ฑ๋ฅ์ ๊ธ๊ฒฉํ ์ ํ์ํฌ ์ ์์ต๋๋ค.
θ_new = θ_old + α∇J(θ)
문제: α가 크면 정책이 급격히 변해 학습이 불안정해집니다.
해결: 정책 변화를 제한합니다.
1.2 ํด๊ฒฐ ๋ฐฉํฅ¶
- TRPO: KL divergence๋ก ์ ๋ขฐ ์์ญ ์ ํ (๋ณต์ก)
- PPO: Clipping์ผ๋ก ๊ฐ๋จํ๊ฒ ์ ํ
2. TRPO (Trust Region Policy Optimization)¶
2.1 ๋ชฉํ ํจ์¶
์ ์ ์ฑ ๊ณผ ์ด์ ์ ์ฑ ์ ๋น์จ์ ์ฌ์ฉ:
$$L^{CPI}(\theta) = \mathbb{E}\left[\frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} A^{\pi_{old}}(s, a)\right]$$
2.2 KL Divergence ์ ์ฝ¶
$$\text{maximize}_\theta \quad L^{CPI}(\theta)$$ $$\text{subject to} \quad \mathbb{E}[D_{KL}(\pi_{\theta_{old}} || \pi_\theta)] \leq \delta$$
2.3 TRPO์ ๋ฌธ์ ์ ¶
- 2์ฐจ ๋ฏธ๋ถ(Hessian) ๊ณ์ฐ ํ์
- Conjugate gradient ์๊ณ ๋ฆฌ์ฆ ํ์
- ๊ตฌํ์ด ๋ณต์กํ๊ณ ๊ณ์ฐ ๋น์ฉ์ด ๋์
3. PPO (Proximal Policy Optimization)¶
3.1 ํต์ฌ ์์ด๋์ด¶
Clipping์ ์ฌ์ฉํ์ฌ ์ ์ฑ ๋น์จ์ ์ ํํฉ๋๋ค.
$$r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$$
3.2 Clipped ๋ชฉํ ํจ์¶
$$L^{CLIP}(\theta) = \mathbb{E}\left[\min(r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t)\right]$$
def compute_ppo_loss(ratio, advantage, clip_epsilon=0.2):
    """PPO clipped surrogate loss.

    Takes the element-wise minimum of the unclipped and clipped
    surrogate objectives and negates the mean so an optimizer can
    minimize it.
    """
    # Restrict the probability ratio to [1 - eps, 1 + eps].
    bounded_ratio = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon)
    # Pessimistic surrogate: the smaller of the two objectives.
    surrogate = torch.min(ratio * advantage, bounded_ratio * advantage)
    return -surrogate.mean()
3.3 Clipping ์ง๊ด¶
Advantage > 0 (좋은 행동):
- ratio 증가 → 확률 증가
- 단, ratio > 1+ε 이상은 무시 (급격한 증가 방지)
Advantage < 0 (나쁜 행동):
- ratio 감소 → 확률 감소
- 단, ratio < 1-ε 이하는 무시 (급격한 감소 방지)
4. PPO ์ ์ฒด ๊ตฌํ¶
4.1 PPO ์์ด์ ํธ¶
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class PPONetwork(nn.Module):
    """Actor-critic network for discrete-action PPO.

    The actor head outputs a softmax distribution over actions; the
    critic head outputs a scalar state value.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        # Policy head: state -> action probabilities.
        self.actor = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1),
        )
        # Value head: state -> scalar V(s).
        self.critic = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, state):
        """Return (action_probs, state_value) for the given state."""
        return self.actor(state), self.critic(state)

    def get_action(self, state, action=None):
        """Sample a new action, or evaluate a provided one.

        Returns (action, log_prob, entropy, value).
        """
        action_probs, state_value = self.forward(state)
        policy = torch.distributions.Categorical(action_probs)
        chosen = policy.sample() if action is None else action
        return chosen, policy.log_prob(chosen), policy.entropy(), state_value
class PPOAgent:
    """PPO agent with a clipped surrogate objective and GAE.

    Workflow: collect a fixed-length rollout with the current policy,
    compute GAE advantages, then run several epochs of shuffled
    minibatch gradient updates on the clipped PPO loss.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        lr=3e-4,
        gamma=0.99,
        gae_lambda=0.95,
        clip_epsilon=0.2,
        value_coef=0.5,
        entropy_coef=0.01,
        max_grad_norm=0.5,
        update_epochs=10,
        batch_size=64
    ):
        self.network = PPONetwork(state_dim, action_dim)
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=lr)
        self.gamma = gamma                  # discount factor
        self.gae_lambda = gae_lambda        # GAE smoothing parameter
        self.clip_epsilon = clip_epsilon    # PPO clipping range
        self.value_coef = value_coef        # weight of the value loss
        self.entropy_coef = entropy_coef    # weight of the entropy bonus
        self.max_grad_norm = max_grad_norm  # gradient-norm clipping threshold
        self.update_epochs = update_epochs  # optimization epochs per rollout
        self.batch_size = batch_size        # minibatch size

    def collect_rollouts(self, env, n_steps):
        """Run the current policy in `env` for `n_steps` transitions.

        Episodes that end mid-rollout are reset and collection
        continues, so a single rollout may span several episodes.

        Returns a dict of numpy arrays ('states', 'actions', 'rewards',
        'dones', 'values', 'log_probs') plus 'last_value', the critic's
        bootstrap estimate for the state following the final step.
        """
        states, actions, rewards, dones = [], [], [], []
        values, log_probs = [], []
        state, _ = env.reset()
        for _ in range(n_steps):
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            with torch.no_grad():
                action, log_prob, _, value = self.network.get_action(state_tensor)
            next_state, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            states.append(state)
            actions.append(action.item())
            rewards.append(reward)
            dones.append(done)
            values.append(value.item())
            log_probs.append(log_prob.item())
            # Start a fresh episode when the current one ends.
            state = next_state if not done else env.reset()[0]
        # Bootstrap value of the state after the last collected step.
        with torch.no_grad():
            _, _, _, last_value = self.network.get_action(
                torch.FloatTensor(state).unsqueeze(0)
            )
        return {
            'states': np.array(states),
            'actions': np.array(actions),
            'rewards': np.array(rewards),
            'dones': np.array(dones),
            'values': np.array(values),
            'log_probs': np.array(log_probs),
            'last_value': last_value.item()
        }

    def compute_gae(self, rollout):
        """Compute GAE advantages and discounted returns.

        Iterates backwards over the rollout; `dones` zeroes both the
        bootstrap term and the advantage carry-over at episode
        boundaries. Returns (advantages, returns) as numpy arrays with
        returns = advantages + values.
        """
        rewards = rollout['rewards']
        values = rollout['values']
        dones = rollout['dones']
        last_value = rollout['last_value']
        advantages = np.zeros_like(rewards)
        last_gae = 0
        for t in reversed(range(len(rewards))):
            # Next-state value: bootstrap estimate for the final step.
            if t == len(rewards) - 1:
                next_value = last_value
            else:
                next_value = values[t + 1]
            next_non_terminal = 1.0 - dones[t]
            # TD error: r_t + gamma * V(s_{t+1}) - V(s_t).
            delta = rewards[t] + self.gamma * next_value * next_non_terminal - values[t]
            # Exponentially-weighted sum of TD errors (GAE recursion).
            advantages[t] = last_gae = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae
        returns = advantages + values
        return advantages, returns

    def update(self, rollout):
        """One PPO optimization phase on a collected rollout.

        Runs `update_epochs` epochs of shuffled minibatch updates on
        the clipped surrogate loss + value loss + entropy bonus.
        Returns (actor_loss, critic_loss) of the last minibatch.
        """
        advantages, returns = self.compute_gae(rollout)
        # Normalize advantages for more stable gradients.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        # Convert rollout data to tensors.
        states = torch.FloatTensor(rollout['states'])
        actions = torch.LongTensor(rollout['actions'])
        old_log_probs = torch.FloatTensor(rollout['log_probs'])
        returns = torch.FloatTensor(returns)
        advantages = torch.FloatTensor(advantages)
        # Several passes over the same rollout data.
        for _ in range(self.update_epochs):
            # Fresh shuffled minibatches each epoch.
            indices = np.random.permutation(len(states))
            for start in range(0, len(states), self.batch_size):
                end = start + self.batch_size
                batch_indices = indices[start:end]
                batch_states = states[batch_indices]
                batch_actions = actions[batch_indices]
                batch_old_log_probs = old_log_probs[batch_indices]
                batch_returns = returns[batch_indices]
                batch_advantages = advantages[batch_indices]
                # Re-evaluate the stored actions under the current policy.
                _, new_log_probs, entropy, values = self.network.get_action(
                    batch_states, batch_actions
                )
                # Probability ratio pi_new / pi_old (in log space).
                ratio = torch.exp(new_log_probs - batch_old_log_probs)
                # Clipped surrogate loss.
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * batch_advantages
                actor_loss = -torch.min(surr1, surr2).mean()
                # Value loss. squeeze(-1), not squeeze(): a size-1 final
                # minibatch must keep its batch dimension so the MSE
                # stays element-wise against batch_returns.
                critic_loss = F.mse_loss(values.squeeze(-1), batch_returns)
                # Entropy bonus (negated: we minimize the total loss).
                entropy_loss = -entropy.mean()
                # Total loss.
                loss = actor_loss + self.value_coef * critic_loss + self.entropy_coef * entropy_loss
                # Gradient step with global-norm clipping.
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.network.parameters(), self.max_grad_norm)
                self.optimizer.step()
        return actor_loss.item(), critic_loss.item()
4.2 PPO ํ์ต ๋ฃจํ¶
import gymnasium as gym
def train_ppo(env_name='CartPole-v1', total_timesteps=100000, n_steps=2048):
    """Train a PPO agent on `env_name`; return (agent, episode_rewards)."""
    env = gym.make(env_name)
    obs_dim = env.observation_space.shape[0]
    n_actions = env.action_space.n
    agent = PPOAgent(obs_dim, n_actions)

    timesteps = 0
    episode_rewards = []
    running_reward = 0

    while timesteps < total_timesteps:
        # Gather one fixed-length rollout with the current policy.
        rollout = agent.collect_rollouts(env, n_steps)
        timesteps += n_steps

        # Accumulate per-episode returns for logging.
        for reward, finished in zip(rollout['rewards'], rollout['dones']):
            running_reward += reward
            if finished:
                episode_rewards.append(running_reward)
                running_reward = 0

        # One PPO optimization phase on the collected data.
        agent.update(rollout)

        # Periodic progress report (roughly every 10k steps).
        if episode_rewards and timesteps % 10000 < n_steps:
            recent = episode_rewards[-10:] if len(episode_rewards) >= 10 else episode_rewards
            avg_reward = np.mean(recent)
            print(f"Timesteps: {timesteps}, Avg Reward: {avg_reward:.2f}")

    return agent, episode_rewards
5. PPO ๋ณํ๋ค¶
5.1 PPO-Clip (๊ธฐ๋ณธ)¶
์์์ ๊ตฌํํ ๋ฐฉ์์ ๋๋ค.
5.2 PPO-Penalty¶
KL divergence๋ฅผ ํ๋ํฐ๋ก ์ถ๊ฐ:
def ppo_penalty_loss(ratio, advantage, old_probs, new_probs, beta=0.01):
    """PPO-penalty objective: negated surrogate plus a weighted KL term."""
    # Unclipped surrogate objective (to be maximized, hence negated below).
    surrogate = (ratio * advantage).mean()
    # KL(old || new): F.kl_div takes log-probs as input, probs as target.
    penalty = F.kl_div(new_probs.log(), old_probs, reduction='batchmean')
    return beta * penalty - surrogate
5.3 Clipped Value Loss¶
๊ฐ์น ํจ์์๋ ํด๋ฆฌํ ์ ์ฉ:
def clipped_value_loss(values, old_values, returns, clip_epsilon=0.2):
    """Value loss with PPO-style clipping around the old value estimates."""
    # Keep the new values within clip_epsilon of the old predictions.
    delta = torch.clamp(values - old_values, -clip_epsilon, clip_epsilon)
    clipped = old_values + delta
    # Pessimistic choice: the larger of the two squared errors.
    unclipped_err = (values - returns) ** 2
    clipped_err = (clipped - returns) ** 2
    return 0.5 * torch.max(unclipped_err, clipped_err).mean()
6. ์ฐ์ ํ๋ ๊ณต๊ฐ PPO¶
class ContinuousPPONetwork(nn.Module):
    """Actor-critic network for continuous-action PPO.

    The actor is a diagonal Gaussian: a state-dependent mean head plus
    a state-independent learned log-std. The critic outputs V(s).
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        # Mean head of the Gaussian policy.
        self.actor_mean = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, action_dim),
        )
        # Log standard deviation, shared across all states.
        self.actor_log_std = nn.Parameter(torch.zeros(action_dim))
        # Value head.
        self.critic = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, state):
        """Return (mean, std, value) for the given state."""
        return self.actor_mean(state), self.actor_log_std.exp(), self.critic(state)

    def get_action(self, state, action=None):
        """Sample or evaluate a continuous action.

        log_prob and entropy are summed over action dimensions.
        Returns (action, log_prob, entropy, value).
        """
        mean, std, value = self.forward(state)
        policy = torch.distributions.Normal(mean, std)
        chosen = policy.sample() if action is None else action
        log_prob = policy.log_prob(chosen).sum(-1)
        entropy = policy.entropy().sum(-1)
        return chosen, log_prob, entropy, value
7. ํ์ดํผํ๋ผ๋ฏธํฐ ๊ฐ์ด๋¶
7.1 ์ผ๋ฐ์ ์ธ ์ค์ ¶
config = {
    # Learning
    'lr': 3e-4,            # learning rate
    'gamma': 0.99,         # discount factor
    'gae_lambda': 0.95,    # GAE lambda
    # PPO-specific
    'clip_epsilon': 0.2,   # clipping range
    'update_epochs': 10,   # update epochs per rollout
    'batch_size': 64,      # minibatch size
    # Loss coefficients
    'value_coef': 0.5,     # value loss coefficient
    'entropy_coef': 0.01,  # entropy coefficient
    # Rollout
    'n_steps': 2048,       # rollout length
    'n_envs': 8,           # number of parallel environments
    # Stability
    'max_grad_norm': 0.5,  # gradient clipping
}
7.2 ํ๊ฒฝ๋ณ ํ๋¶
| ํ๊ฒฝ | lr | n_steps | clip_epsilon |
|---|---|---|---|
| CartPole | 3e-4 | 128 | 0.2 |
| LunarLander | 3e-4 | 2048 | 0.2 |
| Atari | 2.5e-4 | 128 | 0.1 |
| MuJoCo | 3e-4 | 2048 | 0.2 |
8. PPO vs ๋ค๋ฅธ ์๊ณ ๋ฆฌ์ฆ¶
| ์๊ณ ๋ฆฌ์ฆ | ๋ณต์ก๋ | ์ํ ํจ์จ | ์์ ์ฑ |
|---|---|---|---|
| REINFORCE | ๋ฎ์ | ๋ฎ์ | ๋ฎ์ |
| A2C | ์ค๊ฐ | ์ค๊ฐ | ์ค๊ฐ |
| TRPO | ๋์ | ๋์ | ๋์ |
| PPO | ์ค๊ฐ | ๋์ | ๋์ |
| SAC | ์ค๊ฐ | ๋์ | ๋์ |
PPO์ ์ฅ์ : - TRPO ์์ค์ ์ฑ๋ฅ, ๊ตฌํ์ ๊ฐ๋จ - ๋ค์ํ ํ๊ฒฝ์์ ์์ ์ - ํ์ดํผํ๋ผ๋ฏธํฐ ๋ฏผ๊ฐ๋ ๋ฎ์
์์ฝ¶
PPO ํต์ฌ:
L^{CLIP} = E[min(r(θ)A, clip(r(θ), 1-ε, 1+ε)A)]
r(θ) = π_θ(a|s) / π_θ_old(a|s)  # 정책 비율
ํด๋ฆฌํ ํจ๊ณผ: - ์ ์ฑ ๋ณํ๋ฅผ [1-ฮต, 1+ฮต] ๋ฒ์๋ก ์ ํ - ๊ธ๊ฒฉํ ์ ๋ฐ์ดํธ ๋ฐฉ์ง - ํ์ต ์์ ์ฑ ํ๋ณด
๋ค์ ๋จ๊ณ¶
- 11_Multi_Agent_RL.md - ๋ค์ค ์์ด์ ํธ ๊ฐํํ์ต