13. Model-Based Reinforcement Learning

Difficulty: ⭐⭐⭐⭐ (Advanced)

Learning Objectives

  • Understand the distinction between model-free and model-based RL
  • Implement the Dyna architecture for planning with learned models
  • Learn world model approaches (Dreamer, MuZero)
  • Apply Model-Based Policy Optimization (MBPO)
  • Understand when model-based methods outperform model-free ones

Table of Contents

  1. Model-Free vs Model-Based RL
  2. Dyna Architecture
  3. Learning World Models
  4. Model-Based Policy Optimization (MBPO)
  5. MuZero: Planning without a Known Model
  6. Dreamer: World Models for Continuous Control
  7. Practice Problems

1. Model-Free vs Model-Based RL

1.1 Comparison

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚              Model-Free vs Model-Based RL                        β”‚
β”‚                                                                 β”‚
β”‚  Model-Free:                                                    β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”             β”‚
β”‚  β”‚ Agent ──(action)──▢ Environment ──(s',r)──▢ Agent           β”‚
β”‚  β”‚                                                β”‚             β”‚
β”‚  β”‚ β€’ Learn value/policy directly from experience  β”‚             β”‚
β”‚  β”‚ β€’ No explicit model of dynamics                β”‚             β”‚
β”‚  β”‚ β€’ Examples: DQN, PPO, SAC                      β”‚             β”‚
β”‚  β”‚ β€’ Simple but sample-inefficient                β”‚             β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜             β”‚
β”‚                                                                 β”‚
β”‚  Model-Based:                                                   β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”             β”‚
β”‚  β”‚ Agent ──(action)──▢ Environment ──(s',r)──▢ Agent           β”‚
β”‚  β”‚   β”‚                                      β”‚     β”‚             β”‚
β”‚  β”‚   β”‚         β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”         β”‚     β”‚             β”‚
β”‚  β”‚   └────────▢│  Learned Model   β”‚β—€β”€β”€β”€β”€β”€β”€β”€β”€β”˜     β”‚             β”‚
β”‚  β”‚             β”‚  ŝ', rΜ‚ = f(s,a) β”‚               β”‚             β”‚
β”‚  β”‚             β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜               β”‚             β”‚
β”‚  β”‚                      β”‚                         β”‚             β”‚
β”‚  β”‚                 Planning                        β”‚             β”‚
β”‚  β”‚            (simulated rollouts)                 β”‚             β”‚
β”‚  β”‚                                                β”‚             β”‚
β”‚  β”‚ β€’ Learn a model of environment dynamics         β”‚             β”‚
β”‚  β”‚ β€’ Plan using the learned model                  β”‚             β”‚
β”‚  β”‚ β€’ Examples: Dyna, MBPO, MuZero, Dreamer        β”‚             β”‚
β”‚  β”‚ β€’ Sample-efficient but model errors accumulate  β”‚             β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜             β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

1.2 Trade-offs

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                      β”‚ Model-Free       β”‚ Model-Based           β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ Sample efficiency    β”‚ Low              β”‚ High (10-100x fewer)  β”‚
β”‚ Asymptotic perf.     β”‚ High             β”‚ Limited by model err  β”‚
β”‚ Computation          β”‚ Low per step     β”‚ High (planning)       β”‚
β”‚ Implementation       β”‚ Simpler          β”‚ More complex          β”‚
β”‚ Robustness           β”‚ More robust      β”‚ Sensitive to model    β”‚
β”‚ Best for             β”‚ Simulation-heavy β”‚ Real-world, expensive β”‚
β”‚                      β”‚ environments     β”‚ interactions          β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

2. Dyna Architecture

2.1 Dyna-Q Algorithm

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚              Dyna-Q Architecture                                 β”‚
β”‚                                                                 β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                 β”‚
β”‚  β”‚          Real Experience                    β”‚                 β”‚
β”‚  β”‚  s ──(a)──▢ Environment ──▢ s', r          β”‚                 β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                 β”‚
β”‚                    β”‚                                            β”‚
β”‚           β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”                                   β”‚
β”‚           β–Ό        β–Ό        β–Ό                                   β”‚
β”‚     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                         β”‚
β”‚     β”‚ Direct   β”‚ β”‚Model β”‚ β”‚ Planning β”‚                         β”‚
β”‚     β”‚ RL       β”‚ β”‚Learn β”‚ β”‚ (n steps)β”‚                         β”‚
β”‚     β”‚ Q-update β”‚ β”‚      β”‚ β”‚          β”‚                         β”‚
β”‚     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                         β”‚
β”‚           β”‚                     β”‚                               β”‚
β”‚           β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                               β”‚
β”‚                      β–Ό                                          β”‚
β”‚               Q-value Table / Network                           β”‚
β”‚                                                                 β”‚
β”‚  Loop:                                                          β”‚
β”‚  1. Act in real environment, observe (s, a, r, s')              β”‚
β”‚  2. Direct RL: Update Q(s,a)                                    β”‚
β”‚  3. Model learning: Update model(s,a) β†’ (rΜ‚, ŝ')               β”‚
β”‚  4. Planning: Repeat n times:                                   β”‚
β”‚     - Sample random (s, a) from experience                      β”‚
β”‚     - Simulate: rΜ‚, ŝ' = model(s, a)                            β”‚
β”‚     - Update Q(s, a) using simulated experience                 β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

2.2 Dyna-Q Implementation

import numpy as np
from collections import defaultdict

class DynaQ:
    def __init__(self, n_states, n_actions, alpha=0.1, gamma=0.99,
                 epsilon=0.1, n_planning=5):
        self.Q = np.zeros((n_states, n_actions))
        self.model = {}  # (s, a) β†’ (r, s', done)
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.n_planning = n_planning
        self.visited = []  # track visited (s, a) pairs

    def select_action(self, state):
        if np.random.random() < self.epsilon:
            return np.random.randint(self.Q.shape[1])
        return np.argmax(self.Q[state])

    def update(self, s, a, r, s_next, done):
        # Step 1: Direct RL update
        target = r if done else r + self.gamma * np.max(self.Q[s_next])
        self.Q[s, a] += self.alpha * (target - self.Q[s, a])

        # Step 2: Model learning
        self.model[(s, a)] = (r, s_next, done)
        if (s, a) not in self.visited:
            self.visited.append((s, a))

        # Step 3: Planning (n simulated updates)
        for _ in range(self.n_planning):
            # Sample random previously visited (s, a)
            idx = np.random.randint(len(self.visited))
            sim_s, sim_a = self.visited[idx]
            sim_r, sim_s_next, sim_done = self.model[(sim_s, sim_a)]

            # Q-learning update with simulated experience
            sim_target = sim_r if sim_done else sim_r + self.gamma * np.max(self.Q[sim_s_next])
            self.Q[sim_s, sim_a] += self.alpha * (sim_target - self.Q[sim_s, sim_a])

2.3 Dyna-Q+ (Exploration Bonus)

class DynaQPlus(DynaQ):
    """Dyna-Q+ adds exploration bonus for states not visited recently."""

    def __init__(self, *args, kappa=0.001, **kwargs):
        super().__init__(*args, **kwargs)
        self.kappa = kappa
        self.last_visit = defaultdict(int)  # (s, a) β†’ last time step
        self.time_step = 0

    def update(self, s, a, r, s_next, done):
        self.time_step += 1
        self.last_visit[(s, a)] = self.time_step

        # Direct RL update (same as Dyna-Q)
        target = r if done else r + self.gamma * np.max(self.Q[s_next])
        self.Q[s, a] += self.alpha * (target - self.Q[s, a])

        # Model learning
        self.model[(s, a)] = (r, s_next, done)
        if (s, a) not in self.visited:
            self.visited.append((s, a))

        # Planning with exploration bonus
        for _ in range(self.n_planning):
            idx = np.random.randint(len(self.visited))
            sim_s, sim_a = self.visited[idx]
            sim_r, sim_s_next, sim_done = self.model[(sim_s, sim_a)]

            # Add bonus for unvisited time
            tau = self.time_step - self.last_visit.get((sim_s, sim_a), 0)
            bonus = self.kappa * np.sqrt(tau)

            sim_target = (sim_r + bonus) if sim_done else \
                         (sim_r + bonus) + self.gamma * np.max(self.Q[sim_s_next])
            self.Q[sim_s, sim_a] += self.alpha * (sim_target - self.Q[sim_s, sim_a])
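The shape of the bonus is worth a quick sanity check: ΞΊβˆšΟ„ stays negligible for recently visited pairs but grows without bound, so a (s, a) pair that has gone unvisited for thousands of steps eventually looks worth re-trying during planning. The numbers below are purely illustrative:

```python
import numpy as np

# Illustrative only: how the kappa * sqrt(tau) exploration bonus scales
# with tau, the number of steps since a pair was last visited.
kappa = 0.001
taus = np.array([1, 100, 10_000])
bonuses = kappa * np.sqrt(taus)
# A pair untouched for 10,000 steps earns a bonus 100x larger than a
# freshly visited one, nudging planning updates back toward stale pairs.
```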

3. Learning World Models

3.1 Neural Network Dynamics Model

import torch
import torch.nn as nn
import torch.optim as optim

class DynamicsModel(nn.Module):
    """Predicts next state and reward given current state and action."""

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.state_head = nn.Linear(hidden_dim, state_dim)
        self.reward_head = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        features = self.network(x)
        next_state = state + self.state_head(features)  # predict residual
        reward = self.reward_head(features)
        return next_state, reward.squeeze(-1)


class EnsembleDynamicsModel(nn.Module):
    """Ensemble of dynamics models for uncertainty estimation."""

    def __init__(self, state_dim, action_dim, n_models=5, hidden_dim=256):
        super().__init__()
        self.models = nn.ModuleList([
            DynamicsModel(state_dim, action_dim, hidden_dim)
            for _ in range(n_models)
        ])

    def forward(self, state, action):
        predictions = [model(state, action) for model in self.models]
        next_states = torch.stack([p[0] for p in predictions])
        rewards = torch.stack([p[1] for p in predictions])

        # Mean prediction
        mean_next_state = next_states.mean(dim=0)
        mean_reward = rewards.mean(dim=0)

        # Uncertainty (disagreement between models)
        uncertainty = next_states.std(dim=0).mean(dim=-1)

        return mean_next_state, mean_reward, uncertainty

3.2 Training the Model

class ModelTrainer:
    def __init__(self, ensemble, lr=1e-3):
        self.ensemble = ensemble
        self.optimizers = [
            optim.Adam(model.parameters(), lr=lr)
            for model in ensemble.models
        ]

    def train(self, replay_buffer, batch_size=256, epochs=5):
        for epoch in range(epochs):
            for model, optimizer in zip(
                self.ensemble.models, self.optimizers
            ):
                # Each model sees an independently sampled minibatch,
                # a cheap approximation of bootstrap resampling
                states, actions, rewards, next_states = \
                    replay_buffer.sample(batch_size)

                pred_next_states, pred_rewards = model(states, actions)

                state_loss = nn.MSELoss()(pred_next_states, next_states)
                reward_loss = nn.MSELoss()(pred_rewards, rewards)
                loss = state_loss + reward_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

3.3 Model Error and Compounding

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚              Model Error Compounding                             β”‚
β”‚                                                                 β”‚
β”‚  Real trajectory:     sβ‚€ β†’ s₁ β†’ sβ‚‚ β†’ s₃ β†’ ...                β”‚
β”‚                                                                 β”‚
β”‚  Model rollout:       sβ‚€ β†’ ŝ₁ β†’ ŝ₂ β†’ ŝ₃ β†’ ...               β”‚
β”‚                            ↑     ↑     ↑                       β”‚
β”‚                          small  medium  LARGE error             β”‚
β”‚                                                                 β”‚
β”‚  Error grows exponentially with rollout length!                 β”‚
β”‚                                                                 β”‚
β”‚  Mitigation strategies:                                         β”‚
β”‚  1. Short rollouts (H = 1-5 steps)                              β”‚
β”‚  2. Ensemble disagreement as uncertainty                        β”‚
β”‚  3. Truncate when uncertainty is high                           β”‚
β”‚  4. Mix real and simulated data                                 β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
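This compounding is easy to reproduce with a two-line dynamics model. The sketch below uses made-up linear dynamics (not from any benchmark): it perturbs the true transition matrix slightly and rolls both systems forward open-loop.

```python
import numpy as np

# Illustrative only: true dynamics s' = A s vs. a slightly wrong model
# A_hat. The one-step error is tiny, but open-loop rollouts drift apart.
rng = np.random.default_rng(0)
A = np.array([[0.99, 0.05], [-0.05, 0.99]])     # true transition matrix
A_hat = A + 0.01 * rng.standard_normal((2, 2))  # model with ~1% weight error

s_true = s_model = np.array([1.0, 0.0])
errors = []
for _ in range(20):
    s_true = A @ s_true          # ground-truth rollout
    s_model = A_hat @ s_model    # open-loop model rollout
    errors.append(np.linalg.norm(s_true - s_model))
# errors grow steadily with rollout length, even though each
# individual one-step prediction is accurate
```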

4. Model-Based Policy Optimization (MBPO)

4.1 MBPO Algorithm

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚              MBPO: Model-Based Policy Optimization               β”‚
β”‚                                                                 β”‚
β”‚  Key idea: Use learned model for SHORT rollouts only            β”‚
β”‚            (branched from real states)                           β”‚
β”‚                                                                 β”‚
β”‚  Algorithm:                                                     β”‚
β”‚  1. Collect real data D_env from environment                    β”‚
β”‚  2. Train ensemble dynamics model on D_env                      β”‚
β”‚  3. For each real state s in D_env:                             β”‚
β”‚     - Generate k-step model rollout (k = 1~5)                  β”‚
β”‚     - Add simulated transitions to D_model                      β”‚
β”‚  4. Train SAC policy on D_env βˆͺ D_model                        β”‚
β”‚  5. Repeat                                                      β”‚
β”‚                                                                 β”‚
β”‚  Benefits:                                                      β”‚
β”‚  β€’ 10-100x more sample efficient than SAC alone                 β”‚
β”‚  β€’ Model only used for short rollouts β†’ less error              β”‚
β”‚  β€’ Guaranteed monotonic improvement (under assumptions)         β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

4.2 Simplified MBPO Implementation

class MBPO:
    def __init__(self, env, state_dim, action_dim,
                 model_rollout_length=1, rollouts_per_step=400):
        self.env = env
        self.ensemble = EnsembleDynamicsModel(state_dim, action_dim)
        self.model_trainer = ModelTrainer(self.ensemble)

        # SAC as the model-free backbone (SACAgent and ReplayBuffer are
        # assumed to be available from the earlier lessons)
        self.policy = SACAgent(state_dim, action_dim)

        self.env_buffer = ReplayBuffer(capacity=1_000_000)
        self.model_buffer = ReplayBuffer(capacity=1_000_000)

        self.rollout_length = model_rollout_length
        self.rollouts_per_step = rollouts_per_step

    def train(self, total_steps=100_000):
        state, _ = self.env.reset()

        for step in range(total_steps):
            # 1. Real environment interaction
            action = self.policy.select_action(state)
            next_state, reward, done, truncated, _ = self.env.step(action)
            self.env_buffer.add(state, action, reward, next_state, done)
            state = next_state if not (done or truncated) else self.env.reset()[0]

            # 2. Train dynamics model periodically
            if step % 250 == 0 and len(self.env_buffer) > 1000:
                self.model_trainer.train(self.env_buffer, epochs=5)

            # 3. Generate model rollouts
            if len(self.env_buffer) > 1000:
                self._generate_model_rollouts()

            # 4. Train policy on mixed data
            if len(self.env_buffer) > 1000:
                # Sample from both buffers
                real_batch = self.env_buffer.sample(128)
                model_batch = self.model_buffer.sample(128) \
                    if len(self.model_buffer) > 128 else real_batch
                self.policy.update(real_batch, model_batch)

    def _generate_model_rollouts(self):
        """Branch short rollouts from real states."""
        states = self.env_buffer.sample_states(self.rollouts_per_step)

        for state in states:
            s = state.clone()
            for h in range(self.rollout_length):
                a = self.policy.select_action(s)
                s_next, r, uncertainty = self.ensemble(
                    s.unsqueeze(0), a.unsqueeze(0)
                )

                # Stop rollout if model is uncertain
                if uncertainty.item() > 0.5:
                    break

                self.model_buffer.add(
                    s.numpy(), a.numpy(), r.item(),
                    s_next.squeeze(0).detach().numpy(), False
                )
                s = s_next.squeeze(0).detach()

5. MuZero: Planning without a Known Model

5.1 MuZero Architecture

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚              MuZero: Three Learned Functions                     β”‚
β”‚                                                                 β”‚
β”‚  1. Representation: h(observation) β†’ hidden state               β”‚
β”‚     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                                               β”‚
β”‚     β”‚ obs (o_t) │──▢ h_ΞΈ ──▢ s_0 (hidden state)               β”‚
β”‚     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                                               β”‚
β”‚                                                                 β”‚
β”‚  2. Dynamics: g(s_k, a_k) β†’ s_{k+1}, r_k                      β”‚
β”‚     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                                        β”‚
β”‚     β”‚ s_k + action a_k │──▢ g_ΞΈ ──▢ s_{k+1}, rΜ‚_k             β”‚
β”‚     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                                        β”‚
β”‚                                                                 β”‚
β”‚  3. Prediction: f(s_k) β†’ policy, value                          β”‚
β”‚     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                                               β”‚
β”‚     β”‚   s_k     │──▢ f_ΞΈ ──▢ Ο€_k, v_k                         β”‚
β”‚     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                                               β”‚
β”‚                                                                 β”‚
β”‚  Key insight: Model operates in LEARNED hidden state space,     β”‚
β”‚  not observation space β†’ no need to predict pixels              β”‚
β”‚                                                                 β”‚
β”‚  Planning: MCTS (Monte Carlo Tree Search) using learned model   β”‚
β”‚                                                                 β”‚
β”‚  Results:                                                       β”‚
β”‚  β€’ Atari: state of the art across 57 games                      β”‚
β”‚  β€’ Go/Chess/Shogi: matches AlphaZero without rules              β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
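To make the latent-space planning point concrete, here is a toy sketch with random linear stand-ins for h, g, and f (all shapes and weights are invented for illustration). The point is only the data flow: after encoding once, imagined rollouts never touch observations again.

```python
import numpy as np

# Toy stand-ins for MuZero's three functions. Weights are random; this
# only illustrates the flow h -> g -> g -> ... -> f in latent space.
rng = np.random.default_rng(1)
W_h = rng.standard_normal((4, 8)) * 0.5   # representation: obs(8) -> latent(4)
W_g = rng.standard_normal((4, 5)) * 0.5   # dynamics: latent + action -> latent
W_f = rng.standard_normal((3, 4)) * 0.5   # prediction: latent -> 2 logits + value

def h(obs):
    return np.tanh(W_h @ obs)

def g(s, a):
    # returns (next latent state, predicted reward)
    return np.tanh(W_g @ np.append(s, a)), float(np.tanh(s.sum()))

def f(s):
    out = W_f @ s
    return out[:2], float(out[2])   # policy logits, value

obs = rng.standard_normal(8)
s = h(obs)                      # encode the observation once
for a in [0, 1, 0]:             # imagine a 3-step action sequence
    s, r_hat = g(s, a)          # never reconstructs observations
logits, v = f(s)
```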

5.2 MuZero Planning (MCTS)

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚              MCTS in MuZero                                      β”‚
β”‚                                                                 β”‚
β”‚  For each action decision:                                      β”‚
β”‚  1. Run N simulations through the learned model                 β”‚
β”‚  2. Each simulation:                                            β”‚
β”‚     a. SELECT: traverse tree using UCB                          β”‚
β”‚     b. EXPAND: use dynamics model g(s,a) β†’ s', rΜ‚              β”‚
β”‚     c. EVALUATE: use prediction model f(s') β†’ Ο€, v             β”‚
β”‚     d. BACKUP: update visit counts and values                   β”‚
β”‚                                                                 β”‚
β”‚  Tree after 50 simulations:                                     β”‚
β”‚                                                                 β”‚
β”‚             sβ‚€ (root = current state)                           β”‚
β”‚            / | \                                                β”‚
β”‚          aβ‚€  a₁  aβ‚‚                                            β”‚
β”‚         /    |     \                                            β”‚
β”‚       s₁    sβ‚‚    s₃        N(aβ‚€)=20, N(a₁)=25, N(aβ‚‚)=5     β”‚
β”‚      / \    / \    |                                            β”‚
β”‚    aβ‚€  a₁ aβ‚€  aβ‚‚  a₁       Q(a₁) highest β†’ select a₁        β”‚
β”‚    ...                                                          β”‚
β”‚                                                                 β”‚
β”‚  Final action: proportional to visit count N(a) at root         β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
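The SELECT step can be sketched with the PUCT rule used by AlphaZero-style search. The constant c_puct and the statistics below, which mirror the toy tree above, are illustrative:

```python
import numpy as np

# Simplified PUCT score: exploit (Q) plus an exploration term that favors
# high-prior, rarely visited actions. c_puct = 1.25 is a common default.
def puct_scores(Q, N, priors, c_puct=1.25):
    total_visits = N.sum()
    U = c_puct * priors * np.sqrt(total_visits) / (1 + N)
    return Q + U

Q = np.array([0.50, 0.60, 0.10])      # mean backed-up values per action
N = np.array([20, 25, 5])             # visit counts, as in the tree above
priors = np.array([0.4, 0.4, 0.2])    # prior policy from the prediction net
best = int(np.argmax(puct_scores(Q, N, priors)))   # selects action 1
```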

6. Dreamer: World Models for Continuous Control

6.1 Dreamer Architecture

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚              DreamerV3 Architecture                               β”‚
β”‚                                                                 β”‚
β”‚  World Model (RSSM β€” Recurrent State-Space Model):              β”‚
β”‚                                                                 β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”           β”‚
β”‚  β”‚  Deterministic path:                              β”‚           β”‚
β”‚  β”‚  h_t = f_ΞΈ(h_{t-1}, z_{t-1}, a_{t-1})           β”‚           β”‚
β”‚  β”‚                                                   β”‚           β”‚
β”‚  β”‚  Stochastic path:                                 β”‚           β”‚
β”‚  β”‚  Prior:     αΊ‘_t ~ p_ΞΈ(αΊ‘_t | h_t)                β”‚           β”‚
β”‚  β”‚  Posterior: z_t ~ q_ΞΈ(z_t | h_t, o_t)            β”‚           β”‚
β”‚  β”‚                                                   β”‚           β”‚
β”‚  β”‚  Decoder:   Γ΄_t = dec_ΞΈ(h_t, z_t)                β”‚           β”‚
β”‚  β”‚  Reward:    rΜ‚_t = rew_ΞΈ(h_t, z_t)                β”‚           β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜           β”‚
β”‚                                                                 β”‚
β”‚  Actor-Critic (trained entirely in imagination):                β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”           β”‚
β”‚  β”‚  Imagine trajectories using world model only      β”‚           β”‚
β”‚  β”‚  Actor:  a_t ~ Ο€_ΞΈ(a_t | h_t, z_t)              β”‚           β”‚
β”‚  β”‚  Critic: v_ΞΈ(h_t, z_t)                           β”‚           β”‚
β”‚  β”‚  No real environment interaction during policy    β”‚           β”‚
β”‚  β”‚  training!                                        β”‚           β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜           β”‚
β”‚                                                                 β”‚
β”‚  Results:                                                       β”‚
β”‚  β€’ First single algorithm to master 150+ diverse tasks          β”‚
β”‚  β€’ Atari, DMControl, Minecraft diamond without task-specific    β”‚
β”‚    tuning                                                       β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

6.2 Imagination-Based Policy Learning

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚              Dreamer Training Loop                               β”‚
β”‚                                                                 β”‚
β”‚  Outer loop (real environment):                                 β”‚
β”‚  1. Collect experience with current policy β†’ replay buffer      β”‚
β”‚  2. Train world model on replay buffer                          β”‚
β”‚                                                                 β”‚
β”‚  Inner loop (imagination):                                      β”‚
β”‚  3. Sample starting states from replay buffer                   β”‚
β”‚  4. "Dream" H-step trajectories using world model:              β”‚
β”‚     sβ‚€ β†’ s₁ β†’ sβ‚‚ β†’ ... β†’ s_H  (all in latent space)          β”‚
β”‚  5. Compute imagined rewards and values                         β”‚
β”‚  6. Update actor and critic on imagined trajectories            β”‚
β”‚                                                                 β”‚
β”‚  Key advantage: Can train policy on 10000s of imagined          β”‚
β”‚  trajectories per real environment step                         β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
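Steps 5 and 6 hinge on how imagined rewards and critic values are combined into training targets. Dreamer uses Ξ»-returns for this; below is a minimal sketch with made-up numbers, assuming an H=3 imagined trajectory:

```python
import numpy as np

# Lambda-return over an imagined trajectory: blends n-step returns and
# bootstraps from the critic. The rewards/values here are made up.
def lambda_returns(rewards, values, gamma=0.99, lam=0.95):
    """values has length H+1: critic estimates incl. the bootstrap value."""
    H = len(rewards)
    G = np.zeros(H)
    next_return = values[H]                 # bootstrap at the horizon
    for t in reversed(range(H)):
        G[t] = rewards[t] + gamma * ((1 - lam) * values[t + 1]
                                     + lam * next_return)
        next_return = G[t]
    return G

rewards = np.array([0.0, 0.0, 1.0])         # imagined rewards
values = np.array([0.3, 0.5, 0.8, 0.9])     # critic values, incl. bootstrap
G = lambda_returns(rewards, values)         # targets for the critic update
```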

7. Practice Problems

Exercise 1: Dyna-Q on GridWorld

Implement Dyna-Q and compare performance with different planning steps (n=0, 5, 50).

# Starter code
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt

# Simple GridWorld (or use FrozenLake)
env = gym.make("FrozenLake-v1", is_slippery=False)

results = {}
for n_planning in [0, 5, 50]:
    agent = DynaQ(
        n_states=env.observation_space.n,
        n_actions=env.action_space.n,
        n_planning=n_planning
    )

    episode_rewards = []
    for episode in range(500):
        state, _ = env.reset()
        total_reward = 0
        done = False

        while not done:
            action = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.update(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward

        episode_rewards.append(total_reward)
    results[n_planning] = episode_rewards

# Plot: more planning steps β†’ faster learning
for n, rewards in results.items():
    # Smooth with moving average
    smoothed = np.convolve(rewards, np.ones(20)/20, mode='valid')
    plt.plot(smoothed, label=f'n_planning={n}')
plt.xlabel('Episode')
plt.ylabel('Reward (smoothed)')
plt.legend()
plt.title('Dyna-Q: Effect of Planning Steps')
plt.show()

Exercise 2: Neural Dynamics Model

Train a neural network dynamics model on CartPole and evaluate prediction accuracy.

# Collect data from random policy
# Train DynamicsModel to predict next state
# Evaluate: 1-step prediction error vs multi-step rollout error
# Show that error grows with rollout length

# Key metrics to plot:
# - 1-step prediction MSE
# - k-step rollout MSE for k = 1, 5, 10, 20
# - Ensemble disagreement correlation with actual error

Exercise 3: Compare Sample Efficiency

Compare model-free SAC vs MBPO on a continuous control task.

# Use HalfCheetah-v4 or Pendulum-v1
# Plot: reward vs environment steps
# Expected: MBPO reaches good performance in ~10x fewer steps
# But: MBPO has higher wall-clock time per step (model training + planning)

Summary

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚              Model-Based RL Landscape                            β”‚
β”‚                                                                 β”‚
β”‚  Simple                                                 Complex β”‚
β”‚  ←──────────────────────────────────────────────────────────→   β”‚
β”‚                                                                 β”‚
β”‚  Dyna-Q       MBPO          MuZero         DreamerV3            β”‚
β”‚  (tabular)    (ensemble +   (MCTS +        (RSSM +              β”‚
β”‚               short         learned        imagination)         β”‚
β”‚               rollouts)     hidden space)                       β”‚
β”‚                                                                 β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”    β”Œβ”€β”€β”€β”€β”€β”€β”     β”Œβ”€β”€β”€β”€β”€β”€β”       β”Œβ”€β”€β”€β”€β”€β”€β”              β”‚
β”‚  β”‚Sampleβ”‚    β”‚Sampleβ”‚     β”‚Sampleβ”‚       β”‚Sampleβ”‚              β”‚
β”‚  β”‚eff:  β”‚    β”‚eff:  β”‚     β”‚eff:  β”‚       β”‚eff:  β”‚              β”‚
β”‚  β”‚ Med  β”‚    β”‚ High β”‚     β”‚V.Highβ”‚       β”‚V.Highβ”‚              β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”˜    β””β”€β”€β”€β”€β”€β”€β”˜     β””β”€β”€β”€β”€β”€β”€β”˜       β””β”€β”€β”€β”€β”€β”€β”˜              β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Key Takeaways:

  β€’ Model-based RL trades computation for sample efficiency
  β€’ Ensemble models provide uncertainty estimates
  β€’ Short rollouts mitigate compounding model errors
  β€’ MuZero: planning in learned latent space (no observation prediction)
  β€’ Dreamer: entire policy training in imagination

