27
27
28
28
import copy as copy
29
29
import logging
30
+ from functools import partial
30
31
from pathlib import Path
31
- from typing import TYPE_CHECKING , Callable
32
+ from typing import TYPE_CHECKING , Any , Callable
32
33
33
34
import gymnasium
35
+ import jax
34
36
import mujoco
35
37
import numpy as np
36
38
from crazyflow import Sim
37
- from crazyflow .sim .sim import identity
38
39
from gymnasium import spaces
39
40
from scipy .spatial .transform import Rotation as R
40
41
41
- from lsy_drone_racing .envs .utils import (
42
- randomize_drone_inertia_fn ,
43
- randomize_drone_mass_fn ,
44
- randomize_drone_pos_fn ,
45
- randomize_drone_quat_fn ,
46
- )
47
- from lsy_drone_racing .sim .noise import NoiseList
42
+ from lsy_drone_racing .envs .utils import randomize_sim_fn
48
43
from lsy_drone_racing .utils import check_gate_pass
49
44
50
45
if TYPE_CHECKING :
@@ -92,8 +87,8 @@ class DroneRacingEnv(gymnasium.Env):
92
87
low-level controller.
93
88
"""
94
89
95
- gate_spec_path = Path (__file__ ).parents [1 ] / "sim/assets/gate.urdf "
96
- obstacle_spec_path = Path (__file__ ).parents [1 ] / "sim/assets/obstacle.urdf "
90
+ gate_spec_path = Path (__file__ ).parents [1 ] / "sim/assets/gate.xml "
91
+ obstacle_spec_path = Path (__file__ ).parents [1 ] / "sim/assets/obstacle.xml "
97
92
98
93
def __init__ (self , config : dict ):
99
94
"""Initialize the DroneRacingEnv.
@@ -153,7 +148,7 @@ def __init__(self, config: dict):
153
148
self .n_gates = len (config .env .track .gates )
154
149
self .disturbances = self .load_disturbances (config .env .get ("disturbances" , None ))
155
150
self .randomization = self .load_randomizations (config .env .get ("randomization" , None ))
156
- self .contact_mask = np .ones ((self .sim .n_worlds , 29 ), dtype = bool )
151
+ self .contact_mask = np .ones ((self .sim .n_worlds , 25 ), dtype = bool )
157
152
self .contact_mask [..., 0 ] = 0 # Ignore contacts with the floor
158
153
159
154
self .setup_sim ()
@@ -180,6 +175,7 @@ def reset(
180
175
# Randomization of gates, obstacles and drones is compiled into the sim reset function with
181
176
# the sim.reset_hook function, so we don't need to explicitly do it here
182
177
self .sim .reset ()
178
+
183
179
# TODO: Add randomization of gates, obstacles, drone, and disturbances
184
180
self .target_gate = 0
185
181
self ._steps = 0
@@ -313,105 +309,99 @@ def load_track(self, track: dict) -> tuple[dict, dict, dict]:
313
309
for k in ("pos" , "rpy" , "vel" , "rpy_rates" )
314
310
}
315
311
drone ["quat" ] = R .from_euler ("xyz" , drone ["rpy" ]).as_quat ()
316
- # Load the models into the simulation and set their positions
317
- self ._load_track_into_sim (gates , obstacles )
318
312
return gates , obstacles , drone
319
313
320
- def _load_track_into_sim (self , gates : dict , obstacles : dict ):
321
- """Load the track into the simulation."""
322
- gate_spec = mujoco .MjSpec .from_file (str (self .gate_spec_path ))
323
- obstacle_spec = mujoco .MjSpec .from_file (str (self .obstacle_spec_path ))
324
- spec = self .sim .spec
325
- frame = spec .worldbody .add_frame ()
326
- for i in range (len (gates ["pos" ])):
327
- gate = frame .attach_body (gate_spec .find_body ("world" ), "" , f":g{ i } " )
328
- gate .pos = gates ["pos" ][i ]
329
- quat = R .from_euler ("xyz" , gates ["rpy" ][i ]).as_quat ()
330
- gate .quat = quat [[3 , 0 , 1 , 2 ]] # MuJoCo uses wxyz order instead of xyzw
331
- for i in range (len (obstacles ["pos" ])):
332
- obstacle = frame .attach_body (obstacle_spec .find_body ("world" ), "" , f":o{ i } " )
333
- obstacle .pos = obstacles ["pos" ][i ]
334
- self .sim .build ()
335
-
336
314
def load_disturbances (self , disturbances : dict | None = None ) -> dict :
337
315
"""Load the disturbances from the config."""
338
- dist = {} # TODO: Add jax disturbances for the simulator dynamics
316
+ # TODO: Add jax disturbances for the simulator dynamics
339
317
if disturbances is None : # Default: no passive disturbances.
340
- return dist
341
- for mode , spec in disturbances .items ():
342
- dist [mode ] = NoiseList .from_specs ([spec ])
343
- return dist
318
+ return {}
319
+ return {mode : self .load_random_fn (spec ) for mode , spec in disturbances .items ()}
344
320
345
321
def load_randomizations (self , randomizations : dict | None = None ) -> dict :
346
322
"""Load the randomization from the config."""
347
323
if randomizations is None :
348
324
return {}
349
- return {}
325
+ return {mode : self .load_random_fn (spec ) for mode , spec in randomizations .items ()}
326
+
327
+ @staticmethod
328
+ def load_random_fn (fn_spec : dict ) -> Callable :
329
+ """Convert a function spec to a function from jax.random."""
330
+ offset , scale = np .array (fn_spec .get ("offset" , 0 )), np .array (fn_spec .get ("scale" , 1 ))
331
+ kwargs = fn_spec .get ("kwargs" , {})
332
+ if "shape" in kwargs :
333
+ raise KeyError ("Shape must not be specified for randomization functions." )
334
+ kwargs = {k : np .array (v ) if isinstance (v , list ) else v for k , v in kwargs .items ()}
335
+ jax_fn = partial (getattr (jax .random , fn_spec ["fn" ]), ** kwargs )
336
+
337
+ def random_fn (* args : Any , ** kwargs : Any ) -> Array :
338
+ return jax_fn (* args , ** kwargs ) * scale + offset
339
+
340
+ return random_fn
350
341
351
342
def setup_sim (self ):
352
343
"""Setup the simulation data and build the reset and step functions with custom hooks."""
344
+ self ._load_track_into_sim (self .gates , self .obstacles )
353
345
pos = self .drone ["pos" ].reshape (self .sim .data .states .pos .shape )
354
346
quat = self .drone ["quat" ].reshape (self .sim .data .states .quat .shape )
355
347
vel = self .drone ["vel" ].reshape (self .sim .data .states .vel .shape )
356
348
rpy_rates = self .drone ["rpy_rates" ].reshape (self .sim .data .states .rpy_rates .shape )
357
349
states = self .sim .data .states .replace (pos = pos , quat = quat , vel = vel , rpy_rates = rpy_rates )
358
350
self .sim .data = self .sim .data .replace (states = states )
359
- reset_hook = build_reset_hook (self .randomization )
360
- self .sim .reset_hook = reset_hook
351
+ self .sim .reset_hook = build_reset_hook (self .randomization )
361
352
self .sim .build (mjx = False , data = False ) # Save the reset state and rebuild the reset function
362
353
354
+ def _load_track_into_sim (self , gates : dict , obstacles : dict ):
355
+ """Load the track into the simulation."""
356
+ gate_spec = mujoco .MjSpec .from_file (str (self .gate_spec_path ))
357
+ obstacle_spec = mujoco .MjSpec .from_file (str (self .obstacle_spec_path ))
358
+ spec = self .sim .spec
359
+ frame = spec .worldbody .add_frame ()
360
+ n_gates , n_obstacles = len (gates ["pos" ]), len (obstacles ["pos" ])
361
+ for i in range (n_gates ):
362
+ gate = frame .attach_body (gate_spec .find_body ("gate" ), "" , f":{ i } " )
363
+ gate .pos = gates ["pos" ][i ]
364
+ gate .quat = R .from_euler ("xyz" , gates ["rpy" ][i ]).as_quat ()[[3 , 0 , 1 , 2 ]] # MuJoCo order
365
+ gate .mocap = True # Make mocap to modify the position of static bodies during sim
366
+ for i in range (n_obstacles ):
367
+ obstacle = frame .attach_body (obstacle_spec .find_body ("obstacle" ), "" , f":{ i } " )
368
+ obstacle .pos = obstacles ["pos" ][i ]
369
+ obstacle .mocap = True
370
+ self .sim .build (data = False , default_data = False )
371
+ assert not hasattr (self .sim .data , "gate_pos" )
372
+ assert not hasattr (self .sim .data , "obstacle_pos" )
373
+
374
+ gate_ids = [self .sim .mj_model .body (f"gate:{ i } " ).id for i in range (n_gates )]
375
+ gates ["ids" ] = gate_ids
376
+ obstacle_ids = [self .sim .mj_model .body (f"obstacle:{ i } " ).id for i in range (n_obstacles )]
377
+ obstacles ["ids" ] = obstacle_ids
378
+
363
379
def gate_passed (self ) -> bool :
364
380
"""Check if the drone has passed a gate.
365
381
366
382
Returns:
367
383
True if the drone has passed a gate, else False.
368
384
"""
369
- if self .n_gates > 0 and self .target_gate < self .n_gates and self .target_gate != - 1 :
370
- gate_pos = self .gates ["pos" ][self .target_gate ]
371
- gate_rot = R .from_euler ("xyz" , self .gates ["rpy" ][self .target_gate ])
372
- drone_pos = self .sim .data .states .pos [0 , 0 ]
373
- last_drone_pos = self ._last_drone_pos
374
- gate_size = (0.45 , 0.45 )
375
- return check_gate_pass (gate_pos , gate_rot , gate_size , drone_pos , last_drone_pos )
376
- return False
385
+ if self .n_gates <= 0 or self .target_gate >= self .n_gates or self .target_gate == - 1 :
386
+ return False
387
+ gate_pos = self .gates ["pos" ][self .target_gate ]
388
+ gate_rot = R .from_euler ("xyz" , self .gates ["rpy" ][self .target_gate ])
389
+ drone_pos = self .sim .data .states .pos [0 , 0 ]
390
+ gate_size = (0.45 , 0.45 )
391
+ return check_gate_pass (gate_pos , gate_rot , gate_size , drone_pos , self ._last_drone_pos )
377
392
378
393
def close (self ):
379
394
"""Close the environment by stopping the drone and landing back at the starting position."""
380
395
self .sim .close ()
381
396
382
397
383
- def build_reset_hook (randomizations : dict ) -> Callable [[SimData , Array [ bool ] ], SimData ]:
398
+ def build_reset_hook (randomizations : dict ) -> Callable [[SimData , Array ], SimData ]:
384
399
"""Build the reset hook for the simulation."""
385
- modify_drone_pos = identity
386
- if "drone_pos" in randomizations :
387
- modify_drone_pos = randomize_drone_pos_fn (randomizations ["drone_pos" ])
388
- modify_drone_quat = identity
389
- if "drone_rpy" in randomizations :
390
- modify_drone_quat = randomize_drone_quat_fn (randomizations ["drone_rpy" ])
391
- modify_drone_mass = identity
392
- if "drone_mass" in randomizations :
393
- modify_drone_mass = randomize_drone_mass_fn (randomizations ["drone_mass" ])
394
- modify_drone_inertia = identity
395
- if "drone_inertia" in randomizations :
396
- modify_drone_inertia = randomize_drone_inertia_fn (randomizations ["drone_inertia" ])
397
- modify_gate_pos = identity
398
- if "gate_pos" in randomizations :
399
- modify_gate_pos = randomize_gate_pos_fn (randomizations ["gate_pos" ])
400
- modify_gate_rpy = identity
401
- if "gate_rpy" in randomizations :
402
- modify_gate_rpy = randomize_gate_rpy_fn (randomizations ["gate_rpy" ])
403
- modify_obstacle_pos = identity
404
- if "obstacle_pos" in randomizations :
405
- modify_obstacle_pos = randomize_obstacle_pos_fn (randomizations ["obstacle_pos" ])
406
-
407
- def reset_hook (data : SimData , mask : Array [bool ]) -> SimData :
408
- data = modify_drone_pos (data , mask )
409
- data = modify_drone_quat (data , mask )
410
- data = modify_drone_mass (data , mask )
411
- data = modify_drone_inertia (data , mask )
412
- data = modify_gate_pos (data , mask )
413
- data = modify_gate_rpy (data , mask )
414
- data = modify_obstacle_pos (data , mask )
400
+ randomizations = [randomize_sim_fn (target , rng ) for target , rng in randomizations .items ()]
401
+
402
+ def reset_hook (data : SimData , mask : Array ) -> SimData :
403
+ for randomize in randomizations :
404
+ data = randomize (data , mask )
415
405
return data
416
406
417
407
return reset_hook
0 commit comments