1#!/usr/bin/env python3
2"""
3Robot Control using PyBullet Simulation with PPO
4
5This script demonstrates reinforcement learning for robot control using PyBullet.
6We'll train a Kuka robotic arm to reach a target position using Proximal Policy
7Optimization (PPO).
8
9If PyBullet is not available, we fall back to a custom cartpole environment.
10"""
11
12import numpy as np
13import torch
14import torch.nn as nn
15import torch.optim as optim
16from typing import Tuple, List, Dict, Optional
17from dataclasses import dataclass
18import sys
19
20try:
21 import pybullet as p
22 import pybullet_data
23 PYBULLET_AVAILABLE = True
24except ImportError:
25 PYBULLET_AVAILABLE = False
26 print("PyBullet not available. Using custom fallback environment.")
27
28
@dataclass
class Transition:
    """Store a single transition in the environment."""
    # Observation before the action was taken.
    state: np.ndarray
    # Action executed in `state`.
    action: np.ndarray
    # Scalar reward received for the action.
    reward: float
    # Observation after the action was taken.
    next_state: np.ndarray
    # True if this transition ended the episode.
    done: bool
    # Log-probability of `action` under the policy that produced it.
    log_prob: float
    # Critic's value estimate V(state) at collection time.
    value: float
39
40
class ActorCritic(nn.Module):
    """
    Actor-Critic network for continuous control.

    A shared two-layer MLP feeds two heads:
      - Actor: outputs the mean of a diagonal Gaussian policy; the log-std
        is a state-independent learnable parameter.
      - Critic: outputs a scalar state-value estimate.
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 128):
        """
        Args:
            state_dim: Dimension of the observation vector.
            action_dim: Dimension of the continuous action vector.
            hidden_dim: Width of the two hidden layers.
        """
        super().__init__()

        # Shared feature extractor
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Actor head: state-dependent mean, state-independent log-std.
        self.actor_mean = nn.Linear(hidden_dim, action_dim)
        self.actor_log_std = nn.Parameter(torch.zeros(action_dim))

        # Critic head
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass through the network.

        Args:
            state: Batch of observations, shape (batch, state_dim).

        Returns:
            action_mean: Mean of the action distribution, shape (batch, action_dim).
            action_std: Standard deviation (shared across the batch), shape (action_dim,).
            value: State-value estimate, shape (batch, 1).
        """
        features = self.shared(state)

        action_mean = self.actor_mean(features)
        # exp() keeps the std strictly positive while log_std stays unconstrained.
        action_std = torch.exp(self.actor_log_std)
        value = self.critic(features)

        return action_mean, action_std, value

    # FIX: the annotation previously declared a 2-tuple although three values
    # are returned; corrected to match the actual return.
    def get_action(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample an action from the Gaussian policy.

        Returns:
            action: Sampled action, shape (batch, action_dim).
            log_prob: Joint log-probability over action dims, shape (batch,).
            value: Critic value estimate, shape (batch, 1).
        """
        action_mean, action_std, value = self(state)

        # Diagonal Gaussian: sum per-dimension log-probs for the joint log-prob.
        dist = torch.distributions.Normal(action_mean, action_std)
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(dim=-1)

        return action, log_prob, value
94
95
class KukaReachEnv:
    """Kuka robot arm reaching task using PyBullet.

    Observation (17-D): 7 joint positions, 7 joint velocities, and the 3-D
    offset from the end effector to the target. The 7-D action is a joint
    velocity command per joint, expected in [-1, 1].
    """

    def __init__(self, gui: bool = False):
        """Connect to PyBullet and load the plane and the Kuka IIWA arm.

        Args:
            gui: If True, open the interactive GUI; otherwise run headless.
        """
        if gui:
            self.client = p.connect(p.GUI)
        else:
            self.client = p.connect(p.DIRECT)

        # Makes the bundled example URDFs (plane.urdf, kuka_iiwa) findable.
        p.setAdditionalSearchPath(pybullet_data.getDataPath())
        p.setGravity(0, 0, -9.8)

        # Load environment
        self.plane = p.loadURDF("plane.urdf")
        self.robot = p.loadURDF("kuka_iiwa/model.urdf", [0, 0, 0], useFixedBase=True)

        # Robot parameters
        self.num_joints = 7  # Kuka IIWA has 7 DoF
        self.end_effector_index = 6

        # Target position; set on every reset(). target_visual ends up holding
        # the multi-body id of the red marker sphere (see reset()).
        self.target_pos = None
        self.target_visual = None

        self._setup_joints()

    def _setup_joints(self):
        """Setup joint parameters."""
        self.joint_indices = list(range(self.num_joints))

        # Get joint limits. Per the PyBullet getJointInfo docs, indices 8 and 9
        # of the returned tuple are jointLowerLimit / jointUpperLimit.
        self.joint_lower = []
        self.joint_upper = []
        for i in self.joint_indices:
            info = p.getJointInfo(self.robot, i)
            self.joint_lower.append(info[8])
            self.joint_upper.append(info[9])

    def reset(self) -> np.ndarray:
        """Reset environment to initial state.

        Randomizes both the arm configuration and the target position, then
        returns the initial observation.
        """
        # Reset joint positions to random configuration
        for i in self.joint_indices:
            pos = np.random.uniform(self.joint_lower[i], self.joint_upper[i])
            p.resetJointState(self.robot, i, pos)

        # Random target position in a reachable box in front of the arm.
        self.target_pos = np.array([
            np.random.uniform(0.3, 0.6),
            np.random.uniform(-0.3, 0.3),
            np.random.uniform(0.3, 0.6)
        ])

        # Create visual marker for target; drop last reset's marker first.
        if self.target_visual is not None:
            p.removeBody(self.target_visual)

        # target_visual briefly holds the visual-shape id, then is overwritten
        # with the multi-body id that removeBody() needs on the next reset.
        self.target_visual = p.createVisualShape(
            p.GEOM_SPHERE, radius=0.02, rgbaColor=[1, 0, 0, 1]
        )
        self.target_visual = p.createMultiBody(
            baseVisualShapeIndex=self.target_visual,
            basePosition=self.target_pos
        )

        return self._get_obs()

    def _get_obs(self) -> np.ndarray:
        """Get current observation (float32 vector of length 17)."""
        # Joint positions and velocities
        joint_states = p.getJointStates(self.robot, self.joint_indices)
        joint_pos = np.array([state[0] for state in joint_states])
        joint_vel = np.array([state[1] for state in joint_states])

        # End effector position (world frame, first field of getLinkState).
        ee_state = p.getLinkState(self.robot, self.end_effector_index)
        ee_pos = np.array(ee_state[0])

        # NOTE: despite the name, this is the 3-D offset vector to the target,
        # not a scalar distance.
        distance = ee_pos - self.target_pos

        obs = np.concatenate([joint_pos, joint_vel, distance])
        return obs.astype(np.float32)

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool]:
        """Execute action and return next state, reward, done.

        Args:
            action: 7-D joint velocity command; clipped to [-1, 1] and scaled.

        Returns:
            (observation, reward, done). Reward is the negative Euclidean
            distance to the target, plus a +10 bonus (and episode end) when
            the end effector gets within 5 cm.
        """
        # Apply action as joint velocities
        action = np.clip(action, -1, 1)

        for i, idx in enumerate(self.joint_indices):
            p.setJointMotorControl2(
                self.robot, idx,
                p.VELOCITY_CONTROL,
                targetVelocity=action[i] * 2.0,  # Scale velocity
                force=100
            )

        p.stepSimulation()

        # Get new observation
        obs = self._get_obs()

        # Compute reward
        ee_state = p.getLinkState(self.robot, self.end_effector_index)
        ee_pos = np.array(ee_state[0])
        distance = np.linalg.norm(ee_pos - self.target_pos)

        reward = -distance  # Negative distance as reward

        # Success bonus
        if distance < 0.05:
            reward += 10.0
            done = True
        else:
            done = False

        return obs, reward, done

    def close(self):
        """Cleanup environment (disconnects the PyBullet client)."""
        p.disconnect()
216
217
class FallbackCartPoleEnv:
    """CartPole dynamics in plain NumPy, used when PyBullet is unavailable.

    State is (cart position, cart velocity, pole angle, pole angular
    velocity); the action is a 1-D array whose first entry scales the
    horizontal force applied to the cart.
    """

    def __init__(self):
        # Physical constants of the cart-pole system.
        self.gravity = 9.8
        self.mass_cart = 1.0
        self.mass_pole = 0.1
        self.length = 0.5
        self.dt = 0.02

        self.state = None

    def reset(self) -> np.ndarray:
        """Start a new episode near the upright equilibrium."""
        self.state = np.random.randn(4) * 0.1
        return self.state.astype(np.float32)

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool]:
        """Advance the simulation by one Euler step.

        Returns:
            (observation, reward, done): reward is 1.0 while the cart and
            pole stay within bounds, 0.0 on the terminating step.
        """
        push = action[0] * 10.0  # Scale action to a force

        pos, vel, ang, ang_vel = self.state
        cos_a = np.cos(ang)
        sin_a = np.sin(ang)

        total_mass = self.mass_cart + self.mass_pole
        pole_ml = self.mass_pole * self.length

        # Standard cart-pole equations of motion.
        temp = (push + pole_ml * ang_vel**2 * sin_a) / total_mass
        ang_acc = (self.gravity * sin_a - cos_a * temp) / (
            self.length * (4.0 / 3.0 - self.mass_pole * cos_a**2 / total_mass)
        )
        lin_acc = temp - pole_ml * ang_acc * cos_a / total_mass

        # Semi-explicit Euler integration (positions use the old velocities).
        pos = pos + self.dt * vel
        vel = vel + self.dt * lin_acc
        ang = ang + self.dt * ang_vel
        ang_vel = ang_vel + self.dt * ang_acc

        self.state = np.array([pos, vel, ang, ang_vel])

        # Episode ends when the cart leaves the track or the pole falls over.
        done = abs(pos) > 2.4 or abs(ang) > 0.2
        reward = 0.0 if done else 1.0

        return self.state.astype(np.float32), reward, done
269
270
class PPOAgent:
    """Proximal Policy Optimization agent for continuous action spaces.

    Collects transitions in an on-policy buffer, then runs several epochs of
    clipped-surrogate updates over the whole batch and clears the buffer.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        lr: float = 3e-4,
        gamma: float = 0.99,
        eps_clip: float = 0.2,
        k_epochs: int = 4,
        gae_lambda: float = 0.95
    ):
        """
        Args:
            state_dim: Dimension of observations.
            action_dim: Dimension of continuous actions.
            lr: Adam learning rate.
            gamma: Discount factor.
            eps_clip: PPO clipping parameter epsilon.
            k_epochs: Gradient epochs per update() call.
            gae_lambda: Lambda for Generalized Advantage Estimation
                (was hard-coded to 0.95; now a parameter with the same default).
        """
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.k_epochs = k_epochs
        self.gae_lambda = gae_lambda

        self.policy = ActorCritic(state_dim, action_dim)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)

        self.memory: List[Transition] = []

    def select_action(self, state: np.ndarray) -> Tuple[np.ndarray, float, float]:
        """Sample an action from the current policy (no gradients).

        Returns:
            action: Sampled action as a 1-D array of length action_dim.
            log_prob: Log-probability of the sampled action.
            value: Critic's value estimate for `state`.
        """
        state_tensor = torch.FloatTensor(state).unsqueeze(0)

        with torch.no_grad():
            action, log_prob, value = self.policy.get_action(state_tensor)

        return action.numpy()[0], log_prob.item(), value.item()

    def store_transition(self, transition: Transition):
        """Append a transition to the on-policy buffer."""
        self.memory.append(transition)

    def update(self):
        """Run k_epochs of PPO updates on the buffered transitions, then clear the buffer."""
        if not self.memory:
            return  # nothing collected yet

        # Stack into one ndarray first: avoids the slow (and in newer PyTorch
        # versions, warned-about) list-of-arrays -> tensor path.
        states = torch.as_tensor(np.array([t.state for t in self.memory]), dtype=torch.float32)
        actions = torch.as_tensor(np.array([t.action for t in self.memory]), dtype=torch.float32)
        old_log_probs = torch.as_tensor([t.log_prob for t in self.memory], dtype=torch.float32)

        # Discounted returns and GAE advantages, computed backwards in time.
        returns: List[float] = []
        advantages: List[float] = []
        running_return = 0.0
        running_advantage = 0.0
        next_value = 0.0  # V(s_{t+1}); 0 beyond terminal states

        for t in reversed(self.memory):
            mask = 1.0 - float(t.done)  # stop bootstrapping across episode ends
            running_return = t.reward + self.gamma * running_return * mask
            returns.append(running_return)

            # FIX: the one-step TD error is r + gamma * V(s') - V(s). The
            # original used the discounted return in place of V(s'), which is
            # not the GAE delta and inflated the advantages.
            td_error = t.reward + self.gamma * next_value * mask - t.value
            running_advantage = td_error + self.gamma * self.gae_lambda * running_advantage * mask
            advantages.append(running_advantage)

            next_value = t.value

        returns.reverse()
        advantages.reverse()

        returns_t = torch.tensor(returns, dtype=torch.float32)
        advantages_t = torch.tensor(advantages, dtype=torch.float32)
        # Normalize advantages for a better-conditioned policy gradient.
        advantages_t = (advantages_t - advantages_t.mean()) / (advantages_t.std() + 1e-8)

        for _ in range(self.k_epochs):
            # Re-evaluate the stored actions under the current policy.
            action_mean, action_std, values = self.policy(states)
            dist = torch.distributions.Normal(action_mean, action_std)
            log_probs = dist.log_prob(actions).sum(dim=-1)
            entropy = dist.entropy().sum(dim=-1).mean()

            # Clipped surrogate objective.
            ratio = torch.exp(log_probs - old_log_probs)
            surr1 = ratio * advantages_t
            surr2 = torch.clamp(ratio, 1 - self.eps_clip, 1 + self.eps_clip) * advantages_t

            actor_loss = -torch.min(surr1, surr2).mean()
            # FIX: squeeze(-1) rather than squeeze() so a single-transition
            # batch keeps its batch dimension and MSELoss shapes still match.
            critic_loss = nn.MSELoss()(values.squeeze(-1), returns_t)

            loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        self.memory.clear()
