09. Actor-Critic Methods

Difficulty: ⭐⭐⭐⭐ (Advanced)

Learning Objectives

  • Understand Actor-Critic architecture
  • Learn Advantage function and GAE
  • Compare A2C and A3C algorithms
  • Implement Actor-Critic with PyTorch

1. Actor-Critic Overview

1.1 Core Idea

Actor: learns the policy π(a|s; θ)
Critic: learns the value function V(s; w)

Actor-Critic = Policy Gradient + TD Learning

1.2 REINFORCE vs Actor-Critic

| REINFORCE | Actor-Critic |
| --- | --- |
| Updates after the episode ends | Updates every step |
| Uses the actual return G | Uses a TD target |
| High variance | Lower variance, some bias |
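The difference shows up directly in the policy-gradient estimator: REINFORCE scales the score function by the full Monte Carlo return G_t, while Actor-Critic replaces it with a critic-based signal such as the TD error δ_t:

$$\nabla_\theta J \approx \nabla_\theta \log \pi(a_t \mid s_t; \theta) \, G_t \quad \text{(REINFORCE)}$$

$$\nabla_\theta J \approx \nabla_\theta \log \pi(a_t \mid s_t; \theta) \, \delta_t \quad \text{(Actor-Critic)}$$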

2. Advantage Function

2.1 Definition

$$A(s, a) = Q(s, a) - V(s)$$

Meaning: how much better action a is than the policy's average behavior in state s.

2.2 TD Error as Advantage

$$\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$$

$$\mathbb{E}[\delta_t \mid s_t, a_t] = Q(s_t, a_t) - V(s_t) = A(s_t, a_t)$$

When V is the true value function V^π, the TD error is an unbiased estimator of the advantage.

def compute_advantage(rewards, values, next_values, dones, gamma=0.99):
    """Compute 1-step Advantage"""
    advantages = []
    for r, v, nv, d in zip(rewards, values, next_values, dones):
        if d:
            advantage = r - v
        else:
            advantage = r + gamma * nv - v
        advantages.append(advantage)
    return advantages
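A quick sanity check with made-up numbers (the critic values here are illustrative, not from a trained network):

rewards = [1.0, 1.0, 1.0]
values = [0.5, 0.6, 0.7]        # critic estimates of V(s_t)
next_values = [0.6, 0.7, 0.0]   # critic estimates of V(s_{t+1})
dones = [False, False, True]

print(compute_advantage(rewards, values, next_values, dones))
# ≈ [1.094, 1.093, 0.3]: r + γV(s') - V(s) per step, r - V(s) at the terminal step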

3. A2C (Advantage Actor-Critic)

3.1 Network Architecture

import torch
import torch.nn as nn
import torch.nn.functional as F

class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()

        # Shared feature extraction
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU()
        )

        # Actor (policy)
        self.actor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

        # Critic (value)
        self.critic = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, state):
        features = self.shared(state)
        policy = F.softmax(self.actor(features), dim=-1)
        value = self.critic(features)
        return policy, value

    def get_action(self, state):
        policy, value = self.forward(state)
        dist = torch.distributions.Categorical(policy)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action.item(), log_prob, value
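A small smoke test of the network with a CartPole-sized input (4-dimensional state, 2 actions); the state values are arbitrary:

model = ActorCritic(state_dim=4, action_dim=2)
state = torch.FloatTensor([[0.02, -0.01, 0.03, 0.00]])

policy, value = model(state)
print(policy.shape, value.shape)   # torch.Size([1, 2]) torch.Size([1, 1])
print(policy.sum(dim=-1))          # action probabilities sum to 1

action, log_prob, value = model.get_action(state)
print(action, log_prob.shape)      # e.g. 0 or 1, torch.Size([1])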

3.2 A2C Agent

class A2CAgent:
    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99,
                 value_coef=0.5, entropy_coef=0.01):
        self.network = ActorCritic(state_dim, action_dim)
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=lr)

        self.gamma = gamma
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef

        # Episode buffer
        self.log_probs = []
        self.values = []
        self.rewards = []
        self.dones = []
        self.entropies = []

    def choose_action(self, state):
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        policy, value = self.network(state_tensor)

        dist = torch.distributions.Categorical(policy)
        action = dist.sample()

        self.log_probs.append(dist.log_prob(action))
        self.values.append(value)
        self.entropies.append(dist.entropy())

        return action.item()

    def store(self, reward, done):
        self.rewards.append(reward)
        self.dones.append(done)

    def compute_returns(self, next_value):
        """Compute n-step returns"""
        returns = []
        R = next_value

        for r, d in zip(reversed(self.rewards), reversed(self.dones)):
            if d:
                R = 0
            R = r + self.gamma * R
            returns.insert(0, R)

        return torch.tensor(returns)

    def update(self, next_state):
        # Next state value (bootstrapping)
        with torch.no_grad():
            _, next_value = self.network(
                torch.FloatTensor(next_state).unsqueeze(0)
            )
            next_value = next_value.item()

        returns = self.compute_returns(next_value)
        values = torch.cat(self.values).squeeze(-1)
        log_probs = torch.cat(self.log_probs)      # shape (n,), so it matches advantages
        entropies = torch.cat(self.entropies)

        # Advantage
        advantages = returns - values.detach()

        # Compute losses
        actor_loss = -(log_probs * advantages).mean()
        critic_loss = F.mse_loss(values, returns)
        entropy_loss = -entropies.mean()

        total_loss = (actor_loss +
                     self.value_coef * critic_loss +
                     self.entropy_coef * entropy_loss)

        # Update
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), 0.5)
        self.optimizer.step()

        # Reset buffer
        self.log_probs = []
        self.values = []
        self.rewards = []
        self.dones = []
        self.entropies = []

        return actor_loss.item(), critic_loss.item()

3.3 A2C Training

import gymnasium as gym
import numpy as np

def train_a2c(env_name='CartPole-v1', n_episodes=1000, n_steps=5):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    agent = A2CAgent(state_dim, action_dim)
    scores = []

    for episode in range(n_episodes):
        state, _ = env.reset()
        total_reward = 0
        done = False
        step_count = 0

        while not done:
            action = agent.choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            agent.store(reward, done)
            state = next_state
            total_reward += reward
            step_count += 1

            # n-step update
            if step_count % n_steps == 0 or done:
                agent.update(next_state)

        scores.append(total_reward)

        if (episode + 1) % 100 == 0:
            print(f"Episode {episode + 1}, Avg: {np.mean(scores[-100:]):.2f}")

    return agent, scores

4. GAE (Generalized Advantage Estimation)

4.1 n-step Returns Tradeoff

| n | Bias | Variance |
| --- | --- | --- |
| 1 (TD) | High | Low |
| ∞ (MC) | Low | High |
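The n-step advantage estimator makes this tradeoff explicit: it uses n real rewards before bootstrapping from the critic, so small n leans on the (possibly biased) value estimate while large n leans on the (high-variance) sampled rewards:

$$A^{(n)}_t = r_t + \gamma r_{t+1} + \dots + \gamma^{n-1} r_{t+n-1} + \gamma^n V(s_{t+n}) - V(s_t)$$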

4.2 GAE Formula

Exponentially-weighted average of all n-step advantages:

$$A^{GAE}_t = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l}$$

where δ_t = r_t + γV(s_{t+1}) - V(s_t)
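In practice the infinite sum is computed backward over a finite trajectory using the equivalent recursion, which is exactly what the loop below implements:

$$A^{GAE}_t = \delta_t + \gamma \lambda \, A^{GAE}_{t+1}$$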

def compute_gae(rewards, values, next_values, dones, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation"""
    advantages = []
    gae = 0

    for t in reversed(range(len(rewards))):
        if dones[t]:
            delta = rewards[t] - values[t]
            gae = delta
        else:
            delta = rewards[t] + gamma * next_values[t] - values[t]
            gae = delta + gamma * lam * gae

        advantages.insert(0, gae)

    return torch.tensor(advantages)

4.3 A2C with GAE

class A2CWithGAE(A2CAgent):
    def __init__(self, *args, gae_lambda=0.95, **kwargs):
        super().__init__(*args, **kwargs)
        self.gae_lambda = gae_lambda
        self.next_values = []

    def store(self, reward, done, next_value):
        """Store the transition plus the critic's estimate of V(s_{t+1})."""
        super().store(reward, done)
        self.next_values.append(next_value)

    def compute_gae_returns(self):
        """GAE-based advantages and returns"""
        values = torch.cat(self.values).squeeze(-1).tolist()
        next_vals = self.next_values

        advantages = []
        gae = 0

        for t in reversed(range(len(self.rewards))):
            if self.dones[t]:
                # Terminal step: no bootstrapping, and the accumulated GAE resets
                delta = self.rewards[t] - values[t]
                gae = delta
            else:
                delta = self.rewards[t] + self.gamma * next_vals[t] - values[t]
                gae = delta + self.gamma * self.gae_lambda * gae

            advantages.insert(0, gae)

        advantages = torch.tensor(advantages)
        returns = advantages + torch.tensor(values)

        return advantages, returns
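To actually train with GAE, update() also has to be adapted. The sketch below mirrors A2CAgent.update but swaps in the GAE advantages and returns and clears the extra next_values buffer; it assumes store() receives the critic's next-state value as shown above:

    # Sketch: override update() inside A2CWithGAE to use GAE
    def update(self, next_state=None):
        advantages, returns = self.compute_gae_returns()

        values = torch.cat(self.values).squeeze(-1)
        log_probs = torch.cat(self.log_probs)
        entropies = torch.cat(self.entropies)

        actor_loss = -(log_probs * advantages).mean()
        critic_loss = F.mse_loss(values, returns)
        entropy_loss = -entropies.mean()

        total_loss = (actor_loss +
                      self.value_coef * critic_loss +
                      self.entropy_coef * entropy_loss)

        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), 0.5)
        self.optimizer.step()

        # Clear all buffers, including the GAE-specific one
        self.log_probs, self.values, self.entropies = [], [], []
        self.rewards, self.dones, self.next_values = [], [], []

        return actor_loss.item(), critic_loss.item()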

5. A3C (Asynchronous Advantage Actor-Critic)

5.1 Core Idea

Multiple workers interact with their own copies of the environment in parallel and asynchronously push gradient updates to a shared global network.

┌─────────────────────────────────────┐
│          Global Network             │
│         (Shared Parameters)          │
└──────────┬──────────┬───────────────┘
           │          │
     ┌─────┴────┐  ┌──┴─────┐
     │ Worker 1 │  │ Worker 2│  ...
     │   Env 1  │  │  Env 2  │
     └──────────┘  └─────────┘

5.2 Pseudocode

# Worker behavior
def worker(global_network, optimizer, env):
    local_network = copy(global_network)

    while True:
        # Collect experience with local network
        trajectory = collect_trajectory(local_network, env)

        # Compute gradients
        loss = compute_loss(trajectory)
        gradients = compute_gradients(loss, local_network)

        # Asynchronous update
        apply_gradients(optimizer, global_network, gradients)

        # Sync local network
        local_network.load_state_dict(global_network.state_dict())
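Below is a minimal, single-machine sketch of this pattern using torch.multiprocessing. It reuses the ActorCritic network from section 3.1 and a plain Adam optimizer over shared parameters; a full A3C implementation would normally also share the optimizer state (a "shared Adam") and log scores, which is omitted here:

import torch.multiprocessing as mp

def a3c_worker(global_network, optimizer, env_name='CartPole-v1',
               n_updates=2000, n_steps=5, gamma=0.99):
    """One asynchronous worker: roll out locally, push gradients to the global net."""
    env = gym.make(env_name)
    local_network = ActorCritic(env.observation_space.shape[0], env.action_space.n)
    state, _ = env.reset()

    for _ in range(n_updates):
        # Sync the local copy with the current global parameters
        local_network.load_state_dict(global_network.state_dict())

        log_probs, values, rewards, done = [], [], [], False
        for _ in range(n_steps):
            action, log_prob, value = local_network.get_action(
                torch.FloatTensor(state).unsqueeze(0))
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            log_probs.append(log_prob)
            values.append(value)
            rewards.append(reward)
            if done:
                state, _ = env.reset()
                break

        # Bootstrap from the last state unless the episode ended
        with torch.no_grad():
            R = 0.0 if done else local_network(
                torch.FloatTensor(state).unsqueeze(0))[1].item()
        returns = []
        for r in reversed(rewards):
            R = r + gamma * R
            returns.insert(0, R)
        returns = torch.tensor(returns)

        values = torch.cat(values).squeeze(-1)
        log_probs = torch.cat(log_probs)
        advantages = returns - values.detach()
        loss = -(log_probs * advantages).mean() + 0.5 * F.mse_loss(values, returns)

        # Compute gradients on the local copy, then hand them to the global network
        local_network.zero_grad()
        loss.backward()
        for lp, gp in zip(local_network.parameters(), global_network.parameters()):
            gp._grad = lp.grad
        optimizer.step()

def train_a3c(n_workers=4, env_name='CartPole-v1'):
    env = gym.make(env_name)
    global_network = ActorCritic(env.observation_space.shape[0], env.action_space.n)
    global_network.share_memory()  # place parameters in shared memory
    optimizer = torch.optim.Adam(global_network.parameters(), lr=1e-4)

    workers = [mp.Process(target=a3c_worker, args=(global_network, optimizer, env_name))
               for _ in range(n_workers)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()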

5.3 A2C vs A3C

| A2C | A3C |
| --- | --- |
| Synchronous updates | Asynchronous updates |
| Batched rollouts | Streamed per-worker rollouts |
| More stable | Faster via parallelism |
| GPU efficient | CPU efficient |

Current recommendation: A2C is more commonly used, since its synchronous batched updates make better use of GPUs.


6. Continuous Action Space Actor-Critic

class ContinuousActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()

        # Shared network
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Actor: mean and standard deviation
        self.actor_mean = nn.Linear(hidden_dim, action_dim)
        self.actor_log_std = nn.Parameter(torch.zeros(action_dim))

        # Critic
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        features = self.shared(state)
        mean = self.actor_mean(features)
        std = self.actor_log_std.exp()
        value = self.critic(features)
        return mean, std, value

    def get_action(self, state, deterministic=False):
        mean, std, value = self.forward(state)

        if deterministic:
            action = mean
        else:
            dist = torch.distributions.Normal(mean, std)
            action = dist.sample()

        return action, value

    def evaluate(self, state, action):
        mean, std, value = self.forward(state)
        dist = torch.distributions.Normal(mean, std)

        log_prob = dist.log_prob(action).sum(-1, keepdim=True)
        entropy = dist.entropy().sum(-1, keepdim=True)

        return value, log_prob, entropy
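A quick shape check with a Pendulum-sized input (3-dimensional state, 1-dimensional action); the state values are arbitrary:

model = ContinuousActorCritic(state_dim=3, action_dim=1)
state = torch.FloatTensor([[0.1, -0.2, 0.05]])

action, value = model.get_action(state)
print(action.shape, value.shape)       # torch.Size([1, 1]) torch.Size([1, 1])

value, log_prob, entropy = model.evaluate(state, action)
print(log_prob.shape, entropy.shape)   # torch.Size([1, 1]) torch.Size([1, 1])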

7. Learning Stabilization Techniques

7.1 Gradient Clipping

# Gradient norm clipping
torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=0.5)

# Gradient value clipping
torch.nn.utils.clip_grad_value_(network.parameters(), clip_value=1.0)

7.2 Learning Rate Scheduling

from torch.optim.lr_scheduler import LinearLR

scheduler = LinearLR(
    optimizer,
    start_factor=1.0,
    end_factor=0.1,
    total_iters=total_timesteps
)
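The schedule only takes effect when scheduler.step() is called; a hypothetical loop skeleton (compute_loss is a placeholder for the actor-critic loss computation) looks like:

for update in range(total_timesteps):
    loss = compute_loss()   # placeholder for the A2C loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()        # moves lr linearly toward end_factor * base lr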

7.3 Reward Normalization

class RunningMeanStd:
    def __init__(self):
        self.mean = 0
        self.var = 1
        self.count = 1e-4

    def update(self, x):
        batch_mean = np.mean(x)
        batch_var = np.var(x)
        batch_count = len(x)

        delta = batch_mean - self.mean
        total_count = self.count + batch_count

        self.mean += delta * batch_count / total_count
        self.var = (self.var * self.count + batch_var * batch_count +
                   delta**2 * self.count * batch_count / total_count) / total_count
        self.count = total_count

    def normalize(self, x):
        return (x - self.mean) / (np.sqrt(self.var) + 1e-8)
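One possible way to wire this in is to keep running statistics over collected rewards and normalize them before they enter the return computation (the numbers below are illustrative):

import numpy as np

reward_stats = RunningMeanStd()

raw_rewards = [1.0, -0.5, 2.0, 100.0]          # e.g. rewards from one rollout
reward_stats.update(np.array(raw_rewards))     # update running mean/variance
normalized = [reward_stats.normalize(r) for r in raw_rewards]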

8. Practice: LunarLander

def train_lunarlander():
    env = gym.make('LunarLander-v2')  # on newer gymnasium versions use 'LunarLander-v3'
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    agent = A2CAgent(
        state_dim, action_dim,
        lr=7e-4, gamma=0.99,
        value_coef=0.5, entropy_coef=0.01
    )

    scores = []
    n_steps = 5

    for episode in range(2000):
        state, _ = env.reset()
        total_reward = 0
        steps = 0

        while True:
            action = agent.choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            agent.store(reward, done)
            state = next_state
            total_reward += reward
            steps += 1

            if steps % n_steps == 0 or done:
                agent.update(next_state)

            if done:
                break

        scores.append(total_reward)

        if (episode + 1) % 100 == 0:
            avg = np.mean(scores[-100:])
            print(f"Episode {episode + 1}, Avg: {avg:.2f}")

            if avg >= 200:
                print("Solved!")
                break

    return agent, scores

Summary

| Component | Role | Learning Target |
| --- | --- | --- |
| Actor | Policy | θ (policy parameters) |
| Critic | Value evaluation | w (value parameters) |
| Advantage | Action quality measure | A = Q - V ≈ δ |

Loss Function:

L = L_actor + c1 * L_critic + c2 * L_entropy
L_actor = -log π(a|s) * A
L_critic = (V - target)²
L_entropy = -Σ π log π
