28
28
import copy as copy
29
29
import logging
30
30
from pathlib import Path
31
- from typing import TYPE_CHECKING
31
+ from typing import TYPE_CHECKING , Callable
32
32
33
33
import gymnasium
34
34
import mujoco
35
35
import numpy as np
36
36
from crazyflow import Sim
37
+ from crazyflow .sim .sim import identity
37
38
from gymnasium import spaces
38
39
from scipy .spatial .transform import Rotation as R
39
40
41
+ from lsy_drone_racing .envs .utils import (
42
+ randomize_drone_inertia_fn ,
43
+ randomize_drone_mass_fn ,
44
+ randomize_drone_pos_fn ,
45
+ randomize_drone_quat_fn ,
46
+ )
40
47
from lsy_drone_racing .sim .noise import NoiseList
41
48
from lsy_drone_racing .utils import check_gate_pass
42
49
43
50
if TYPE_CHECKING :
51
+ from crazyflow .sim .structs import SimData
52
+ from jax import Array
44
53
from numpy .typing import NDArray
45
54
46
55
logger = logging .getLogger (__name__ )
@@ -106,6 +115,7 @@ def __init__(self, config: dict):
106
115
)
107
116
if config .sim .sim_freq % config .env .freq != 0 :
108
117
raise ValueError (f"({ config .sim .sim_freq = } ) is no multiple of ({ config .env .freq = } )" )
118
+
109
119
self .action_space = spaces .Box (low = - 1 , high = 1 , shape = (13 ,))
110
120
n_gates , n_obstacles = len (config .env .track .gates ), len (config .env .track .obstacles )
111
121
self .observation_space = spaces .Dict (
@@ -134,16 +144,20 @@ def __init__(self, config: dict):
134
144
"ang_vel" : spaces .Box (low = - np .inf , high = np .inf , shape = (3 ,), dtype = np .float64 ),
135
145
}
136
146
)
147
+
137
148
self .target_gate = 0
138
149
self .symbolic = self .sim .symbolic () if config .env .symbolic else None
139
150
self ._steps = 0
140
151
self ._last_drone_pos = np .zeros (3 )
141
152
self .gates , self .obstacles , self .drone = self .load_track (config .env .track )
142
153
self .n_gates = len (config .env .track .gates )
143
154
self .disturbances = self .load_disturbances (config .env .get ("disturbances" , None ))
155
+ self .randomization = self .load_randomizations (config .env .get ("randomization" , None ))
144
156
self .contact_mask = np .ones ((self .sim .n_worlds , 29 ), dtype = bool )
145
157
self .contact_mask [..., 0 ] = 0 # Ignore contacts with the floor
146
158
159
+ self .setup_sim ()
160
+
147
161
self .gates_visited = np .array ([False ] * len (config .env .track .gates ))
148
162
self .obstacles_visited = np .array ([False ] * len (config .env .track .obstacles ))
149
163
@@ -167,13 +181,6 @@ def reset(
167
181
# the sim.reset_hook function, so we don't need to explicitly do it here
168
182
self .sim .reset ()
169
183
# TODO: Add randomization of gates, obstacles, drone, and disturbances
170
- states = self .sim .data .states .replace (
171
- pos = self .drone ["pos" ].reshape ((1 , 1 , 3 )),
172
- quat = self .drone ["quat" ].reshape ((1 , 1 , 4 )),
173
- vel = self .drone ["vel" ].reshape ((1 , 1 , 3 )),
174
- rpy_rates = self .drone ["rpy_rates" ].reshape ((1 , 1 , 3 )),
175
- )
176
- self .sim .data = self .sim .data .replace (states = states )
177
184
self .target_gate = 0
178
185
self ._steps = 0
179
186
self ._last_drone_pos [:] = self .sim .data .states .pos [0 , 0 ]
@@ -335,6 +342,24 @@ def load_disturbances(self, disturbances: dict | None = None) -> dict:
335
342
dist [mode ] = NoiseList .from_specs ([spec ])
336
343
return dist
337
344
345
+ def load_randomizations (self , randomizations : dict | None = None ) -> dict :
346
+ """Load the randomization from the config."""
347
+ if randomizations is None :
348
+ return {}
349
+ return {}
350
+
351
+ def setup_sim (self ):
352
+ """Setup the simulation data and build the reset and step functions with custom hooks."""
353
+ pos = self .drone ["pos" ].reshape (self .sim .data .states .pos .shape )
354
+ quat = self .drone ["quat" ].reshape (self .sim .data .states .quat .shape )
355
+ vel = self .drone ["vel" ].reshape (self .sim .data .states .vel .shape )
356
+ rpy_rates = self .drone ["rpy_rates" ].reshape (self .sim .data .states .rpy_rates .shape )
357
+ states = self .sim .data .states .replace (pos = pos , quat = quat , vel = vel , rpy_rates = rpy_rates )
358
+ self .sim .data = self .sim .data .replace (states = states )
359
+ reset_hook = build_reset_hook (self .randomization )
360
+ self .sim .reset_hook = reset_hook
361
+ self .sim .build (mjx = False , data = False ) # Save the reset state and rebuild the reset function
362
+
338
363
def gate_passed (self ) -> bool :
339
364
"""Check if the drone has passed a gate.
340
365
@@ -355,6 +380,43 @@ def close(self):
355
380
self .sim .close ()
356
381
357
382
383
+ def build_reset_hook (randomizations : dict ) -> Callable [[SimData , Array [bool ]], SimData ]:
384
+ """Build the reset hook for the simulation."""
385
+ modify_drone_pos = identity
386
+ if "drone_pos" in randomizations :
387
+ modify_drone_pos = randomize_drone_pos_fn (randomizations ["drone_pos" ])
388
+ modify_drone_quat = identity
389
+ if "drone_rpy" in randomizations :
390
+ modify_drone_quat = randomize_drone_quat_fn (randomizations ["drone_rpy" ])
391
+ modify_drone_mass = identity
392
+ if "drone_mass" in randomizations :
393
+ modify_drone_mass = randomize_drone_mass_fn (randomizations ["drone_mass" ])
394
+ modify_drone_inertia = identity
395
+ if "drone_inertia" in randomizations :
396
+ modify_drone_inertia = randomize_drone_inertia_fn (randomizations ["drone_inertia" ])
397
+ modify_gate_pos = identity
398
+ if "gate_pos" in randomizations :
399
+ modify_gate_pos = randomize_gate_pos_fn (randomizations ["gate_pos" ])
400
+ modify_gate_rpy = identity
401
+ if "gate_rpy" in randomizations :
402
+ modify_gate_rpy = randomize_gate_rpy_fn (randomizations ["gate_rpy" ])
403
+ modify_obstacle_pos = identity
404
+ if "obstacle_pos" in randomizations :
405
+ modify_obstacle_pos = randomize_obstacle_pos_fn (randomizations ["obstacle_pos" ])
406
+
407
+ def reset_hook (data : SimData , mask : Array [bool ]) -> SimData :
408
+ data = modify_drone_pos (data , mask )
409
+ data = modify_drone_quat (data , mask )
410
+ data = modify_drone_mass (data , mask )
411
+ data = modify_drone_inertia (data , mask )
412
+ data = modify_gate_pos (data , mask )
413
+ data = modify_gate_rpy (data , mask )
414
+ data = modify_obstacle_pos (data , mask )
415
+ return data
416
+
417
+ return reset_hook
418
+
419
+
358
420
class DroneRacingThrustEnv (DroneRacingEnv ):
359
421
"""Drone racing environment with a collective thrust attitude command interface.
360
422
0 commit comments