1#!/usr/bin/env python3
2"""
3Robot Control with MuJoCo and Soft Actor-Critic (SAC)
4
5MuJoCo (Multi-Joint dynamics with Contact) is the industry-standard physics
6engine for robotics research. It provides accurate simulation of rigid body
7dynamics, contacts, and actuator models — making it the benchmark environment
8for continuous-control RL research (locomotion, manipulation, whole-body control).
9
10Key concepts demonstrated:
11 1. Continuous action spaces: robot joints require real-valued torques, not
12 discrete choices. Gaussian policies parameterize mean and std for each
13 action dimension.
14 2. Soft Actor-Critic (SAC): off-policy Actor-Critic that adds an entropy bonus
15 to the reward. Maximizing H(π) alongside return drives exploration and
16 prevents premature convergence to suboptimal deterministic policies.
17 3. Replay buffer: off-policy algorithms reuse past experience, dramatically
18 improving sample efficiency over on-policy methods (PPO, REINFORCE).
19 4. Twin Q-networks: two separate critics take the minimum of their predictions,
20 reducing the systematic overestimation that destabilises single-critic SAC.
21
22Target environment: HalfCheetah-v4 (or Ant-v4) from Gymnasium.
23Fallback: a lightweight custom continuous pendulum when MuJoCo is absent.
24
25Requirements (full): torch gymnasium[mujoco] numpy
26Requirements (fallback): torch numpy
27"""
28
29import numpy as np
30import torch
31import torch.nn as nn
32import torch.nn.functional as F
33import torch.optim as optim
34from collections import deque
35from typing import Tuple, Optional
36import random
37import sys
38
39# ---------------------------------------------------------------------------
40# Hyperparameters
41# ---------------------------------------------------------------------------
BUFFER_CAPACITY = 100_000   # Replay buffer size
BATCH_SIZE = 256            # Minibatch size for SAC updates
HIDDEN_DIM = 256            # Hidden units in all networks
GAMMA = 0.99                # Discount factor
TAU = 0.005                 # Soft target-network update rate
ACTOR_LR = 3e-4             # Actor learning rate
CRITIC_LR = 3e-4            # Critic learning rate
ALPHA_LR = 3e-4             # Entropy coefficient learning rate
LOG_STD_MIN = -20           # Clamp for numerical stability
LOG_STD_MAX = 2             # Clamp to prevent too-wide distributions
UPDATE_AFTER = 1_000        # Warm-up steps before first gradient update
UPDATE_EVERY = 50           # Gradient steps per environment step interval
TRAINING_EPISODES = 10      # Episodes to run (demo; increase for real training)
MAX_STEPS = 300             # Max steps per episode
SEED = 42                   # RNG seed for reproducibility
57
58# ---------------------------------------------------------------------------
59# MuJoCo availability check
60# ---------------------------------------------------------------------------
try:
    import gymnasium as gym

    # Probe for a working MuJoCo install by constructing the target env once.
    env_test = gym.make("HalfCheetah-v4")
    env_test.close()
    MUJOCO_AVAILABLE = True
    TARGET_ENV = "HalfCheetah-v4"
except Exception:
    MUJOCO_AVAILABLE = False
    # Define TARGET_ENV in BOTH branches so any later reference is safe
    # (previously it was undefined in the fallback path — a latent NameError).
    TARGET_ENV = None
    print(
        "MuJoCo / Gymnasium[mujoco] not found.\n"
        "To install: pip install gymnasium[mujoco]\n"
        "            pip install mujoco\n"
        "Falling back to a custom continuous pendulum environment.\n"
    )
