@@ -63,8 +63,6 @@ class EnvData:
63
63
contact_masks : Array
64
64
pos_limit_low : Array
65
65
pos_limit_high : Array
66
- rpy_limit_low : Array
67
- rpy_limit_high : Array
68
66
gate_mj_ids : Array
69
67
obstacle_mj_ids : Array
70
68
max_episode_steps : Array
@@ -84,8 +82,6 @@ def create(
84
82
sensor_range : float ,
85
83
pos_limit_low : Array ,
86
84
pos_limit_high : Array ,
87
- rpy_limit_low : Array ,
88
- rpy_limit_high : Array ,
89
85
device : Device ,
90
86
) -> EnvData :
91
87
"""Create a new environment data struct with default values."""
@@ -100,8 +96,6 @@ def create(
100
96
steps = jp .zeros (n_envs , dtype = int , device = device ),
101
97
pos_limit_low = jp .array (pos_limit_low , dtype = np .float32 , device = device ),
102
98
pos_limit_high = jp .array (pos_limit_high , dtype = np .float32 , device = device ),
103
- rpy_limit_low = jp .array (rpy_limit_low , dtype = np .float32 , device = device ),
104
- rpy_limit_high = jp .array (rpy_limit_high , dtype = np .float32 , device = device ),
105
99
gate_mj_ids = jp .array (gate_mj_ids , dtype = int , device = device ),
106
100
obstacle_mj_ids = jp .array (obstacle_mj_ids , dtype = int , device = device ),
107
101
max_episode_steps = jp .array ([max_episode_steps ], dtype = int , device = device ),
@@ -124,12 +118,12 @@ def build_observation_space(n_gates: int, n_obstacles: int) -> spaces.Dict:
124
118
"""Create the observation space for the environment."""
125
119
obs_spec = {
126
120
"pos" : spaces .Box (low = - np .inf , high = np .inf , shape = (3 ,)),
127
- "rpy " : spaces .Box (low = - np . pi , high = np . pi , shape = (3 ,)),
121
+ "quat " : spaces .Box (low = - 1 , high = 1 , shape = (4 ,)),
128
122
"vel" : spaces .Box (low = - np .inf , high = np .inf , shape = (3 ,)),
129
123
"ang_vel" : spaces .Box (low = - np .inf , high = np .inf , shape = (3 ,)),
130
124
"target_gate" : spaces .Discrete (n_gates , start = - 1 ),
131
125
"gates_pos" : spaces .Box (low = - np .inf , high = np .inf , shape = (n_gates , 3 )),
132
- "gates_rpy " : spaces .Box (low = - np . pi , high = np . pi , shape = (n_gates , 3 )),
126
+ "gates_quat " : spaces .Box (low = - 1 , high = 1 , shape = (n_gates , 4 )),
133
127
"gates_visited" : spaces .Box (low = 0 , high = 1 , shape = (n_gates ,), dtype = bool ),
134
128
"obstacles_pos" : spaces .Box (low = - np .inf , high = np .inf , shape = (n_obstacles , 3 )),
135
129
"obstacles_visited" : spaces .Box (low = 0 , high = 1 , shape = (n_obstacles ,), dtype = bool ),
@@ -160,15 +154,15 @@ class RaceCoreEnv:
160
154
161
155
The observation space is a dictionary with the following keys:
162
156
- "pos": Drone position
163
- - "rpy ": Drone orientation (roll, pitch, yaw )
157
+ - "quat ": Drone orientation as a quaternion (x, y, z, w )
164
158
- "vel": Drone linear velocity
165
159
- "ang_vel": Drone angular velocity
166
- - "gates.pos ": Positions of the gates
167
- - "gates.rpy ": Orientations of the gates
168
- - "gates.visited ": Flags indicating if the drone already was/ is in the sensor range of the
160
+ - "gates_pos ": Positions of the gates
161
+ - "gates_quat ": Orientations of the gates
162
+ - "gates_visited ": Flags indicating if the drone already was/ is in the sensor range of the
169
163
gates and the true position is known
170
- - "obstacles.pos ": Positions of the obstacles
171
- - "obstacles.visited ": Flags indicating if the drone already was/ is in the sensor range of the
164
+ - "obstacles_pos ": Positions of the obstacles
165
+ - "obstacles_visited ": Flags indicating if the drone already was/ is in the sensor range of the
172
166
obstacles and the true position is known
173
167
- "target_gate": The current target gate index
174
168
@@ -253,7 +247,6 @@ def __init__(
253
247
gate_mj_ids , obstacle_mj_ids = self .gates ["mj_ids" ], self .obstacles ["mj_ids" ]
254
248
pos_limit_low = jp .array ([- 3 , - 3 , 0 ], dtype = np .float32 , device = self .device )
255
249
pos_limit_high = jp .array ([3 , 3 , 2.5 ], dtype = np .float32 , device = self .device )
256
- rpy_limit = jp .array ([jp .pi / 2 , jp .pi / 2 , jp .pi ], dtype = jp .float32 , device = self .device )
257
250
self .data = EnvData .create (
258
251
n_envs ,
259
252
n_drones ,
@@ -266,8 +259,6 @@ def __init__(
266
259
sensor_range ,
267
260
pos_limit_low ,
268
261
pos_limit_high ,
269
- - rpy_limit ,
270
- rpy_limit ,
271
262
self .device ,
272
263
)
273
264
@@ -315,16 +306,14 @@ def _step(
315
306
self .sim .data = self ._warp_disabled_drones (self .sim .data , self .data .disabled_drones )
316
307
# Apply the environment logic. Check which drones are now disabled, check which gates have
317
308
# been passed, and update the target gate.
318
- drone_pos , drone_quat = self .sim .data .states .pos , self . sim . data . states . quat
309
+ drone_pos = self .sim .data .states .pos
319
310
mocap_pos , mocap_quat = self .sim .data .mjx_data .mocap_pos , self .sim .data .mjx_data .mocap_quat
320
311
contacts = self .sim .contacts ()
321
312
# Get marked_for_reset before it is updated, because the autoreset needs to be based on the
322
313
# previous flags, not the ones from the current step
323
314
marked_for_reset = self .data .marked_for_reset
324
315
# Apply the environment logic with updated simulation data.
325
- self .data = self ._step_env (
326
- self .data , drone_pos , drone_quat , mocap_pos , mocap_quat , contacts
327
- )
316
+ self .data = self ._step_env (self .data , drone_pos , mocap_pos , mocap_quat , contacts )
328
317
# Auto-reset envs. Add configuration option to disable for single-world envs
329
318
if self .autoreset and marked_for_reset .any ():
330
319
self ._reset (mask = marked_for_reset )
@@ -357,27 +346,25 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
357
346
"""Return the observation of the environment."""
358
347
# Add the gate and obstacle poses to the info. If gates or obstacles are in sensor range,
359
348
# use the actual pose, otherwise use the nominal pose.
360
- gates_pos , gates_rpy , obstacles_pos = self ._obs (
349
+ gates_pos , gates_quat , obstacles_pos = self ._obs (
361
350
self .sim .data .mjx_data .mocap_pos ,
362
351
self .sim .data .mjx_data .mocap_quat ,
363
352
self .data .gates_visited ,
364
353
self .gates ["mj_ids" ],
365
354
self .gates ["nominal_pos" ],
366
- self .gates ["nominal_rpy " ],
355
+ self .gates ["nominal_quat " ],
367
356
self .data .obstacles_visited ,
368
357
self .obstacles ["mj_ids" ],
369
358
self .obstacles ["nominal_pos" ],
370
359
)
371
- quat = self .sim .data .states .quat
372
- rpy = R .from_quat (quat .reshape (- 1 , 4 )).as_euler ("xyz" ).reshape ((* quat .shape [:- 1 ], 3 ))
373
360
obs = {
374
361
"pos" : np .array (self .sim .data .states .pos , dtype = np .float32 ),
375
- "rpy " : rpy . astype ( np .float32 ),
362
+ "quat " : np . array ( self . sim . data . states . quat , dtype = np .float32 ),
376
363
"vel" : np .array (self .sim .data .states .vel , dtype = np .float32 ),
377
364
"ang_vel" : np .array (self .sim .data .states .ang_vel , dtype = np .float32 ),
378
365
"target_gate" : np .array (self .data .target_gate , dtype = int ),
379
366
"gates_pos" : np .asarray (gates_pos , dtype = np .float32 ),
380
- "gates_rpy " : np .asarray (gates_rpy , dtype = np .float32 ),
367
+ "gates_quat " : np .asarray (gates_quat , dtype = np .float32 ),
381
368
"gates_visited" : np .asarray (self .data .gates_visited , dtype = bool ),
382
369
"obstacles_pos" : np .asarray (obstacles_pos , dtype = np .float32 ),
383
370
"obstacles_visited" : np .asarray (self .data .obstacles_visited , dtype = bool ),
@@ -447,16 +434,11 @@ def _reset_env_data(data: EnvData, drone_pos: Array, mask: Array | None = None)
447
434
@staticmethod
448
435
@jax .jit
449
436
def _step_env (
450
- data : EnvData ,
451
- drone_pos : Array ,
452
- drone_quat : Array ,
453
- mocap_pos : Array ,
454
- mocap_quat : Array ,
455
- contacts : Array ,
437
+ data : EnvData , drone_pos : Array , mocap_pos : Array , mocap_quat : Array , contacts : Array
456
438
) -> EnvData :
457
439
"""Step the environment data."""
458
440
n_gates = len (data .gate_mj_ids )
459
- disabled_drones = RaceCoreEnv ._disabled_drones (drone_pos , drone_quat , contacts , data )
441
+ disabled_drones = RaceCoreEnv ._disabled_drones (drone_pos , contacts , data )
460
442
gates_pos = mocap_pos [:, data .gate_mj_ids ]
461
443
obstacles_pos = mocap_pos [:, data .obstacle_mj_ids ]
462
444
# We need to convert the mocap quat from MuJoCo order to scipy order
@@ -498,27 +480,24 @@ def _obs(
498
480
gates_visited : Array ,
499
481
gate_mocap_ids : Array ,
500
482
nominal_gate_pos : NDArray ,
501
- nominal_gate_rpy : NDArray ,
483
+ nominal_gate_quat : NDArray ,
502
484
obstacles_visited : Array ,
503
485
obstacle_mocap_ids : Array ,
504
486
nominal_obstacle_pos : NDArray ,
505
487
) -> tuple [Array , Array ]:
506
488
"""Get the nominal or real gate positions and orientations depending on the sensor range."""
507
489
mask , real_pos = gates_visited [..., None ], mocap_pos [:, gate_mocap_ids ]
508
- real_rpy = JaxR . from_quat ( mocap_quat [:, gate_mocap_ids ][..., [1 , 2 , 3 , 0 ]]). as_euler ( "xyz" )
490
+ real_quat = mocap_quat [:, gate_mocap_ids ][..., [1 , 2 , 3 , 0 ]]
509
491
gates_pos = jp .where (mask , real_pos [:, None ], nominal_gate_pos [None , None ])
510
- gates_rpy = jp .where (mask , real_rpy [:, None ], nominal_gate_rpy [None , None ])
492
+ gates_quat = jp .where (mask , real_quat [:, None ], nominal_gate_quat [None , None ])
511
493
mask , real_pos = obstacles_visited [..., None ], mocap_pos [:, obstacle_mocap_ids ]
512
494
obstacles_pos = jp .where (mask , real_pos [:, None ], nominal_obstacle_pos [None , None ])
513
- return gates_pos , gates_rpy , obstacles_pos
495
+ return gates_pos , gates_quat , obstacles_pos
514
496
515
497
@staticmethod
516
- def _disabled_drones (pos : Array , quat : Array , contacts : Array , data : EnvData ) -> Array :
517
- rpy = JaxR .from_quat (quat ).as_euler ("xyz" )
518
- disabled = jp .logical_or (data .disabled_drones , jp .all (pos < data .pos_limit_low , axis = - 1 ))
519
- disabled = jp .logical_or (disabled , jp .all (pos > data .pos_limit_high , axis = - 1 ))
520
- disabled = jp .logical_or (disabled , jp .all (rpy < data .rpy_limit_low , axis = - 1 ))
521
- disabled = jp .logical_or (disabled , jp .all (rpy > data .rpy_limit_high , axis = - 1 ))
498
+ def _disabled_drones (pos : Array , contacts : Array , data : EnvData ) -> Array :
499
+ disabled = jp .logical_or (data .disabled_drones , jp .any (pos < data .pos_limit_low , axis = - 1 ))
500
+ disabled = jp .logical_or (disabled , jp .any (pos > data .pos_limit_high , axis = - 1 ))
522
501
disabled = jp .logical_or (disabled , data .target_gate == - 1 )
523
502
contacts = jp .any (jp .logical_and (contacts [:, None , :], data .contact_masks ), axis = - 1 )
524
503
disabled = jp .logical_or (disabled , contacts )
@@ -539,8 +518,13 @@ def _warp_disabled_drones(data: SimData, mask: Array) -> SimData:
539
518
def _load_track (self , track : dict ) -> tuple [dict , dict , dict ]:
540
519
"""Load the track from the config file."""
541
520
gate_pos = np .array ([g ["pos" ] for g in track .gates ])
542
- gate_rpy = np .array ([g ["rpy" ] for g in track .gates ])
543
- gates = {"pos" : gate_pos , "rpy" : gate_rpy , "nominal_pos" : gate_pos , "nominal_rpy" : gate_rpy }
521
+ gate_quat = R .from_euler ("xyz" , np .array ([g ["rpy" ] for g in track .gates ])).as_quat ()
522
+ gates = {
523
+ "pos" : gate_pos ,
524
+ "quat" : gate_quat ,
525
+ "nominal_pos" : gate_pos ,
526
+ "nominal_quat" : gate_quat ,
527
+ }
544
528
obstacle_pos = np .array ([o ["pos" ] for o in track .obstacles ])
545
529
obstacles = {"pos" : obstacle_pos , "nominal_pos" : obstacle_pos }
546
530
drone_keys = ("pos" , "rpy" , "vel" , "ang_vel" )
@@ -578,7 +562,7 @@ def _load_track_into_sim(self, gate_spec: MjSpec, obstacle_spec: MjSpec):
578
562
gate = frame .attach_body (gate_spec .find_body ("gate" ), "" , f":{ i } " )
579
563
gate .pos = self .gates ["pos" ][i ]
580
564
# Convert from scipy order to MuJoCo order
581
- gate .quat = R . from_euler ( "xyz" , self .gates ["rpy " ][i ]). as_quat () [[3 , 0 , 1 , 2 ]]
565
+ gate .quat = self .gates ["quat " ][i ][[3 , 0 , 1 , 2 ]]
582
566
gate .mocap = True # Make mocap to modify the position of static bodies during sim
583
567
for i in range (n_obstacles ):
584
568
obstacle = frame .attach_body (obstacle_spec .find_body ("obstacle" ), "" , f":{ i } " )
0 commit comments