1"""
2Soft Actor-Critic (SAC) for Continuous Control
3===============================================
4
5SAC is a state-of-the-art off-policy RL algorithm for continuous action spaces.
6Key features:
71. Maximum Entropy RL: encourages exploration through entropy regularization
82. Twin Q-networks: reduces overestimation bias
93. Squashed Gaussian policy: bounded continuous actions
104. Automatic temperature tuning: adaptive entropy coefficient
11
12This simplified implementation demonstrates core SAC concepts on Pendulum-v1.
13
14Requirements: torch, gymnasium, numpy, matplotlib
15"""
16
17import gymnasium as gym
18import numpy as np
19import torch
20import torch.nn as nn
21import torch.nn.functional as F
22import torch.optim as optim
23from collections import deque
24import random
25import matplotlib.pyplot as plt
26from typing import Tuple, List
27
28
29# Set random seeds for reproducibility
30torch.manual_seed(42)
31np.random.seed(42)
32random.seed(42)
33
34
class ReplayBuffer:
    """
    Fixed-capacity FIFO store of transitions for off-policy learning.

    Oldest transitions are evicted automatically once capacity is reached.
    """

    def __init__(self, capacity: int = 100000):
        """
        Create an empty buffer.

        Args:
            capacity: Maximum number of transitions retained
        """
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append a single (s, a, r, s', done) transition."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size: int) -> Tuple:
        """
        Draw a uniform random minibatch without replacement.

        Args:
            batch_size: Number of transitions to draw

        Returns:
            Tuple of batched (states, actions, rewards, next_states, dones),
            with rewards and dones as float32 arrays.
        """
        picked = random.sample(self.buffer, batch_size)
        # Transpose list-of-transitions into per-field columns.
        states, actions, rewards, next_states, dones = (list(col) for col in zip(*picked))

        batched = (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.array(next_states),
            np.array(dones, dtype=np.float32),
        )
        return batched

    def __len__(self):
        """Current number of stored transitions."""
        return len(self.buffer)
76
77
class GaussianActor(nn.Module):
    """
    Tanh-squashed Gaussian policy network.

    The network outputs the mean and (clamped) log standard deviation of a
    diagonal Gaussian; sampled actions are squashed into [-1, 1] via tanh,
    with the log-probability corrected for the change of variables.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        hidden_dim: int = 256,
        log_std_min: float = -20,
        log_std_max: float = 2
    ):
        """
        Build the actor network.

        Args:
            state_dim: Observation dimension
            action_dim: Action dimension
            hidden_dim: Hidden layer size
            log_std_min: Lower clamp for log standard deviation
            log_std_max: Upper clamp for log standard deviation
        """
        super().__init__()

        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        # Two shared hidden layers feeding both output heads.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)

        # Independent heads for the Gaussian parameters.
        self.mean_head = nn.Linear(hidden_dim, action_dim)
        self.log_std_head = nn.Linear(hidden_dim, action_dim)

    def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute Gaussian parameters for the given states.

        Args:
            state: State tensor

        Returns:
            (mean, log_std); log_std is clamped to [log_std_min, log_std_max]
        """
        hidden = F.relu(self.fc2(F.relu(self.fc1(state))))

        mean = self.mean_head(hidden)
        log_std = self.log_std_head(hidden).clamp(self.log_std_min, self.log_std_max)

        return mean, log_std

    def sample(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Draw a squashed action and its log-probability.

        Uses the reparameterization trick so gradients flow through the
        sample into the network parameters.

        Args:
            state: State tensor

        Returns:
            (action, log_prob) with action in (-1, 1) and log_prob of
            shape (..., 1) summed over action dimensions
        """
        mean, log_std = self(state)
        dist = torch.distributions.Normal(mean, log_std.exp())

        pre_tanh = dist.rsample()  # differentiable sample
        action = torch.tanh(pre_tanh)

        # Change-of-variables correction for the tanh squashing;
        # 1e-6 guards against log(0) when |action| is near 1.
        log_prob = dist.log_prob(pre_tanh) - torch.log(1 - action.pow(2) + 1e-6)
        log_prob = log_prob.sum(dim=-1, keepdim=True)

        return action, log_prob
161
162
class TwinQCritic(nn.Module):
    """
    Pair of independent Q-networks over (state, action) pairs.

    Taking the minimum of the two estimates downstream reduces the
    overestimation bias of a single Q-function. Architectures are identical;
    parameters are separate.
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256):
        """
        Build both Q-networks.

        Args:
            state_dim: Observation dimension
            action_dim: Action dimension
            hidden_dim: Hidden layer size
        """
        super().__init__()

        # First Q-network: two hidden layers, scalar output.
        self.q1_fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.q1_fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.q1_out = nn.Linear(hidden_dim, 1)

        # Second Q-network: same shape, independent parameters.
        self.q2_fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.q2_fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.q2_out = nn.Linear(hidden_dim, 1)

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Evaluate both Q-networks on the same (state, action) input.

        Args:
            state: State tensor
            action: Action tensor

        Returns:
            (Q1_value, Q2_value), each of shape (..., 1)
        """
        sa = torch.cat([state, action], dim=-1)

        hidden1 = F.relu(self.q1_fc2(F.relu(self.q1_fc1(sa))))
        hidden2 = F.relu(self.q2_fc2(F.relu(self.q2_fc1(sa))))

        return self.q1_out(hidden1), self.q2_out(hidden2)
214
215
class SACAgent:
    """
    Soft Actor-Critic agent with automatic temperature tuning.

    Holds a squashed-Gaussian actor, twin Q-critics with a Polyak-averaged
    target copy, and (optionally) a learned entropy temperature alpha.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        lr: float = 3e-4,
        gamma: float = 0.99,
        tau: float = 0.005,
        alpha: float = 0.2,
        auto_tune_alpha: bool = True
    ):
        """
        Initialize SAC agent.

        Args:
            state_dim: Observation dimension
            action_dim: Action dimension
            lr: Learning rate (shared by actor, critic, and alpha optimizers)
            gamma: Discount factor
            tau: Polyak averaging coefficient for target networks
            alpha: Entropy coefficient (initial value if auto-tuning)
            auto_tune_alpha: Whether to automatically tune alpha
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.gamma = gamma
        self.tau = tau
        self.auto_tune_alpha = auto_tune_alpha

        # Actor network
        self.actor = GaussianActor(state_dim, action_dim).to(self.device)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)

        # Twin Q-networks; the target starts as an exact copy of the online
        # critic and thereafter trails it via Polyak averaging (see update()).
        self.critic = TwinQCritic(state_dim, action_dim).to(self.device)
        self.critic_target = TwinQCritic(state_dim, action_dim).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr)

        # Automatic temperature tuning: optimize log(alpha) instead of alpha
        # so the temperature stays positive without explicit constraints.
        if self.auto_tune_alpha:
            self.target_entropy = -action_dim  # Heuristic: -dim(A)
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha_optimizer = optim.Adam([self.log_alpha], lr=lr)
            self.alpha = self.log_alpha.exp().item()
        else:
            self.alpha = alpha

    def select_action(self, state: np.ndarray, deterministic: bool = False) -> np.ndarray:
        """
        Select action from policy.

        Args:
            state: Current state
            deterministic: If True, return mean action (for evaluation)

        Returns:
            Selected action, tanh-squashed into [-1, 1]
        """
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)

        with torch.no_grad():
            if deterministic:
                # Evaluation mode: take the (still squashed) distribution mean.
                mean, _ = self.actor(state)
                action = torch.tanh(mean)
            else:
                action, _ = self.actor.sample(state)

        return action.cpu().numpy()[0]

    def update(self, batch_size: int, replay_buffer: ReplayBuffer) -> None:
        """
        Run one gradient step on critic, actor, (optionally) temperature,
        then softly update the target critic.

        Args:
            batch_size: Batch size for sampling
            replay_buffer: Replay buffer to sample from (must hold at least
                batch_size transitions, or sampling raises ValueError)
        """
        # Sample batch
        states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

        states = torch.FloatTensor(states).to(self.device)
        actions = torch.FloatTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).unsqueeze(1).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).unsqueeze(1).to(self.device)

        # ==================== Update Critic ====================
        # The entire target computation is wrapped in no_grad: the TD target
        # is a fixed regression label, not a differentiable quantity.
        with torch.no_grad():
            # Sample next actions from current policy
            next_actions, next_log_probs = self.actor.sample(next_states)

            # Compute target Q-values using target networks; the element-wise
            # min counters overestimation bias (clipped double-Q).
            q1_target, q2_target = self.critic_target(next_states, next_actions)
            min_q_target = torch.min(q1_target, q2_target)

            # Soft value: Q minus alpha-weighted log-prob (entropy bonus);
            # (1 - dones) zeroes the bootstrap at episode termination.
            target_q = rewards + (1 - dones) * self.gamma * (min_q_target - self.alpha * next_log_probs)

        # Current Q-values
        q1, q2 = self.critic(states, actions)

        # Critic loss: MSE between current and target Q-values
        critic_loss = F.mse_loss(q1, target_q) + F.mse_loss(q2, target_q)

        # Update critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # ==================== Update Actor ====================
        # Sample actions from current policy (reparameterized, so gradients
        # flow from the Q-values back into the actor parameters)
        new_actions, log_probs = self.actor.sample(states)

        # Compute Q-values for new actions
        q1_new, q2_new = self.critic(states, new_actions)
        min_q_new = torch.min(q1_new, q2_new)

        # Actor loss: maximize Q - alpha * log_prob (equivalent to minimize negative)
        actor_loss = (self.alpha * log_probs - min_q_new).mean()

        # Update actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # ==================== Update Temperature ====================
        if self.auto_tune_alpha:
            # detach(): the temperature loss must adjust log_alpha only,
            # never backprop into the policy that produced log_probs.
            alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()

            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()

            # Cache the scalar so the critic/actor losses above use a plain
            # float on the next call (no gradient through alpha there).
            self.alpha = self.log_alpha.exp().item()

        # ==================== Update Target Networks ====================
        # Polyak averaging: target = tau * current + (1 - tau) * target
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
359
360
def train_sac(
    env_name: str = 'Pendulum-v1',
    max_episodes: int = 100,
    max_steps: int = 200,
    batch_size: int = 128,
    buffer_capacity: int = 100000,
    warmup_steps: int = 1000
) -> List[float]:
    """
    Train SAC agent on continuous control environment.

    Args:
        env_name: Gymnasium environment name
        max_episodes: Number of training episodes
        max_steps: Maximum steps per episode
        batch_size: Batch size for updates
        buffer_capacity: Replay buffer capacity
        warmup_steps: Random exploration steps before training. May be
            smaller than batch_size; updates simply start once the buffer
            holds a full batch.

    Returns:
        List of episode rewards

    NOTE(review): the actor emits actions in [-1, 1] while Pendulum-v1's
    action space is [-2, 2], so torque is effectively capped — confirm
    whether action rescaling to the env's bounds is intended.
    """
    # Create environment
    env = gym.make(env_name)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    # Create agent and replay buffer
    agent = SACAgent(state_dim, action_dim)
    replay_buffer = ReplayBuffer(buffer_capacity)

    episode_rewards = []
    total_steps = 0

    print(f"Training SAC on {env_name}...")
    print(f"State dim: {state_dim}, Action dim: {action_dim}\n")

    # Warmup: collect random transitions so early updates don't overfit
    # to a handful of correlated samples.
    state, _ = env.reset()
    for _ in range(warmup_steps):
        action = env.action_space.sample()
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        replay_buffer.push(state, action, reward, next_state, done)

        if done:
            state, _ = env.reset()
        else:
            state = next_state

    print(f"Warmup complete: {warmup_steps} steps collected\n")

    # Training loop
    for episode in range(max_episodes):
        state, _ = env.reset()
        episode_reward = 0

        for step in range(max_steps):
            # Select action
            action = agent.select_action(state)

            # Take action
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            # Store transition
            replay_buffer.push(state, action, reward, next_state, done)

            # Update agent once the buffer holds a full batch. Fix: the
            # update was previously unconditional, so warmup_steps smaller
            # than batch_size crashed in random.sample with ValueError.
            if len(replay_buffer) >= batch_size:
                agent.update(batch_size, replay_buffer)

            state = next_state
            episode_reward += reward
            total_steps += 1

            if done:
                break

        episode_rewards.append(episode_reward)

        # Print progress
        if (episode + 1) % 10 == 0:
            avg_reward = np.mean(episode_rewards[-10:])
            print(f"Episode {episode + 1}/{max_episodes}, "
                  f"Avg Reward (last 10): {avg_reward:.2f}, "
                  f"Alpha: {agent.alpha:.3f}")

    env.close()
    return episode_rewards
452
453
def plot_results(rewards: List[float], save_path: str = 'sac_training.png'):
    """
    Plot learning curve: raw episode rewards plus a 10-episode moving average.

    Args:
        rewards: List of episode rewards
        save_path: Destination for the PNG. Fix: previously a hard-coded,
            machine-specific absolute path that failed anywhere else and
            disagreed with the printed filename; now defaults to the
            current directory and is configurable.
    """
    plt.figure(figsize=(10, 6))
    plt.plot(rewards, alpha=0.3, label='Episode Reward')

    # Overlay a moving average to smooth episode-to-episode noise.
    window = 10
    if len(rewards) >= window:
        moving_avg = np.convolve(rewards, np.ones(window) / window, mode='valid')
        plt.plot(range(window - 1, len(rewards)), moving_avg, linewidth=2,
                 label=f'{window}-Episode Moving Average')

    plt.xlabel('Episode', fontsize=12)
    plt.ylabel('Reward', fontsize=12)
    plt.title('SAC Training on Pendulum-v1', fontsize=14, fontweight='bold')
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    plt.savefig(save_path, dpi=150)
    print(f"\nPlot saved to: {save_path}")
    plt.show()
480
481
if __name__ == '__main__':
    # Train agent
    rewards = train_sac(
        env_name='Pendulum-v1',
        max_episodes=100,
        max_steps=200,
        batch_size=128,
        warmup_steps=1000
    )

    # Plot results
    plot_results(rewards)

    # Print final performance
    print(f"\n{'='*60}")
    print("Training Complete!")
    print(f"{'='*60}")
    print(f"Final 10-episode average reward: {np.mean(rewards[-10:]):.2f}")

    # Best sliding-window average. Fix: the original used
    # range(len(rewards) - 10), which skipped the final 10-episode window
    # (off-by-one) and raised ValueError on an empty max() when exactly 10
    # episodes were run; the guard also handles runs shorter than a window.
    window = 10
    if len(rewards) >= window:
        best = max(np.mean(rewards[i:i + window])
                   for i in range(len(rewards) - window + 1))
        print(f"Best 10-episode average reward: {best:.2f}")