75
76
77# ---------------------------------------------------------------------------
78# Replay Buffer
79# ---------------------------------------------------------------------------
class ReplayBuffer:
    """
    Circular experience replay buffer for off-policy learning.

    Off-policy algorithms (SAC, DQN, TD3) can reuse transitions collected
    under any policy. This breaks temporal correlations in the data stream
    and makes gradient updates more statistically efficient.
    """

    def __init__(self, capacity: int = BUFFER_CAPACITY):
        # A bounded deque silently evicts the oldest transition once full,
        # giving circular-buffer semantics for free.
        self._storage: deque = deque(maxlen=capacity)

    def push(
        self,
        state: np.ndarray,
        action: np.ndarray,
        reward: float,
        next_state: np.ndarray,
        done: bool,
    ) -> None:
        """Append one (s, a, r, s', done) transition."""
        self._storage.append((state, action, reward, next_state, done))

    def sample(self, batch_size: int) -> Tuple:
        """Uniformly sample a minibatch and return it as float32 tensors.

        Rewards and done flags gain a trailing unit dimension so they
        broadcast against (batch, 1) critic outputs.
        """
        transitions = random.sample(self._storage, batch_size)
        columns = [np.array(col) for col in zip(*transitions)]
        states, actions, rewards, next_states, dones = columns
        return (
            torch.FloatTensor(states),
            torch.FloatTensor(actions),
            torch.FloatTensor(rewards).unsqueeze(1),
            torch.FloatTensor(next_states),
            torch.FloatTensor(dones).unsqueeze(1),
        )

    def __len__(self) -> int:
        """Number of transitions currently held."""
        return len(self._storage)
115
116
117# ---------------------------------------------------------------------------
118# Networks
119# ---------------------------------------------------------------------------
class GaussianActor(nn.Module):
    """
    Stochastic actor for continuous action spaces.

    Outputs a squashed Gaussian policy:
      1. Two hidden layers produce mean and log_std.
      2. An action sample z ~ N(mean, std) is passed through tanh to bound
         it to (-1, 1) — matching most MuJoCo action spaces.
      3. The log-probability is corrected for the tanh squashing:
         log π(a|s) = log N(z|s) − Σ log(1 − tanh²(z))

    This squashing trick (Haarnoja et al., 2018) avoids actions that would
    saturate actuator limits while keeping the policy differentiable.
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = HIDDEN_DIM):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.mean_layer = nn.Linear(hidden_dim, action_dim)
        self.log_std_layer = nn.Linear(hidden_dim, action_dim)

    def forward(
        self, state: torch.Tensor, deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (action, log_prob) with tanh squashing applied.

        Args:
            state: Batch of observations, shape (batch, state_dim).
            deterministic: When True, skip sampling and squash the
                distribution mean — the conventional evaluation-time
                policy. Defaults to False so existing callers keep the
                original stochastic behaviour (same RNG draws).
        """
        features = self.net(state)
        mean = self.mean_layer(features)
        # Clamp log_std so exp() cannot produce degenerate (≈0) or
        # exploding standard deviations.
        log_std = self.log_std_layer(features).clamp(LOG_STD_MIN, LOG_STD_MAX)
        std = log_std.exp()

        dist = torch.distributions.Normal(mean, std)
        # rsample() uses the reparameterisation trick so gradients flow
        # through the sample back to mean/std.
        z = mean if deterministic else dist.rsample()
        action = torch.tanh(z)

        # Change-of-variables correction for tanh squashing; the 1e-6
        # guards log(0) when tanh saturates near ±1.
        log_prob = dist.log_prob(z) - torch.log(1 - action.pow(2) + 1e-6)
        log_prob = log_prob.sum(dim=-1, keepdim=True)

        return action, log_prob
162
163
class TwinQNetwork(nn.Module):
    """
    Twin Q-networks (critic) for SAC.

    Two independent Q-functions share no weights. During the TD target
    computation the *minimum* of Q1 and Q2 is used, which counters the
    positive bias that arises when a single network both selects and
    evaluates actions (Fujimoto et al., 2018).
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = HIDDEN_DIM):
        super().__init__()

        def build_head() -> nn.Sequential:
            # One Q-head: (state ‖ action) -> scalar value.
            return nn.Sequential(
                nn.Linear(state_dim + action_dim, hidden_dim), nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                nn.Linear(hidden_dim, 1),
            )

        self.q1 = build_head()
        self.q2 = build_head()

    def forward(
        self, state: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (Q1(s,a), Q2(s,a)) for a batch of state/action pairs."""
        joint = torch.cat([state, action], dim=-1)
        return self.q1(joint), self.q2(joint)
