12_sac_continuous.py

Download
python 501 lines 15.2 KB
  1"""
  2Soft Actor-Critic (SAC) for Continuous Control
  3===============================================
  4
  5SAC is a state-of-the-art off-policy RL algorithm for continuous action spaces.
  6Key features:
  71. Maximum Entropy RL: encourages exploration through entropy regularization
  82. Twin Q-networks: reduces overestimation bias
  93. Squashed Gaussian policy: bounded continuous actions
 104. Automatic temperature tuning: adaptive entropy coefficient
 11
 12This simplified implementation demonstrates core SAC concepts on Pendulum-v1.
 13
 14Requirements: torch, gymnasium, numpy, matplotlib
 15"""
 16
 17import gymnasium as gym
 18import numpy as np
 19import torch
 20import torch.nn as nn
 21import torch.nn.functional as F
 22import torch.optim as optim
 23from collections import deque
 24import random
 25import matplotlib.pyplot as plt
 26from typing import Tuple, List
 27
 28
# Seed PyTorch, NumPy's global generator and the standard-library `random`
# module with the same fixed value so repeated runs draw the same
# pseudo-random numbers from those three sources.
# NOTE(review): the Gymnasium environment is reset without a seed further
# down, so environment randomness is not covered by these calls — confirm
# whether env.reset(seed=...) is also wanted.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
 33
 34
 35class ReplayBuffer:
 36    """
 37    Experience replay buffer for off-policy learning.
 38    """
 39
 40    def __init__(self, capacity: int = 100000):
 41        """
 42        Initialize replay buffer.
 43
 44        Args:
 45            capacity: Maximum buffer size
 46        """
 47        self.buffer = deque(maxlen=capacity)
 48
 49    def push(self, state, action, reward, next_state, done):
 50        """Store transition in buffer."""
 51        self.buffer.append((state, action, reward, next_state, done))
 52
 53    def sample(self, batch_size: int) -> Tuple:
 54        """
 55        Sample random batch from buffer.
 56
 57        Args:
 58            batch_size: Number of transitions to sample
 59
 60        Returns:
 61            Tuple of batched (states, actions, rewards, next_states, dones)
 62        """
 63        batch = random.sample(self.buffer, batch_size)
 64        states, actions, rewards, next_states, dones = zip(*batch)
 65
 66        return (
 67            np.array(states),
 68            np.array(actions),
 69            np.array(rewards, dtype=np.float32),
 70            np.array(next_states),
 71            np.array(dones, dtype=np.float32)
 72        )
 73
 74    def __len__(self):
 75        return len(self.buffer)
 76
 77
 78class GaussianActor(nn.Module):
 79    """
 80    Gaussian policy network with tanh squashing.
 81    Outputs mean and log_std for a Gaussian distribution.
 82    Actions are sampled and squashed to [-1, 1] range.
 83    """
 84
 85    def __init__(
 86        self,
 87        state_dim: int,
 88        action_dim: int,
 89        hidden_dim: int = 256,
 90        log_std_min: float = -20,
 91        log_std_max: float = 2
 92    ):
 93        """
 94        Initialize actor network.
 95
 96        Args:
 97            state_dim: Observation dimension
 98            action_dim: Action dimension
 99            hidden_dim: Hidden layer size
100            log_std_min: Minimum log standard deviation
101            log_std_max: Maximum log standard deviation
102        """
103        super().__init__()
104
105        self.log_std_min = log_std_min
106        self.log_std_max = log_std_max
107
108        # Shared layers
109        self.fc1 = nn.Linear(state_dim, hidden_dim)
110        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
111
112        # Separate heads for mean and log_std
113        self.mean_head = nn.Linear(hidden_dim, action_dim)
114        self.log_std_head = nn.Linear(hidden_dim, action_dim)
115
116    def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
117        """
118        Forward pass through network.
119
120        Args:
121            state: State tensor
122
123        Returns:
124            (mean, log_std) for Gaussian distribution
125        """
126        x = F.relu(self.fc1(state))
127        x = F.relu(self.fc2(x))
128
129        mean = self.mean_head(x)
130        log_std = self.log_std_head(x)
131        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
132
133        return mean, log_std
134
135    def sample(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
136        """
137        Sample action from policy and compute log probability.
138
139        Args:
140            state: State tensor
141
142        Returns:
143            (action, log_prob) where action is squashed to [-1, 1]
144        """
145        mean, log_std = self.forward(state)
146        std = log_std.exp()
147
148        # Sample from Gaussian
149        normal = torch.distributions.Normal(mean, std)
150        x = normal.rsample()  # Reparameterization trick
151
152        # Squash to [-1, 1] using tanh
153        action = torch.tanh(x)
154
155        # Compute log probability with change of variables correction
156        log_prob = normal.log_prob(x)
157        log_prob -= torch.log(1 - action.pow(2) + 1e-6)  # tanh correction
158        log_prob = log_prob.sum(dim=-1, keepdim=True)
159
160        return action, log_prob
161
162
class TwinQCritic(nn.Module):
    """
    Pair of independent Q-networks evaluated on the same (state, action) input.

    The two networks share an architecture (two ReLU hidden layers and a
    scalar output) but not parameters; callers can take the minimum of the
    two estimates to curb overestimation.
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256):
        """
        Build both Q-networks.

        Args:
            state_dim: Size of the observation vector
            action_dim: Size of the action vector
            hidden_dim: Width of both hidden layers in each network
        """
        super().__init__()

        # Each network consumes the state and action concatenated together.
        input_dim = state_dim + action_dim

        # First critic
        self.q1_fc1 = nn.Linear(input_dim, hidden_dim)
        self.q1_fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.q1_out = nn.Linear(hidden_dim, 1)

        # Second critic
        self.q2_fc1 = nn.Linear(input_dim, hidden_dim)
        self.q2_fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.q2_out = nn.Linear(hidden_dim, 1)

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Evaluate both Q-networks on the given state-action pairs.

        Args:
            state: State tensor
            action: Action tensor

        Returns:
            (Q1_value, Q2_value), each with a trailing dimension of size 1
        """
        state_action = torch.cat((state, action), dim=-1)

        hidden_1 = torch.relu(self.q1_fc2(torch.relu(self.q1_fc1(state_action))))
        hidden_2 = torch.relu(self.q2_fc2(torch.relu(self.q2_fc1(state_action))))

        return self.q1_out(hidden_1), self.q2_out(hidden_2)
214
215
class SACAgent:
    """
    Soft Actor-Critic agent with optional automatic temperature tuning.

    Owns a tanh-squashed Gaussian actor, a twin Q critic together with a
    Polyak-averaged target copy of it, and (when auto-tuning is enabled) a
    learnable log-temperature whose exponential is the entropy coefficient.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        lr: float = 3e-4,
        gamma: float = 0.99,
        tau: float = 0.005,
        alpha: float = 0.2,
        auto_tune_alpha: bool = True
    ):
        """
        Initialize SAC agent.

        Args:
            state_dim: Observation dimension
            action_dim: Action dimension
            lr: Learning rate shared by the actor, critic and temperature
                optimizers
            gamma: Discount factor
            tau: Polyak averaging coefficient for target networks
            alpha: Entropy coefficient (initial value if auto-tuning)
            auto_tune_alpha: Whether to automatically tune alpha

        Raises:
            ValueError: If auto-tuning is requested with a non-positive
                alpha, whose logarithm seeds the learned parameter
        """
        if auto_tune_alpha and alpha <= 0:
            raise ValueError("alpha must be positive when auto_tune_alpha is True")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.gamma = gamma
        self.tau = tau
        self.auto_tune_alpha = auto_tune_alpha

        # Actor network
        self.actor = GaussianActor(state_dim, action_dim).to(self.device)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)

        # Twin Q-networks; the target copy starts identical to the online one
        self.critic = TwinQCritic(state_dim, action_dim).to(self.device)
        self.critic_target = TwinQCritic(state_dim, action_dim).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr)

        # Automatic temperature tuning
        if self.auto_tune_alpha:
            self.target_entropy = -float(action_dim)  # Heuristic: -dim(A)
            # Optimise log(alpha) so alpha stays positive. Start from the
            # caller's alpha, as documented, rather than a fixed 1.0.
            self.log_alpha = torch.tensor(
                [float(np.log(alpha))],
                dtype=torch.float32,
                requires_grad=True,
                device=self.device,
            )
            self.alpha_optimizer = optim.Adam([self.log_alpha], lr=lr)
            self.alpha = self.log_alpha.exp().item()
        else:
            self.alpha = alpha

    def _as_tensor(self, array) -> torch.Tensor:
        """Convert array-like data to a float32 tensor on the agent's device."""
        return torch.as_tensor(array, dtype=torch.float32, device=self.device)

    def select_action(self, state: np.ndarray, deterministic: bool = False) -> np.ndarray:
        """
        Select action from policy.

        Args:
            state: Current state (a single, unbatched observation)
            deterministic: If True, return tanh of the mean (for evaluation)
                instead of a random sample

        Returns:
            Selected action as a NumPy array in [-1, 1], without the batch
            dimension
        """
        # Add a batch dimension of 1; it is stripped again on return.
        state_batch = self._as_tensor(state).unsqueeze(0)

        with torch.no_grad():
            if deterministic:
                mean, _ = self.actor(state_batch)
                action = torch.tanh(mean)
            else:
                action, _ = self.actor.sample(state_batch)

        return action.cpu().numpy()[0]

    def update(self, batch_size: int, replay_buffer: ReplayBuffer):
        """
        Run one gradient step on critic, actor and (if enabled) temperature,
        then move the target critic towards the online critic.

        Args:
            batch_size: Batch size for sampling
            replay_buffer: Replay buffer to sample from

        Raises:
            ValueError: Propagated from the buffer's sampling when it holds
                fewer than `batch_size` transitions
        """
        # Sample batch
        states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

        states = self._as_tensor(states)
        actions = self._as_tensor(actions)
        # Rewards and dones arrive as 1-D arrays; make them column vectors so
        # they broadcast against the (batch, 1) Q-values.
        rewards = self._as_tensor(rewards).unsqueeze(1)
        next_states = self._as_tensor(next_states)
        dones = self._as_tensor(dones).unsqueeze(1)

        # ==================== Update Critic ====================
        with torch.no_grad():
            # Sample next actions from current policy
            next_actions, next_log_probs = self.actor.sample(next_states)

            # Compute target Q-values using target networks
            q1_target, q2_target = self.critic_target(next_states, next_actions)
            min_q_target = torch.min(q1_target, q2_target)

            # Soft Bellman target; (1 - dones) removes the bootstrap term for
            # transitions flagged as done.
            target_q = rewards + (1 - dones) * self.gamma * (min_q_target - self.alpha * next_log_probs)

        # Current Q-values
        q1, q2 = self.critic(states, actions)

        # Critic loss: MSE between current and target Q-values
        critic_loss = F.mse_loss(q1, target_q) + F.mse_loss(q2, target_q)

        # Update critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # ==================== Update Actor ====================
        # Sample actions from current policy
        new_actions, log_probs = self.actor.sample(states)

        # Compute Q-values for new actions
        q1_new, q2_new = self.critic(states, new_actions)
        min_q_new = torch.min(q1_new, q2_new)

        # Actor loss: maximize Q - alpha * log_prob (equivalent to minimize negative)
        actor_loss = (self.alpha * log_probs - min_q_new).mean()

        # Update actor. The backward pass also fills critic gradients, but
        # only the actor optimizer steps here and the critic gradients are
        # cleared before the next critic step.
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # ==================== Update Temperature ====================
        if self.auto_tune_alpha:
            # Gradient flows only into log_alpha; the entropy gap is detached.
            alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()

            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()

            self.alpha = self.log_alpha.exp().item()

        # ==================== Update Target Networks ====================
        # Polyak averaging: target = tau * current + (1 - tau) * target
        with torch.no_grad():
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.copy_(self.tau * param + (1 - self.tau) * target_param)
359
360
def train_sac(
    env_name: str = 'Pendulum-v1',
    max_episodes: int = 100,
    max_steps: int = 200,
    batch_size: int = 128,
    buffer_capacity: int = 100000,
    warmup_steps: int = 1000
) -> List[float]:
    """
    Train SAC agent on continuous control environment.

    The agent acts in a normalized [-1, 1] action space. Actions are rescaled
    to the environment's bounds before each step, and the normalized action is
    what gets stored in the replay buffer, so stored actions always match the
    range the policy can produce.

    Args:
        env_name: Gymnasium environment name
        max_episodes: Number of training episodes
        max_steps: Maximum steps per episode
        batch_size: Batch size for updates
        buffer_capacity: Replay buffer capacity
        warmup_steps: Random exploration steps before training

    Returns:
        List of episode rewards (one undiscounted return per training episode)
    """
    # Create environment
    env = gym.make(env_name)

    try:
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]

        # Affine map from the policy's [-1, 1] range to the environment bounds.
        # Unbounded action spaces fall back to the identity map.
        low = np.asarray(env.action_space.low, dtype=np.float32)
        high = np.asarray(env.action_space.high, dtype=np.float32)
        if np.all(np.isfinite(low)) and np.all(np.isfinite(high)):
            action_scale = (high - low) / 2.0
            action_bias = (high + low) / 2.0
        else:
            action_scale = np.ones_like(low)
            action_bias = np.zeros_like(low)

        # Create agent and replay buffer
        agent = SACAgent(state_dim, action_dim)
        replay_buffer = ReplayBuffer(buffer_capacity)

        episode_rewards = []

        print(f"Training SAC on {env_name}...")
        print(f"State dim: {state_dim}, Action dim: {action_dim}\n")

        # Warmup: collect random transitions, sampled uniformly in the
        # normalized action space so they cover the same range as the policy.
        state, _ = env.reset()
        for _ in range(warmup_steps):
            action = np.random.uniform(-1.0, 1.0, size=action_dim).astype(np.float32)
            next_state, reward, terminated, truncated, _ = env.step(action_bias + action_scale * action)

            # Store `terminated` only: a time-limit truncation is not a true
            # terminal state, so the critic must still bootstrap from it.
            replay_buffer.push(state, action, reward, next_state, terminated)

            if terminated or truncated:
                state, _ = env.reset()
            else:
                state = next_state

        print(f"Warmup complete: {warmup_steps} steps collected\n")

        # Training loop
        for episode in range(max_episodes):
            state, _ = env.reset()
            episode_reward = 0.0

            for _ in range(max_steps):
                # Select action (normalized to [-1, 1])
                action = agent.select_action(state)

                # Take action, rescaled to the environment's bounds
                next_state, reward, terminated, truncated, _ = env.step(action_bias + action_scale * action)

                # Store transition (see warmup note on `terminated`)
                replay_buffer.push(state, action, reward, next_state, terminated)

                # Update agent once the buffer can supply a full batch
                if len(replay_buffer) >= batch_size:
                    agent.update(batch_size, replay_buffer)

                state = next_state
                episode_reward += float(reward)

                if terminated or truncated:
                    break

            episode_rewards.append(episode_reward)

            # Print progress
            if (episode + 1) % 10 == 0:
                avg_reward = np.mean(episode_rewards[-10:])
                print(f"Episode {episode + 1}/{max_episodes}, "
                      f"Avg Reward (last 10): {avg_reward:.2f}, "
                      f"Alpha: {agent.alpha:.3f}")
    finally:
        # Release the environment even if training raises.
        env.close()

    return episode_rewards
