14. Soft Actor-Critic (SAC)

Difficulty: ⭐⭐⭐⭐ (Advanced)

Learning Objectives

  • Understand maximum entropy reinforcement learning
  • Implement the SAC algorithm for continuous action spaces
  • Learn automatic temperature (alpha) tuning
  • Compare SAC with PPO and TD3
  • Apply SAC to practical continuous control tasks

Table of Contents

  1. Maximum Entropy RL
  2. SAC Algorithm
  3. SAC Implementation
  4. Automatic Temperature Tuning
  5. SAC vs Other Algorithms
  6. Practical Tips
  7. Practice Problems

1. Maximum Entropy RL

1.1 Standard RL vs Maximum Entropy RL

┌─────────────────────────────────────────────────────────────────┐
│              Maximum Entropy Framework                          │
│                                                                 │
│  Standard RL objective:                                         │
│  π* = argmax_π E[ Σ γ^t r_t ]                                   │
│       → maximize expected return only                           │
│                                                                 │
│  Maximum Entropy RL objective:                                  │
│  π* = argmax_π E[ Σ γ^t (r_t + α H(π(·|s_t))) ]                 │
│       → maximize return + policy entropy                        │
│                                                                 │
│  Where:                                                         │
│  • H(π(·|s)) = -E[log π(a|s)] is the policy entropy             │
│  • α (temperature) controls exploration-exploitation balance    │
│                                                                 │
│  Benefits of maximum entropy:                                   │
│  1. Encourages exploration (higher entropy = more random)       │
│  2. Captures multiple modes (doesn't collapse to one solution)  │
│  3. More robust to perturbations                                │
│  4. Better transfer and fine-tuning                             │
└─────────────────────────────────────────────────────────────────┘
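The entropy term H(π(·|s)) = -E[log π(a|s)] is easy to sanity-check numerically. A minimal stdlib sketch (toy 1-D Gaussian policy, values chosen only for illustration) comparing a Monte Carlo estimate against the closed-form Gaussian entropy:

```python
import math
import random

# H(N(mu, sigma^2)) should equal 0.5 * log(2*pi*e*sigma^2); this is the
# quantity that the alpha * H bonus rewards in the maximum entropy objective.
random.seed(0)
mu, sigma = 0.0, 0.5

def log_pdf(a):
    return -0.5 * math.log(2 * math.pi * sigma**2) - (a - mu)**2 / (2 * sigma**2)

samples = [random.gauss(mu, sigma) for _ in range(200_000)]
mc_entropy = -sum(log_pdf(a) for a in samples) / len(samples)  # -E[log pi(a)]
closed_form = 0.5 * math.log(2 * math.pi * math.e * sigma**2)

print(round(mc_entropy, 3), round(closed_form, 3))  # the two agree closely
```

Note that a narrower policy (smaller σ) gives lower entropy, so the α·H bonus pushes against premature collapse to a near-deterministic policy.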

1.2 Soft Bellman Equation

┌─────────────────────────────────────────────────────────────────┐
│              Soft Value Functions                               │
│                                                                 │
│  Soft state value:                                              │
│  V(s) = E_a~π [ Q(s,a) - α log π(a|s) ]                         │
│                                                                 │
│  Soft Q-value (Bellman equation):                               │
│  Q(s,a) = r(s,a) + γ E_s' [ V(s') ]                             │
│         = r(s,a) + γ E_s',a'~π [ Q(s',a') - α log π(a'|s') ]    │
│                                                                 │
│  Soft policy improvement:                                       │
│  π_new = argmin_π D_KL( π(·|s) || exp(Q(s,·)/α) / Z(s) )        │
│                                                                 │
│  In practice: π outputs the mean and std of a Gaussian          │
│  a = tanh(μ + σ·ε),  ε ~ N(0, I)                                │
└─────────────────────────────────────────────────────────────────┘
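The soft state value definition can be checked by Monte Carlo. A toy stdlib sketch with a hypothetical quadratic Q-function and a 1-D Gaussian policy (both made up here for illustration; not part of SAC itself):

```python
import math
import random

# V(s) = E_{a~pi}[ Q(s,a) - alpha * log pi(a|s) ]: the -alpha * log pi term
# is an entropy bonus, so the soft value exceeds the plain expected Q here.
random.seed(1)
alpha, mu, sigma = 0.2, 0.3, 0.5

def q(a):  # toy Q-function for one fixed state (assumption, illustration only)
    return -(a - 1.0) ** 2

def log_pi(a):
    return -0.5 * math.log(2 * math.pi * sigma**2) - (a - mu)**2 / (2 * sigma**2)

samples = [random.gauss(mu, sigma) for _ in range(100_000)]
v_soft = sum(q(a) - alpha * log_pi(a) for a in samples) / len(samples)
v_hard = sum(q(a) for a in samples) / len(samples)

print(v_soft > v_hard)  # entropy bonus raises the soft value
```

The gap between the two estimates is exactly α times the policy entropy, which is what a larger temperature amplifies.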

2. SAC Algorithm

2.1 SAC Components

┌─────────────────────────────────────────────────────────────────┐
│              SAC Architecture                                   │
│                                                                 │
│  ┌──────────────────────────────────────────────────┐           │
│  │                  Actor (Policy)                  │           │
│  │  π_φ(a|s): squashed Gaussian                     │           │
│  │  Input: state s                                  │           │
│  │  Output: μ(s), σ(s) → a = tanh(μ + σ·ε)          │           │
│  └──────────────────────────────────────────────────┘           │
│                                                                 │
│  ┌──────────────────────────────────────────────────┐           │
│  │              Twin Critics (Q1, Q2)               │           │
│  │  Q_θ1(s, a), Q_θ2(s, a)                          │           │
│  │  Input: state s, action a                        │           │
│  │  Output: Q-value                                 │           │
│  │  → use min(Q1, Q2) to prevent overestimation     │           │
│  └──────────────────────────────────────────────────┘           │
│                                                                 │
│  ┌──────────────────────────────────────────────────┐           │
│  │            Target Networks (Q1', Q2')            │           │
│  │  Soft update: θ' ← τθ + (1-τ)θ'                  │           │
│  │  Provides stable targets for critic training     │           │
│  └──────────────────────────────────────────────────┘           │
│                                                                 │
│  ┌──────────────────────────────────────────────────┐           │
│  │              Temperature (α)                     │           │
│  │  Controls the entropy bonus                      │           │
│  │  Can be fixed or automatically tuned             │           │
│  └──────────────────────────────────────────────────┘           │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

2.2 SAC Update Rules

┌─────────────────────────────────────────────────────────────────┐
│              SAC Training Steps                                 │
│                                                                 │
│  For each gradient step:                                        │
│                                                                 │
│  1. Sample a batch (s, a, r, s', done) from the replay buffer   │
│                                                                 │
│  2. Compute the critic target:                                  │
│     a' ~ π_φ(·|s')                                              │
│     y = r + γ(1-done) × [min(Q'₁(s',a'), Q'₂(s',a'))            │
│                          - α log π_φ(a'|s')]                    │
│                                                                 │
│  3. Update critics (minimize MSE):                              │
│     L_Q = E[(Q_θi(s,a) - y)²]  for i = 1, 2                     │
│                                                                 │
│  4. Update actor (minimize):                                    │
│     ã ~ π_φ(·|s)  (reparameterization trick)                    │
│     L_π = E[α log π_φ(ã|s) - min(Q_θ1(s,ã), Q_θ2(s,ã))]         │
│                                                                 │
│  5. Update temperature (if auto-tuning):                        │
│     L_α = E[-α (log π_φ(ã|s) + H_target)]                       │
│                                                                 │
│  6. Soft-update target networks:                                │
│     θ'_i ← τ θ_i + (1-τ) θ'_i                                   │
└─────────────────────────────────────────────────────────────────┘

3. SAC Implementation

3.1 Actor Network (Squashed Gaussian Policy)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

LOG_STD_MIN = -20
LOG_STD_MAX = 2

class GaussianActor(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.mean_head = nn.Linear(hidden_dim, action_dim)
        self.log_std_head = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        x = self.net(state)
        mean = self.mean_head(x)
        log_std = self.log_std_head(x)
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        return mean, log_std

    def sample(self, state):
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)

        # Reparameterization trick: z = μ + σ·ε
        z = normal.rsample()

        # Squash through tanh
        action = torch.tanh(z)

        # Log probability with correction for tanh squashing
        log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + 1e-6)
        log_prob = log_prob.sum(dim=-1, keepdim=True)

        return action, log_prob

    def deterministic_action(self, state):
        mean, _ = self.forward(state)
        return torch.tanh(mean)
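The `log_prob` correction in `sample` comes from the change of variables a = tanh(z): the squashed density is p(a) = N(atanh(a)) / (1 - a²). A standalone stdlib check (toy σ, chosen only for this sketch) that the corrected density actually integrates to 1 over the squashed action range:

```python
import math

sigma = 0.7  # toy policy std, assumption for this check only

def log_prob_a(a):
    # density of a = tanh(z) with z ~ N(0, sigma^2):
    # log p(a) = log N(atanh(a)) - log(1 - a^2), matching the correction above
    z = math.atanh(a)
    log_normal = -0.5 * math.log(2 * math.pi * sigma**2) - z**2 / (2 * sigma**2)
    return log_normal - math.log(1 - a * a)

# Riemann sum of p(a) over (-1, 1); a proper density should integrate to ~1
total, da = 0.0, 1e-4
a = -1.0 + da
while a < 1.0 - da:
    total += math.exp(log_prob_a(a)) * da
    a += da

print(round(total, 2))  # ≈ 1.0
```

If the `- log(1 - a²)` term were dropped, the same sum would come out well below 1, which is the symptom behind the "verify log_prob correction term" tip later in this lesson.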

3.2 Critic Networks

class TwinQCritic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()

        # Q1 network
        self.q1 = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

        # Q2 network
        self.q2 = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        return self.q1(x), self.q2(x)

    def q1_forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        return self.q1(x)

3.3 SAC Agent

import copy

class SACAgent:
    def __init__(self, state_dim, action_dim, hidden_dim=256,
                 lr=3e-4, gamma=0.99, tau=0.005, alpha=0.2,
                 auto_alpha=True, target_entropy=None):

        self.gamma = gamma
        self.tau = tau
        self.auto_alpha = auto_alpha

        # Networks
        self.actor = GaussianActor(state_dim, action_dim, hidden_dim)
        self.critic = TwinQCritic(state_dim, action_dim, hidden_dim)
        self.critic_target = copy.deepcopy(self.critic)

        # Freeze target parameters
        for param in self.critic_target.parameters():
            param.requires_grad = False

        # Optimizers
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)

        # Temperature (alpha)
        if auto_alpha:
            self.target_entropy = target_entropy if target_entropy is not None else -action_dim
            self.log_alpha = torch.zeros(1, requires_grad=True)
            self.alpha = self.log_alpha.exp().item()
            self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=lr)
        else:
            self.alpha = alpha

    def select_action(self, state, deterministic=False):
        state = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():
            if deterministic:
                action = self.actor.deterministic_action(state)
            else:
                action, _ = self.actor.sample(state)
        return action.squeeze(0).numpy()

    def update(self, batch):
        states, actions, rewards, next_states, dones = batch

        # --- Update Critics ---
        with torch.no_grad():
            next_actions, next_log_probs = self.actor.sample(next_states)
            q1_target, q2_target = self.critic_target(next_states, next_actions)
            q_target = torch.min(q1_target, q2_target)
            target = rewards + self.gamma * (1 - dones) * \
                     (q_target - self.alpha * next_log_probs)

        q1, q2 = self.critic(states, actions)
        critic_loss = F.mse_loss(q1, target) + F.mse_loss(q2, target)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # --- Update Actor ---
        new_actions, log_probs = self.actor.sample(states)
        q1_new, q2_new = self.critic(states, new_actions)
        q_new = torch.min(q1_new, q2_new)

        actor_loss = (self.alpha * log_probs - q_new).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # --- Update Temperature ---
        if self.auto_alpha:
            alpha_loss = -(self.log_alpha * (log_probs.detach() + self.target_entropy)).mean()

            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            self.alpha = self.log_alpha.exp().item()

        # --- Soft Update Target Networks ---
        for param, target_param in zip(
            self.critic.parameters(), self.critic_target.parameters()
        ):
            target_param.data.copy_(
                self.tau * param.data + (1 - self.tau) * target_param.data
            )

        return {
            'critic_loss': critic_loss.item(),
            'actor_loss': actor_loss.item(),
            'alpha': self.alpha,
            'entropy': -log_probs.mean().item()
        }

3.4 Replay Buffer

import numpy as np

class ReplayBuffer:
    def __init__(self, state_dim, action_dim, capacity=1_000_000):
        self.capacity = capacity
        self.idx = 0
        self.size = 0

        self.states = np.zeros((capacity, state_dim), dtype=np.float32)
        self.actions = np.zeros((capacity, action_dim), dtype=np.float32)
        self.rewards = np.zeros((capacity, 1), dtype=np.float32)
        self.next_states = np.zeros((capacity, state_dim), dtype=np.float32)
        self.dones = np.zeros((capacity, 1), dtype=np.float32)

    def add(self, state, action, reward, next_state, done):
        self.states[self.idx] = state
        self.actions[self.idx] = action
        self.rewards[self.idx] = reward
        self.next_states[self.idx] = next_state
        self.dones[self.idx] = done

        self.idx = (self.idx + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size):
        idxs = np.random.randint(0, self.size, size=batch_size)
        return (
            torch.FloatTensor(self.states[idxs]),
            torch.FloatTensor(self.actions[idxs]),
            torch.FloatTensor(self.rewards[idxs]),
            torch.FloatTensor(self.next_states[idxs]),
            torch.FloatTensor(self.dones[idxs])
        )

    def __len__(self):
        return self.size
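The `add` method implements a circular buffer: once `idx` wraps past `capacity`, new transitions overwrite the oldest slots while `size` stays capped. A quick standalone check of that index arithmetic:

```python
# Six writes into a capacity-4 buffer: the 5th and 6th overwrite slots 0 and 1,
# mirroring the (idx + 1) % capacity and min(size + 1, capacity) updates above.
capacity = 4
idx, size = 0, 0
writes = []
for _ in range(6):
    writes.append(idx)
    idx = (idx + 1) % capacity
    size = min(size + 1, capacity)

print(writes, size)  # [0, 1, 2, 3, 0, 1] 4
```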

3.5 Training Loop

import gymnasium as gym

def train_sac(env_name='Pendulum-v1', total_steps=100_000,
              batch_size=256, start_steps=5000):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    action_scale = env.action_space.high[0]

    agent = SACAgent(state_dim, action_dim)
    buffer = ReplayBuffer(state_dim, action_dim)

    state, _ = env.reset()
    episode_reward = 0
    episode_rewards = []

    for step in range(total_steps):
        # Random actions for initial exploration
        if step < start_steps:
            action = env.action_space.sample()
        else:
            action = agent.select_action(state) * action_scale

        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        buffer.add(state, action / action_scale, reward, next_state, float(terminated))

        state = next_state
        episode_reward += reward

        if done:
            episode_rewards.append(episode_reward)
            state, _ = env.reset()
            episode_reward = 0

        # Update after collecting enough data
        if step >= start_steps and len(buffer) >= batch_size:
            batch = buffer.sample(batch_size)
            metrics = agent.update(batch)

            if step % 1000 == 0:
                avg_reward = np.mean(episode_rewards[-10:]) if episode_rewards else 0
                print(f"Step {step}: avg_reward={avg_reward:.1f}, "
                      f"alpha={metrics['alpha']:.3f}, "
                      f"entropy={metrics['entropy']:.3f}")

    return agent, episode_rewards

4. Automatic Temperature Tuning

4.1 Why Auto-Tuning Matters

┌─────────────────────────────────────────────────────────────────┐
│              Temperature (α) Effect                             │
│                                                                 │
│  α too high:                     α too low:                     │
│  ┌─────────────────────┐         ┌─────────────────────┐        │
│  │ Entropy dominates   │         │ Return dominates    │        │
│  │ → nearly random     │         │ → premature         │        │
│  │ → slow learning     │         │   convergence       │        │
│  │ → poor performance  │         │ → poor exploration  │        │
│  └─────────────────────┘         └─────────────────────┘        │
│                                                                 │
│  Auto-tuning:                                                   │
│  ┌──────────────────────────────────────────────────┐           │
│  │ Constraint: H(π(·|s)) ≥ H_target                 │           │
│  │ If entropy < target: increase α (explore more)   │           │
│  │ If entropy > target: decrease α (exploit more)   │           │
│  │ H_target = -dim(A) (heuristic for continuous)    │           │
│  └──────────────────────────────────────────────────┘           │
│                                                                 │
│  Typical α trajectory during training:                          │
│  α                                                              │
│  │                                                              │
│  │  ╲                                                           │
│  │   ╲                                                          │
│  │    ╲___                                                      │
│  │        ╲____                                                 │
│  │             ╲_________                                       │
│  │                       ─────                                  │
│  └────────────────────────────▶ steps                           │
│  (starts high for exploration, decreases as policy converges)   │
└─────────────────────────────────────────────────────────────────┘

4.2 Temperature Loss

# Automatic temperature tuning objective:
# L(α) = E_a~π [ -α log π(a|s) - α H_target ]
#      = E_a~π [ -α (log π(a|s) + H_target) ]

# In the update step:
alpha_loss = -(self.log_alpha * (log_probs.detach() + self.target_entropy)).mean()

# Intuition:
# When entropy (-log_probs) < H_target: log_probs + H_target > 0
# → gradient pushes log_alpha up → α increases → more entropy encouraged
# When entropy > H_target: log_probs + H_target < 0
# → gradient pushes log_alpha down → α decreases
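That push-up/push-down intuition is just the sign of the gradient with respect to log_alpha. A toy check with made-up numbers (not from a real run):

```python
# L(alpha) = -log_alpha * (log_prob + H_target), so the gradient w.r.t.
# log_alpha is -(log_prob + H_target); gradient *descent* moves log_alpha
# in the opposite direction of this value.
H_target = -2.0

def grad_log_alpha(mean_log_prob):
    return -(mean_log_prob + H_target)

# Entropy -3.0 is below the target -2.0: gradient is negative,
# so a descent step raises log_alpha and alpha grows (more exploration).
print(grad_log_alpha(3.0))   # -1.0
# Entropy 5.0 is above the target: gradient is positive, alpha shrinks.
print(grad_log_alpha(-5.0))  # 7.0
```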

5. SAC vs Other Algorithms

5.1 Comparison Table

┌───────────────┬──────────┬──────────┬──────────┬───────────────┐
│               │ SAC      │ PPO      │ TD3      │ DDPG          │
├───────────────┼──────────┼──────────┼──────────┼───────────────┤
│ Policy type   │ Stochas. │ Stochas. │ Determin.│ Deterministic │
│ On/Off policy │ Off      │ On       │ Off      │ Off           │
│ Action space  │ Contin.  │ Both     │ Contin.  │ Continuous    │
│ Entropy reg.  │ Yes      │ Yes      │ No       │ No            │
│ Twin critics  │ Yes      │ No       │ Yes      │ No            │
│ Sample eff.   │ High     │ Low      │ High     │ Medium        │
│ Stability     │ High     │ High     │ Medium   │ Low           │
│ Hyperparams   │ Few      │ Many     │ Medium   │ Many          │
│ Auto-tuning   │ α tuning │ No       │ No       │ No            │
└───────────────┴──────────┴──────────┴──────────┴───────────────┘

5.2 When to Use SAC

Use SAC when:
✓ Continuous action spaces (robotics, control)
✓ Sample efficiency matters (real world, expensive simulation)
✓ You want stable training with minimal tuning
✓ Multi-modal optimal policies exist

Use PPO instead when:
✓ Discrete action spaces
✓ On-policy learning is preferred
✓ Simulation is cheap (can generate many samples)
✓ Distributed training (PPO scales better)

Use TD3 instead when:
✓ A deterministic policy is preferred
✓ A simpler implementation is needed
✓ No entropy regularization is wanted

6. Practical Tips

6.1 Hyperparameters

┌────────────────────────┬──────────────┬──────────────────────────┐
│ Hyperparameter         │ Default      │ Notes                    │
├────────────────────────┼──────────────┼──────────────────────────┤
│ Learning rate          │ 3e-4         │ Same for actor & critic  │
│ Discount (γ)           │ 0.99         │ Standard                 │
│ Soft update (τ)        │ 0.005        │ Slow target updates      │
│ Batch size             │ 256          │ Larger is more stable    │
│ Buffer size            │ 1M           │ Large replay buffer      │
│ Hidden layers          │ (256, 256)   │ 2 layers is standard     │
│ Start steps            │ 5000-10000   │ Random exploration first │
│ Target entropy         │ -dim(A)      │ Heuristic, works well    │
│ Gradient steps/env step│ 1            │ 1:1 ratio is standard    │
└────────────────────────┴──────────────┴──────────────────────────┘

6.2 Common Issues and Solutions

Issue: Training instability / Q-values diverge
โ†’ Check reward scale (normalize if needed)
โ†’ Reduce learning rate
โ†’ Increase batch size

Issue: Low entropy (premature convergence)
โ†’ Enable auto alpha tuning
โ†’ Increase initial alpha
โ†’ Check action bounds

Issue: Slow learning
โ†’ Increase start_steps for better initial exploration
โ†’ Try larger networks
โ†’ Check reward shaping

Issue: Action values saturating at bounds
โ†’ Ensure proper action scaling
โ†’ Check tanh squashing implementation
โ†’ Verify log_prob correction term
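For the reward-scale issue above, one common remedy is dividing rewards by a running standard deviation before they enter the buffer. A minimal sketch (a hypothetical helper, not part of the SAC algorithm itself) using Welford's online variance:

```python
class RunningRewardScaler:
    """Divide rewards by a running std estimate (Welford's algorithm).
    Hypothetical helper for taming reward scale; not part of SAC proper."""

    def __init__(self, eps=1e-8):
        self.count, self.mean, self.m2 = 0, 0.0, 0.0
        self.eps = eps

    def scale(self, reward):
        # Welford's online update of the running mean and variance
        self.count += 1
        delta = reward - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (reward - self.mean)
        std = (self.m2 / max(self.count - 1, 1)) ** 0.5
        return reward / (std + self.eps) if self.count > 1 else reward

scaler = RunningRewardScaler()
scaled = [scaler.scale(r) for r in [10.0, -20.0, 5.0, 100.0]]
```

In the training loop this would mean storing `scaler.scale(reward)` instead of `reward` in `buffer.add(...)`. Note that scaling changes the effective return magnitude, so γ-sensitive tasks may need retuning afterwards.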

7. Practice Problems

Exercise 1: SAC on Pendulum

Train SAC on Pendulum-v1 and plot the learning curve.

# Train SAC
agent, rewards = train_sac('Pendulum-v1', total_steps=50_000)

# Plot learning curve
import matplotlib.pyplot as plt
window = 10
smoothed = np.convolve(rewards, np.ones(window)/window, mode='valid')
plt.plot(smoothed)
plt.xlabel('Episode')
plt.ylabel('Episode Reward')
plt.title('SAC on Pendulum-v1')
plt.show()

# Expected: converges to ~-200 within 20K steps

Exercise 2: Compare SAC vs PPO

Train both SAC and PPO on a continuous control task and compare sample efficiency.

# Use HalfCheetah-v4 or Hopper-v4
# Plot reward vs environment steps for both algorithms
# Expected: SAC reaches same performance in ~5x fewer environment steps
# But PPO may have lower wall-clock time per step

Exercise 3: Ablation Study

Run SAC with these variations and compare:

1. Fixed alpha = 0.2 (no auto-tuning)
2. Auto alpha (default)
3. No entropy term (alpha = 0, TD3-like)
4. Single Q-network (no twin critics)

# Expected findings:
# - Auto alpha > fixed alpha (adapts to task)
# - With entropy > without (better exploration)
# - Twin critics > single (prevents overestimation)

Exercise 4: Custom Environment

Apply SAC to a custom continuous control task.

# Example: reaching task
import numpy as np
import gymnasium as gym
from gymnasium import spaces

class ReachingEnv(gym.Env):
    """2D reaching task: move arm tip to target."""

    def __init__(self):
        self.observation_space = spaces.Box(-np.inf, np.inf, shape=(4,))
        self.action_space = spaces.Box(-1.0, 1.0, shape=(2,))
        self.target = np.array([0.5, 0.5])

    def reset(self, seed=None):
        super().reset(seed=seed)
        self.pos = np.random.uniform(-1, 1, size=2)
        return np.concatenate([self.pos, self.target]), {}

    def step(self, action):
        self.pos = np.clip(self.pos + action * 0.1, -1, 1)
        dist = np.linalg.norm(self.pos - self.target)
        reward = -dist
        done = dist < 0.05
        return np.concatenate([self.pos, self.target]), reward, done, False, {}

Summary

┌─────────────────────────────────────────────────────────────────┐
│              SAC Key Components                                 │
│                                                                 │
│  1. Maximum entropy objective: reward + α × entropy             │
│  2. Squashed Gaussian policy: a = tanh(μ + σε)                  │
│  3. Twin Q-critics: min(Q1, Q2) prevents overestimation         │
│  4. Automatic temperature: α adapts to maintain target entropy  │
│  5. Off-policy: high sample efficiency via replay buffer        │
│                                                                 │
│  SAC is the go-to algorithm for continuous control tasks        │
│  due to its stability, sample efficiency, and minimal tuning.   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