190
191
192# ---------------------------------------------------------------------------
193# SAC Agent
194# ---------------------------------------------------------------------------
class SACAgent:
    """
    Soft Actor-Critic (SAC) — Haarnoja et al., 2018/2019.

    SAC maximises a temperature-weighted entropy-augmented objective:
        J(π) = Σ_t E[ r(s,a) + α · H(π(·|s)) ]

    Key advantages over on-policy algorithms (PPO) for robotics:
      • Off-policy: far more sample efficient (critical for physical robots).
      • Entropy regularisation: automatic exploration without ε-greedy schedules.
      • Automatic temperature α: no manual tuning of the exploration weight.
    """

    def __init__(self, state_dim: int, action_dim: int, action_scale: float = 1.0):
        # Maps the actor's tanh output in (-1, 1) onto the environment's
        # (assumed symmetric) action range [-action_scale, action_scale].
        self.action_scale = action_scale
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Actor (squashed-Gaussian policy)
        self.actor = GaussianActor(state_dim, action_dim).to(self.device)
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=ACTOR_LR)

        # Twin critics + target critics; targets start as exact copies and
        # then track the critics via Polyak averaging.
        self.critic = TwinQNetwork(state_dim, action_dim).to(self.device)
        self.critic_target = TwinQNetwork(state_dim, action_dim).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optim = optim.Adam(self.critic.parameters(), lr=CRITIC_LR)

        # Entropy temperature α (learned automatically)
        self.target_entropy = -float(action_dim)  # heuristic: -|A|
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha_optim = optim.Adam([self.log_alpha], lr=ALPHA_LR)

        self.replay_buffer = ReplayBuffer()
        self.total_steps = 0  # environment transitions observed so far

    @property
    def alpha(self) -> torch.Tensor:
        """Entropy temperature; exp() keeps it strictly positive."""
        return self.log_alpha.exp()

    # ------------------------------------------------------------------
    def select_action(self, state: np.ndarray, deterministic: bool = False) -> np.ndarray:
        """Sample an env-scale action from the actor.

        Deterministic mode squashes the Gaussian mean directly (no wasted
        stochastic sample) — the standard evaluation-time policy.
        """
        state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            if deterministic:
                features = self.actor.net(state_t)
                action = torch.tanh(self.actor.mean_layer(features))
            else:
                action, _ = self.actor(state_t)
        return action.cpu().numpy()[0] * self.action_scale

    # ------------------------------------------------------------------
    def _soft_update(self) -> None:
        """Polyak averaging: θ_target ← τ·θ + (1−τ)·θ_target."""
        for p, pt in zip(self.critic.parameters(), self.critic_target.parameters()):
            pt.data.copy_(TAU * p.data + (1 - TAU) * pt.data)

    # ------------------------------------------------------------------
    def update(self) -> Optional[dict]:
        """One gradient step for critic, actor, and temperature.

        Returns a dict of scalar diagnostics, or None while the buffer
        does not yet hold a full minibatch.
        """
        if len(self.replay_buffer) < BATCH_SIZE:
            return None

        states, actions, rewards, next_states, dones = self.replay_buffer.sample(BATCH_SIZE)
        states, actions, rewards, next_states, dones = (
            x.to(self.device) for x in (states, actions, rewards, next_states, dones)
        )

        # ---- Critic update -------------------------------------------
        with torch.no_grad():
            next_actions, next_log_probs = self.actor(next_states)
            # The replay buffer stores env-scale actions (select_action
            # multiplies by action_scale), so the critics must be queried
            # in the same units. Without this, action_scale != 1 silently
            # trains the critics on inconsistent action magnitudes.
            q1_next, q2_next = self.critic_target(
                next_states, next_actions * self.action_scale
            )
            # min(Q1, Q2) counters overestimation; α·logπ is the entropy bonus.
            q_next = torch.min(q1_next, q2_next) - self.alpha * next_log_probs
            q_target = rewards + GAMMA * (1 - dones) * q_next

        q1, q2 = self.critic(states, actions)
        critic_loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)

        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        # ---- Actor update --------------------------------------------
        new_actions, log_probs = self.actor(states)
        # Same unit fix as above: evaluate critics on env-scale actions.
        q1_new, q2_new = self.critic(states, new_actions * self.action_scale)
        q_new = torch.min(q1_new, q2_new)
        # α is detached here: the actor must not tune the temperature.
        actor_loss = (self.alpha.detach() * log_probs - q_new).mean()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        # ---- Temperature (α) update ----------------------------------
        # Drives the policy entropy toward target_entropy.
        alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()

        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()

        self._soft_update()

        return {
            "critic_loss": critic_loss.item(),
            "actor_loss": actor_loss.item(),
            "alpha": self.alpha.item(),
        }

    # ------------------------------------------------------------------
    def step_and_maybe_update(
        self,
        state: np.ndarray,
        action: np.ndarray,
        reward: float,
        next_state: np.ndarray,
        done: bool,
    ) -> Optional[dict]:
        """Store transition; trigger gradient updates on schedule.

        No updates occur during the first UPDATE_AFTER warm-up steps.
        Afterwards, UPDATE_EVERY gradient steps run once every
        UPDATE_EVERY environment steps (≈1 gradient step per env step).
        Returns the last update's diagnostics, or None.
        """
        self.replay_buffer.push(state, action, reward, next_state, done)
        self.total_steps += 1

        if self.total_steps < UPDATE_AFTER:
            return None

        if self.total_steps % UPDATE_EVERY == 0:
            info = None
            for _ in range(UPDATE_EVERY):
                info = self.update()
            return info
        return None