452
453
def plot_results(rewards: List[float], save_path: str = 'sac_training.png', show: bool = True):
    """
    Plot learning curve and save it as an image.

    Draws the raw per-episode rewards and, when at least 10 episodes are
    available, a 10-episode moving average on top.

    Args:
        rewards: List of episode rewards
        save_path: Where to write the figure; missing parent directories are
            created. Defaults to a file in the current working directory
            instead of a machine-specific absolute path.
        show: Whether to open the interactive window after saving
    """
    # Local import: pathlib is not among the file-level imports.
    from pathlib import Path

    plt.figure(figsize=(10, 6))
    plt.plot(rewards, alpha=0.3, label='Episode Reward')

    # Moving average
    window = 10
    if len(rewards) >= window:
        moving_avg = np.convolve(rewards, np.ones(window)/window, mode='valid')
        # 'valid' mode yields len - window + 1 points; align each with the
        # last episode of its window.
        plt.plot(range(window-1, len(rewards)), moving_avg, linewidth=2, label=f'{window}-Episode Moving Average')

    plt.xlabel('Episode', fontsize=12)
    plt.ylabel('Reward', fontsize=12)
    plt.title('SAC Training on Pendulum-v1', fontsize=14, fontweight='bold')
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    out_path = Path(save_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path, dpi=150)
    # Report the path actually written, not a hard-coded file name.
    print(f"\nPlot saved to: {out_path}")

    if show:
        plt.show()
