13_robot_pybullet.py

Download
python 419 lines 12.7 KB
  1#!/usr/bin/env python3
  2"""
  3Robot Control using PyBullet Simulation with PPO
  4
  5This script demonstrates reinforcement learning for robot control using PyBullet.
  6We'll train a Kuka robotic arm to reach a target position using Proximal Policy
  7Optimization (PPO).
  8
  9If PyBullet is not available, we fall back to a custom cartpole environment.
 10"""
 11
 12import numpy as np
 13import torch
 14import torch.nn as nn
 15import torch.optim as optim
 16from typing import Tuple, List, Dict, Optional
 17from dataclasses import dataclass
 18import sys
 19
 20try:
 21    import pybullet as p
 22    import pybullet_data
 23    PYBULLET_AVAILABLE = True
 24except ImportError:
 25    PYBULLET_AVAILABLE = False
 26    print("PyBullet not available. Using custom fallback environment.")
 27
 28
 29@dataclass
 30class Transition:
 31    """Store a single transition in the environment."""
 32    state: np.ndarray
 33    action: np.ndarray
 34    reward: float
 35    next_state: np.ndarray
 36    done: bool
 37    log_prob: float
 38    value: float
 39
 40
 41class ActorCritic(nn.Module):
 42    """
 43    Actor-Critic network for continuous control.
 44
 45    Actor outputs mean and log_std for a Gaussian policy.
 46    Critic outputs state value.
 47    """
 48
 49    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 128):
 50        super().__init__()
 51
 52        # Shared feature extractor
 53        self.shared = nn.Sequential(
 54            nn.Linear(state_dim, hidden_dim),
 55            nn.ReLU(),
 56            nn.Linear(hidden_dim, hidden_dim),
 57            nn.ReLU()
 58        )
 59
 60        # Actor head
 61        self.actor_mean = nn.Linear(hidden_dim, action_dim)
 62        self.actor_log_std = nn.Parameter(torch.zeros(action_dim))
 63
 64        # Critic head
 65        self.critic = nn.Linear(hidden_dim, 1)
 66
 67    def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 68        """
 69        Forward pass through network.
 70
 71        Returns:
 72            action_mean: Mean of action distribution
 73            action_std: Standard deviation of action distribution
 74            value: State value estimate
 75        """
 76        features = self.shared(state)
 77
 78        action_mean = self.actor_mean(features)
 79        action_std = torch.exp(self.actor_log_std)
 80        value = self.critic(features)
 81
 82        return action_mean, action_std, value
 83
 84    def get_action(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 85        """Sample action from policy."""
 86        action_mean, action_std, value = self(state)
 87
 88        # Create normal distribution
 89        dist = torch.distributions.Normal(action_mean, action_std)
 90        action = dist.sample()
 91        log_prob = dist.log_prob(action).sum(dim=-1)
 92
 93        return action, log_prob, value
 94
 95
 96class KukaReachEnv:
 97    """Kuka robot arm reaching task using PyBullet."""
 98
 99    def __init__(self, gui: bool = False):
100        if gui:
101            self.client = p.connect(p.GUI)
102        else:
103            self.client = p.connect(p.DIRECT)
104
105        p.setAdditionalSearchPath(pybullet_data.getDataPath())
106        p.setGravity(0, 0, -9.8)
107
108        # Load environment
109        self.plane = p.loadURDF("plane.urdf")
110        self.robot = p.loadURDF("kuka_iiwa/model.urdf", [0, 0, 0], useFixedBase=True)
111
112        # Robot parameters
113        self.num_joints = 7  # Kuka IIWA has 7 DoF
114        self.end_effector_index = 6
115
116        # Target position
117        self.target_pos = None
118        self.target_visual = None
119
120        self._setup_joints()
121
122    def _setup_joints(self):
123        """Setup joint parameters."""
124        self.joint_indices = list(range(self.num_joints))
125
126        # Get joint limits
127        self.joint_lower = []
128        self.joint_upper = []
129        for i in self.joint_indices:
130            info = p.getJointInfo(self.robot, i)
131            self.joint_lower.append(info[8])
132            self.joint_upper.append(info[9])
133
134    def reset(self) -> np.ndarray:
135        """Reset environment to initial state."""
136        # Reset joint positions to random configuration
137        for i in self.joint_indices:
138            pos = np.random.uniform(self.joint_lower[i], self.joint_upper[i])
139            p.resetJointState(self.robot, i, pos)
140
141        # Random target position
142        self.target_pos = np.array([
143            np.random.uniform(0.3, 0.6),
144            np.random.uniform(-0.3, 0.3),
145            np.random.uniform(0.3, 0.6)
146        ])
147
148        # Create visual marker for target
149        if self.target_visual is not None:
150            p.removeBody(self.target_visual)
151
152        self.target_visual = p.createVisualShape(
153            p.GEOM_SPHERE, radius=0.02, rgbaColor=[1, 0, 0, 1]
154        )
155        self.target_visual = p.createMultiBody(
156            baseVisualShapeIndex=self.target_visual,
157            basePosition=self.target_pos
158        )
159
160        return self._get_obs()
161
162    def _get_obs(self) -> np.ndarray:
163        """Get current observation."""
164        # Joint positions and velocities
165        joint_states = p.getJointStates(self.robot, self.joint_indices)
166        joint_pos = np.array([state[0] for state in joint_states])
167        joint_vel = np.array([state[1] for state in joint_states])
168
169        # End effector position
170        ee_state = p.getLinkState(self.robot, self.end_effector_index)
171        ee_pos = np.array(ee_state[0])
172
173        # Distance to target
174        distance = ee_pos - self.target_pos
175
176        obs = np.concatenate([joint_pos, joint_vel, distance])
177        return obs.astype(np.float32)
178
179    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool]:
180        """Execute action and return next state, reward, done."""
181        # Apply action as joint velocities
182        action = np.clip(action, -1, 1)
183
184        for i, idx in enumerate(self.joint_indices):
185            p.setJointMotorControl2(
186                self.robot, idx,
187                p.VELOCITY_CONTROL,
188                targetVelocity=action[i] * 2.0,  # Scale velocity
189                force=100
190            )
191
192        p.stepSimulation()
193
194        # Get new observation
195        obs = self._get_obs()
196
197        # Compute reward
198        ee_state = p.getLinkState(self.robot, self.end_effector_index)
199        ee_pos = np.array(ee_state[0])
200        distance = np.linalg.norm(ee_pos - self.target_pos)
201
202        reward = -distance  # Negative distance as reward
203
204        # Success bonus
205        if distance < 0.05:
206            reward += 10.0
207            done = True
208        else:
209            done = False
210
211        return obs, reward, done
212
213    def close(self):
214        """Cleanup environment."""
215        p.disconnect()
216
217
class FallbackCartPoleEnv:
    """Minimal continuous-force CartPole used when PyBullet is unavailable.

    The state ``[x, x_dot, theta, theta_dot]`` is advanced with the classic
    cart-pole equations and an explicit Euler step of ``dt`` seconds. The
    first action entry is clipped to [-1, 1] and scaled by ``force_mag`` into
    a horizontal force on the cart. Each step yields reward 1.0 while cart
    and pole stay within their thresholds, and 0.0 on the step where either
    is exceeded (which also ends the episode).
    """

    def __init__(self):
        self.gravity = 9.8
        self.mass_cart = 1.0
        self.mass_pole = 0.1
        self.length = 0.5
        self.dt = 0.02

        self.force_mag = 10.0       # force applied for |action| == 1
        self.x_threshold = 2.4      # episode ends once |x| exceeds this
        self.theta_threshold = 0.2  # radians; episode ends once |theta| exceeds this

        self.state = None

    def reset(self) -> np.ndarray:
        """Start near upright: each state component is drawn from N(0, 0.1**2)."""
        self.state = np.random.randn(4) * 0.1
        return self.state.astype(np.float32)

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool]:
        """Apply ``action`` for one time step.

        Args:
            action: array whose first element is the normalised cart force.

        Returns:
            (next_state as float32, reward, done)

        Raises:
            RuntimeError: if called before ``reset``.
        """
        if self.state is None:
            raise RuntimeError("reset() must be called before step()")

        # Clip like KukaReachEnv does: samples from the Gaussian policy are
        # unbounded, and an unclipped sample would apply an arbitrary force.
        force = float(np.clip(action[0], -1.0, 1.0)) * self.force_mag

        x, x_dot, theta, theta_dot = self.state

        cos_theta = np.cos(theta)
        sin_theta = np.sin(theta)

        total_mass = self.mass_cart + self.mass_pole
        pole_mass_length = self.mass_pole * self.length

        # Physics equations
        temp = (force + pole_mass_length * theta_dot**2 * sin_theta) / total_mass
        theta_acc = (self.gravity * sin_theta - cos_theta * temp) / (
            self.length * (4.0 / 3.0 - self.mass_pole * cos_theta**2 / total_mass)
        )
        x_acc = temp - pole_mass_length * theta_acc * cos_theta / total_mass

        # Explicit Euler update (positions advance with pre-update velocities)
        x += self.dt * x_dot
        x_dot += self.dt * x_acc
        theta += self.dt * theta_dot
        theta_dot += self.dt * theta_acc

        self.state = np.array([x, x_dot, theta, theta_dot])

        # Reward and done
        done = bool(abs(x) > self.x_threshold or abs(theta) > self.theta_threshold)
        reward = 0.0 if done else 1.0

        return self.state.astype(np.float32), reward, done
