09. Actor-Critic ๋ฐฉ๋ฒ๋ก
09. Actor-Critic ๋ฐฉ๋ฒ๋ก ¶
๋์ด๋: โญโญโญโญ (๊ณ ๊ธ)
ํ์ต ๋ชฉํ¶
- Actor-Critic ์ํคํ ์ฒ ์ดํด
- Advantage ํจ์์ GAE ํ์ต
- A2C์ A3C ์๊ณ ๋ฆฌ์ฆ ๋น๊ต
- PyTorch๋ก Actor-Critic ๊ตฌํ
1. Actor-Critic ๊ฐ์¶
1.1 ํต์ฌ ์์ด๋์ด¶
Actor: ์ ์ฑ ฯ(a|s;ฮธ)๋ฅผ ํ์ต Critic: ๊ฐ์น ํจ์ V(s;w)๋ฅผ ํ์ต
Actor-Critic = Policy Gradient + TD Learning
1.2 REINFORCE vs Actor-Critic¶
| REINFORCE | Actor-Critic |
|---|---|
| ์ํผ์๋ ์ข ๋ฃ ํ ์ ๋ฐ์ดํธ | ๋งค ์คํ ์ ๋ฐ์ดํธ |
| ์ค์ ๋ฆฌํด G ์ฌ์ฉ | TD Target ์ฌ์ฉ |
| ๋์ ๋ถ์ฐ | ๋ฎ์ ๋ถ์ฐ, ์ฝ๊ฐ์ ํธํฅ |
2. Advantage ํจ์¶
2.1 ์ ์¶
$$A(s, a) = Q(s, a) - V(s)$$
์๋ฏธ: ํ๊ท ๋ณด๋ค ์ผ๋ง๋ ์ข์ ํ๋์ธ๊ฐ
2.2 TD Error๋ฅผ Advantage๋ก ์ฌ์ฉ¶
ฮด_t = r_t + ฮณV(s_{t+1}) - V(s_t)
E[ฮด_t | s_t, a_t] = Q(s_t, a_t) - V(s_t) = A(s_t, a_t)
TD Error๋ Advantage์ ๋ถํธ ์ถ์ ๋์ ๋๋ค.
def compute_advantage(rewards, values, next_values, dones, gamma=0.99):
    """Compute 1-step advantages: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t).

    For terminal steps (done=True) the bootstrap term is dropped, so the
    advantage is simply r_t - V(s_t). Returns a plain Python list.
    """
    return [
        (r - v) if done else (r + gamma * nv - v)
        for r, v, nv, done in zip(rewards, values, next_values, dones)
    ]
3. A2C (Advantage Actor-Critic)¶
3.1 ๋คํธ์ํฌ ๊ตฌ์กฐ¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class ActorCritic(nn.Module):
    """Shared-trunk actor-critic network for discrete action spaces."""

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        # Shared feature extractor
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU()
        )
        # Actor head (policy logits)
        self.actor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
        # Critic head (state value)
        self.critic = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, state):
        """Return (action probabilities, state value) for a batch of states."""
        h = self.shared(state)
        return F.softmax(self.actor(h), dim=-1), self.critic(h)

    def get_action(self, state):
        """Sample an action; return (action int, log-prob tensor, value tensor)."""
        probs, value = self.forward(state)
        dist = torch.distributions.Categorical(probs)
        act = dist.sample()
        return act.item(), dist.log_prob(act), value