324
325
326# ---------------------------------------------------------------------------
327# Fallback environment (lightweight continuous pendulum)
328# ---------------------------------------------------------------------------
class ContinuousPendulumEnv:
    """
    Simple undamped pendulum with continuous torque input.

    State: [cos θ, sin θ, θ̇] (3-dim, matching Gymnasium Pendulum-v1)
    Action: [τ] ∈ [-1, 1]
    Reward: -(θ² + 0.1·θ̇² + 0.001·τ²) — penalise distance from upright
    """

    def __init__(self):
        self.max_torque = 2.0
        self.max_speed = 8.0
        self.g, self.m, self.l = 10.0, 1.0, 1.0
        self.dt = 0.05
        self.observation_space_shape = (3,)
        self.action_space_shape = (1,)
        self.theta = self.theta_dot = 0.0

    def reset(self) -> np.ndarray:
        """Start from a uniformly random angle and small random velocity."""
        self.theta = np.random.uniform(-np.pi, np.pi)
        self.theta_dot = np.random.uniform(-1.0, 1.0)
        return self._obs()

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool]:
        """Advance one Euler-integration step; returns (obs, reward, done)."""
        torque = float(np.clip(action[0], -1, 1)) * self.max_torque

        # Angular acceleration: gravity term plus applied torque
        # (θ + π because θ = 0 is the upright position).
        gravity_accel = -3 * self.g / (2 * self.l) * np.sin(self.theta + np.pi)
        torque_accel = 3.0 / (self.m * self.l ** 2) * torque
        unclipped_speed = self.theta_dot + (gravity_accel + torque_accel) * self.dt

        self.theta_dot = np.clip(unclipped_speed, -self.max_speed, self.max_speed)
        self.theta += self.theta_dot * self.dt

        angle_cost = self.angle_normalize(self.theta) ** 2
        speed_cost = 0.1 * self.theta_dot ** 2
        effort_cost = 0.001 * torque ** 2
        # Episode never terminates; the training loop imposes the time limit.
        return self._obs(), -(angle_cost + speed_cost + effort_cost), False

    def _obs(self) -> np.ndarray:
        """Encode the angle as (cos, sin) to avoid the ±π discontinuity."""
        return np.array(
            [np.cos(self.theta), np.sin(self.theta), self.theta_dot],
            dtype=np.float32,
        )

    @staticmethod
    def angle_normalize(x: float) -> float:
        """Wrap an angle into [-π, π)."""
        return ((x + np.pi) % (2 * np.pi)) - np.pi