269
270
class PPOAgent:
    """Proximal Policy Optimization agent with a clipped surrogate objective.

    Transitions are buffered via ``store_transition``. ``update`` turns the
    whole buffer into one batch, estimates advantages with GAE, takes
    ``k_epochs`` full-batch Adam steps on the combined actor / critic /
    entropy loss, and then clears the buffer.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        lr: float = 3e-4,
        gamma: float = 0.99,
        eps_clip: float = 0.2,
        k_epochs: int = 4,
        gae_lambda: float = 0.95,
    ):
        """
        Args:
            state_dim: size of the observation vector.
            action_dim: size of the continuous action vector.
            lr: Adam learning rate.
            gamma: discount factor.
            eps_clip: clipping range for the probability ratio.
            k_epochs: gradient steps taken per ``update`` call.
            gae_lambda: GAE lambda (the value previously hard-coded in update).
        """
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.k_epochs = k_epochs
        self.gae_lambda = gae_lambda

        self.policy = ActorCritic(state_dim, action_dim)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)

        self.memory: List[Transition] = []

    def select_action(self, state: np.ndarray) -> Tuple[np.ndarray, float, float]:
        """Sample an action for a single observation.

        Returns:
            (action, log_prob, value): the sampled action as a 1-D array, its
            log-probability under the current policy, and the critic's V(state).
        """
        state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)

        with torch.no_grad():
            action, log_prob, value = self.policy.get_action(state_tensor)

        return action.numpy()[0], log_prob.item(), value.item()

    def store_transition(self, transition: Transition):
        """Append one transition to the rollout buffer."""
        self.memory.append(transition)

    def _compute_gae(
        self, values: torch.Tensor, next_values: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(advantages, returns)`` for the buffer using GAE.

        ``values[i]`` is V(s_i) and ``next_values[i]`` is V(s'_i) for
        ``self.memory[i]``. The TD error is
        ``r_i + gamma * V(s'_i) * (1 - done_i) - V(s_i)``, so an episode cut
        off by a step limit (``done=False``) is bootstrapped, not treated as
        terminal. The backward recursion restarts wherever the trajectory does
        not continue into ``self.memory[i + 1]``: at terminal steps, at the end
        of the buffer, and where a transition's ``next_state`` differs from
        the next transition's ``state`` (an episode boundary).
        Returns are ``advantages + values`` (the critic's regression targets).
        """
        n = len(self.memory)
        v = values.tolist()
        v_next = next_values.tolist()
        advantages = [0.0] * n
        running = 0.0
        for i in range(n - 1, -1, -1):
            t = self.memory[i]
            not_done = 0.0 if t.done else 1.0
            delta = t.reward + self.gamma * v_next[i] * not_done - v[i]
            continues = (
                not t.done
                and i + 1 < n
                and np.array_equal(t.next_state, self.memory[i + 1].state)
            )
            running = delta + (self.gamma * self.gae_lambda * running if continues else 0.0)
            advantages[i] = running

        adv = torch.tensor(advantages, dtype=torch.float32)
        return adv, adv + values

    def update(self):
        """Run one PPO update on the buffered transitions, then clear the buffer.

        Does nothing when the buffer is empty. Assumes the buffer was collected
        with the current policy (update called right after the rollout), since
        V(s') is evaluated with the current critic and combined with the
        stored ``t.value``.
        """
        if not self.memory:
            return

        # np.stack first: building a tensor from a list of ndarrays is very slow.
        states = torch.as_tensor(
            np.stack([t.state for t in self.memory]), dtype=torch.float32
        )
        actions = torch.as_tensor(
            np.stack([t.action for t in self.memory]), dtype=torch.float32
        )
        next_states = torch.as_tensor(
            np.stack([t.next_state for t in self.memory]), dtype=torch.float32
        )
        old_log_probs = torch.tensor([t.log_prob for t in self.memory], dtype=torch.float32)
        old_values = torch.tensor([t.value for t in self.memory], dtype=torch.float32)

        # Bootstrap values V(s') for the TD errors.
        with torch.no_grad():
            _, _, next_values = self.policy(next_states)

        advantages, returns = self._compute_gae(old_values, next_values.squeeze(-1))
        # std() of a single sample is NaN, so only normalise with 2+ samples.
        if advantages.numel() > 1:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        mse = nn.MSELoss()

        # PPO update
        for _ in range(self.k_epochs):
            # Evaluate actions under the current policy
            action_mean, action_std, values = self.policy(states)
            dist = torch.distributions.Normal(action_mean, action_std)
            log_probs = dist.log_prob(actions).sum(dim=-1)
            entropy = dist.entropy().sum(dim=-1).mean()

            # Clipped surrogate objective
            ratio = torch.exp(log_probs - old_log_probs)
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            actor_loss = -torch.min(surr1, surr2).mean()
            # squeeze(-1) keeps shape (N,) even for N == 1, matching `returns`.
            critic_loss = mse(values.squeeze(-1), returns)

            loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        self.memory.clear()
