13. Model-Based Reinforcement Learning¶
Difficulty: ★★★★ (Advanced)
Learning Objectives¶
- Understand the distinction between model-free and model-based RL
- Implement the Dyna architecture for planning with learned models
- Learn world model approaches (Dreamer, MuZero)
- Apply Model-Based Policy Optimization (MBPO)
- Understand when model-based methods outperform model-free ones
Table of Contents¶
- Model-Free vs Model-Based RL
- Dyna Architecture
- Learning World Models
- Model-Based Policy Optimization (MBPO)
- MuZero: Planning without a Known Model
- Dreamer: World Models for Continuous Control
- Practice Problems
1. Model-Free vs Model-Based RL¶
1.1 Comparison¶
Model-Free:

    Agent --(action)--> Environment --(s', r)--> Agent

- Learn a value function or policy directly from experience
- No explicit model of the dynamics
- Examples: DQN, PPO, SAC
- Simple, but sample-inefficient

Model-Based:

    Agent --(action)--> Environment --(s', r)--> Agent
        \                                   /
         +--> Learned Model: ŝ', r̂ = f(s, a)
                    |
              Planning (simulated rollouts)

- Learn a model of the environment dynamics
- Plan using the learned model
- Examples: Dyna, MBPO, MuZero, Dreamer
- Sample-efficient, but model errors accumulate
1.2 Trade-offs¶
|                   | Model-Free                    | Model-Based                         |
|-------------------|-------------------------------|-------------------------------------|
| Sample efficiency | Low                           | High (10-100x fewer interactions)   |
| Asymptotic perf.  | High                          | Limited by model error              |
| Computation       | Low per step                  | High (planning)                     |
| Implementation    | Simpler                       | More complex                        |
| Robustness        | More robust                   | Sensitive to model quality          |
| Best for          | Simulation-heavy environments | Real-world, expensive interactions  |
2. Dyna Architecture¶
2.1 Dyna-Q Algorithm¶
Real experience drives three processes that all write into the same Q-table (or network):

    Real experience: s --(a)--> Environment --> s', r
        +-- Direct RL:      update Q(s, a) from the real transition
        +-- Model learning: model(s, a) <- (r̂, ŝ')
        +-- Planning:       n extra Q-updates from simulated transitions

Loop:
1. Act in the real environment, observe (s, a, r, s')
2. Direct RL: update Q(s, a)
3. Model learning: update model(s, a) -> (r̂, ŝ')
4. Planning, repeated n times:
   - Sample a random previously visited (s, a)
   - Simulate: r̂, ŝ' = model(s, a)
   - Update Q(s, a) using the simulated experience
2.2 Dyna-Q Implementation¶
import numpy as np
from collections import defaultdict

class DynaQ:
    def __init__(self, n_states, n_actions, alpha=0.1, gamma=0.99,
                 epsilon=0.1, n_planning=5):
        self.Q = np.zeros((n_states, n_actions))
        self.model = {}  # (s, a) -> (r, s', done)
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.n_planning = n_planning
        self.visited = []  # track visited (s, a) pairs

    def select_action(self, state):
        if np.random.random() < self.epsilon:
            return np.random.randint(self.Q.shape[1])
        return np.argmax(self.Q[state])

    def update(self, s, a, r, s_next, done):
        # Step 1: Direct RL update
        target = r if done else r + self.gamma * np.max(self.Q[s_next])
        self.Q[s, a] += self.alpha * (target - self.Q[s, a])
        # Step 2: Model learning (deterministic model: store last observed outcome)
        self.model[(s, a)] = (r, s_next, done)
        if (s, a) not in self.visited:
            self.visited.append((s, a))
        # Step 3: Planning (n simulated updates)
        for _ in range(self.n_planning):
            # Sample a random previously visited (s, a)
            idx = np.random.randint(len(self.visited))
            sim_s, sim_a = self.visited[idx]
            sim_r, sim_s_next, sim_done = self.model[(sim_s, sim_a)]
            # Q-learning update with simulated experience
            sim_target = sim_r if sim_done else \
                sim_r + self.gamma * np.max(self.Q[sim_s_next])
            self.Q[sim_s, sim_a] += self.alpha * (sim_target - self.Q[sim_s, sim_a])
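To see why planning helps, here is a tiny standalone sketch (a hypothetical 3-state deterministic chain, separate from the class above) in which repeated simulated updates from a stored model propagate value back to the start state without any further environment steps:

```python
import numpy as np

# Toy 3-state chain: 0 -> 1 -> 2 (terminal), reward 1 on the final step.
# The "model" is the dict of observed transitions, as in DynaQ above.
model = {(0, 0): (0.0, 1, False),   # (s, a) -> (r, s', done)
         (1, 0): (1.0, 2, True)}
Q = np.zeros((3, 1))
alpha, gamma = 0.5, 0.9

# Planning only: replay stored transitions until values propagate backward.
rng = np.random.default_rng(0)
keys = list(model)
for _ in range(200):
    s, a = keys[rng.integers(len(keys))]
    r, s_next, done = model[(s, a)]
    target = r if done else r + gamma * Q[s_next].max()
    Q[s, a] += alpha * (target - Q[s, a])

# Q(1, 0) converges toward 1.0 and Q(0, 0) toward gamma * 1.0 = 0.9,
# even though no new environment transitions were collected.
print(Q[:, 0])
```

With n_planning = 0 the same value propagation would need many additional real episodes, which is exactly the sample-efficiency gap Exercise 1 measures.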
2.3 Dyna-Q+ (Exploration Bonus)¶
class DynaQPlus(DynaQ):
    """Dyna-Q+ adds an exploration bonus for pairs not visited recently."""

    def __init__(self, *args, kappa=0.001, **kwargs):
        super().__init__(*args, **kwargs)
        self.kappa = kappa
        self.last_visit = defaultdict(int)  # (s, a) -> last visit time step
        self.time_step = 0

    def update(self, s, a, r, s_next, done):
        self.time_step += 1
        self.last_visit[(s, a)] = self.time_step
        # Direct RL update (same as Dyna-Q)
        target = r if done else r + self.gamma * np.max(self.Q[s_next])
        self.Q[s, a] += self.alpha * (target - self.Q[s, a])
        # Model learning
        self.model[(s, a)] = (r, s_next, done)
        if (s, a) not in self.visited:
            self.visited.append((s, a))
        # Planning with exploration bonus
        for _ in range(self.n_planning):
            idx = np.random.randint(len(self.visited))
            sim_s, sim_a = self.visited[idx]
            sim_r, sim_s_next, sim_done = self.model[(sim_s, sim_a)]
            # Bonus kappa * sqrt(tau), where tau = time since the last real visit
            tau = self.time_step - self.last_visit.get((sim_s, sim_a), 0)
            bonus = self.kappa * np.sqrt(tau)
            sim_target = (sim_r + bonus) if sim_done else \
                (sim_r + bonus) + self.gamma * np.max(self.Q[sim_s_next])
            self.Q[sim_s, sim_a] += self.alpha * (sim_target - self.Q[sim_s, sim_a])
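A quick back-of-the-envelope check of the κ√τ bonus (all numbers hypothetical) shows how stale pairs are nudged back into the planning updates:

```python
import numpy as np

kappa = 0.001
time_step = 10_000
last_visit = {("s0", "a0"): 9_990,   # tried 10 steps ago
              ("s1", "a0"): 0}       # not revisited since the start

# bonus = kappa * sqrt(tau), tau = time since last visit
bonuses = {sa: kappa * np.sqrt(time_step - t) for sa, t in last_visit.items()}
# The stale pair receives a much larger bonus, so planning keeps
# re-trying it in case the environment has changed.
print(bonuses)
```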
3. Learning World Models¶
3.1 Neural Network Dynamics Model¶
import torch
import torch.nn as nn
import torch.optim as optim

class DynamicsModel(nn.Module):
    """Predicts next state and reward given current state and action."""

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.state_head = nn.Linear(hidden_dim, state_dim)
        self.reward_head = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        features = self.network(x)
        next_state = state + self.state_head(features)  # predict the residual
        reward = self.reward_head(features)
        return next_state, reward.squeeze(-1)

class EnsembleDynamicsModel(nn.Module):
    """Ensemble of dynamics models for uncertainty estimation."""

    def __init__(self, state_dim, action_dim, n_models=5, hidden_dim=256):
        super().__init__()
        self.models = nn.ModuleList([
            DynamicsModel(state_dim, action_dim, hidden_dim)
            for _ in range(n_models)
        ])

    def forward(self, state, action):
        predictions = [model(state, action) for model in self.models]
        next_states = torch.stack([p[0] for p in predictions])
        rewards = torch.stack([p[1] for p in predictions])
        # Mean prediction across the ensemble
        mean_next_state = next_states.mean(dim=0)
        mean_reward = rewards.mean(dim=0)
        # Uncertainty: disagreement (std) between ensemble members
        uncertainty = next_states.std(dim=0).mean(dim=-1)
        return mean_next_state, mean_reward, uncertainty
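The uncertainty estimate is nothing more exotic than the per-state standard deviation across ensemble members. A numpy sketch of the same reduction, with random stand-in predictions:

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in predictions from 5 ensemble members for 3 states (state_dim=4).
base = rng.normal(size=(3, 4))
preds = np.stack([base + 0.1 * rng.normal(size=(3, 4)) for _ in range(5)])

mean_pred = preds.mean(axis=0)                  # (3, 4) ensemble mean
uncertainty = preds.std(axis=0).mean(axis=-1)   # (3,)  one scalar per state

# Members that agree exactly yield zero disagreement.
agree = np.stack([base] * 5)
zero_unc = agree.std(axis=0).mean(axis=-1)      # all zeros
```

(numpy's `std` uses the population estimator while torch defaults to the sample estimator, so the constants differ slightly, but the idea is identical.)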
3.2 Training the Model¶
class ModelTrainer:
    def __init__(self, ensemble, lr=1e-3):
        self.ensemble = ensemble
        self.optimizers = [
            optim.Adam(model.parameters(), lr=lr)
            for model in ensemble.models
        ]

    def train(self, replay_buffer, batch_size=256, epochs=5):
        for epoch in range(epochs):
            for model, optimizer in zip(self.ensemble.models, self.optimizers):
                # Each model gets its own random batch, which decorrelates
                # the members (a cheap approximation to bootstrap sampling)
                states, actions, rewards, next_states = \
                    replay_buffer.sample(batch_size)
                pred_next_states, pred_rewards = model(states, actions)
                state_loss = nn.MSELoss()(pred_next_states, next_states)
                reward_loss = nn.MSELoss()(pred_rewards, rewards)
                loss = state_loss + reward_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
3.3 Model Error and Compounding¶
    Real trajectory:  s0 -> s1 -> s2 -> s3 -> ...
    Model rollout:    s0 -> ŝ1 -> ŝ2 -> ŝ3 -> ...
                            small  medium LARGE error

One-step prediction errors compound, so rollout error can grow exponentially with rollout length.

Mitigation strategies:
1. Short rollouts (H = 1-5 steps)
2. Ensemble disagreement as an uncertainty signal
3. Truncate rollouts when uncertainty is high
4. Mix real and simulated data
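The compounding effect is easy to reproduce with a toy linear system (all matrices hypothetical): a model that is only slightly wrong per step drifts far from the true trajectory over a long rollout.

```python
import numpy as np

rng = np.random.default_rng(0)
A = np.array([[0.99, 0.05],
              [-0.05, 0.99]])                 # true dynamics (stable rotation)
A_hat = A + 0.02 * rng.normal(size=(2, 2))    # slightly wrong learned model

s_true = s_hat = np.array([1.0, 0.0])
errors = []
for k in range(20):
    s_true = A @ s_true        # real trajectory
    s_hat = A_hat @ s_hat      # open-loop model rollout
    errors.append(np.linalg.norm(s_true - s_hat))

# The 20-step rollout error dwarfs the 1-step error.
print(errors[0], errors[-1])
```

This is precisely why MBPO (next section) branches only short rollouts from real states rather than imagining whole episodes.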
4. Model-Based Policy Optimization (MBPO)¶
4.1 MBPO Algorithm¶
Key idea: use the learned model for SHORT rollouts only, branched from real states.

Algorithm:
1. Collect real data D_env from the environment
2. Train an ensemble dynamics model on D_env
3. For each real state s in D_env:
   - generate a k-step model rollout (k = 1-5)
   - add the simulated transitions to D_model
4. Train a SAC policy on D_env ∪ D_model
5. Repeat

Benefits:
- 10-100x more sample-efficient than SAC alone
- The model is only used for short rollouts, so less error accumulates
- Monotonic improvement guarantees (under assumptions)
4.2 Simplified MBPO Implementation¶
class MBPO:
    def __init__(self, env, state_dim, action_dim,
                 model_rollout_length=1, rollouts_per_step=400):
        self.env = env
        self.ensemble = EnsembleDynamicsModel(state_dim, action_dim)
        self.model_trainer = ModelTrainer(self.ensemble)
        # SAC as the model-free backbone (SACAgent and ReplayBuffer are
        # assumed to be available, e.g. from an earlier chapter)
        self.policy = SACAgent(state_dim, action_dim)
        self.env_buffer = ReplayBuffer(capacity=1_000_000)
        self.model_buffer = ReplayBuffer(capacity=1_000_000)
        self.rollout_length = model_rollout_length
        self.rollouts_per_step = rollouts_per_step

    def train(self, total_steps=100_000):
        state, _ = self.env.reset()
        for step in range(total_steps):
            # 1. Real environment interaction
            action = self.policy.select_action(state)
            next_state, reward, done, truncated, _ = self.env.step(action)
            self.env_buffer.add(state, action, reward, next_state, done)
            state = next_state if not (done or truncated) else self.env.reset()[0]
            # 2. Train the dynamics model periodically
            if step % 250 == 0 and len(self.env_buffer) > 1000:
                self.model_trainer.train(self.env_buffer, epochs=5)
            # 3. Generate model rollouts
            if len(self.env_buffer) > 1000:
                self._generate_model_rollouts()
            # 4. Train the policy on mixed real/simulated data
            if len(self.env_buffer) > 1000:
                real_batch = self.env_buffer.sample(128)
                model_batch = self.model_buffer.sample(128) \
                    if len(self.model_buffer) > 128 else real_batch
                self.policy.update(real_batch, model_batch)

    def _generate_model_rollouts(self):
        """Branch short rollouts from real states."""
        states = self.env_buffer.sample_states(self.rollouts_per_step)
        for state in states:
            s = state.clone()
            for h in range(self.rollout_length):
                a = self.policy.select_action(s)
                with torch.no_grad():
                    s_next, r, uncertainty = self.ensemble(
                        s.unsqueeze(0), a.unsqueeze(0)
                    )
                # Stop the rollout if the model is too uncertain
                if uncertainty.item() > 0.5:
                    break
                self.model_buffer.add(
                    s.numpy(), a.numpy(), r.item(),
                    s_next.squeeze(0).numpy(), False
                )
                s = s_next.squeeze(0)
5. MuZero: Planning without a Known Model¶
5.1 MuZero Architecture¶
MuZero learns three functions:

1. Representation: h_θ(o_t) -> s_0 maps an observation to a hidden state
2. Dynamics: g_θ(s_k, a_k) -> s_{k+1}, r̂_k steps the hidden state forward
3. Prediction: f_θ(s_k) -> π_k, v_k outputs a policy and value for a hidden state

Key insight: the model operates in a LEARNED hidden-state space, not in observation space, so it never needs to predict pixels.

Planning: MCTS (Monte Carlo Tree Search) over the learned model.

Results:
- Atari: superhuman performance across the 57-game benchmark
- Go/Chess/Shogi: matches AlphaZero without being given the rules
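Since the three functions are only specified by their signatures here, a toy sketch with random linear stand-ins (every weight below is hypothetical) shows how they compose into a latent-space rollout:

```python
import numpy as np

rng = np.random.default_rng(0)
D = 4  # latent dimension
W_h = 0.1 * rng.normal(size=(D, 8))      # representation h: obs (8,) -> latent
W_g = 0.1 * rng.normal(size=(D, D + 1))  # dynamics g: [latent, action] -> latent
w_r = 0.1 * rng.normal(size=D + 1)       # reward head of g
W_pi = 0.1 * rng.normal(size=(2, D))     # prediction f: policy logits (2 actions)
w_v = 0.1 * rng.normal(size=D)           # prediction f: value

obs = rng.normal(size=8)
s = np.tanh(W_h @ obs)                   # 1. s_0 = h(o_t)
rewards = []
for a in [0, 1, 0]:                      # unroll a candidate action sequence
    x = np.concatenate([s, [float(a)]])
    rewards.append(w_r @ x)              # 2. r̂_k from g(s_k, a_k)
    s = np.tanh(W_g @ x)                 #    s_{k+1} from g(s_k, a_k)
pi_logits, v = W_pi @ s, w_v @ s         # 3. π, v = f(s_k) at the leaf
```

The point of the sketch: the observation is encoded once, and everything after that, including rewards, policy, and value, is computed purely in the learned hidden space.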
5.2 MuZero Planning (MCTS)¶
For each action decision:
1. Run N simulations through the learned model
2. Each simulation:
   a. SELECT: traverse the tree using UCB
   b. EXPAND: use the dynamics model g(s, a) -> s', r̂
   c. EVALUATE: use the prediction model f(s') -> π, v
   d. BACKUP: update visit counts and values along the path

Example: after 50 simulations the root s0 has children a1, a2, a3 with N(a1) = 20, N(a2) = 25, N(a3) = 5; a2 has the highest Q, so SELECT keeps descending through it.

Final action: chosen in proportion to the visit counts N(a) at the root.
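The SELECT step typically scores children with the AlphaZero-style pUCT rule, Q(a) + c · P(a) · √(ΣN) / (1 + N(a)). A numpy sketch using visit counts like those in the 50-simulation example (the Q and P values are made up):

```python
import numpy as np

# Root statistics after 50 simulations (illustrative numbers).
N = np.array([20, 25, 5])          # visit counts for a1, a2, a3
Q = np.array([0.55, 0.60, 0.20])   # mean backed-up values
P = np.array([0.4, 0.4, 0.2])      # prior policy from f(s)
c = 1.25                           # exploration constant

# pUCT score used during SELECT:
ucb = Q + c * P * np.sqrt(N.sum()) / (1 + N)
next_to_expand = int(np.argmax(ucb))   # a2 (index 1) wins here

# Final action at the root: sampled in proportion to visit counts.
probs = N / N.sum()
```

Note how the √(ΣN)/(1+N) term inflates the score of rarely visited children, so even a low-Q action like a3 is eventually re-explored.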
6. Dreamer: World Models for Continuous Control¶
6.1 Dreamer Architecture¶
World Model (RSSM, Recurrent State-Space Model):

    Deterministic path: h_t = f_θ(h_{t-1}, z_{t-1}, a_{t-1})
    Stochastic path:
        Prior:     ẑ_t ~ p_θ(ẑ_t | h_t)
        Posterior: z_t ~ q_θ(z_t | h_t, o_t)
    Decoder: ô_t = dec_θ(h_t, z_t)
    Reward:  r̂_t = rew_θ(h_t, z_t)

Actor-Critic (trained entirely in imagination):

    Imagine trajectories using the world model only:
        Actor:  a_t ~ π_θ(a_t | h_t, z_t)
        Critic: v_θ(h_t, z_t)
    No real environment interaction during policy training!

Results:
- First single algorithm to master 150+ diverse tasks
- Atari, DMControl, and Minecraft diamond collection without task-specific tuning
6.2 Imagination-Based Policy Learning¶
Outer loop (real environment):
1. Collect experience with the current policy -> replay buffer
2. Train the world model on the replay buffer

Inner loop (imagination):
3. Sample starting states from the replay buffer
4. "Dream" H-step trajectories using the world model:
   s0 -> s1 -> s2 -> ... -> s_H (all in latent space)
5. Compute imagined rewards and values
6. Update the actor and critic on the imagined trajectories

Key advantage: the policy can be trained on tens of thousands of imagined trajectories per real environment step.
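The imagined rewards and values are combined into λ-returns, which Dreamer uses as targets for the actor and critic. A minimal numpy version of that recursion (the rewards and values below are made up):

```python
import numpy as np

def lambda_returns(rewards, values, gamma=0.99, lam=0.95):
    """G_t = r_t + gamma * ((1 - lam) * v_{t+1} + lam * G_{t+1}), with G_H = v_H."""
    H = len(rewards)
    G = np.empty(H)
    nxt = values[H]                  # bootstrap from the final imagined value
    for t in reversed(range(H)):
        G[t] = rewards[t] + gamma * ((1 - lam) * values[t + 1] + lam * nxt)
        nxt = G[t]
    return G

rewards = np.array([0.0, 0.0, 1.0])      # imagined rewards r_0..r_2
values = np.array([0.5, 0.6, 0.8, 0.0])  # critic values v_0..v_3
G = lambda_returns(rewards, values)
```

With lam=1.0 this reduces to the discounted sum of rewards plus a bootstrap at the horizon; with lam=0.0 it reduces to one-step TD targets, so λ interpolates between the two.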
7. Practice Problems¶
Exercise 1: Dyna-Q on GridWorld¶
Implement Dyna-Q and compare performance with different planning steps (n=0, 5, 50).
# Starter code
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt

# Simple GridWorld (or use FrozenLake)
env = gym.make("FrozenLake-v1", is_slippery=False)

results = {}
for n_planning in [0, 5, 50]:
    agent = DynaQ(
        n_states=env.observation_space.n,
        n_actions=env.action_space.n,
        n_planning=n_planning,
    )
    episode_rewards = []
    for episode in range(500):
        state, _ = env.reset()
        total_reward = 0
        done = False
        while not done:
            action = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.update(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward
        episode_rewards.append(total_reward)
    results[n_planning] = episode_rewards

# Plot: more planning steps -> faster learning
for n, rewards in results.items():
    # Smooth with a moving average
    smoothed = np.convolve(rewards, np.ones(20) / 20, mode='valid')
    plt.plot(smoothed, label=f'n_planning={n}')
plt.xlabel('Episode')
plt.ylabel('Reward (smoothed)')
plt.legend()
plt.title('Dyna-Q: Effect of Planning Steps')
plt.show()
Exercise 2: Neural Dynamics Model¶
Train a neural network dynamics model on CartPole and evaluate prediction accuracy.
# Collect data from random policy
# Train DynamicsModel to predict next state
# Evaluate: 1-step prediction error vs multi-step rollout error
# Show that error grows with rollout length
# Key metrics to plot:
# - 1-step prediction MSE
# - k-step rollout MSE for k = 1, 5, 10, 20
# - Ensemble disagreement correlation with actual error
Exercise 3: Compare Sample Efficiency¶
Compare model-free SAC vs MBPO on a continuous control task.
# Use HalfCheetah-v4 or Pendulum-v1
# Plot: reward vs environment steps
# Expected: MBPO reaches good performance in ~10x fewer steps
# But: MBPO has higher wall-clock time per step (model training + planning)
Summary¶
Model-Based RL Landscape (simple -> complex):

| Method    | Key idea                       | Sample efficiency |
|-----------|--------------------------------|-------------------|
| Dyna-Q    | tabular model + planning       | Medium            |
| MBPO      | ensemble + short rollouts      | High              |
| MuZero    | MCTS in a learned hidden space | Very high         |
| DreamerV3 | RSSM + imagination             | Very high         |
Key Takeaways:
- Model-based RL trades computation for sample efficiency
- Ensemble models provide uncertainty estimates
- Short rollouts mitigate compounding model errors
- MuZero: planning in a learned latent space (no observation prediction)
- Dreamer: entire policy training in imagination
References¶
- Sutton & Barto, Ch. 8: "Planning and Learning with Tabular Methods"
- MBPO: Janner et al., 2019
- MuZero: Schrittwieser et al., 2020
- DreamerV3: Hafner et al., 2023
- OpenAI Spinning Up: Model-Based Methods