3.2 A2C ์์ด์ ํธ¶
class A2CAgent:
    """Synchronous Advantage Actor-Critic (A2C) agent with n-step rollouts."""

    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99,
                 value_coef=0.5, entropy_coef=0.01):
        self.network = ActorCritic(state_dim, action_dim)
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=lr)
        self.gamma = gamma
        self.value_coef = value_coef      # weight of the critic loss
        self.entropy_coef = entropy_coef  # weight of the entropy bonus
        # Rollout buffers, cleared after each update()
        self.log_probs = []
        self.values = []
        self.rewards = []
        self.dones = []
        self.entropies = []

    def choose_action(self, state):
        """Sample an action and cache log-prob / value / entropy for update()."""
        obs = torch.FloatTensor(state).unsqueeze(0)
        probs, value = self.network(obs)
        dist = torch.distributions.Categorical(probs)
        act = dist.sample()
        self.log_probs.append(dist.log_prob(act))
        self.values.append(value)
        self.entropies.append(dist.entropy())
        return act.item()

    def store(self, reward, done):
        """Record the environment feedback for the most recent action."""
        self.rewards.append(reward)
        self.dones.append(done)

    def compute_returns(self, next_value):
        """n-step bootstrapped returns, scanning the rollout backwards."""
        returns = []
        running = next_value
        for reward, done in zip(reversed(self.rewards), reversed(self.dones)):
            if done:
                running = 0  # no bootstrapping across episode boundaries
            running = reward + self.gamma * running
            returns.insert(0, running)
        return torch.tensor(returns)

    def update(self, next_state):
        """One gradient step on actor + critic from the buffered rollout."""
        # Bootstrap value of the state following the rollout
        with torch.no_grad():
            _, next_value = self.network(
                torch.FloatTensor(next_state).unsqueeze(0)
            )
        returns = self.compute_returns(next_value.item())
        values = torch.cat(self.values).squeeze()
        log_probs = torch.stack(self.log_probs)
        entropies = torch.stack(self.entropies)

        # Advantage — critic detached so actor gradients do not flow into it
        advantages = returns - values.detach()

        actor_loss = -(log_probs * advantages).mean()
        critic_loss = F.mse_loss(values, returns)
        entropy_loss = -entropies.mean()
        total_loss = (actor_loss
                      + self.value_coef * critic_loss
                      + self.entropy_coef * entropy_loss)

        # Gradient step with norm clipping for stability
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), 0.5)
        self.optimizer.step()

        # Reset rollout buffers
        self.log_probs = []
        self.values = []
        self.rewards = []
        self.dones = []
        self.entropies = []
        return actor_loss.item(), critic_loss.item()
3.3 A2C ํ์ต¶
import gymnasium as gym
import numpy as np
def train_a2c(env_name='CartPole-v1', n_episodes=1000, n_steps=5):
    """Train an A2C agent with synchronous n-step updates.

    Returns (agent, list of per-episode returns). Prints a running average
    every 100 episodes.
    """
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = A2CAgent(state_dim, action_dim)
    scores = []

    for episode in range(n_episodes):
        state, _ = env.reset()
        episode_return = 0
        done = False
        t = 0
        while not done:
            action = agent.choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.store(reward, done)
            state = next_state
            episode_return += reward
            t += 1
            # n-step update (also flushes the buffer at episode end)
            if t % n_steps == 0 or done:
                agent.update(next_state)
        scores.append(episode_return)
        if (episode + 1) % 100 == 0:
            print(f"Episode {episode + 1}, Avg: {np.mean(scores[-100:]):.2f}")
    return agent, scores
