@@ -227,7 +227,7 @@ def step(
227
227
self .sim .step (self .sim .freq // self .freq )
228
228
# TODO: Clean up the accelerated functions
229
229
self .disabled_drones = np .array (
230
- self .update_active_drones_acc (
230
+ self ._disabled_drones (
231
231
self .sim .data .states .pos [0 ],
232
232
self .sim .data .states .quat [0 ],
233
233
self .pos_bounds .low ,
@@ -242,7 +242,7 @@ def step(
242
242
)
243
243
self .sim .data = self .warp_disabled_drones (self .sim .data , self .disabled_drones )
244
244
# TODO: Clean up the accelerated functions
245
- passed = self .gate_passed_accelerated (
245
+ passed = self ._gate_passed (
246
246
self .target_gate ,
247
247
self .gates ["mocap_ids" ],
248
248
self .sim .data .mjx_data .mocap_pos [0 ],
@@ -263,21 +263,11 @@ def render(self):
263
263
def obs (self ) -> dict [str , NDArray [np .floating ]]:
264
264
"""Return the observation of the environment."""
265
265
# TODO: Accelerate this function
266
- obs = {
267
- "pos" : np .array (self .sim .data .states .pos [0 ], dtype = np .float32 ),
268
- "rpy" : R .from_quat (self .sim .data .states .quat [0 ]).as_euler ("xyz" ).astype (np .float32 ),
269
- "vel" : np .array (self .sim .data .states .vel [0 ], dtype = np .float32 ),
270
- "ang_vel" : np .array (self .sim .data .states .rpy_rates [0 ], dtype = np .float32 ),
271
- }
272
- obs ["target_gate" ] = self .target_gate
273
266
# Add the gate and obstacle poses to the info. If gates or obstacles are in sensor range,
274
267
# use the actual pose, otherwise use the nominal pose.
275
- drone_pos = self .sim .data .states .pos [0 ]
276
- # Performance optimization: Get a continuous slice instead of using a list of indices which
277
- # copies the data. Assumes that the mocap ids are consecutive.
278
- gates_visited , gates_pos , gates_rpy = self .obs_acc_gates (
268
+ gates_visited , gates_pos , gates_rpy = self ._obs_gates (
279
269
self .gates_visited ,
280
- drone_pos ,
270
+ self . sim . data . states . pos [ 0 ] ,
281
271
self .sim .data .mjx_data .mocap_pos [0 ],
282
272
self .sim .data .mjx_data .mocap_quat [0 ],
283
273
self .gates ["mocap_ids" ],
@@ -286,62 +276,66 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
286
276
self .gates ["nominal_rpy" ],
287
277
)
288
278
self .gates_visited = np .asarray (gates_visited , dtype = bool )
289
- obs ["gates_pos" ] = np .asarray (gates_pos , dtype = np .float32 )
290
- obs ["gates_rpy" ] = np .asarray (gates_rpy , dtype = np .float32 )
291
- obs ["gates_visited" ] = self .gates_visited
292
-
293
- obstacles_visited , obstacles_pos = self .obs_acc_obstacles (
279
+ obstacles_visited , obstacles_pos = self ._obs_obstacles (
294
280
self .obstacles_visited ,
295
- drone_pos ,
281
+ self . sim . data . states . pos [ 0 ] ,
296
282
self .sim .data .mjx_data .mocap_pos [0 ],
297
283
self .obstacles ["mocap_ids" ],
298
284
self .sensor_range ,
299
285
self .obstacles ["nominal_pos" ],
300
286
)
301
287
self .obstacles_visited = np .asarray (obstacles_visited , dtype = bool )
302
- obs ["obstacles_pos" ] = np .asarray (obstacles_pos , dtype = np .float32 )
303
- obs ["obstacles_visited" ] = self .obstacles_visited
304
288
# TODO: Decide on observation disturbances
289
+ obs = {
290
+ "pos" : np .array (self .sim .data .states .pos [0 ], dtype = np .float32 ),
291
+ "rpy" : R .from_quat (self .sim .data .states .quat [0 ]).as_euler ("xyz" ).astype (np .float32 ),
292
+ "vel" : np .array (self .sim .data .states .vel [0 ], dtype = np .float32 ),
293
+ "ang_vel" : np .array (self .sim .data .states .rpy_rates [0 ], dtype = np .float32 ),
294
+ "target_gate" : self .target_gate ,
295
+ "gates_pos" : np .asarray (gates_pos , dtype = np .float32 ),
296
+ "gates_rpy" : np .asarray (gates_rpy , dtype = np .float32 ),
297
+ "gates_visited" : self .gates_visited ,
298
+ "obstacles_pos" : np .asarray (obstacles_pos , dtype = np .float32 ),
299
+ "obstacles_visited" : self .obstacles_visited ,
300
+ }
305
301
return obs
306
302
307
303
@staticmethod
@jax.jit
def _obs_gates(
    gates_visited: NDArray,
    drone_pos: Array,
    mocap_pos: Array,
    mocap_quat: Array,
    mocap_ids: NDArray,
    sensor_range: float,
    nominal_pos: NDArray,
    nominal_rpy: NDArray,
) -> tuple[Array, Array, Array]:
    """Return the visited flags and the observed gate poses.

    A gate is flagged as visited once the xy-distance between the drone and the gate's real
    position is below ``sensor_range``; flags that are already set stay set. Visited gates
    report their real pose taken from the mocap data, all other gates report the nominal pose.

    Args:
        gates_visited: Flags of the gates that have been visited so far.
        drone_pos: Drone position(s). Only the first two components enter the range check.
        mocap_pos: Mocap positions, indexed with ``mocap_ids``.
        mocap_quat: Mocap quaternions, indexed with ``mocap_ids``. Component 0 is moved to the
            last place before the conversion to Euler angles.
        mocap_ids: Mocap indices of the gates.
        sensor_range: Distance in the xy-plane below which a gate counts as in range.
        nominal_pos: Gate positions reported for gates that have not been visited.
        nominal_rpy: Gate "xyz" Euler angles reported for gates that have not been visited.

    Returns:
        The updated visited flags, the gate positions and the gate "xyz" Euler angles.
    """
    real_pos = mocap_pos[mocap_ids]
    # Update the visited flags first: a gate stays visited once it was in range.
    offset_xy = drone_pos[..., None, :2] - real_pos[..., :2]
    within_range = jp.linalg.norm(offset_xy, axis=-1) < sensor_range
    visited = jp.logical_or(gates_visited, within_range)
    # Shift the quaternion components by one so that component 0 ends up last.
    quat = jp.roll(mocap_quat[mocap_ids], -1, axis=-1)
    real_rpy = jax.scipy.spatial.transform.Rotation.from_quat(quat).as_euler("xyz")
    use_real = visited[..., None]
    pos = jp.where(use_real, real_pos, nominal_pos)
    rpy = jp.where(use_real, real_rpy, nominal_rpy)
    return visited, pos, rpy
331
324
332
325
@staticmethod
@jax.jit
def _obs_obstacles(
    visited: NDArray,
    drone_pos: Array,
    mocap_pos: Array,
    mocap_ids: NDArray,
    sensor_range: float,
    nominal_pos: NDArray,
) -> tuple[Array, Array]:
    """Return the visited flags and the observed obstacle positions.

    An obstacle is flagged as visited once the xy-distance between the drone and the obstacle's
    real position is below ``sensor_range``; flags that are already set stay set. Visited
    obstacles report their real position from the mocap data, all others the nominal position.

    Args:
        visited: Flags of the obstacles that have been visited so far.
        drone_pos: Drone position(s). Only the first two components enter the range check.
        mocap_pos: Mocap positions, indexed with ``mocap_ids``.
        mocap_ids: Mocap indices of the obstacles.
        sensor_range: Distance in the xy-plane below which an obstacle counts as in range.
        nominal_pos: Positions reported for obstacles that have not been visited.

    Returns:
        The updated visited flags and the obstacle positions.
    """
    real_pos = mocap_pos[mocap_ids]
    offset_xy = drone_pos[..., None, :2] - real_pos[..., :2]
    within_range = jp.linalg.norm(offset_xy, axis=-1) < sensor_range
    seen = jp.logical_or(visited, within_range)
    obstacles_pos = jp.where(seen[..., None], real_pos, nominal_pos)
    return seen, obstacles_pos
345
339
346
340
def reward (self ) -> float :
347
341
"""Compute the reward for the current state.
@@ -368,33 +362,20 @@ def info(self) -> dict:
368
362
"""Return an info dictionary containing additional information about the environment."""
369
363
return {"collisions" : np .any (self .sim .contacts (), axis = - 1 ), "symbolic_model" : self .symbolic }
370
364
371
- def update_active_drones (self ):
372
- # TODO: Accelerate
373
- pos = self .sim .data .states .pos [0 , ...]
374
- rpy = R .from_quat (self .sim .data .states .quat [0 , ...]).as_euler ("xyz" )
375
- disabled = np .logical_or (self .disabled_drones , np .all (pos < self .pos_bounds .low , axis = - 1 ))
376
- disabled = np .logical_or (disabled , np .all (pos > self .pos_bounds .high , axis = - 1 ))
377
- disabled = np .logical_or (disabled , np .all (rpy < self .rpy_bounds .low , axis = - 1 ))
378
- disabled = np .logical_or (disabled , np .all (rpy > self .rpy_bounds .high , axis = - 1 ))
379
- disabled = np .logical_or (disabled , self .target_gate == - 1 )
380
- contacts = np .any (np .logical_and (self .sim .contacts (), self .contact_masks ), axis = - 1 )
381
- disabled = np .logical_or (disabled , contacts )
382
- self .disabled_drones = disabled
383
-
384
365
@staticmethod
385
366
@jax .jit
386
- def update_active_drones_acc (
387
- pos ,
388
- quat ,
389
- pos_low ,
390
- pos_high ,
391
- rpy_low ,
392
- rpy_high ,
393
- target_gate ,
394
- disabled_drones ,
395
- contacts ,
396
- contact_masks ,
397
- ):
367
+ def _disabled_drones (
368
+ pos : Array ,
369
+ quat : Array ,
370
+ pos_low : NDArray ,
371
+ pos_high : NDArray ,
372
+ rpy_low : NDArray ,
373
+ rpy_high : NDArray ,
374
+ target_gate : NDArray ,
375
+ disabled_drones : NDArray ,
376
+ contacts : Array ,
377
+ contact_masks : NDArray ,
378
+ ) -> Array :
398
379
rpy = jax .scipy .spatial .transform .Rotation .from_quat (quat ).as_euler ("xyz" )
399
380
disabled = jp .logical_or (disabled_drones , jp .all (pos < pos_low , axis = - 1 ))
400
381
disabled = jp .logical_or (disabled , jp .all (pos > pos_high , axis = - 1 ))
@@ -481,30 +462,9 @@ def _load_track_into_sim(self, gates: dict, obstacles: dict):
481
462
mocap_ids = [int (mj_model .body (f"obstacle:{ i } " ).mocapid ) for i in range (n_obstacles )]
482
463
obstacles ["mocap_ids" ] = np .array (mocap_ids , dtype = np .int32 )
483
464
484
- def gate_passed (self ) -> bool :
485
- """Check if the drone has passed a gate.
486
-
487
- Returns:
488
- True if the drone has passed a gate, else False.
489
- """
490
- passed = np .zeros (self .sim .n_drones , dtype = bool )
491
- if self .n_gates <= 0 :
492
- return passed
493
- gate_ids = self .target_gate % self .n_gates
494
- gate_mj_id = self .gates ["mocap_ids" ][gate_ids ]
495
- gate_pos = self .sim .data .mjx_data .mocap_pos [0 , gate_mj_id ].squeeze ()
496
- gate_rot = R .from_quat (self .sim .data .mjx_data .mocap_quat [0 , gate_mj_id ], scalar_first = True )
497
- drone_pos = self .sim .data .states .pos [0 ]
498
- gate_size = (0.45 , 0.45 )
499
- for i in range (self .sim .n_drones ):
500
- passed [i ] = check_gate_pass (
501
- gate_pos [i ], gate_rot [i ], gate_size , drone_pos [i ], self ._last_drone_pos [i ]
502
- )
503
- return passed
504
-
505
465
@staticmethod
506
466
@jax .jit
507
- def gate_passed_accelerated (
467
+ def _gate_passed (
508
468
target_gate : NDArray ,
509
469
mocap_ids : NDArray ,
510
470
mocap_pos : Array ,
@@ -518,18 +478,16 @@ def gate_passed_accelerated(
518
478
Returns:
519
479
True if the drone has passed a gate, else False.
520
480
"""
521
- # TODO: Test, refactor, optimize. Cover cases with no gates.
522
- gate_ids = target_gate % n_gates
523
- gate_mj_id = mocap_ids [gate_ids ]
524
- gate_pos = mocap_pos [gate_mj_id ]
525
- gate_rot = jax .scipy .spatial .transform .Rotation .from_quat (
526
- mocap_quat [gate_mj_id ][..., [1 , 2 , 3 , 0 ]]
527
- )
481
+ # TODO: Test. Cover cases with no gates.
482
+ ids = mocap_ids [target_gate % n_gates ]
483
+ gate_pos = mocap_pos [ids ]
484
+ gate_quat = mocap_quat [ids ][..., [1 , 2 , 3 , 0 ]]
485
+ gate_rot = jax .scipy .spatial .transform .Rotation .from_quat (gate_quat )
528
486
gate_size = (0.45 , 0.45 )
529
487
last_pos_local = gate_rot .apply (last_drone_pos - gate_pos , inverse = True )
530
488
pos_local = gate_rot .apply (drone_pos - gate_pos , inverse = True )
531
- # Check the plane intersection. If passed, calculate the point of the intersection and check if
532
- # it is within the gate box.
489
+ # Check if the line between the last position and the current position intersects the plane.
490
+ # If so, calculate the point of the intersection and check if it is within the gate box.
533
491
passed_plane = (last_pos_local [..., 1 ] < 0 ) & (pos_local [..., 1 ] > 0 )
534
492
alpha = - last_pos_local [..., 1 ] / (pos_local [..., 1 ] - last_pos_local [..., 1 ])
535
493
x_intersect = alpha * (pos_local [..., 0 ]) + (1 - alpha ) * last_pos_local [..., 0 ]
@@ -540,6 +498,7 @@ def gate_passed_accelerated(
540
498
@staticmethod
@jax.jit
def warp_disabled_drones(data: SimData, mask: NDArray) -> SimData:
    """Warp the disabled drones below the ground.

    Args:
        data: Simulation data whose ``states.pos`` entries are updated.
        mask: Flags of the disabled drones. Reshaped to ``(1, -1, 1)`` so that it broadcasts
            against ``data.states.pos``.

    Returns:
        A copy of ``data`` in which every position component of a flagged drone is set to -1.
    """
    disabled = mask.reshape((1, -1, 1))
    states = data.states
    warped_states = states.replace(pos=jax.numpy.where(disabled, -1, states.pos))
    return data.replace(states=warped_states)
0 commit comments