19
19
from crazyflow .sim .symbolic import symbolic_attitude
20
20
from flax .struct import dataclass
21
21
from gymnasium import spaces
22
- from jax .scipy .spatial .transform import Rotation as JaxR
23
22
from scipy .spatial .transform import Rotation as R
24
23
25
24
from lsy_drone_racing .envs .randomize import (
@@ -264,7 +263,7 @@ def __init__(
264
263
265
264
def _reset (
266
265
self , * , seed : int | None = None , options : dict | None = None , mask : Array | None = None
267
- ) -> tuple [dict [str , NDArray [ np . floating ] ], dict ]:
266
+ ) -> tuple [dict [str , Array ], dict ]:
268
267
"""Reset the environment.
269
268
270
269
Args:
@@ -287,9 +286,7 @@ def _reset(
287
286
self .data = self ._reset_env_data (self .data , self .sim .data .states .pos , mask )
288
287
return self .obs (), self .info ()
289
288
290
- def _step (
291
- self , action : NDArray [np .floating ]
292
- ) -> tuple [dict [str , NDArray [np .floating ]], float , bool , bool , dict ]:
289
+ def _step (self , action : Array ) -> tuple [dict [str , Array ], float , bool , bool , dict ]:
293
290
"""Step the firmware_wrapper class and its environment.
294
291
295
292
This function should be called once at the rate of ctrl_freq. Step processes and high level
@@ -319,7 +316,7 @@ def _step(
319
316
self ._reset (mask = marked_for_reset )
320
317
return self .obs (), self .reward (), self .terminated (), self .truncated (), self .info ()
321
318
322
- def apply_action (self , action : NDArray [ np . floating ] ):
319
+ def apply_action (self , action : Array ):
323
320
"""Apply the commanded state action to the simulation."""
324
321
action = action .reshape ((self .sim .n_worlds , self .sim .n_drones , - 1 ))
325
322
if "action" in self .disturbances :
@@ -342,7 +339,7 @@ def close(self):
342
339
"""Close the environment by stopping the drone and landing back at the starting position."""
343
340
self .sim .close ()
344
341
345
- def obs (self ) -> dict [str , NDArray [ np . floating ] ]:
342
+ def obs (self ) -> dict [str , Array ]:
346
343
"""Return the observation of the environment."""
347
344
# Add the gate and obstacle poses to the info. If gates or obstacles are in sensor range,
348
345
# use the actual pose, otherwise use the nominal pose.
@@ -564,13 +561,19 @@ def _load_track_into_sim(self, gate_spec: MjSpec, obstacle_spec: MjSpec):
564
561
frame = self .sim .spec .worldbody .add_frame ()
565
562
n_gates , n_obstacles = len (self .gates ["pos" ]), len (self .obstacles ["pos" ])
566
563
for i in range (n_gates ):
567
- gate = frame .attach_body (gate_spec .find_body ("gate" ), "" , f":{ i } " )
564
+ gate_body = gate_spec .body ("gate" )
565
+ if gate_body is None :
566
+ raise ValueError ("Gate body not found in gate spec" )
567
+ gate = frame .attach_body (gate_body , "" , f":{ i } " )
568
568
gate .pos = self .gates ["pos" ][i ]
569
569
# Convert from scipy order to MuJoCo order
570
570
gate .quat = self .gates ["quat" ][i ][[3 , 0 , 1 , 2 ]]
571
571
gate .mocap = True # Make mocap to modify the position of static bodies during sim
572
572
for i in range (n_obstacles ):
573
- obstacle = frame .attach_body (obstacle_spec .find_body ("obstacle" ), "" , f":{ i } " )
573
+ obstacle_body = obstacle_spec .body ("obstacle" )
574
+ if obstacle_body is None :
575
+ raise ValueError ("Obstacle body not found in obstacle spec" )
576
+ obstacle = frame .attach_body (obstacle_body , "" , f":{ i } " )
574
577
obstacle .pos = self .obstacles ["pos" ][i ]
575
578
obstacle .mocap = True
576
579
self .sim .build (data = False , default_data = False )
0 commit comments