# 11_model_based_dyna.py
"""
Dyna-Q: Model-Based Reinforcement Learning
===========================================

Dyna-Q combines model-free Q-learning with model-based planning.
The agent learns:
1. Q-values from real experience (Q-learning)
2. A model of the environment (transition dynamics)
3. Uses the model to generate simulated experience for additional learning

This demonstrates the sample efficiency benefit of model-based planning.

Requirements: gymnasium, numpy, matplotlib
"""
 15
 16import gymnasium as gym
 17import numpy as np
 18import matplotlib.pyplot as plt
 19from collections import defaultdict
 20from typing import Tuple, List
 21
 22
 23class DynaQ:
 24    """
 25    Dyna-Q agent with tabular Q-learning and model-based planning.
 26
 27    The agent maintains:
 28    - Q-table: state-action values
 29    - Model: transition dynamics (s, a) -> (s', r)
 30    - Memory: visited state-action pairs for planning
 31    """
 32
 33    def __init__(
 34        self,
 35        n_states: int,
 36        n_actions: int,
 37        learning_rate: float = 0.1,
 38        gamma: float = 0.95,
 39        epsilon: float = 0.1,
 40        n_planning: int = 5
 41    ):
 42        """
 43        Initialize Dyna-Q agent.
 44
 45        Args:
 46            n_states: Number of states
 47            n_actions: Number of actions
 48            learning_rate: Q-learning rate
 49            gamma: Discount factor
 50            epsilon: Exploration rate
 51            n_planning: Number of planning steps per real step
 52        """
 53        self.n_states = n_states
 54        self.n_actions = n_actions
 55        self.lr = learning_rate
 56        self.gamma = gamma
 57        self.epsilon = epsilon
 58        self.n_planning = n_planning
 59
 60        # Q-table: Q[s, a]
 61        self.Q = np.zeros((n_states, n_actions))
 62
 63        # Model: stores (next_state, reward) for each (state, action)
 64        # Using defaultdict to handle unseen state-action pairs
 65        self.model = {}
 66
 67        # Memory: set of visited (state, action) pairs
 68        self.memory = set()
 69
 70    def select_action(self, state: int) -> int:
 71        """
 72        Select action using epsilon-greedy policy.
 73
 74        Args:
 75            state: Current state
 76
 77        Returns:
 78            Selected action
 79        """
 80        if np.random.random() < self.epsilon:
 81            return np.random.randint(self.n_actions)
 82        else:
 83            return np.argmax(self.Q[state])
 84
 85    def update_q(self, state: int, action: int, reward: float, next_state: int):
 86        """
 87        Update Q-value using Q-learning rule.
 88
 89        Args:
 90            state: Current state
 91            action: Action taken
 92            reward: Reward received
 93            next_state: Next state
 94        """
 95        # Q-learning update
 96        td_target = reward + self.gamma * np.max(self.Q[next_state])
 97        td_error = td_target - self.Q[state, action]
 98        self.Q[state, action] += self.lr * td_error
 99
100    def update_model(self, state: int, action: int, reward: float, next_state: int):
101        """
102        Update model with new transition.
103
104        Args:
105            state: Current state
106            action: Action taken
107            reward: Reward received
108            next_state: Next state
109        """
110        # Store transition in model
111        self.model[(state, action)] = (next_state, reward)
112        # Add to memory for planning
113        self.memory.add((state, action))
114
115    def plan(self):
116        """
117        Perform model-based planning by sampling from model.
118
119        For n_planning steps:
120        1. Sample a previously visited (s, a) pair
121        2. Use model to get (s', r)
122        3. Update Q-value with simulated experience
123        """
124        for _ in range(self.n_planning):
125            if not self.memory:
126                break
127
128            # Sample random state-action from memory
129            state, action = list(self.memory)[np.random.randint(len(self.memory))]
130
131            # Get predicted next state and reward from model
132            next_state, reward = self.model[(state, action)]
133
134            # Update Q-value with simulated experience
135            self.update_q(state, action, reward, next_state)
136
137    def learn(
138        self,
139        state: int,
140        action: int,
141        reward: float,
142        next_state: int,
143        done: bool
144    ):
145        """
146        Main learning step: Q-learning + model update + planning.
147
148        Args:
149            state: Current state
150            action: Action taken
151            reward: Reward received
152            next_state: Next state
153            done: Episode termination flag
154        """
155        # (a) Q-learning from real experience
156        self.update_q(state, action, reward, next_state)
157
158        # (b) Update model
159        self.update_model(state, action, reward, next_state)
160
161        # (c) Model-based planning
162        self.plan()
163
164
def train_dyna_q(
    env: gym.Env,
    agent: DynaQ,
    n_episodes: int = 500
) -> List[float]:
    """
    Train a Dyna-Q agent on an environment.

    Args:
        env: Gymnasium environment
        agent: DynaQ agent
        n_episodes: Number of training episodes

    Returns:
        List of per-episode cumulative rewards
    """
    episode_rewards: List[float] = []

    for episode in range(n_episodes):
        state, _ = env.reset()
        total = 0

        # Roll out one episode, learning from every real transition.
        while True:
            action = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            # Q-learning + model update + planning on this transition
            agent.learn(state, action, reward, next_state, done)

            total += reward
            state = next_state
            if done:
                break

        episode_rewards.append(total)

        # Periodic progress report
        if (episode + 1) % 100 == 0:
            avg_reward = np.mean(episode_rewards[-100:])
            print(f"Episode {episode + 1}/{n_episodes}, "
                  f"Avg Reward (last 100): {avg_reward:.2f}, "
                  f"Planning steps: {agent.n_planning}")

    return episode_rewards