379
380
381# ---------------------------------------------------------------------------
382# Training loop
383# ---------------------------------------------------------------------------
def train(episodes: int = TRAINING_EPISODES) -> None:
    """Train a SAC agent on HalfCheetah-v4 (MuJoCo) or the pendulum fallback.

    Bug fix: only genuine environment *termination* is stored as `done` in
    the replay buffer. A time-limit truncation (Gymnasium `truncated`, or
    hitting MAX_STEPS) ends the episode but must NOT zero the TD bootstrap
    target — treating truncation as terminal biases the critic.
    """
    random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

    if MUJOCO_AVAILABLE:
        import gymnasium as gym
        env = gym.make(TARGET_ENV)
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]
        action_scale = float(env.action_space.high[0])
        print(f"Environment : {TARGET_ENV}")
        print(f" obs dim : {state_dim} | action dim : {action_dim}")
        print(f" action scale : ±{action_scale}")
    else:
        env = ContinuousPendulumEnv()
        state_dim = env.observation_space_shape[0]
        action_dim = env.action_space_shape[0]
        action_scale = 1.0
        print("Environment : ContinuousPendulum (fallback)")

    agent = SACAgent(state_dim, action_dim, action_scale)
    print(f"Device : {agent.device}\n")
    print(f"{'Episode':>8} {'Steps':>7} {'Reward':>10} {'Alpha':>7}")
    print("-" * 45)

    for ep in range(1, episodes + 1):
        if MUJOCO_AVAILABLE:
            state, _ = env.reset(seed=SEED + ep)
        else:
            state = env.reset()

        ep_reward = 0.0
        last_info: Optional[dict] = None

        for _ in range(MAX_STEPS):
            # Random exploration during warm-up; policy afterwards
            if agent.total_steps < UPDATE_AFTER:
                if MUJOCO_AVAILABLE:
                    action = env.action_space.sample()
                else:
                    action = np.random.uniform(-1, 1, size=(action_dim,)).astype(np.float32)
            else:
                action = agent.select_action(state)

            if MUJOCO_AVAILABLE:
                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                # Store only true termination: truncated episodes should
                # still bootstrap through their final transition.
                stored_done = terminated
            else:
                next_state, reward, done = env.step(action)
                stored_done = done  # pendulum never terminates (always False)

            info = agent.step_and_maybe_update(state, action, reward, next_state, stored_done)
            if info:
                last_info = info

            ep_reward += reward
            state = next_state
            if done:
                break

        alpha_val = last_info["alpha"] if last_info else float(agent.alpha)
        print(
            f"{ep:>8} {agent.total_steps:>7} "
            f"{ep_reward:>10.2f} {alpha_val:>7.4f}"
        )

    if MUJOCO_AVAILABLE:
        env.close()
    print("\nDone. Increase TRAINING_EPISODES / MAX_STEPS for meaningful learning.")
451
452
453# ---------------------------------------------------------------------------
454# Entry point
455# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    # Minimal CLI: only the episode count is configurable.
    cli = argparse.ArgumentParser(description="SAC robot control with MuJoCo")
    cli.add_argument(
        "--episodes",
        type=int,
        default=TRAINING_EPISODES,
        help="Number of training episodes",
    )
    train(episodes=cli.parse_args().episodes)