353
354
def train_robot(episodes: int = 500, gui: bool = False):
    """Train an agent with PPO on the Kuka reach task (or the CartPole fallback).

    Args:
        episodes: Number of training episodes to run.
        gui: Show the PyBullet GUI (only meaningful when PyBullet is installed).
    """
    # Pick the environment and its matching dimensions.
    if PYBULLET_AVAILABLE:
        env = KukaReachEnv(gui=gui)
        state_dim, action_dim = 17, 7  # 7 joint pos + 7 joint vel + 3 distance
        print("Using PyBullet Kuka environment")
    else:
        env = FallbackCartPoleEnv()
        state_dim, action_dim = 4, 1
        print("Using fallback CartPole environment")

    agent = PPOAgent(state_dim, action_dim)

    episode_rewards: List[float] = []
    update_freq = 10  # Run a PPO update every N episodes

    for ep in range(1, episodes + 1):
        state = env.reset()
        total = 0.0

        # Roll out at most 200 steps per episode.
        for _ in range(200):
            action, log_prob, value = agent.select_action(state)
            next_state, reward, done = env.step(action)

            agent.store_transition(
                Transition(state, action, reward, next_state, done, log_prob, value)
            )

            total += reward
            state = next_state

            if done:
                break

        episode_rewards.append(total)

        # Periodic policy update from the accumulated buffer.
        if ep % update_freq == 0:
            agent.update()

        # Progress report over the last 10 episodes.
        if ep % 10 == 0:
            avg_reward = np.mean(episode_rewards[-10:])
            print(f"Episode {ep}, Avg Reward: {avg_reward:.2f}")

    if PYBULLET_AVAILABLE:
        env.close()

    print(f"\nTraining completed! Final average reward: {np.mean(episode_rewards[-100:]):.2f}")
407
408
if __name__ == "__main__":
    import argparse

    # Command-line entry point: episode count plus an optional GUI flag.
    cli = argparse.ArgumentParser(description="Robot control with PPO")
    cli.add_argument("--episodes", type=int, default=500, help="Number of episodes")
    cli.add_argument("--gui", action="store_true", help="Show PyBullet GUI")
    opts = cli.parse_args()

    train_robot(episodes=opts.episodes, gui=opts.gui)