14_robot_mujoco.py

Download
python 467 lines 17.8 KB
  1#!/usr/bin/env python3
  2"""
  3Robot Control with MuJoCo and Soft Actor-Critic (SAC)
  4
  5MuJoCo (Multi-Joint dynamics with Contact) is the industry-standard physics
  6engine for robotics research. It provides accurate simulation of rigid body
  7dynamics, contacts, and actuator models — making it the benchmark environment
  8for continuous-control RL research (locomotion, manipulation, whole-body control).
  9
 10Key concepts demonstrated:
 11  1. Continuous action spaces: robot joints require real-valued torques, not
 12     discrete choices. Gaussian policies parameterize mean and std for each
 13     action dimension.
 14  2. Soft Actor-Critic (SAC): off-policy Actor-Critic that adds an entropy bonus
 15     to the reward. Maximizing H(π) alongside return drives exploration and
 16     prevents premature convergence to suboptimal deterministic policies.
 17  3. Replay buffer: off-policy algorithms reuse past experience, dramatically
 18     improving sample efficiency over on-policy methods (PPO, REINFORCE).
 19  4. Twin Q-networks: two separate critics take the minimum of their predictions,
 20     reducing the systematic overestimation that destabilises single-critic SAC.
 21
 22Target environment: HalfCheetah-v4 (or Ant-v4) from Gymnasium.
 23Fallback: a lightweight custom continuous pendulum when MuJoCo is absent.
 24
 25Requirements (full):  torch  gymnasium[mujoco]  numpy
 26Requirements (fallback):  torch  numpy
 27"""
 28
 29import numpy as np
 30import torch
 31import torch.nn as nn
 32import torch.nn.functional as F
 33import torch.optim as optim
 34from collections import deque
 35from typing import Tuple, Optional
 36import random
 37import sys
 38
 39# ---------------------------------------------------------------------------
 40# Hyperparameters
 41# ---------------------------------------------------------------------------
 42BUFFER_CAPACITY  = 100_000   # Replay buffer size
 43BATCH_SIZE       = 256        # Minibatch size for SAC updates
 44HIDDEN_DIM       = 256        # Hidden units in all networks
 45GAMMA            = 0.99       # Discount factor
 46TAU              = 0.005      # Soft target-network update rate
 47ACTOR_LR         = 3e-4       # Actor learning rate
 48CRITIC_LR        = 3e-4       # Critic learning rate
 49ALPHA_LR         = 3e-4       # Entropy coefficient learning rate
 50LOG_STD_MIN      = -20        # Clamp for numerical stability
 51LOG_STD_MAX      = 2          # Clamp to prevent too-wide distributions
 52UPDATE_AFTER     = 1_000      # Warm-up steps before first gradient update
 53UPDATE_EVERY     = 50         # Gradient steps per environment step interval
 54TRAINING_EPISODES= 10         # Episodes to run (demo; increase for real training)
 55MAX_STEPS        = 300        # Max steps per episode
 56SEED             = 42
 57
 58# ---------------------------------------------------------------------------
 59# MuJoCo availability check
 60# ---------------------------------------------------------------------------
 61try:
 62    import gymnasium as gym
 63    env_test = gym.make("HalfCheetah-v4")
 64    env_test.close()
 65    MUJOCO_AVAILABLE = True
 66    TARGET_ENV = "HalfCheetah-v4"
 67except Exception:
 68    MUJOCO_AVAILABLE = False
 69    print(
 70        "MuJoCo / Gymnasium[mujoco] not found.\n"
 71        "To install:  pip install gymnasium[mujoco]\n"
 72        "             pip install mujoco\n"
 73        "Falling back to a custom continuous pendulum environment.\n"
 74    )
 75
 76
 77# ---------------------------------------------------------------------------
 78# Replay Buffer
 79# ---------------------------------------------------------------------------
 80class ReplayBuffer:
 81    """
 82    Circular experience replay buffer for off-policy learning.
 83
 84    Off-policy algorithms (SAC, DQN, TD3) can reuse transitions collected
 85    under any policy.  This breaks temporal correlations in the data stream
 86    and makes gradient updates more statistically efficient.
 87    """
 88
 89    def __init__(self, capacity: int = BUFFER_CAPACITY):
 90        self.buffer = deque(maxlen=capacity)
 91
 92    def push(
 93        self,
 94        state: np.ndarray,
 95        action: np.ndarray,
 96        reward: float,
 97        next_state: np.ndarray,
 98        done: bool,
 99    ) -> None:
100        self.buffer.append((state, action, reward, next_state, done))
101
102    def sample(self, batch_size: int) -> Tuple:
103        batch = random.sample(self.buffer, batch_size)
104        states, actions, rewards, next_states, dones = map(np.array, zip(*batch))
105        return (
106            torch.FloatTensor(states),
107            torch.FloatTensor(actions),
108            torch.FloatTensor(rewards).unsqueeze(1),
109            torch.FloatTensor(next_states),
110            torch.FloatTensor(dones).unsqueeze(1),
111        )
112
113    def __len__(self) -> int:
114        return len(self.buffer)
115
116
# ---------------------------------------------------------------------------
# Networks
# ---------------------------------------------------------------------------
class GaussianActor(nn.Module):
    """
    Stochastic actor for continuous action spaces.

    Outputs a squashed Gaussian policy:
      1. Two hidden ReLU layers produce features, from which separate linear
         heads produce the mean and (clamped) log_std of a diagonal Gaussian.
      2. A pre-activation z ~ N(mean, std) is passed through tanh to bound
         the action to [-1, 1].
      3. The log-probability is corrected for the tanh change of variables:
            log π(a|s) = Σ_i [ log N(z_i | mean_i, std_i) − log(1 − tanh²(z_i)) ]

    The correction term uses the exact identity
        log(1 − tanh²(z)) = 2 · (log 2 − z − softplus(−2z)),
    which stays finite and accurate when tanh saturates, unlike adding an
    epsilon inside the log (that caps each term at log(1e-6) and biases the
    entropy estimate).

    This squashing trick (Haarnoja et al., 2018) avoids actions that would
    saturate actuator limits while keeping the policy differentiable.
    """

    # log(2) as a plain Python float so it combines with tensors directly.
    _LOG_2 = float(np.log(2.0))

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = HIDDEN_DIM):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.mean_layer    = nn.Linear(hidden_dim, action_dim)
        self.log_std_layer = nn.Linear(hidden_dim, action_dim)

    def forward(
        self, state: torch.Tensor, deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return ``(action, log_prob)`` with tanh squashing applied.

        Args:
            state: Observation tensor whose last dimension is ``state_dim``.
            deterministic: If True, use z = mean instead of sampling (for
                evaluation); the log-prob is then evaluated at the mean.

        Returns:
            action: Squashed action with last dimension ``action_dim``,
                entries in [-1, 1].
            log_prob: log π(action | state) summed over the action dimension,
                keeping a trailing dimension of size 1.
        """
        features = self.net(state)
        mean     = self.mean_layer(features)
        log_std  = self.log_std_layer(features).clamp(LOG_STD_MIN, LOG_STD_MAX)
        dist     = torch.distributions.Normal(mean, log_std.exp())

        # rsample() is the reparameterised draw, so gradients reach the actor.
        z      = mean if deterministic else dist.rsample()
        action = torch.tanh(z)

        # Numerically stable log|d tanh(z)/dz| = log(1 - tanh(z)^2).
        log_det  = 2.0 * (self._LOG_2 - z - F.softplus(-2.0 * z))
        log_prob = (dist.log_prob(z) - log_det).sum(dim=-1, keepdim=True)

        return action, log_prob
162
163
class TwinQNetwork(nn.Module):
    """
    Twin Q-networks (critic) for SAC.

    Holds two independently initialised MLPs, ``q1`` and ``q2``, that share no
    weights.  Each maps a concatenated (state, action) vector to one scalar
    value estimate.  The agent takes the element-wise minimum of the two when
    forming TD targets, countering the positive bias that arises when a single
    network both selects and evaluates actions (Fujimoto et al., 2018).
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = HIDDEN_DIM):
        super().__init__()
        joint_dim = state_dim + action_dim
        # q1 is built before q2 so parameter initialisation order is stable.
        self.q1 = self._make_head(joint_dim, hidden_dim)
        self.q2 = self._make_head(joint_dim, hidden_dim)

    @staticmethod
    def _make_head(in_features: int, hidden_dim: int) -> nn.Sequential:
        """Two hidden ReLU layers followed by a single linear output unit."""
        return nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(
        self, state: torch.Tensor, action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(q1(s, a), q2(s, a))``, each with a trailing dimension of 1."""
        joint = torch.cat((state, action), dim=-1)
        return self.q1(joint), self.q2(joint)
190
191
# ---------------------------------------------------------------------------
# SAC Agent
# ---------------------------------------------------------------------------
class SACAgent:
    """
    Soft Actor-Critic (SAC) — Haarnoja et al., 2018/2019.

    SAC maximises a temperature-weighted entropy-augmented objective:
        J(π) = Σ_t  E[ r(s,a) + α · H(π(·|s)) ]

    Key advantages over on-policy algorithms (PPO) for robotics:
      • Off-policy: far more sample efficient (critical for physical robots).
      • Entropy regularisation: automatic exploration without ε-greedy schedules.
      • Automatic temperature α: no manual tuning of the exploration weight.

    Action scaling: the actor emits tanh-squashed actions in [-1, 1].
    ``select_action`` multiplies them by ``action_scale`` for the environment,
    and ``step_and_maybe_update`` divides by ``action_scale`` before storing.
    The replay buffer, and therefore the critics, therefore always work in
    the same normalised space that the actor produces inside ``update``.
    """

    def __init__(self, state_dim: int, action_dim: int, action_scale: float = 1.0):
        """
        Args:
            state_dim: Length of the observation vector.
            action_dim: Length of the action vector.
            action_scale: Positive factor mapping actor outputs in [-1, 1] to
                environment actions.  A per-dimension NumPy array also works
                because every use of it broadcasts.

        Raises:
            ValueError: If any entry of ``action_scale`` is not positive.
        """
        if np.any(np.asarray(action_scale) <= 0):
            raise ValueError(f"action_scale must be positive, got {action_scale!r}")
        self.action_scale = action_scale
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Actor
        self.actor = GaussianActor(state_dim, action_dim).to(self.device)
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=ACTOR_LR)

        # Twin critics + target critics (targets start as exact copies)
        self.critic        = TwinQNetwork(state_dim, action_dim).to(self.device)
        self.critic_target = TwinQNetwork(state_dim, action_dim).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optim  = optim.Adam(self.critic.parameters(), lr=CRITIC_LR)

        # Entropy temperature α, learned in log space so α stays positive.
        self.target_entropy = -float(action_dim)          # heuristic: -|A|
        self.log_alpha      = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha_optim    = optim.Adam([self.log_alpha], lr=ALPHA_LR)

        self.replay_buffer  = ReplayBuffer()
        self.total_steps    = 0

    @property
    def alpha(self) -> torch.Tensor:
        """Current entropy temperature α = exp(log_alpha)."""
        return self.log_alpha.exp()

    # ------------------------------------------------------------------
    def select_action(self, state: np.ndarray, deterministic: bool = False) -> np.ndarray:
        """
        Return an environment-scale action for a single (unbatched) state.

        Samples from the policy by default.  With ``deterministic=True`` it
        returns ``tanh(mean) * action_scale`` without drawing a sample.
        """
        state_t = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            if deterministic:
                # Read the mean head directly; calling self.actor(...) here
                # would draw (and discard) a random sample.
                features = self.actor.net(state_t)
                action   = torch.tanh(self.actor.mean_layer(features))
            else:
                action, _ = self.actor(state_t)
        return action.cpu().numpy()[0] * self.action_scale

    # ------------------------------------------------------------------
    def _soft_update(self) -> None:
        """Polyak averaging: θ_target ← τ·θ + (1−τ)·θ_target."""
        with torch.no_grad():
            for p, pt in zip(self.critic.parameters(), self.critic_target.parameters()):
                pt.mul_(1.0 - TAU).add_(p, alpha=TAU)

    # ------------------------------------------------------------------
    def update(self) -> Optional[dict]:
        """
        Run one gradient step for critic, actor and temperature, then
        soft-update the target critics.

        Returns:
            None if the buffer holds fewer than BATCH_SIZE transitions,
            otherwise a dict with scalar ``critic_loss``, ``actor_loss`` and
            ``alpha`` values.
        """
        if len(self.replay_buffer) < BATCH_SIZE:
            return None

        states, actions, rewards, next_states, dones = self.replay_buffer.sample(BATCH_SIZE)
        states, actions, rewards, next_states, dones = (
            x.to(self.device) for x in (states, actions, rewards, next_states, dones)
        )

        # ---- Critic update -------------------------------------------
        # Soft TD target: r + γ(1 − d)[min_i Q'_i(s', a') − α log π(a'|s')].
        with torch.no_grad():
            next_actions, next_log_probs = self.actor(next_states)
            q1_next, q2_next = self.critic_target(next_states, next_actions)
            q_next   = torch.min(q1_next, q2_next) - self.alpha * next_log_probs
            q_target = rewards + GAMMA * (1 - dones) * q_next

        q1, q2 = self.critic(states, actions)
        critic_loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)

        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        # ---- Actor update --------------------------------------------
        # Gradients this backward pass leaves on critic parameters are cleared
        # by critic_optim.zero_grad() before the next critic step.
        new_actions, log_probs = self.actor(states)
        q1_new, q2_new = self.critic(states, new_actions)
        q_new    = torch.min(q1_new, q2_new)
        actor_loss = (self.alpha.detach() * log_probs - q_new).mean()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        # ---- Temperature (α) update ----------------------------------
        alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()

        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()

        self._soft_update()

        return {
            "critic_loss": critic_loss.item(),
            "actor_loss":  actor_loss.item(),
            "alpha":       self.alpha.item(),
        }

    # ------------------------------------------------------------------
    def step_and_maybe_update(
        self,
        state: np.ndarray,
        action: np.ndarray,
        reward: float,
        next_state: np.ndarray,
        done: bool,
    ) -> Optional[dict]:
        """
        Store a transition, then run a burst of updates when the schedule allows.

        ``action`` is the environment-scale action that was executed.  It is
        divided by ``action_scale`` before storage so the critic is trained on
        the same [-1, 1] scale the actor feeds it in ``update``.  ``done``
        should be True only for true terminal states, because it zeroes the
        bootstrap term of the TD target.

        Returns:
            The info dict from the last update of a burst, or None when no
            update ran (warm-up, off-schedule step, or too little data).
        """
        normalised_action = np.asarray(action, dtype=np.float32) / self.action_scale
        self.replay_buffer.push(state, normalised_action, reward, next_state, done)
        self.total_steps += 1

        if self.total_steps < UPDATE_AFTER:
            return None

        if self.total_steps % UPDATE_EVERY == 0:
            info = None
            for _ in range(UPDATE_EVERY):
                info = self.update()
            return info
        return None
324
325
# ---------------------------------------------------------------------------
# Fallback environment (lightweight continuous pendulum)
# ---------------------------------------------------------------------------
class ContinuousPendulumEnv:
    """
    Minimal frictionless pendulum driven by a continuous torque.

    Observation: ``[cos θ, sin θ, θ̇]`` as float32 (the Pendulum-v1 layout).
    Action: ``[u]``, clipped to [-1, 1]; the applied torque is
    ``u * max_torque`` with ``max_torque = 2.0``.
    Reward: −(θ² + 0.1·θ̇² + 0.001·torque²), with θ wrapped to [-π, π) and
    evaluated on the post-step state.  It is largest at θ = 0 (upright)
    with no spin.  The episode never terminates on its own.
    """

    def __init__(self):
        self.max_torque = 2.0
        self.max_speed = 8.0
        self.g = 10.0
        self.m = 1.0
        self.l = 1.0
        self.dt = 0.05
        self.observation_space_shape = (3,)
        self.action_space_shape = (1,)
        self.theta = 0.0
        self.theta_dot = 0.0

    def reset(self) -> np.ndarray:
        """Draw a random start angle and angular velocity; return the observation."""
        start_angle = np.random.uniform(-np.pi, np.pi)  # drawn first (RNG order)
        start_spin = np.random.uniform(-1.0, 1.0)
        self.theta, self.theta_dot = start_angle, start_spin
        return self._obs()

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool]:
        """Advance one Euler step of length dt; return ``(obs, reward, done)``."""
        u = float(np.clip(action[0], -1, 1))
        torque = u * self.max_torque

        gravity_accel = -3 * self.g / (2 * self.l) * np.sin(self.theta + np.pi)
        torque_accel = 3.0 / (self.m * self.l ** 2) * torque
        speed = self.theta_dot + (gravity_accel + torque_accel) * self.dt

        self.theta_dot = np.clip(speed, -self.max_speed, self.max_speed)
        self.theta += self.theta_dot * self.dt

        cost = (
            self.angle_normalize(self.theta) ** 2
            + 0.1 * self.theta_dot ** 2
            + 0.001 * torque ** 2
        )
        return self._obs(), -cost, False   # pendulum never terminates

    def _obs(self) -> np.ndarray:
        """Encode the angle as (cos, sin) plus angular velocity, as float32."""
        features = [np.cos(self.theta), np.sin(self.theta), self.theta_dot]
        return np.array(features, dtype=np.float32)

    @staticmethod
    def angle_normalize(x: float) -> float:
        """Wrap an angle in radians into [-π, π)."""
        full_turn = 2 * np.pi
        return ((x + np.pi) % full_turn) - np.pi
379
380
# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
def train(episodes: int = TRAINING_EPISODES) -> None:
    """
    Train a SAC agent for ``episodes`` episodes, printing one summary line each.

    Uses ``TARGET_ENV`` when ``MUJOCO_AVAILABLE`` is true, otherwise the
    ``ContinuousPendulumEnv`` fallback.  Each episode runs for at most
    ``MAX_STEPS`` steps.  Until the agent has taken ``UPDATE_AFTER`` steps,
    actions are drawn uniformly at random to seed the replay buffer.

    Args:
        episodes: Number of episodes to run.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    if MUJOCO_AVAILABLE:
        import gymnasium as gym
        env = gym.make(TARGET_ENV)
        # The action space has its own RNG; seed it so warm-up actions drawn
        # with action_space.sample() are reproducible too.
        env.action_space.seed(SEED)
        state_dim    = env.observation_space.shape[0]
        action_dim   = env.action_space.shape[0]
        action_scale = float(env.action_space.high[0])
        print(f"Environment : {TARGET_ENV}")
        print(f"  obs dim   : {state_dim}  |  action dim : {action_dim}")
        print(f"  action scale : ±{action_scale}")
    else:
        env          = ContinuousPendulumEnv()
        state_dim    = env.observation_space_shape[0]
        action_dim   = env.action_space_shape[0]
        action_scale = 1.0
        print("Environment : ContinuousPendulum (fallback)")

    try:
        agent = SACAgent(state_dim, action_dim, action_scale)
        print(f"Device : {agent.device}\n")
        print(f"{'Episode':>8}  {'Steps':>7}  {'Reward':>10}  {'Alpha':>7}")
        print("-" * 45)

        for ep in range(1, episodes + 1):
            if MUJOCO_AVAILABLE:
                state, _ = env.reset(seed=SEED + ep)
            else:
                state = env.reset()

            ep_reward = 0.0
            last_info: Optional[dict] = None

            for _ in range(MAX_STEPS):
                # Random exploration during warm-up; policy afterwards
                if agent.total_steps < UPDATE_AFTER:
                    if MUJOCO_AVAILABLE:
                        action = env.action_space.sample()
                    else:
                        action = np.random.uniform(-1, 1, size=(action_dim,)).astype(np.float32)
                else:
                    action = agent.select_action(state)

                if MUJOCO_AVAILABLE:
                    next_state, reward, terminated, truncated, _ = env.step(action)
                else:
                    next_state, reward, terminated = env.step(action)
                    truncated = False

                # Only a true terminal state may zero the bootstrap term of the
                # TD target.  A time-limit truncation ends the episode but the
                # state still has future value, so it is stored as non-terminal.
                info = agent.step_and_maybe_update(
                    state, action, reward, next_state, terminated
                )
                if info is not None:
                    last_info = info

                ep_reward += reward
                state      = next_state
                if terminated or truncated:
                    break

            alpha_val = last_info["alpha"] if last_info is not None else float(agent.alpha)
            print(
                f"{ep:>8}  {agent.total_steps:>7}  "
                f"{ep_reward:>10.2f}  {alpha_val:>7.4f}"
            )
    finally:
        # Release the simulator even if training raises part-way through.
        if MUJOCO_AVAILABLE:
            env.close()

    print("\nDone. Increase TRAINING_EPISODES / MAX_STEPS for meaningful learning.")
451
452
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import ArgumentParser

    # The episode count is the only command-line knob; everything else comes
    # from the module-level hyperparameters.
    cli = ArgumentParser(description="SAC robot control with MuJoCo")
    cli.add_argument(
        "--episodes",
        type=int,
        default=TRAINING_EPISODES,
        help="Number of training episodes",
    )
    train(episodes=cli.parse_args().episodes)