212
213
def moving_average(data: List[float], window: int = 50) -> np.ndarray:
    """
    Smooth a learning curve with a trailing moving average.

    Args:
        data: Input data
        window: Window size

    Returns:
        Array of length len(data) - window + 1 (empty if data is shorter
        than the window)
    """
    # Prefix sums with a leading zero: window sum = totals[i+window] - totals[i]
    totals = np.concatenate(([0], np.cumsum(data)))
    return (totals[window:] - totals[:-window]) / window
227
228
def compare_planning_steps(save_path: str = 'dyna_q_comparison.png'):
    """
    Compare Dyna-Q performance with different numbers of planning steps.
    Demonstrates that more planning leads to faster learning.

    Args:
        save_path: Where to save the learning-curve plot (default: current
            working directory; previously a hardcoded machine-specific
            absolute path that disagreed with the printed message)
    """
    # FrozenLake-v1 (4x4 grid, deterministic)
    env = gym.make('FrozenLake-v1', is_slippery=False)

    n_states = env.observation_space.n
    n_actions = env.action_space.n
    n_episodes = 500

    # Test different planning steps
    planning_configs = [0, 5, 50]  # 0 = pure Q-learning
    results = {}

    print("Training agents with different planning steps...\n")

    for n_planning in planning_configs:
        print(f"\n{'='*60}")
        print(f"Training with n_planning = {n_planning}")
        print(f"{'='*60}")

        # Fresh agent per configuration so runs are independent
        agent = DynaQ(
            n_states=n_states,
            n_actions=n_actions,
            learning_rate=0.1,
            gamma=0.95,
            epsilon=0.1,
            n_planning=n_planning
        )

        # Train agent
        rewards = train_dyna_q(env, agent, n_episodes)
        results[n_planning] = rewards

    env.close()

    # Plot smoothed learning curves for all configurations
    plt.figure(figsize=(12, 6))

    for n_planning, rewards in results.items():
        smoothed = moving_average(rewards, window=50)
        label = f'n_planning={n_planning}'
        if n_planning == 0:
            label += ' (Pure Q-learning)'
        plt.plot(smoothed, label=label, linewidth=2)

    plt.xlabel('Episode', fontsize=12)
    plt.ylabel('Average Reward (50-episode window)', fontsize=12)
    plt.title('Dyna-Q: Effect of Planning Steps on Learning Speed', fontsize=14, fontweight='bold')
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    # Print final performance
    print(f"\n{'='*60}")
    print("Final Performance (last 100 episodes):")
    print(f"{'='*60}")
    for n_planning, rewards in results.items():
        avg_reward = np.mean(rewards[-100:])
        print(f"n_planning={n_planning:2d}: {avg_reward:.3f}")

    # Save and show; the printed path now matches the actual destination
    plt.savefig(save_path, dpi=150)
    print(f"\nPlot saved to: {save_path}")
    plt.show()
296
297
# Script entry point: run the planning-steps comparison experiment.
if __name__ == '__main__':
    compare_planning_steps()