10. PPO์™€ TRPO

10. PPO์™€ TRPO

๋‚œ์ด๋„: โญโญโญโญ (๊ณ ๊ธ‰)

ํ•™์Šต ๋ชฉํ‘œ

  • ์ •์ฑ… ์—…๋ฐ์ดํŠธ์˜ ์•ˆ์ •์„ฑ ๋ฌธ์ œ ์ดํ•ด
  • TRPO์˜ ์‹ ๋ขฐ ์˜์—ญ ๊ฐœ๋… ํ•™์Šต
  • PPO์˜ ํด๋ฆฌํ•‘ ๋ฉ”์ปค๋‹ˆ์ฆ˜ ์ดํ•ด
  • PyTorch๋กœ PPO ๊ตฌํ˜„

1. ์ •์ฑ… ์ตœ์ ํ™”์˜ ๋ฌธ์ œ

1.1 ํฐ ์—…๋ฐ์ดํŠธ์˜ ์œ„ํ—˜์„ฑ

์ •์ฑ… ๊ฒฝ์‚ฌ์—์„œ ๋„ˆ๋ฌด ํฐ ์—…๋ฐ์ดํŠธ๋Š” ์„ฑ๋Šฅ์„ ๊ธ‰๊ฒฉํžˆ ์ €ํ•˜์‹œํ‚ฌ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

ฮธ_new = ฮธ_old + ฮฑโˆ‡J(ฮธ)

๋ฌธ์ œ: ฮฑ๊ฐ€ ํฌ๋ฉด ์ •์ฑ…์ด ๊ธ‰๊ฒฉํžˆ ๋ณ€ํ•ด ํ•™์Šต ๋ถˆ์•ˆ์ •
ํ•ด๊ฒฐ: ์ •์ฑ… ๋ณ€ํ™”๋ฅผ ์ œํ•œ

1.2 ํ•ด๊ฒฐ ๋ฐฉํ–ฅ

  • TRPO: KL divergence๋กœ ์‹ ๋ขฐ ์˜์—ญ ์ œํ•œ (๋ณต์žก)
  • PPO: Clipping์œผ๋กœ ๊ฐ„๋‹จํ•˜๊ฒŒ ์ œํ•œ
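The effect of the step size α can be made concrete by measuring the KL divergence between the policy before and after a single gradient step. The sketch below uses a toy 4-action softmax policy (the policy, step, and α values are illustrative assumptions, not part of this lesson's agent):

```python
import torch
import torch.nn.functional as F

# A toy softmax policy over 4 actions, parameterized by logits.
logits = torch.zeros(4, requires_grad=True)
old_probs = F.softmax(logits, dim=-1).detach()

# One policy-gradient-style step: increase action 0's log-probability.
loss = -F.log_softmax(logits, dim=-1)[0]
loss.backward()

kls = []
for alpha in (0.1, 1.0, 10.0):
    new_probs = F.softmax(logits.detach() - alpha * logits.grad, dim=-1)
    # KL(old || new): how far one gradient step moved the policy
    kl = (old_probs * (old_probs.log() - new_probs.log())).sum().item()
    kls.append(kl)
    print(f"alpha={alpha:>4}: KL(old || new) = {kl:.4f}")
```

The larger α is, the further the new policy drifts from the old one, which is exactly the quantity TRPO and PPO try to keep small.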

2. TRPO (Trust Region Policy Optimization)

2.1 Objective Function

TRPO optimizes a surrogate objective built from the ratio of the new policy to the old policy:

$$L^{CPI}(\theta) = \mathbb{E}\left[\frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} A^{\pi_{old}}(s, a)\right]$$

2.2 KL Divergence Constraint

$$\text{maximize}_\theta \quad L^{CPI}(\theta)$$
$$\text{subject to} \quad \mathbb{E}[D_{KL}(\pi_{\theta_{old}} || \pi_\theta)] \leq \delta$$

2.3 Drawbacks of TRPO

  • Requires second-order (Hessian) derivative computations
  • Requires the conjugate gradient algorithm
  • Complex to implement and computationally expensive

3. PPO (Proximal Policy Optimization)

3.1 Core Idea

PPO constrains the policy ratio with a simple clipping rule.

$$r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$$

3.2 Clipped Objective Function

$$L^{CLIP}(\theta) = \mathbb{E}\left[\min(r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t)\right]$$

def compute_ppo_loss(ratio, advantage, clip_epsilon=0.2):
    """PPO clipped loss"""
    # Clip the ratio to [1 - epsilon, 1 + epsilon]
    clipped_ratio = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon)

    # Take the smaller (more pessimistic) of the two terms
    loss1 = ratio * advantage
    loss2 = clipped_ratio * advantage

    return -torch.min(loss1, loss2).mean()

3.3 Clipping Intuition

Advantage > 0 (good action):
- Increasing the ratio increases the action's probability
- But any ratio above 1+ε is ignored (prevents runaway increases)

Advantage < 0 (bad action):
- Decreasing the ratio decreases the action's probability
- But any ratio below 1-ε is ignored (prevents runaway decreases)

4. PPO ์ „์ฒด ๊ตฌํ˜„

4.1 PPO ์—์ด์ „ํŠธ

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class PPONetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()

        # Actor
        self.actor = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )

        # Critic
        self.critic = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, state):
        return self.actor(state), self.critic(state)

    def get_action(self, state, action=None):
        probs, value = self.forward(state)
        dist = torch.distributions.Categorical(probs)

        if action is None:
            action = dist.sample()

        return action, dist.log_prob(action), dist.entropy(), value


class PPOAgent:
    def __init__(
        self,
        state_dim,
        action_dim,
        lr=3e-4,
        gamma=0.99,
        gae_lambda=0.95,
        clip_epsilon=0.2,
        value_coef=0.5,
        entropy_coef=0.01,
        max_grad_norm=0.5,
        update_epochs=10,
        batch_size=64
    ):
        self.network = PPONetwork(state_dim, action_dim)
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=lr)

        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_epsilon = clip_epsilon
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.update_epochs = update_epochs
        self.batch_size = batch_size

    def collect_rollouts(self, env, n_steps):
        """๊ฒฝํ—˜ ์ˆ˜์ง‘"""
        states, actions, rewards, dones = [], [], [], []
        values, log_probs = [], []

        state, _ = env.reset()

        for _ in range(n_steps):
            state_tensor = torch.FloatTensor(state).unsqueeze(0)

            with torch.no_grad():
                action, log_prob, _, value = self.network.get_action(state_tensor)

            next_state, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated

            states.append(state)
            actions.append(action.item())
            rewards.append(reward)
            dones.append(done)
            values.append(value.item())
            log_probs.append(log_prob.item())

            state = next_state if not done else env.reset()[0]

        # ๋งˆ์ง€๋ง‰ ์ƒํƒœ์˜ ๊ฐ€์น˜
        with torch.no_grad():
            _, _, _, last_value = self.network.get_action(
                torch.FloatTensor(state).unsqueeze(0)
            )

        return {
            'states': np.array(states),
            'actions': np.array(actions),
            'rewards': np.array(rewards),
            'dones': np.array(dones),
            'values': np.array(values),
            'log_probs': np.array(log_probs),
            'last_value': last_value.item()
        }

    def compute_gae(self, rollout):
        """Compute GAE (Generalized Advantage Estimation)"""
        rewards = rollout['rewards']
        values = rollout['values']
        dones = rollout['dones']
        last_value = rollout['last_value']

        advantages = np.zeros_like(rewards)
        last_gae = 0

        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_value = last_value
            else:
                next_value = values[t + 1]

            next_non_terminal = 1.0 - dones[t]
            delta = rewards[t] + self.gamma * next_value * next_non_terminal - values[t]
            advantages[t] = last_gae = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae

        returns = advantages + values
        return advantages, returns

    def update(self, rollout):
        """Run PPO updates on one rollout"""
        advantages, returns = self.compute_gae(rollout)

        # ์ •๊ทœํ™”
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # ํ…์„œ ๋ณ€ํ™˜
        states = torch.FloatTensor(rollout['states'])
        actions = torch.LongTensor(rollout['actions'])
        old_log_probs = torch.FloatTensor(rollout['log_probs'])
        returns = torch.FloatTensor(returns)
        advantages = torch.FloatTensor(advantages)

        # ์—ฌ๋Ÿฌ ์—ํญ ์—…๋ฐ์ดํŠธ
        for _ in range(self.update_epochs):
            # ๋ฏธ๋‹ˆ๋ฐฐ์น˜ ์ƒ์„ฑ
            indices = np.random.permutation(len(states))

            for start in range(0, len(states), self.batch_size):
                end = start + self.batch_size
                batch_indices = indices[start:end]

                batch_states = states[batch_indices]
                batch_actions = actions[batch_indices]
                batch_old_log_probs = old_log_probs[batch_indices]
                batch_returns = returns[batch_indices]
                batch_advantages = advantages[batch_indices]

                # ํ˜„์žฌ ์ •์ฑ…์œผ๋กœ ํ‰๊ฐ€
                _, new_log_probs, entropy, values = self.network.get_action(
                    batch_states, batch_actions
                )

                # ๋น„์œจ ๊ณ„์‚ฐ
                ratio = torch.exp(new_log_probs - batch_old_log_probs)

                # Clipped surrogate loss
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * batch_advantages
                actor_loss = -torch.min(surr1, surr2).mean()

                # ๊ฐ€์น˜ ์†์‹ค
                critic_loss = F.mse_loss(values.squeeze(), batch_returns)

                # ์—”ํŠธ๋กœํ”ผ ๋ณด๋„ˆ์Šค
                entropy_loss = -entropy.mean()

                # ์ด ์†์‹ค
                loss = actor_loss + self.value_coef * critic_loss + self.entropy_coef * entropy_loss

                # ์—…๋ฐ์ดํŠธ
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.network.parameters(), self.max_grad_norm)
                self.optimizer.step()

        return actor_loss.item(), critic_loss.item()
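A quick way to sanity-check the GAE recursion in compute_gae above is to set λ = 1, where the returns must reduce to plain discounted returns. The function below restates the same backward recursion standalone; the reward/value numbers are an assumed toy trajectory:

```python
import numpy as np

def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Same backward recursion as PPOAgent.compute_gae."""
    advantages = np.zeros_like(rewards)
    last_gae = 0.0
    for t in reversed(range(len(rewards))):
        next_value = last_value if t == len(rewards) - 1 else values[t + 1]
        mask = 1.0 - dones[t]  # zero out bootstrapping at episode ends
        delta = rewards[t] + gamma * next_value * mask - values[t]
        advantages[t] = last_gae = delta + gamma * lam * mask * last_gae
    return advantages, advantages + values

rewards = np.array([1.0, 1.0, 1.0])
values = np.array([0.5, 0.5, 0.5])
dones = np.array([0.0, 0.0, 1.0])   # episode ends at the last step

adv, ret = compute_gae(rewards, values, dones, last_value=0.0, lam=1.0)
# With lam=1: ret = [1 + 0.99 + 0.99^2, 1 + 0.99, 1]
print(ret)
```

With λ < 1 the advantages interpolate toward the one-step TD error, trading variance for bias.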

4.2 PPO Training Loop

import gymnasium as gym

def train_ppo(env_name='CartPole-v1', total_timesteps=100000, n_steps=2048):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    agent = PPOAgent(state_dim, action_dim)

    timesteps = 0
    episode_rewards = []
    current_episode_reward = 0

    while timesteps < total_timesteps:
        # ๋กค์•„์›ƒ ์ˆ˜์ง‘
        rollout = agent.collect_rollouts(env, n_steps)
        timesteps += n_steps

        # ์—ํ”ผ์†Œ๋“œ ๋ณด์ƒ ์ถ”์ 
        for r, d in zip(rollout['rewards'], rollout['dones']):
            current_episode_reward += r
            if d:
                episode_rewards.append(current_episode_reward)
                current_episode_reward = 0

        # PPO update
        actor_loss, critic_loss = agent.update(rollout)

        # ๋กœ๊น…
        if len(episode_rewards) > 0 and timesteps % 10000 < n_steps:
            avg_reward = np.mean(episode_rewards[-10:]) if len(episode_rewards) >= 10 else np.mean(episode_rewards)
            print(f"Timesteps: {timesteps}, Avg Reward: {avg_reward:.2f}")

    return agent, episode_rewards

5. PPO Variants

5.1 PPO-Clip (default)

This is the version implemented above.

5.2 PPO-Penalty

Instead of clipping, adds the KL divergence as a penalty term:

def ppo_penalty_loss(ratio, advantage, old_probs, new_probs, beta=0.01):
    policy_loss = (ratio * advantage).mean()

    kl_div = F.kl_div(new_probs.log(), old_probs, reduction='batchmean')

    return -policy_loss + beta * kl_div
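One subtlety in ppo_penalty_loss is the argument order of F.kl_div: it takes its first argument in log-space and computes the KL divergence of the second argument from the first. A quick check with toy probabilities (assumed values) against the manual formula:

```python
import torch
import torch.nn.functional as F

old_probs = torch.tensor([[0.8, 0.2]])
new_probs = torch.tensor([[0.6, 0.4]])

# F.kl_div(input, target) takes `input` as log-probabilities and
# computes KL(target || input), i.e. KL(pi_old || pi_theta) here --
# the same direction as the TRPO constraint.
kl_lib = F.kl_div(new_probs.log(), old_probs, reduction='batchmean').item()
kl_manual = (old_probs * (old_probs / new_probs).log()).sum().item()

print(f"kl_div: {kl_lib:.4f}, manual: {kl_manual:.4f}")
```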

5.3 Clipped Value Loss

Clipping can also be applied to the value function:

def clipped_value_loss(values, old_values, returns, clip_epsilon=0.2):
    # ํด๋ฆฌํ•‘๋œ ๊ฐ€์น˜
    clipped_values = old_values + torch.clamp(
        values - old_values, -clip_epsilon, clip_epsilon
    )

    # ๋‘ ์†์‹ค ์ค‘ ํฐ ๊ฐ’
    loss1 = (values - returns) ** 2
    loss2 = (clipped_values - returns) ** 2

    return 0.5 * torch.max(loss1, loss2).mean()
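Taking the max may look backwards for a loss, but it is what discourages large critic jumps: a prediction that lands near the target by moving far from the old value is still charged through the clipped branch. A numeric check (the function is restated so the sketch runs standalone; the values are illustrative):

```python
import torch

def clipped_value_loss(values, old_values, returns, clip_epsilon=0.2):
    """Same function as above, restated for a standalone run."""
    clipped_values = old_values + torch.clamp(
        values - old_values, -clip_epsilon, clip_epsilon
    )
    loss1 = (values - returns) ** 2
    loss2 = (clipped_values - returns) ** 2
    return 0.5 * torch.max(loss1, loss2).mean()

old_values = torch.tensor([1.0])
returns = torch.tensor([2.0])

# Prediction jumps from 1.0 to 1.9 -- close to the return of 2.0,
# but far outside the clip range around the old value.
values = torch.tensor([1.9])
clipped = clipped_value_loss(values, old_values, returns).item()
plain = (0.5 * (values - returns) ** 2).mean().item()
print(f"clipped loss: {clipped:.3f}, plain MSE: {plain:.3f}")
```

The plain MSE would be nearly zero, while the clipped loss stays large, so the critic is pushed to move in smaller steps per update.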

6. ์—ฐ์† ํ–‰๋™ ๊ณต๊ฐ„ PPO

class ContinuousPPONetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()

        # Actor
        self.actor_mean = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, action_dim)
        )
        self.actor_log_std = nn.Parameter(torch.zeros(action_dim))

        # Critic
        self.critic = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, state):
        mean = self.actor_mean(state)
        std = self.actor_log_std.exp()
        value = self.critic(state)
        return mean, std, value

    def get_action(self, state, action=None):
        mean, std, value = self.forward(state)
        dist = torch.distributions.Normal(mean, std)

        if action is None:
            action = dist.sample()

        log_prob = dist.log_prob(action).sum(-1)
        entropy = dist.entropy().sum(-1)

        return action, log_prob, entropy, value
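The .sum(-1) in get_action reflects the diagonal Gaussian assumption: action dimensions are independent, so joint log-probabilities and entropies are per-dimension sums. A minimal check with stand-in values for the network outputs (mean and std here are assumptions, not real actor outputs):

```python
import math
import torch

torch.manual_seed(0)

# Diagonal Gaussian head: two independent action dimensions.
mean = torch.zeros(2)   # stands in for actor_mean(state)
std = torch.ones(2)     # stands in for actor_log_std.exp()
dist = torch.distributions.Normal(mean, std)

action = dist.sample()
log_prob = dist.log_prob(action).sum(-1).item()   # joint log-density
entropy = dist.entropy().sum(-1).item()           # joint entropy

# Entropy of a standard normal is 0.5 * ln(2*pi*e) per dimension.
print(f"log_prob = {log_prob:.4f}, entropy = {entropy:.4f}")
```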

7. ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ ๊ฐ€์ด๋“œ

7.1 ์ผ๋ฐ˜์ ์ธ ์„ค์ •

config = {
    # Training
    'lr': 3e-4,                  # learning rate
    'gamma': 0.99,               # discount factor
    'gae_lambda': 0.95,          # GAE lambda

    # PPO-specific
    'clip_epsilon': 0.2,         # clipping range
    'update_epochs': 10,         # update epochs per rollout
    'batch_size': 64,            # minibatch size

    # Rollout
    'n_steps': 2048,             # rollout length
    'n_envs': 8,                 # number of parallel environments

    # Loss coefficients
    'value_coef': 0.5,           # value loss coefficient
    'entropy_coef': 0.01,        # entropy coefficient

    # Stabilization
    'max_grad_norm': 0.5,        # gradient clipping
}

7.2 ํ™˜๊ฒฝ๋ณ„ ํŠœ๋‹

ํ™˜๊ฒฝ lr n_steps clip_epsilon
CartPole 3e-4 128 0.2
LunarLander 3e-4 2048 0.2
Atari 2.5e-4 128 0.1
MuJoCo 3e-4 2048 0.2

8. PPO vs Other Algorithms

| Algorithm | Complexity | Sample Efficiency | Stability |
|-----------|------------|-------------------|-----------|
| REINFORCE | Low        | Low               | Low       |
| A2C       | Medium     | Medium            | Medium    |
| TRPO      | High       | High              | High      |
| PPO       | Medium     | High              | High      |
| SAC       | Medium     | High              | High      |

Advantages of PPO:
- TRPO-level performance with a much simpler implementation
- Stable across a wide range of environments
- Low sensitivity to hyperparameters


์š”์•ฝ

PPO ํ•ต์‹ฌ:

L^{CLIP} = E[min(r(ฮธ)A, clip(r(ฮธ), 1-ฮต, 1+ฮต)A)]

r(ฮธ) = ฯ€_ฮธ(a|s) / ฯ€_ฮธ_old(a|s)  # ์ •์ฑ… ๋น„์œจ

ํด๋ฆฌํ•‘ ํšจ๊ณผ: - ์ •์ฑ… ๋ณ€ํ™”๋ฅผ [1-ฮต, 1+ฮต] ๋ฒ”์œ„๋กœ ์ œํ•œ - ๊ธ‰๊ฒฉํ•œ ์—…๋ฐ์ดํŠธ ๋ฐฉ์ง€ - ํ•™์Šต ์•ˆ์ •์„ฑ ํ™•๋ณด


๋‹ค์Œ ๋‹จ๊ณ„

  • 11_Multi_Agent_RL.md - ๋‹ค์ค‘ ์—์ด์ „ํŠธ ๊ฐ•ํ™”ํ•™์Šต