Skip to content

Commit 28fd1cb

Browse files
committed
[wip,broken] Load gates and obstacles into the simulation. Fix environments for all control types
1 parent 73b979d commit 28fd1cb

File tree

7 files changed

+91
-102
lines changed

7 files changed

+91
-102
lines changed

config/level0.toml

+3-7
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,9 @@ practice_without_track_objects = false
1717

1818
[sim]
1919
# Physics options:
20-
# "pyb": PyBullet
21-
# "dyn": Mathematical dynamics model
22-
# "pyb_gnd" PyBullet with ground effect
23-
# "pyb_drag": PyBullet with drag
24-
# "pyb_dw": PyBullet with downwash
25-
# "pyb_gnd_drag_dw": PyBullet with ground effect, drag, and downwash.
26-
# "sys_id": System identification model. Only supported for attitude control interface (DroneRacingThrust-v0)
20+
# "analytical": Analytical, simplified dynamics model
21+
# "mujoco": Mujoco dynamics. May take longer to compile at startup.
22+
# "sys_id": System identification model.
2723
physics = "analytical"
2824

2925
camera_view = [5.0, -40.0, -40.0, 0.5, -1.0, 0.5]

lsy_drone_racing/control/thrust_controller.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,11 @@ def __init__(self, initial_obs: dict[str, NDArray[np.floating]], initial_info: d
5252
# Same waypoints as in the trajectory controller. Determined by trial and error.
5353
waypoints = np.array(
5454
[
55-
[1.0, 1.0, 0.0],
55+
[1.0, 1.0, 0.05],
5656
[0.8, 0.5, 0.2],
57-
[0.55, -0.8, 0.4],
57+
[0.55, -0.8, 0.5],
5858
[0.2, -1.8, 0.65],
59-
[1.1, -1.35, 1.0],
59+
[1.1, -1.35, 1.1],
6060
[0.2, 0.0, 0.65],
6161
[0.0, 0.75, 0.525],
6262
[0.0, 0.75, 1.1],
@@ -124,7 +124,6 @@ def compute_control(
124124
target_thrust += self.kp * pos_error
125125
target_thrust += self.ki * self.i_error
126126
target_thrust += self.kd * vel_error
127-
# target_thrust += params.quad.m * desired_acc
128127
target_thrust[2] += self.drone_mass * self.g
129128

130129
# Update z_axis to the current orientation of the drone

lsy_drone_racing/control/trajectory_controller.py

+3-21
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from typing import TYPE_CHECKING
1515

1616
import numpy as np
17-
import pybullet as p
1817
from scipy.interpolate import CubicSpline
1918

2019
from lsy_drone_racing.control import BaseController
@@ -37,11 +36,11 @@ def __init__(self, initial_obs: dict[str, NDArray[np.floating]], initial_info: d
3736
super().__init__(initial_obs, initial_info)
3837
waypoints = np.array(
3938
[
40-
[1.0, 1.0, 0.0],
39+
[1.0, 1.0, 0.05],
4140
[0.8, 0.5, 0.2],
42-
[0.55, -0.8, 0.4],
41+
[0.55, -0.8, 0.5],
4342
[0.2, -1.8, 0.65],
44-
[1.1, -1.35, 1.0],
43+
[1.1, -1.35, 1.1],
4544
[0.2, 0.0, 0.65],
4645
[0.0, 0.75, 0.525],
4746
[0.0, 0.75, 1.1],
@@ -55,23 +54,6 @@ def __init__(self, initial_obs: dict[str, NDArray[np.floating]], initial_info: d
5554
self._tick = 0
5655
self._freq = initial_info["env_freq"]
5756

58-
# Generate points along the spline for visualization
59-
t_vis = np.linspace(0, self.t_total - 1, 100)
60-
spline_points = self.trajectory(t_vis)
61-
try:
62-
# Plot the spline as a line in PyBullet
63-
for i in range(len(spline_points) - 1):
64-
p.addUserDebugLine(
65-
spline_points[i],
66-
spline_points[i + 1],
67-
lineColorRGB=[1, 0, 0], # Red color
68-
lineWidth=2,
69-
lifeTime=0, # 0 means the line persists indefinitely
70-
physicsClientId=0,
71-
)
72-
except p.error:
73-
... # Ignore errors if PyBullet is not available
74-
7557
def compute_control(
7658
self, obs: dict[str, NDArray[np.floating]], info: dict | None = None
7759
) -> NDArray[np.floating]:

lsy_drone_racing/envs/drone_racing_env.py

+64-48
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,17 @@
2727

2828
import copy as copy
2929
import logging
30+
from pathlib import Path
3031
from typing import TYPE_CHECKING
3132

3233
import gymnasium
34+
import mujoco
3335
import numpy as np
3436
from crazyflow import Sim
3537
from gymnasium import spaces
3638
from scipy.spatial.transform import Rotation as R
3739

3840
from lsy_drone_racing.sim.noise import NoiseList
39-
from lsy_drone_racing.sim.physics import PhysicsMode
4041
from lsy_drone_racing.utils import check_gate_pass
4142

4243
if TYPE_CHECKING:
@@ -70,16 +71,21 @@ class DroneRacingEnv(gymnasium.Env):
7071
- "ang_vel": Drone angular velocity
7172
- "gates.pos": Positions of the gates
7273
- "gates.rpy": Orientations of the gates
73-
- "gates.visited": Flags indicating if the drone already was/ is in the sensor range of the gates and the true position is known
74+
- "gates.visited": Flags indicating if the drone already was/ is in the sensor range of the
75+
gates and the true position is known
7476
- "obstacles.pos": Positions of the obstacles
75-
- "obstacles.visited": Flags indicating if the drone already was/ is in the sensor range of the obstacles and the true position is known
77+
- "obstacles.visited": Flags indicating if the drone already was/ is in the sensor range of the
78+
obstacles and the true position is known
7679
- "target_gate": The current target gate index
7780
7881
The action space consists of a desired full-state command
7982
[x, y, z, vx, vy, vz, ax, ay, az, yaw, rrate, prate, yrate] that is tracked by the drone's
8083
low-level controller.
8184
"""
8285

86+
gate_spec_path = Path(__file__).parents[1] / "sim/assets/gate.urdf"
87+
obstacle_spec_path = Path(__file__).parents[1] / "sim/assets/obstacle.urdf"
88+
8389
def __init__(self, config: dict):
8490
"""Initialize the DroneRacingEnv.
8591
@@ -92,13 +98,12 @@ def __init__(self, config: dict):
9298
n_worlds=1,
9399
n_drones=1,
94100
physics=config.sim.physics,
95-
control="state",
101+
control=config.sim.get("control", "state"),
96102
freq=config.sim.sim_freq,
97103
state_freq=config.env.freq,
98104
attitude_freq=config.sim.attitude_freq,
99105
rng_key=config.env.seed,
100106
)
101-
self.contact_mask = np.array([0], dtype=bool)
102107
if config.sim.sim_freq % config.env.freq != 0:
103108
raise ValueError(f"({config.sim.sim_freq=}) is no multiple of ({config.env.freq=})")
104109
self.action_space = spaces.Box(low=-1, high=1, shape=(13,))
@@ -134,6 +139,7 @@ def __init__(self, config: dict):
134139
self._steps = 0
135140
self._last_drone_pos = np.zeros(3)
136141
self.gates, self.obstacles, self.drone = self.load_track(config.env.track)
142+
self.n_gates = len(config.env.track.gates)
137143
self.disturbances = self.load_disturbances(config.env.get("disturbances", None))
138144

139145
self.gates_visited = np.array([False] * len(config.env.track.gates))
@@ -151,13 +157,12 @@ def reset(
151157
Returns:
152158
Observation and info.
153159
"""
154-
# The system identification model is based on the attitude control interface. We cannot
155-
# support its use with the full state control interface
156160
if self.config.env.reseed:
157161
self.sim.seed(self.config.env.seed)
158162
if seed is not None:
159163
self.sim.seed(seed)
160164
self.sim.reset()
165+
# TODO: Add randomization of gates, obstacles, drone, and disturbances
161166
states = self.sim.data.states.replace(
162167
pos=self.drone["pos"].reshape((1, 1, 3)),
163168
quat=self.drone["quat"].reshape((1, 1, 4)),
@@ -168,12 +173,12 @@ def reset(
168173
self.target_gate = 0
169174
self._steps = 0
170175
self._last_drone_pos[:] = self.sim.data.states.pos[0, 0]
171-
info = self.info
176+
info = self.info()
172177
info["sim_freq"] = self.sim.data.core.freq
173178
info["low_level_ctrl_freq"] = self.sim.data.controls.attitude_freq
174-
info["drone_mass"] = self.sim.default_data.params.mass[0, 0]
179+
info["drone_mass"] = self.sim.default_data.params.mass[0, 0, 0]
175180
info["env_freq"] = self.config.env.freq
176-
return self.obs, info
181+
return self.obs(), info
177182

178183
def step(
179184
self, action: NDArray[np.floating]
@@ -187,20 +192,21 @@ def step(
187192
action: Full-state command [x, y, z, vx, vy, vz, ax, ay, az, yaw, rrate, prate, yrate]
188193
to follow.
189194
"""
190-
assert (
191-
self.config.sim.physics != PhysicsMode.SYS_ID
192-
), "sys_id model not supported for full state control interface"
193-
action = action.astype(np.float64) # Drone firmware expects float64
194195
assert action.shape == self.action_space.shape, f"Invalid action shape: {action.shape}"
195-
self.sim.state_control(action.reshape((1, 1, 13)))
196-
self.sim.step(self.sim.freq // self.sim.control_freq)
197-
return self.obs, self.reward, self.terminated, False, self.info
196+
# TODO: Add action noise
197+
# TODO: Check why sim is being compiled twice
198+
self.sim.state_control(action.reshape((1, 1, 13)).astype(np.float32))
199+
self.sim.step(self.sim.freq // self.config.env.freq)
200+
self.target_gate += self.gate_passed()
201+
if self.target_gate == self.n_gates:
202+
self.target_gate = -1
203+
self._last_drone_pos[:] = self.sim.data.states.pos[0, 0]
204+
return self.obs(), self.reward(), self.terminated(), False, self.info()
198205

199206
def render(self):
200207
"""Render the environment."""
201208
self.sim.render()
202209

203-
@property
204210
def obs(self) -> dict[str, NDArray[np.floating]]:
205211
"""Return the observation of the environment."""
206212
obs = {
@@ -240,7 +246,6 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
240246
obs = self.disturbances["observation"].apply(obs)
241247
return obs
242248

243-
@property
244249
def reward(self) -> float:
245250
"""Compute the reward for the current state.
246251
@@ -254,7 +259,6 @@ def reward(self) -> float:
254259
"""
255260
return -1.0 if self.target_gate != -1 else 0.0
256261

257-
@property
258262
def terminated(self) -> bool:
259263
"""Check if the episode is terminated.
260264
@@ -274,18 +278,17 @@ def terminated(self) -> bool:
274278
}
275279
if state not in self.state_space:
276280
return True # Drone is out of bounds
277-
if np.logical_and(self.sim.contacts("drone:0"), self.contact_mask).any():
281+
if self.sim.contacts("drone:0").any():
278282
return True
279283
if self.sim.data.states.pos[0, 0, 2] < 0.0:
280284
return True
281285
if self.target_gate == -1: # Drone has passed all gates
282286
return True
283287
return False
284288

285-
@property
286289
def info(self) -> dict:
287290
"""Return an info dictionary containing additional information about the environment."""
288-
return {"collisions": self.sim.contacts("drone:0"), "symbolic_model": self.symbolic}
291+
return {"collisions": self.sim.contacts("drone:0").any(), "symbolic_model": self.symbolic}
289292

290293
def load_track(self, track: dict) -> tuple[dict, dict, dict]:
291294
"""Load the track from the config file."""
@@ -299,8 +302,32 @@ def load_track(self, track: dict) -> tuple[dict, dict, dict]:
299302
for k in ("pos", "rpy", "vel", "rpy_rates")
300303
}
301304
drone["quat"] = R.from_euler("xyz", drone["rpy"]).as_quat()
305+
# Load the models into the simulation and set their positions
306+
self._load_track_into_sim(gates, obstacles)
302307
return gates, obstacles, drone
303308

309+
def _load_track_into_sim(self, gates: dict, obstacles: dict):
310+
"""Load the track into the simulation."""
311+
gate_spec = mujoco.MjSpec.from_file(str(self.gate_spec_path))
312+
obstacle_spec = mujoco.MjSpec.from_file(str(self.obstacle_spec_path))
313+
spec = self.sim.spec
314+
frame = spec.worldbody.add_frame()
315+
for i in range(len(gates["pos"])):
316+
gate = frame.attach_body(gate_spec.find_body("world"), "", f":g{i}")
317+
gate.pos = gates["pos"][i]
318+
quat = R.from_euler("xyz", gates["rpy"][i]).as_quat()
319+
gate.quat = quat[[3, 0, 1, 2]] # MuJoCo uses wxyz order instead of xyzw
320+
for i in range(len(obstacles["pos"])):
321+
obstacle = frame.attach_body(obstacle_spec.find_body("world"), "", f":o{i}")
322+
obstacle.pos = obstacles["pos"][i]
323+
# TODO: Simplify rebuilding the simulation after changing the mujoco model
324+
self.sim.mj_model, self.sim.mj_data, self.sim.mjx_model, mjx_data = self.sim.compile_mj(
325+
spec
326+
)
327+
self.sim.data = self.sim.data.replace(mjx_data=mjx_data)
328+
self.sim.default_data = self.sim.data.replace()
329+
self.sim.build()
330+
304331
def load_disturbances(self, disturbances: dict | None = None) -> dict:
305332
"""Load the disturbances from the config."""
306333
dist = {}
@@ -316,10 +343,10 @@ def gate_passed(self) -> bool:
316343
Returns:
317344
True if the drone has passed a gate, else False.
318345
"""
319-
if self.sim.n_gates > 0 and self.target_gate < self.sim.n_gates and self.target_gate != -1:
320-
gate_pos = self.sim.gates[self.target_gate]["pos"]
321-
gate_rot = R.from_euler("xyz", self.sim.gates[self.target_gate]["rpy"])
322-
drone_pos = self.sim.drone.pos
346+
if self.n_gates > 0 and self.target_gate < self.n_gates and self.target_gate != -1:
347+
gate_pos = self.gates["pos"][self.target_gate]
348+
gate_rot = R.from_euler("xyz", self.gates["rpy"][self.target_gate])
349+
drone_pos = self.sim.data.states.pos[0, 0]
323350
last_drone_pos = self._last_drone_pos
324351
gate_size = (0.45, 0.45)
325352
return check_gate_pass(gate_pos, gate_rot, gate_size, drone_pos, last_drone_pos)
@@ -343,6 +370,7 @@ def __init__(self, config: dict):
343370
Args:
344371
config: Configuration dictionary for the environment.
345372
"""
373+
config.sim.control = "attitude"
346374
super().__init__(config)
347375
bounds = np.array([1, np.pi, np.pi, np.pi], dtype=np.float32)
348376
self.action_space = spaces.Box(low=-bounds, high=bounds)
@@ -356,24 +384,12 @@ def step(
356384
action: Thrust command [thrust, roll, pitch, yaw].
357385
"""
358386
assert action.shape == self.action_space.shape, f"Invalid action shape: {action.shape}"
359-
action = action.astype(np.float64)
360-
collision = False
361-
# We currently need to differentiate between the sys_id backend and all others because the
362-
# simulation step size is different for the sys_id backend (we do not substep in the
363-
# identified model). In future iterations, the sim API should be flexible to handle both
364-
# cases without an explicit step_sys_id function.
365-
if self.config.sim.physics == "sys_id":
366-
cmd_thrust, cmd_rpy = action[0], action[1:]
367-
self.sim.step_sys_id(cmd_thrust, cmd_rpy, 1 / self.config.env.freq)
368-
self.target_gate += self.gate_passed()
369-
if self.target_gate == self.sim.n_gates:
370-
self.target_gate = -1
371-
self._last_drone_pos[:] = self.sim.drone.pos
372-
else:
373-
# Crazyflie firmware expects negated pitch command. TODO: Check why this is the case and
374-
# fix this on the firmware side if possible.
375-
cmd_thrust, cmd_rpy = action[0], action[1:] * np.array([1, -1, 1])
376-
self.sim.drone.collective_thrust_cmd(cmd_thrust, cmd_rpy)
377-
collision = self._inner_step_loop()
378-
terminated = self.terminated or collision
379-
return self.obs, self.reward, terminated, False, self.info
387+
# TODO: Add action noise
388+
# TODO: Check why sim is being compiled twice
389+
self.sim.attitude_control(action.reshape((1, 1, 4)).astype(np.float32))
390+
self.sim.step(self.sim.freq // self.config.env.freq)
391+
self.target_gate += self.gate_passed()
392+
if self.target_gate == self.n_gates:
393+
self.target_gate = -1
394+
self._last_drone_pos[:] = self.sim.data.states.pos[0, 0]
395+
return self.obs(), self.reward(), self.terminated(), False, self.info()

0 commit comments

Comments
 (0)