14. Soft Actor-Critic (SAC)
Difficulty: ★★★★ (Advanced)
Learning Objectives
- Understand maximum entropy reinforcement learning
- Implement the SAC algorithm for continuous action spaces
- Learn automatic temperature (alpha) tuning
- Compare SAC with PPO and TD3
- Apply SAC to practical continuous control tasks
Table of Contents
- Maximum Entropy Reinforcement Learning
- The SAC Algorithm
- SAC Implementation
- Automatic Temperature Tuning
- SAC vs Other Algorithms
- Practical Tips
- Exercises
1. Maximum Entropy Reinforcement Learning
1.1 Standard RL vs Maximum Entropy RL
Maximum Entropy Framework

Standard RL objective:
    π* = argmax_π E[ Σ_t γ^t r_t ]
    → maximize expected return only

Maximum Entropy RL objective:
    π* = argmax_π E[ Σ_t γ^t (r_t + α H(π(·|s_t))) ]
    → maximize return plus policy entropy

Where:
- H(π(·|s)) = -E[log π(a|s)] is the policy entropy
- α (the temperature) controls the exploration-exploitation balance

Benefits of maximum entropy:
1. Encourages exploration (higher entropy means more random actions)
2. Captures multiple modes (doesn't collapse to a single solution)
3. More robust to perturbations
4. Better transfer and fine-tuning
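To make the objective concrete, here is a minimal NumPy sketch (a toy two-action example, not part of SAC itself) showing that with α > 0 the entropy-regularized objective can prefer a stochastic policy over a greedy one when two actions have nearly equal value:

```python
import numpy as np

def entropy(p):
    # H(π) = -Σ p log p, the policy entropy term in the objective
    p = np.asarray(p, dtype=float)
    return float(-np.sum(p * np.log(p + 1e-12)))

def soft_objective(p, q_values, alpha):
    # One-step maximum-entropy objective: E_{a~π}[Q(s,a)] + α H(π)
    p = np.asarray(p, dtype=float)
    return float(np.dot(p, q_values) + alpha * entropy(p))

q = np.array([1.0, 0.99])   # two actions with nearly equal Q-values
greedy = [1.0, 0.0]         # deterministic policy
mixed = [0.5, 0.5]          # stochastic policy

print(soft_objective(greedy, q, 0.0) > soft_objective(mixed, q, 0.0))   # True
print(soft_objective(mixed, q, 0.1) > soft_objective(greedy, q, 0.1))   # True
```

With α = 0 the greedy policy scores higher; with α = 0.1 the uniform policy's entropy bonus outweighs the tiny return gap, illustrating why maximum-entropy policies keep multiple good actions alive.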
1.2 The Soft Bellman Equation
Soft Value Functions

Soft state value:
    V(s) = E_{a~π}[ Q(s,a) - α log π(a|s) ]

Soft Q-value (Bellman equation):
    Q(s,a) = r(s,a) + γ E_{s'}[ V(s') ]
           = r(s,a) + γ E_{s'}[ E_{a'~π}[ Q(s',a') - α log π(a'|s') ] ]

Soft policy improvement:
    π_new = argmin_π D_KL( π(·|s) ‖ exp(Q(s,·)/α) / Z(s) )

In practice π outputs the mean and standard deviation of a Gaussian:
    a = tanh(μ + σ·ε), ε ~ N(0, I)
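Squashing through tanh changes the density, so the log-probability needs the change-of-variables correction log p(a) = log N(z; μ, σ) - log(1 - tanh(z)²). This can be sanity-checked numerically: with the correction, the squashed density still integrates to 1 over the action interval. A small illustrative NumPy check:

```python
import numpy as np

# For a = tanh(z) with z ~ N(mu, sigma):
#   log p(a) = log N(z; mu, sigma) - log(1 - tanh(z)^2)
mu, sigma = 0.3, 0.8
z = np.linspace(-6, 6, 20001)
log_normal = -0.5 * ((z - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2 * np.pi))
a = np.tanh(z)
log_p_a = log_normal - np.log(1.0 - a ** 2)

# A valid density over the squashed actions must integrate to 1 on (-1, 1);
# integrate with the trapezoid rule on the (non-uniform) a-grid.
p_a = np.exp(log_p_a)
integral = float(np.sum(0.5 * (p_a[1:] + p_a[:-1]) * np.diff(a)))
print(round(integral, 3))  # ≈ 1.0
```

Dropping the correction term makes the "density" integrate to less than 1, which is exactly the bug the `log_prob` correction in the actor implementation below guards against.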
2. The SAC Algorithm
2.1 SAC Components
SAC Architecture

Actor (policy):
    π_φ(a|s): squashed Gaussian
    Input: state s
    Output: μ(s), σ(s) → a = tanh(μ + σ·ε)

Twin critics (Q1, Q2):
    Q_θ1(s,a), Q_θ2(s,a)
    Input: state s and action a; output: Q-value
    Uses min(Q1, Q2) to prevent overestimation

Target networks (Q1', Q2'):
    Soft update: θ' ← τθ + (1-τ)θ'
    Provide stable targets for critic training

Temperature (α):
    Controls the entropy bonus
    Can be fixed or automatically tuned
2.2 SAC Update Rules
SAC Training Steps

For each gradient step:

1. Sample a batch (s, a, r, s', done) from the replay buffer.

2. Compute the target:
       a' ~ π_φ(·|s')
       y = r + γ(1-done) × [min(Q'_1(s',a'), Q'_2(s',a')) - α log π_φ(a'|s')]

3. Update the critics by minimizing the MSE:
       L_Q = E[(Q_θi(s,a) - y)²] for i = 1, 2

4. Update the actor by minimizing:
       ã ~ π_φ(·|s) (reparameterization trick)
       L_π = E[α log π_φ(ã|s) - min(Q_θ1(s,ã), Q_θ2(s,ã))]

5. Update the temperature (if auto-tuning) by minimizing:
       L_α = E[-α (log π_φ(ã|s) + H_target)]

6. Soft-update the target networks:
       θ'_i ← τ θ_i + (1-τ) θ'_i
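Step 6 is an exponential moving average: with τ = 0.005 the target weights trail the online weights slowly, which is what keeps the critic targets stable. A minimal scalar illustration of this (not the network update itself):

```python
# Soft target update θ' ← τ·θ + (1-τ)·θ' applied repeatedly is an
# exponential moving average; the target approaches the online value
# at a rate controlled by τ.
tau = 0.005
theta, theta_target = 1.0, 0.0  # online weight fixed at 1, target starts at 0
for _ in range(1000):
    theta_target = tau * theta + (1 - tau) * theta_target

# After n updates the remaining gap is (1 - τ)^n of the original gap
print(round(theta_target, 3))  # ≈ 0.993
```

Even after 1000 gradient steps the target has not fully caught up, so the bootstrapped targets in step 2 change smoothly rather than chasing every critic update.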
3. SAC Implementation
3.1 Actor Network (Squashed Gaussian Policy)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

LOG_STD_MIN = -20
LOG_STD_MAX = 2

class GaussianActor(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.mean_head = nn.Linear(hidden_dim, action_dim)
        self.log_std_head = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        x = self.net(state)
        mean = self.mean_head(x)
        log_std = self.log_std_head(x)
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        return mean, log_std

    def sample(self, state):
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        # Reparameterization trick: z = μ + σ·ε
        z = normal.rsample()
        # Squash through tanh
        action = torch.tanh(z)
        # Log probability with correction for tanh squashing
        log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + 1e-6)
        log_prob = log_prob.sum(dim=-1, keepdim=True)
        return action, log_prob

    def deterministic_action(self, state):
        mean, _ = self.forward(state)
        return torch.tanh(mean)
3.2 Critic Network
class TwinQCritic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        # Q1 network
        self.q1 = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        # Q2 network
        self.q2 = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        return self.q1(x), self.q2(x)

    def q1_forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        return self.q1(x)
3.3 SAC Agent
import copy

class SACAgent:
    def __init__(self, state_dim, action_dim, hidden_dim=256,
                 lr=3e-4, gamma=0.99, tau=0.005, alpha=0.2,
                 auto_alpha=True, target_entropy=None):
        self.gamma = gamma
        self.tau = tau
        self.auto_alpha = auto_alpha
        # Networks
        self.actor = GaussianActor(state_dim, action_dim, hidden_dim)
        self.critic = TwinQCritic(state_dim, action_dim, hidden_dim)
        self.critic_target = copy.deepcopy(self.critic)
        # Freeze target parameters
        for param in self.critic_target.parameters():
            param.requires_grad = False
        # Optimizers
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)
        # Temperature (alpha)
        if auto_alpha:
            self.target_entropy = target_entropy or -action_dim
            self.log_alpha = torch.zeros(1, requires_grad=True)
            self.alpha = self.log_alpha.exp().item()
            self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=lr)
        else:
            self.alpha = alpha

    def select_action(self, state, deterministic=False):
        state = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():
            if deterministic:
                action = self.actor.deterministic_action(state)
            else:
                action, _ = self.actor.sample(state)
        return action.squeeze(0).numpy()

    def update(self, batch):
        states, actions, rewards, next_states, dones = batch
        # --- Update Critics ---
        with torch.no_grad():
            next_actions, next_log_probs = self.actor.sample(next_states)
            q1_target, q2_target = self.critic_target(next_states, next_actions)
            q_target = torch.min(q1_target, q2_target)
            target = rewards + self.gamma * (1 - dones) * \
                (q_target - self.alpha * next_log_probs)
        q1, q2 = self.critic(states, actions)
        critic_loss = F.mse_loss(q1, target) + F.mse_loss(q2, target)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        # --- Update Actor ---
        new_actions, log_probs = self.actor.sample(states)
        q1_new, q2_new = self.critic(states, new_actions)
        q_new = torch.min(q1_new, q2_new)
        actor_loss = (self.alpha * log_probs - q_new).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        # --- Update Temperature ---
        if self.auto_alpha:
            alpha_loss = -(self.log_alpha * (log_probs.detach() + self.target_entropy)).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            self.alpha = self.log_alpha.exp().item()
        # --- Soft Update Target Networks ---
        for param, target_param in zip(
            self.critic.parameters(), self.critic_target.parameters()
        ):
            target_param.data.copy_(
                self.tau * param.data + (1 - self.tau) * target_param.data
            )
        return {
            'critic_loss': critic_loss.item(),
            'actor_loss': actor_loss.item(),
            'alpha': self.alpha,
            'entropy': -log_probs.mean().item()
        }
3.4 Replay Buffer
import numpy as np

class ReplayBuffer:
    def __init__(self, state_dim, action_dim, capacity=1_000_000):
        self.capacity = capacity
        self.idx = 0
        self.size = 0
        self.states = np.zeros((capacity, state_dim), dtype=np.float32)
        self.actions = np.zeros((capacity, action_dim), dtype=np.float32)
        self.rewards = np.zeros((capacity, 1), dtype=np.float32)
        self.next_states = np.zeros((capacity, state_dim), dtype=np.float32)
        self.dones = np.zeros((capacity, 1), dtype=np.float32)

    def add(self, state, action, reward, next_state, done):
        self.states[self.idx] = state
        self.actions[self.idx] = action
        self.rewards[self.idx] = reward
        self.next_states[self.idx] = next_state
        self.dones[self.idx] = done
        self.idx = (self.idx + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size):
        idxs = np.random.randint(0, self.size, size=batch_size)
        return (
            torch.FloatTensor(self.states[idxs]),
            torch.FloatTensor(self.actions[idxs]),
            torch.FloatTensor(self.rewards[idxs]),
            torch.FloatTensor(self.next_states[idxs]),
            torch.FloatTensor(self.dones[idxs])
        )

    def __len__(self):
        return self.size
3.5 Training Loop
import gymnasium as gym

def train_sac(env_name='Pendulum-v1', total_steps=100_000,
              batch_size=256, start_steps=5000):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    action_scale = env.action_space.high[0]
    agent = SACAgent(state_dim, action_dim)
    buffer = ReplayBuffer(state_dim, action_dim)
    state, _ = env.reset()
    episode_reward = 0
    episode_rewards = []
    for step in range(total_steps):
        # Random actions for initial exploration
        if step < start_steps:
            action = env.action_space.sample()
        else:
            action = agent.select_action(state) * action_scale
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        # Store actions normalized to [-1, 1] to match the actor's output range
        buffer.add(state, action / action_scale, reward, next_state, float(terminated))
        state = next_state
        episode_reward += reward
        if done:
            episode_rewards.append(episode_reward)
            state, _ = env.reset()
            episode_reward = 0
        # Update after collecting enough data
        if step >= start_steps and len(buffer) >= batch_size:
            batch = buffer.sample(batch_size)
            metrics = agent.update(batch)
            if step % 1000 == 0:
                avg_reward = np.mean(episode_rewards[-10:]) if episode_rewards else 0
                print(f"Step {step}: avg_reward={avg_reward:.1f}, "
                      f"alpha={metrics['alpha']:.3f}, "
                      f"entropy={metrics['entropy']:.3f}")
    return agent, episode_rewards
4. Automatic Temperature Tuning
4.1 Why Automatic Tuning Matters
Temperature (α) Effect

α too high: entropy dominates → the policy stays nearly random, learning
is slow, and final performance is poor.

α too low: return dominates → premature convergence and poor exploration.

Auto-tuning:
    Constraint: H(π(·|s)) ≥ H_target
    If entropy < target: increase α (explore more)
    If entropy > target: decrease α (exploit more)
    H_target = -dim(A) (a common heuristic for continuous actions)

Typical α trajectory during training: α starts high to drive exploration,
then decreases toward a small value as the policy converges.
4.2 Temperature Loss Function
# Automatic temperature tuning objective:
# L(α) = E_{a~π}[-α log π(a|s) - α H_target]
#      = E_{a~π}[-α (log π(a|s) + H_target)]
# In the update step:
alpha_loss = -(self.log_alpha * (log_probs.detach() + self.target_entropy)).mean()
# Intuition:
# When entropy (-log_probs) < H_target: log_probs + H_target > 0
#   → the gradient pushes log_alpha up → α increases → more entropy encouraged
# When entropy > H_target: log_probs + H_target < 0
#   → the gradient pushes log_alpha down → α decreases
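Because the loss is linear in log_alpha, its gradient is simply the negative mean constraint violation. A dependency-free sketch of this sign logic (the gradient is computed by hand here, not with autograd):

```python
# d/d(log_alpha) of -(log_alpha * mean(log_probs + H_target))
# is -mean(log_probs + H_target); gradient *descent* moves log_alpha
# opposite to that sign.
def log_alpha_grad(log_probs, h_target):
    c = sum(lp + h_target for lp in log_probs) / len(log_probs)
    return -c

# Entropy below target (log_probs large): gradient is negative, so a
# descent step increases log_alpha and alpha grows (more exploration).
print(log_alpha_grad([1.0, 1.5], h_target=-1.0) < 0)    # True
# Entropy above target (log_probs very negative): alpha shrinks.
print(log_alpha_grad([-3.0, -2.5], h_target=-1.0) > 0)  # True
```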
5. SAC vs Other Algorithms
5.1 Comparison Table
                  SAC         PPO         TD3            DDPG
Policy type       Stochastic  Stochastic  Deterministic  Deterministic
On/off-policy     Off         On          Off            Off
Action space      Continuous  Both        Continuous     Continuous
Entropy reg.      Yes         Yes         No             No
Twin critics      Yes         No          Yes            No
Sample eff.       High        Low         High           Medium
Stability         High        High        Medium         Low
Hyperparams       Few         Many        Medium         Many
Auto-tuning       α tuning    No          No             No
5.2 When to Use SAC
Use SAC when:
✓ The action space is continuous (robotics, control)
✓ Sample efficiency matters (real-world, expensive simulation)
✓ You want stable training with minimal tuning
✓ Multi-modal optimal policies exist

Use PPO instead when:
✓ The action space is discrete
✓ On-policy learning is preferred
✓ Simulation is cheap (you can generate many samples)
✓ You need distributed training (PPO scales better)

Use TD3 instead when:
✓ A deterministic policy is preferred
✓ A simpler implementation is needed
✓ No entropy regularization is wanted
6. Practical Tips
6.1 Hyperparameters
Hyperparameter             Default       Notes
Learning rate              3e-4          Same for actor and critic
Discount (γ)               0.99          Standard
Soft update (τ)            0.005         Slow target updates
Batch size                 256           Larger is more stable
Buffer size                1M            Large replay buffer
Hidden layers              (256, 256)    Two layers is standard
Start steps                5000-10000    Random exploration first
Target entropy             -dim(A)       Heuristic, works well
Gradient steps/env step    1             A 1:1 ratio is standard
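When wiring up training, the table's defaults can be collected in one place. The key names in this sketch are illustrative (they mirror the `SACAgent` and `train_sac` arguments defined earlier, not a library API):

```python
# Illustrative defaults mirroring the table above; key names are hypothetical.
sac_defaults = {
    "lr": 3e-4,                    # shared actor/critic learning rate
    "gamma": 0.99,                 # discount factor
    "tau": 0.005,                  # soft-update coefficient
    "batch_size": 256,
    "buffer_capacity": 1_000_000,
    "hidden_dim": 256,             # two hidden layers of this width
    "start_steps": 5000,           # random exploration before learning
    "target_entropy": None,        # None → fall back to the -dim(A) heuristic
    "grad_steps_per_env_step": 1,
}
print(sorted(sac_defaults))
```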
6.2 Common Problems and Fixes
Issue: Training instability / Q-values diverge
→ Check the reward scale (normalize if needed)
→ Reduce the learning rate
→ Increase the batch size

Issue: Low entropy (premature convergence)
→ Enable automatic alpha tuning
→ Increase the initial alpha
→ Check the action bounds

Issue: Slow learning
→ Increase start_steps for better initial exploration
→ Try larger networks
→ Check the reward shaping

Issue: Actions saturating at the bounds
→ Ensure proper action scaling
→ Check the tanh squashing implementation
→ Verify the log_prob correction term
7. Exercises
Exercise 1: SAC on Pendulum
Train SAC on Pendulum-v1 and plot the learning curve.
# Train SAC
agent, rewards = train_sac('Pendulum-v1', total_steps=50_000)
# Plot learning curve
import matplotlib.pyplot as plt
window = 10
smoothed = np.convolve(rewards, np.ones(window)/window, mode='valid')
plt.plot(smoothed)
plt.xlabel('Episode')
plt.ylabel('Episode Reward')
plt.title('SAC on Pendulum-v1')
plt.show()
# Expected: converges to ~-200 within 20K steps
Exercise 2: SAC vs PPO Comparison
Train both SAC and PPO on a continuous control task and compare their sample efficiency.
# Use HalfCheetah-v4 or Hopper-v4
# Plot reward vs environment steps for both algorithms
# Expected: SAC reaches same performance in ~5x fewer environment steps
# But PPO may have lower wall-clock time per step
Exercise 3: Ablation Study
Run SAC with the following variants and compare the results:
1. Fixed alpha = 0.2 (no auto-tuning)
2. Automatic alpha (the default)
3. No entropy term (alpha = 0, similar to TD3)
4. A single Q-network (no twin critics)
# Expected findings:
# - Auto alpha > fixed alpha (adapts to task)
# - With entropy > without (better exploration)
# - Twin critics > single (prevents overestimation)
Exercise 4: Custom Environment
Apply SAC to a custom continuous control task.
# Example: reaching task
import gymnasium as gym
from gymnasium import spaces

class ReachingEnv(gym.Env):
    """2D reaching task: move the arm tip to the target."""
    def __init__(self):
        self.observation_space = spaces.Box(-np.inf, np.inf, shape=(4,))
        self.action_space = spaces.Box(-1.0, 1.0, shape=(2,))
        self.target = np.array([0.5, 0.5])

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.pos = np.random.uniform(-1, 1, size=2)
        return np.concatenate([self.pos, self.target]), {}

    def step(self, action):
        self.pos = np.clip(self.pos + action * 0.1, -1, 1)
        dist = np.linalg.norm(self.pos - self.target)
        reward = -dist
        terminated = bool(dist < 0.05)
        return np.concatenate([self.pos, self.target]), reward, terminated, False, {}
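Before training SAC on a custom environment, it is worth sanity-checking the reward geometry. A standalone NumPy check of the dynamics above (the helper here re-implements the `step` arithmetic of `ReachingEnv` for illustration):

```python
import numpy as np

target = np.array([0.5, 0.5])

def step_once(pos, action):
    # Same dynamics as ReachingEnv.step: small clipped move,
    # negative-distance reward
    new_pos = np.clip(pos + np.asarray(action) * 0.1, -1, 1)
    return new_pos, -float(np.linalg.norm(new_pos - target))

pos = np.array([-0.5, -0.5])
direction = (target - pos) / np.linalg.norm(target - pos)

_, r_toward = step_once(pos, direction)   # unit action toward the target
_, r_away = step_once(pos, -direction)    # unit action away from it

print(r_toward > r_away)  # True
```

If moving toward the target did not increase the reward, SAC would have no useful gradient signal to follow, so this kind of check catches sign errors early.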
Summary
SAC Key Components

1. Maximum entropy objective: reward + α × entropy
2. Squashed Gaussian policy: a = tanh(μ + σε)
3. Twin Q-critics: min(Q1, Q2) prevents overestimation
4. Automatic temperature: α adapts to maintain a target entropy
5. Off-policy: high sample efficiency via a replay buffer

SAC is the go-to algorithm for continuous control tasks thanks to its
stability, sample efficiency, and minimal tuning.
References
- Haarnoja et al. (2018), "Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor" (SAC v1)
- Haarnoja et al. (2018), "Soft Actor-Critic Algorithms and Applications" (SAC v2, automatic temperature tuning)
- OpenAI Spinning Up: SAC
- Stable-Baselines3: SAC
- CleanRL: SAC implementation