480
481
def best_window_average(rewards: List[float], window: int = 10) -> float:
    """
    Return the highest mean over any `window` consecutive rewards.

    Every full window is considered, including the final one. When fewer than
    `window` rewards exist, the mean of all of them is returned, matching how
    `rewards[-window:]` behaves for the final-average report.

    Args:
        rewards: Sequence of episode rewards
        window: Number of consecutive episodes per average (must be >= 1)

    Returns:
        Best windowed mean, or NaN when `rewards` is empty

    Raises:
        ValueError: If `window` is smaller than 1
    """
    if window < 1:
        raise ValueError("window must be at least 1")
    if len(rewards) == 0:
        return float('nan')

    # len - window + 1 full windows exist; keep at least one (partial) window.
    n_windows = max(1, len(rewards) - window + 1)
    return max(float(np.mean(rewards[i:i + window])) for i in range(n_windows))


if __name__ == '__main__':
    # Train agent
    rewards = train_sac(
        env_name='Pendulum-v1',
        max_episodes=100,
        max_steps=200,
        batch_size=128,
        warmup_steps=1000
    )

    # Plot results
    plot_results(rewards)

    # Print final performance
    print(f"\n{'='*60}")
    print("Training Complete!")
    print(f"{'='*60}")
    print(f"Final 10-episode average reward: {np.mean(rewards[-10:]):.2f}")
    # The helper includes the last window, so "best" can never be lower than
    # the final average, and short runs no longer raise on an empty max().
    print(f"Best 10-episode average reward: {best_window_average(rewards, 10):.2f}")