@@ -358,20 +358,20 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
358
358
self .obstacles ["nominal_pos" ],
359
359
)
360
360
obs = {
361
- "pos" : np . array ( self .sim .data .states .pos , dtype = np . float32 ) ,
362
- "quat" : np . array ( self .sim .data .states .quat , dtype = np . float32 ) ,
363
- "vel" : np . array ( self .sim .data .states .vel , dtype = np . float32 ) ,
364
- "ang_vel" : np . array ( self .sim .data .states .ang_vel , dtype = np . float32 ) ,
365
- "target_gate" : np . array ( self .data .target_gate , dtype = int ) ,
366
- "gates_pos" : np . asarray ( gates_pos , dtype = np . float32 ) ,
367
- "gates_quat" : np . asarray ( gates_quat , dtype = np . float32 ) ,
368
- "gates_visited" : np . asarray ( self .data .gates_visited , dtype = bool ) ,
369
- "obstacles_pos" : np . asarray ( obstacles_pos , dtype = np . float32 ) ,
370
- "obstacles_visited" : np . asarray ( self .data .obstacles_visited , dtype = bool ) ,
361
+ "pos" : self .sim .data .states .pos ,
362
+ "quat" : self .sim .data .states .quat ,
363
+ "vel" : self .sim .data .states .vel ,
364
+ "ang_vel" : self .sim .data .states .ang_vel ,
365
+ "target_gate" : self .data .target_gate ,
366
+ "gates_pos" : gates_pos ,
367
+ "gates_quat" : gates_quat ,
368
+ "gates_visited" : self .data .gates_visited ,
369
+ "obstacles_pos" : obstacles_pos ,
370
+ "obstacles_visited" : self .data .obstacles_visited ,
371
371
}
372
372
return obs
373
373
374
- def reward (self ) -> NDArray [ np . float32 ] :
374
+ def reward (self ) -> Array :
375
375
"""Compute the reward for the current state.
376
376
377
377
Note:
@@ -382,19 +382,19 @@ def reward(self) -> NDArray[np.float32]:
382
382
Returns:
383
383
Reward for the current state.
384
384
"""
385
- return np . array ( - 1.0 * (self .data .target_gate == - 1 ), dtype = np . float32 )
385
+ return - 1.0 * (self .data .target_gate == - 1 ) # Implicit float conversion
386
386
387
- def terminated (self ) -> NDArray [ np . bool_ ] :
387
+ def terminated (self ) -> Array :
388
388
"""Check if the episode is terminated.
389
389
390
390
Returns:
391
391
True if all drones have been disabled, else False.
392
392
"""
393
- return np . array ( self .data .disabled_drones , dtype = bool )
393
+ return self .data .disabled_drones
394
394
395
- def truncated (self ) -> NDArray [ np . bool_ ] :
395
+ def truncated (self ) -> Array :
396
396
"""Array of booleans indicating if the episode is truncated."""
397
- return np . tile (self .data .steps >= self .data .max_episode_steps , ( self .sim .n_drones , 1 ) )
397
+ return self . _truncated (self .data .steps , self .data .max_episode_steps , self .sim .n_drones )
398
398
399
399
def info (self ) -> dict :
400
400
"""Return an info dictionary containing additional information about the environment."""
@@ -494,6 +494,11 @@ def _obs(
494
494
obstacles_pos = jp .where (mask , real_pos [:, None ], nominal_obstacle_pos [None , None ])
495
495
return gates_pos , gates_quat , obstacles_pos
496
496
497
+ @staticmethod
498
+ @partial (jax .jit , static_argnames = "n_drones" )
499
+ def _truncated (steps : Array , max_episode_steps : Array , n_drones : int ) -> Array :
500
+ return jp .tile (steps >= max_episode_steps , (n_drones , 1 ))
501
+
497
502
@staticmethod
498
503
def _disabled_drones (pos : Array , contacts : Array , data : EnvData ) -> Array :
499
504
disabled = jp .logical_or (data .disabled_drones , jp .any (pos < data .pos_limit_low , axis = - 1 ))
0 commit comments