353
354
def train_robot(
    episodes: int = 500,
    gui: bool = False,
    max_steps: int = 200,
    update_freq: int = 10,
) -> List[float]:
    """Train a PPO agent on the Kuka reach task (or the CartPole fallback).

    Args:
        episodes: number of training episodes.
        gui: show the PyBullet GUI (ignored by the fallback environment).
        max_steps: step cap per episode; longer episodes are truncated.
        update_freq: run a PPO update after every ``update_freq`` episodes.

    Returns:
        The total reward of every episode, in order.

    Raises:
        ValueError: if ``update_freq`` is smaller than 1.
    """
    if update_freq < 1:
        raise ValueError(f"update_freq must be >= 1, got {update_freq}")

    # Create environment
    if PYBULLET_AVAILABLE:
        env = KukaReachEnv(gui=gui)
        state_dim = 17  # 7 joint pos + 7 joint vel + 3 distance
        action_dim = 7
        print("Using PyBullet Kuka environment")
    else:
        env = FallbackCartPoleEnv()
        state_dim = 4
        action_dim = 1
        print("Using fallback CartPole environment")

    agent = PPOAgent(state_dim, action_dim)
    episode_rewards: List[float] = []

    try:
        for episode in range(episodes):
            state = env.reset()
            episode_reward = 0.0

            for _ in range(max_steps):
                action, log_prob, value = agent.select_action(state)
                next_state, reward, done = env.step(action)

                agent.store_transition(
                    Transition(state, action, reward, next_state, done, log_prob, value)
                )

                episode_reward += reward
                state = next_state

                if done:
                    break

            episode_rewards.append(episode_reward)

            # Update policy
            if (episode + 1) % update_freq == 0:
                agent.update()

            if (episode + 1) % 10 == 0:
                avg_reward = np.mean(episode_rewards[-10:])
                print(f"Episode {episode + 1}, Avg Reward: {avg_reward:.2f}")
    finally:
        # Release the simulator even if training raises; the fallback
        # environment has no close() method.
        close = getattr(env, "close", None)
        if close is not None:
            close()

    if episode_rewards:
        print(f"\nTraining completed! Final average reward: {np.mean(episode_rewards[-100:]):.2f}")
    else:
        # np.mean([]) would warn and print nan.
        print("\nTraining completed! No episodes were run.")
    return episode_rewards
407
408
if __name__ == "__main__":
    # Command-line entry point: parse the options and launch training.
    import argparse

    cli = argparse.ArgumentParser(description="Robot control with PPO")
    cli.add_argument("--episodes", type=int, default=500, help="Number of episodes")
    cli.add_argument("--gui", action="store_true", help="Show PyBullet GUI")

    # The parsed namespace holds exactly `episodes` and `gui`, which are
    # forwarded to train_robot as keyword arguments.
    train_robot(**vars(cli.parse_args()))