4. GAE (Generalized Advantage Estimation)¶
4.1 n-step Returns ํธ๋ ์ด๋์คํ¶
| n | ํธํฅ | ๋ถ์ฐ |
|---|---|---|
| 1 (TD) | ๋์ | ๋ฎ์ |
| โ (MC) | ๋ฎ์ | ๋์ |
4.2 GAE ๊ณต์¶
๋ชจ๋ n-step advantage๋ฅผ ๊ธฐํ๊ธ์์ ์ผ๋ก ๊ฐ์ค ํ๊ท :
$$A^{GAE}\_t = \sum\_{l=0}^{\infty} (\gamma \lambda)^l \delta\_{t+l}$$
์ฌ๊ธฐ์ ฮด_t = r_t + ฮณV(s_{t+1}) - V(s_t)
def compute_gae(rewards, values, next_values, dones, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation.

    A_t = sum_{l>=0} (gamma * lam)^l * delta_{t+l}, computed by a single
    backward scan. The running accumulator resets at episode boundaries
    (done=True), where the bootstrap term is also dropped.
    """
    n = len(rewards)
    advantages = [0.0] * n
    running = 0
    for t in range(n - 1, -1, -1):
        if dones[t]:
            # Terminal transition: delta has no bootstrap, accumulator resets.
            running = rewards[t] - values[t]
        else:
            delta = rewards[t] + gamma * next_values[t] - values[t]
            running = delta + gamma * lam * running
        advantages[t] = running
    return torch.tensor(advantages)
4.3 GAE ์ ์ฉ A2C¶
class A2CWithGAE(A2CAgent):
    """A2C agent using Generalized Advantage Estimation (GAE).

    Extends A2CAgent with a lambda-weighted advantage estimator. Callers must
    record V(s_{t+1}) for every step via store_next_value(); the base class's
    choose_action()/store() do not populate self.next_values.
    """

    def __init__(self, *args, gae_lambda=0.95, **kwargs):
        super().__init__(*args, **kwargs)
        self.gae_lambda = gae_lambda  # lambda: bias/variance trade-off in GAE
        self.next_values = []  # V(s_{t+1}) per step, filled by store_next_value()

    def store_next_value(self, next_value):
        """Record the critic's value of the next state for the current step.

        Fix: the original class read self.next_values in compute_gae_returns()
        but nothing ever appended to it, so any non-empty rollout raised an
        IndexError. Call this once per environment step.
        """
        self.next_values.append(next_value)

    def compute_gae_returns(self):
        """Return (advantages, returns) computed with GAE(gamma, lambda)."""
        # flatten() (not squeeze()) so a 1-step rollout still yields a list.
        values = torch.cat(self.values).flatten().tolist()
        next_vals = self.next_values
        advantages = []
        gae = 0
        # Backward scan; the accumulator resets at episode boundaries.
        for t in reversed(range(len(self.rewards))):
            if self.dones[t]:
                delta = self.rewards[t] - values[t]  # terminal: no bootstrap
                gae = delta
            else:
                delta = self.rewards[t] + self.gamma * next_vals[t] - values[t]
                gae = delta + self.gamma * self.gae_lambda * gae
            advantages.insert(0, gae)
        advantages = torch.tensor(advantages)
        # returns = A + V are the lambda-return targets for the critic.
        returns = advantages + torch.tensor(values)
        return advantages, returns
5. A3C (Asynchronous Advantage Actor-Critic)¶
5.1 ํต์ฌ ์์ด๋์ด¶
์ฌ๋ฌ ์์ปค๊ฐ ๋ณ๋ ฌ๋ก ํ๊ฒฝ๊ณผ ์ํธ์์ฉํ๋ฉฐ ๋น๋๊ธฐ์ ์ผ๋ก ๊ทธ๋๋์ธํธ๋ฅผ ์ ๋ฐ์ดํธํฉ๋๋ค.
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ Global Network โ
โ (Shared Parameters) โ
โโโโโโโโโโโโฌโโโโโโโโโโโฌโโโโโโโโโโโโโโโโ
โ โ
โโโโโโโดโโโโโ โโโโดโโโโโโ
โ Worker 1 โ โ Worker 2โ ...
โ Env 1 โ โ Env 2 โ
โโโโโโโโโโโโ โโโโโโโโโโโ
5.2 ์์ฌ ์ฝ๋¶
# Behavior of each A3C worker (pseudo-code: copy, collect_trajectory,
# compute_loss, compute_gradients and apply_gradients are illustrative
# helpers, not defined here).
def worker(global_network, optimizer, env):
    local_network = copy(global_network)
    while True:
        # Collect experience with the local network
        trajectory = collect_trajectory(local_network, env)
        # Compute gradients from the rollout loss
        loss = compute_loss(trajectory)
        gradients = compute_gradients(loss, local_network)
        # Asynchronously apply gradients to the shared global network
        apply_gradients(optimizer, global_network, gradients)
        # Re-sync the local network with the global parameters
        local_network.load_state_dict(global_network.state_dict())
5.3 A2C vs A3C¶
| A2C | A3C |
|---|---|
| ๋๊ธฐ ์ ๋ฐ์ดํธ | ๋น๋๊ธฐ ์ ๋ฐ์ดํธ |
| ๋ฐฐ์น ์ฒ๋ฆฌ | ์คํธ๋ฆผ ์ฒ๋ฆฌ |
| ๋ ์์ ์ | ๋ ๋น ๋ฆ (๋ณ๋ ฌ) |
| GPU ํจ์จ์ | CPU ํจ์จ์ |
ํ์ฌ ๊ถ์ฅ: A2C๊ฐ GPU์์ ๋ ํจ์จ์ ์ด๋ฏ๋ก ๋ง์ด ์ฌ์ฉ๋จ
6. ์ฐ์ ํ๋ ๊ณต๊ฐ Actor-Critic¶
class ContinuousActorCritic(nn.Module):
    """Actor-critic for continuous actions: Gaussian policy whose log-std is
    a learned parameter independent of the state."""

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        # Shared trunk
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        # Actor: mean head + state-independent learned log-std
        self.actor_mean = nn.Linear(hidden_dim, action_dim)
        self.actor_log_std = nn.Parameter(torch.zeros(action_dim))
        # Critic head
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        """Return (mean, std, value) for a batch of states."""
        h = self.shared(state)
        return self.actor_mean(h), self.actor_log_std.exp(), self.critic(h)

    def get_action(self, state, deterministic=False):
        """Sample an action (or take the mean when deterministic)."""
        mean, std, value = self.forward(state)
        if deterministic:
            return mean, value
        dist = torch.distributions.Normal(mean, std)
        return dist.sample(), value

    def evaluate(self, state, action):
        """Return (value, log-prob, entropy) of an action under the policy."""
        mean, std, value = self.forward(state)
        dist = torch.distributions.Normal(mean, std)
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)
        entropy = dist.entropy().sum(-1, keepdim=True)
        return value, log_prob, entropy
7. ํ์ต ์์ ํ ๊ธฐ๋ฒ¶
7.1 ๊ทธ๋๋์ธํธ ํด๋ฆฌํ¶
# Gradient norm clipping
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=0.5)
# ๊ทธ๋๋์ธํธ ๊ฐ ํด๋ฆฌํ
torch.nn.utils.clip_grad_value_(network.parameters(), clip_value=1.0)
7.2 ํ์ต๋ฅ ์ค์ผ์ค๋ง¶
from torch.optim.lr_scheduler import LinearLR
scheduler = LinearLR(
optimizer,
start_factor=1.0,
end_factor=0.1,
total_iters=total_timesteps
)
7.3 ๋ณด์ ์ ๊ทํ¶
class RunningMeanStd:
    """Running mean/variance tracker (parallel-moments update) used for
    reward or observation normalization."""

    def __init__(self):
        self.mean = 0
        self.var = 1
        self.count = 1e-4  # tiny prior count avoids division by zero

    def update(self, x):
        """Fold a batch of samples into the running statistics."""
        b_mean = np.mean(x)
        b_var = np.var(x)
        b_n = len(x)
        diff = b_mean - self.mean
        n_total = self.count + b_n
        # Chan et al. parallel combination of two sample moments.
        self.mean += diff * b_n / n_total
        m2 = (self.var * self.count
              + b_var * b_n
              + diff ** 2 * self.count * b_n / n_total)
        self.var = m2 / n_total
        self.count = n_total

    def normalize(self, x):
        """Standardize x using the current running statistics."""
        return (x - self.mean) / (np.sqrt(self.var) + 1e-8)
8. ์ค์ต: LunarLander¶
def train_lunarlander():
    """Train an A2C agent on LunarLander; stop early once the 100-episode
    average return reaches 200 ("solved"). Returns (agent, scores)."""
    env = gym.make('LunarLander-v2')
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = A2CAgent(
        state_dim, action_dim,
        lr=7e-4, gamma=0.99,
        value_coef=0.5, entropy_coef=0.01
    )
    scores = []
    n_steps = 5

    for episode in range(2000):
        state, _ = env.reset()
        episode_return = 0
        t = 0
        while True:
            action = agent.choose_action(state)
            next_state, reward, done, truncated, _ = env.step(action)
            agent.store(reward, done or truncated)
            state = next_state
            episode_return += reward
            t += 1
            # n-step update (also flushes the buffer at episode end)
            if t % n_steps == 0 or done or truncated:
                agent.update(next_state)
            if done or truncated:
                break
        scores.append(episode_return)
        if (episode + 1) % 100 == 0:
            avg = np.mean(scores[-100:])
            print(f"Episode {episode + 1}, Avg: {avg:.2f}")
            if avg >= 200:
                print("Solved!")
                break
    return agent, scores
์์ฝ¶
| ๊ตฌ์ฑ์์ | ์ญํ | ํ์ต ๋์ |
|---|---|---|
| Actor | ์ ์ฑ | ฮธ (์ ์ฑ ํ๋ผ๋ฏธํฐ) |
| Critic | ๊ฐ์น ํ๊ฐ | w (๊ฐ์น ํ๋ผ๋ฏธํฐ) |
| Advantage | ํ๋ ํ์ง ์ธก์ | A = Q - V โ ฮด |
์์ค ํจ์:
L = L_actor + c1 * L_critic + c2 * L_entropy
L_actor = -log ฯ(a|s) * A
L_critic = (V - target)ยฒ
L_entropy = -ฮฃ ฯ log ฯ
๋ค์ ๋จ๊ณ¶
- 10_PPO_TRPO.md - ์ ๋ขฐ ์์ญ ์ ์ฑ ์ต์ ํ