27
27
28
28
import copy as copy
29
29
import logging
30
+ from pathlib import Path
30
31
from typing import TYPE_CHECKING
31
32
32
33
import gymnasium
34
+ import mujoco
33
35
import numpy as np
34
36
from crazyflow import Sim
35
37
from gymnasium import spaces
36
38
from scipy .spatial .transform import Rotation as R
37
39
38
40
from lsy_drone_racing .sim .noise import NoiseList
39
- from lsy_drone_racing .sim .physics import PhysicsMode
40
41
from lsy_drone_racing .utils import check_gate_pass
41
42
42
43
if TYPE_CHECKING :
@@ -70,16 +71,21 @@ class DroneRacingEnv(gymnasium.Env):
70
71
- "ang_vel": Drone angular velocity
71
72
- "gates.pos": Positions of the gates
72
73
- "gates.rpy": Orientations of the gates
73
- - "gates.visited": Flags indicating if the drone already was/ is in the sensor range of the gates and the true position is known
74
+ - "gates.visited": Flags indicating if the drone already was/ is in the sensor range of the
75
+ gates and the true position is known
74
76
- "obstacles.pos": Positions of the obstacles
75
- - "obstacles.visited": Flags indicating if the drone already was/ is in the sensor range of the obstacles and the true position is known
77
+ - "obstacles.visited": Flags indicating if the drone already was/ is in the sensor range of the
78
+ obstacles and the true position is known
76
79
- "target_gate": The current target gate index
77
80
78
81
The action space consists of a desired full-state command
79
82
[x, y, z, vx, vy, vz, ax, ay, az, yaw, rrate, prate, yrate] that is tracked by the drone's
80
83
low-level controller.
81
84
"""
82
85
86
+ gate_spec_path = Path (__file__ ).parents [1 ] / "sim/assets/gate.urdf"
87
+ obstacle_spec_path = Path (__file__ ).parents [1 ] / "sim/assets/obstacle.urdf"
88
+
83
89
def __init__ (self , config : dict ):
84
90
"""Initialize the DroneRacingEnv.
85
91
@@ -92,13 +98,12 @@ def __init__(self, config: dict):
92
98
n_worlds = 1 ,
93
99
n_drones = 1 ,
94
100
physics = config .sim .physics ,
95
- control = " state" ,
101
+ control = config . sim . get ( "control" , " state") ,
96
102
freq = config .sim .sim_freq ,
97
103
state_freq = config .env .freq ,
98
104
attitude_freq = config .sim .attitude_freq ,
99
105
rng_key = config .env .seed ,
100
106
)
101
- self .contact_mask = np .array ([0 ], dtype = bool )
102
107
if config .sim .sim_freq % config .env .freq != 0 :
103
108
raise ValueError (f"({ config .sim .sim_freq = } ) is no multiple of ({ config .env .freq = } )" )
104
109
self .action_space = spaces .Box (low = - 1 , high = 1 , shape = (13 ,))
@@ -134,6 +139,7 @@ def __init__(self, config: dict):
134
139
self ._steps = 0
135
140
self ._last_drone_pos = np .zeros (3 )
136
141
self .gates , self .obstacles , self .drone = self .load_track (config .env .track )
142
+ self .n_gates = len (config .env .track .gates )
137
143
self .disturbances = self .load_disturbances (config .env .get ("disturbances" , None ))
138
144
139
145
self .gates_visited = np .array ([False ] * len (config .env .track .gates ))
@@ -151,13 +157,12 @@ def reset(
151
157
Returns:
152
158
Observation and info.
153
159
"""
154
- # The system identification model is based on the attitude control interface. We cannot
155
- # support its use with the full state control interface
156
160
if self .config .env .reseed :
157
161
self .sim .seed (self .config .env .seed )
158
162
if seed is not None :
159
163
self .sim .seed (seed )
160
164
self .sim .reset ()
165
+ # TODO: Add randomization of gates, obstacles, drone, and disturbances
161
166
states = self .sim .data .states .replace (
162
167
pos = self .drone ["pos" ].reshape ((1 , 1 , 3 )),
163
168
quat = self .drone ["quat" ].reshape ((1 , 1 , 4 )),
@@ -168,12 +173,12 @@ def reset(
168
173
self .target_gate = 0
169
174
self ._steps = 0
170
175
self ._last_drone_pos [:] = self .sim .data .states .pos [0 , 0 ]
171
- info = self .info
176
+ info = self .info ()
172
177
info ["sim_freq" ] = self .sim .data .core .freq
173
178
info ["low_level_ctrl_freq" ] = self .sim .data .controls .attitude_freq
174
- info ["drone_mass" ] = self .sim .default_data .params .mass [0 , 0 ]
179
+ info ["drone_mass" ] = self .sim .default_data .params .mass [0 , 0 , 0 ]
175
180
info ["env_freq" ] = self .config .env .freq
176
- return self .obs , info
181
+ return self .obs () , info
177
182
178
183
def step (
179
184
self , action : NDArray [np .floating ]
@@ -187,20 +192,21 @@ def step(
187
192
action: Full-state command [x, y, z, vx, vy, vz, ax, ay, az, yaw, rrate, prate, yrate]
188
193
to follow.
189
194
"""
190
- assert (
191
- self .config .sim .physics != PhysicsMode .SYS_ID
192
- ), "sys_id model not supported for full state control interface"
193
- action = action .astype (np .float64 ) # Drone firmware expects float64
194
195
assert action .shape == self .action_space .shape , f"Invalid action shape: { action .shape } "
195
- self .sim .state_control (action .reshape ((1 , 1 , 13 )))
196
- self .sim .step (self .sim .freq // self .sim .control_freq )
197
- return self .obs , self .reward , self .terminated , False , self .info
196
+ # TODO: Add action noise
197
+ # TODO: Check why sim is being compiled twice
198
+ self .sim .state_control (action .reshape ((1 , 1 , 13 )).astype (np .float32 ))
199
+ self .sim .step (self .sim .freq // self .config .env .freq )
200
+ self .target_gate += self .gate_passed ()
201
+ if self .target_gate == self .n_gates :
202
+ self .target_gate = - 1
203
+ self ._last_drone_pos [:] = self .sim .data .states .pos [0 , 0 ]
204
+ return self .obs (), self .reward (), self .terminated (), False , self .info ()
198
205
199
206
def render (self ):
200
207
"""Render the environment."""
201
208
self .sim .render ()
202
209
203
- @property
204
210
def obs (self ) -> dict [str , NDArray [np .floating ]]:
205
211
"""Return the observation of the environment."""
206
212
obs = {
@@ -240,7 +246,6 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
240
246
obs = self .disturbances ["observation" ].apply (obs )
241
247
return obs
242
248
243
- @property
244
249
def reward (self ) -> float :
245
250
"""Compute the reward for the current state.
246
251
@@ -254,7 +259,6 @@ def reward(self) -> float:
254
259
"""
255
260
return - 1.0 if self .target_gate != - 1 else 0.0
256
261
257
- @property
258
262
def terminated (self ) -> bool :
259
263
"""Check if the episode is terminated.
260
264
@@ -274,18 +278,17 @@ def terminated(self) -> bool:
274
278
}
275
279
if state not in self .state_space :
276
280
return True # Drone is out of bounds
277
- if np . logical_and ( self .sim .contacts ("drone:0" ), self . contact_mask ).any ():
281
+ if self .sim .contacts ("drone:0" ).any ():
278
282
return True
279
283
if self .sim .data .states .pos [0 , 0 , 2 ] < 0.0 :
280
284
return True
281
285
if self .target_gate == - 1 : # Drone has passed all gates
282
286
return True
283
287
return False
284
288
285
- @property
286
289
def info (self ) -> dict :
287
290
"""Return an info dictionary containing additional information about the environment."""
288
- return {"collisions" : self .sim .contacts ("drone:0" ), "symbolic_model" : self .symbolic }
291
+ return {"collisions" : self .sim .contacts ("drone:0" ). any () , "symbolic_model" : self .symbolic }
289
292
290
293
def load_track (self , track : dict ) -> tuple [dict , dict , dict ]:
291
294
"""Load the track from the config file."""
@@ -299,8 +302,32 @@ def load_track(self, track: dict) -> tuple[dict, dict, dict]:
299
302
for k in ("pos" , "rpy" , "vel" , "rpy_rates" )
300
303
}
301
304
drone ["quat" ] = R .from_euler ("xyz" , drone ["rpy" ]).as_quat ()
305
+ # Load the models into the simulation and set their positions
306
+ self ._load_track_into_sim (gates , obstacles )
302
307
return gates , obstacles , drone
303
308
309
+ def _load_track_into_sim (self , gates : dict , obstacles : dict ):
310
+ """Load the track into the simulation."""
311
+ gate_spec = mujoco .MjSpec .from_file (str (self .gate_spec_path ))
312
+ obstacle_spec = mujoco .MjSpec .from_file (str (self .obstacle_spec_path ))
313
+ spec = self .sim .spec
314
+ frame = spec .worldbody .add_frame ()
315
+ for i in range (len (gates ["pos" ])):
316
+ gate = frame .attach_body (gate_spec .find_body ("world" ), "" , f":g{ i } " )
317
+ gate .pos = gates ["pos" ][i ]
318
+ quat = R .from_euler ("xyz" , gates ["rpy" ][i ]).as_quat ()
319
+ gate .quat = quat [[3 , 0 , 1 , 2 ]] # MuJoCo uses wxyz order instead of xyzw
320
+ for i in range (len (obstacles ["pos" ])):
321
+ obstacle = frame .attach_body (obstacle_spec .find_body ("world" ), "" , f":o{ i } " )
322
+ obstacle .pos = obstacles ["pos" ][i ]
323
+ # TODO: Simplify rebuilding the simulation after changing the mujoco model
324
+ self .sim .mj_model , self .sim .mj_data , self .sim .mjx_model , mjx_data = self .sim .compile_mj (
325
+ spec
326
+ )
327
+ self .sim .data = self .sim .data .replace (mjx_data = mjx_data )
328
+ self .sim .default_data = self .sim .data .replace ()
329
+ self .sim .build ()
330
+
304
331
def load_disturbances (self , disturbances : dict | None = None ) -> dict :
305
332
"""Load the disturbances from the config."""
306
333
dist = {}
@@ -316,10 +343,10 @@ def gate_passed(self) -> bool:
316
343
Returns:
317
344
True if the drone has passed a gate, else False.
318
345
"""
319
- if self .sim . n_gates > 0 and self .target_gate < self . sim .n_gates and self .target_gate != - 1 :
320
- gate_pos = self .sim . gates [self . target_gate ][ "pos" ]
321
- gate_rot = R .from_euler ("xyz" , self .sim . gates [self . target_gate ][ "rpy" ])
322
- drone_pos = self .sim .drone . pos
346
+ if self .n_gates > 0 and self .target_gate < self .n_gates and self .target_gate != - 1 :
347
+ gate_pos = self .gates ["pos" ][ self . target_gate ]
348
+ gate_rot = R .from_euler ("xyz" , self .gates ["rpy" ][ self . target_gate ])
349
+ drone_pos = self .sim .data . states . pos [ 0 , 0 ]
323
350
last_drone_pos = self ._last_drone_pos
324
351
gate_size = (0.45 , 0.45 )
325
352
return check_gate_pass (gate_pos , gate_rot , gate_size , drone_pos , last_drone_pos )
@@ -343,6 +370,7 @@ def __init__(self, config: dict):
343
370
Args:
344
371
config: Configuration dictionary for the environment.
345
372
"""
373
+ config .sim .control = "attitude"
346
374
super ().__init__ (config )
347
375
bounds = np .array ([1 , np .pi , np .pi , np .pi ], dtype = np .float32 )
348
376
self .action_space = spaces .Box (low = - bounds , high = bounds )
@@ -356,24 +384,12 @@ def step(
356
384
action: Thrust command [thrust, roll, pitch, yaw].
357
385
"""
358
386
assert action .shape == self .action_space .shape , f"Invalid action shape: { action .shape } "
359
- action = action .astype (np .float64 )
360
- collision = False
361
- # We currently need to differentiate between the sys_id backend and all others because the
362
- # simulation step size is different for the sys_id backend (we do not substep in the
363
- # identified model). In future iterations, the sim API should be flexible to handle both
364
- # cases without an explicit step_sys_id function.
365
- if self .config .sim .physics == "sys_id" :
366
- cmd_thrust , cmd_rpy = action [0 ], action [1 :]
367
- self .sim .step_sys_id (cmd_thrust , cmd_rpy , 1 / self .config .env .freq )
368
- self .target_gate += self .gate_passed ()
369
- if self .target_gate == self .sim .n_gates :
370
- self .target_gate = - 1
371
- self ._last_drone_pos [:] = self .sim .drone .pos
372
- else :
373
- # Crazyflie firmware expects negated pitch command. TODO: Check why this is the case and
374
- # fix this on the firmware side if possible.
375
- cmd_thrust , cmd_rpy = action [0 ], action [1 :] * np .array ([1 , - 1 , 1 ])
376
- self .sim .drone .collective_thrust_cmd (cmd_thrust , cmd_rpy )
377
- collision = self ._inner_step_loop ()
378
- terminated = self .terminated or collision
379
- return self .obs , self .reward , terminated , False , self .info
387
+ # TODO: Add action noise
388
+ # TODO: Check why sim is being compiled twice
389
+ self .sim .attitude_control (action .reshape ((1 , 1 , 4 )).astype (np .float32 ))
390
+ self .sim .step (self .sim .freq // self .config .env .freq )
391
+ self .target_gate += self .gate_passed ()
392
+ if self .target_gate == self .n_gates :
393
+ self .target_gate = - 1
394
+ self ._last_drone_pos [:] = self .sim .data .states .pos [0 , 0 ]
395
+ return self .obs (), self .reward (), self .terminated (), False , self .info ()
0 commit comments