|
31 | 31 | randomize_gate_rpy_fn,
|
32 | 32 | randomize_obstacle_pos_fn,
|
33 | 33 | )
|
| 34 | +from lsy_drone_racing.utils.utils import gate_passed |
34 | 35 |
|
35 | 36 | if TYPE_CHECKING:
|
36 | 37 | from crazyflow.sim.structs import SimData
|
@@ -457,15 +458,15 @@ def _step_env(
|
457 | 458 | n_gates = len(data.gate_mj_ids)
|
458 | 459 | disabled_drones = RaceCoreEnv._disabled_drones(drone_pos, drone_quat, contacts, data)
|
459 | 460 | gates_pos = mocap_pos[:, data.gate_mj_ids]
|
| 461 | + obstacles_pos = mocap_pos[:, data.obstacle_mj_ids] |
460 | 462 | # We need to convert the mocap quat from MuJoCo order to scipy order
|
461 | 463 | gates_quat = mocap_quat[:, data.gate_mj_ids][..., [3, 0, 1, 2]]
|
462 |
| - obstacles_pos = mocap_pos[:, data.obstacle_mj_ids] |
463 | 464 | # Extract the gate poses of the current target gates and check if the drones have passed
|
464 | 465 | # them between the last and current position
|
465 | 466 | gate_ids = data.gate_mj_ids[data.target_gate % n_gates]
|
466 | 467 | gate_pos = gates_pos[jp.arange(gates_pos.shape[0])[:, None], gate_ids]
|
467 | 468 | gate_quat = gates_quat[jp.arange(gates_quat.shape[0])[:, None], gate_ids]
|
468 |
| - passed = RaceCoreEnv._gate_passed(drone_pos, data.last_drone_pos, gate_pos, gate_quat) |
| 469 | + passed = gate_passed(drone_pos, data.last_drone_pos, gate_pos, gate_quat, (0.45, 0.45)) |
469 | 470 | # Update the target gate index. Increment by one if drones have passed a gate
|
470 | 471 | target_gate = data.target_gate + passed * ~disabled_drones
|
471 | 472 | target_gate = jp.where(target_gate >= n_gates, -1, target_gate)
|
@@ -511,6 +512,30 @@ def _obs(
|
511 | 512 | obstacles_pos = jp.where(mask, real_pos[:, None], nominal_obstacle_pos[None, None])
|
512 | 513 | return gates_pos, gates_rpy, obstacles_pos
|
513 | 514 |
|
| 515 | + @staticmethod |
| 516 | + def _disabled_drones(pos: Array, quat: Array, contacts: Array, data: EnvData) -> Array: |
| 517 | + rpy = JaxR.from_quat(quat).as_euler("xyz") |
| 518 | + disabled = jp.logical_or(data.disabled_drones, jp.all(pos < data.pos_limit_low, axis=-1)) |
| 519 | + disabled = jp.logical_or(disabled, jp.all(pos > data.pos_limit_high, axis=-1)) |
| 520 | + disabled = jp.logical_or(disabled, jp.all(rpy < data.rpy_limit_low, axis=-1)) |
| 521 | + disabled = jp.logical_or(disabled, jp.all(rpy > data.rpy_limit_high, axis=-1)) |
| 522 | + disabled = jp.logical_or(disabled, data.target_gate == -1) |
| 523 | + contacts = jp.any(jp.logical_and(contacts[:, None, :], data.contact_masks), axis=-1) |
| 524 | + disabled = jp.logical_or(disabled, contacts) |
| 525 | + return disabled |
| 526 | + |
| 527 | + @staticmethod |
| 528 | + def _visited(drone_pos: Array, target_pos: Array, sensor_range: float, visited: Array) -> Array: |
| 529 | + dpos = drone_pos[..., None, :2] - target_pos[:, None, :, :2] |
| 530 | + return jp.logical_or(visited, jp.linalg.norm(dpos, axis=-1) < sensor_range) |
| 531 | + |
| 532 | + @staticmethod |
| 533 | + @jax.jit |
| 534 | + def _warp_disabled_drones(data: SimData, mask: Array) -> SimData: |
| 535 | + """Warp the disabled drones below the ground.""" |
| 536 | + pos = jax.numpy.where(mask[..., None], -1, data.states.pos) |
| 537 | + return data.replace(states=data.states.replace(pos=pos)) |
| 538 | + |
514 | 539 | def _load_track(self, track: dict) -> tuple[dict, dict, dict]:
|
515 | 540 | """Load the track from the config file."""
|
516 | 541 | gate_pos = np.array([g["pos"] for g in track.gates])
|
@@ -593,52 +618,6 @@ def _load_contact_masks(self, sim: Sim) -> Array:
|
593 | 618 | masks = np.tile(masks[None, ...], (sim.n_worlds, 1, 1))
|
594 | 619 | return masks
|
595 | 620 |
|
596 |
| - @staticmethod |
597 |
| - def _disabled_drones(pos: Array, quat: Array, contacts: Array, data: EnvData) -> Array: |
598 |
| - rpy = JaxR.from_quat(quat).as_euler("xyz") |
599 |
| - disabled = jp.logical_or(data.disabled_drones, jp.all(pos < data.pos_limit_low, axis=-1)) |
600 |
| - disabled = jp.logical_or(disabled, jp.all(pos > data.pos_limit_high, axis=-1)) |
601 |
| - disabled = jp.logical_or(disabled, jp.all(rpy < data.rpy_limit_low, axis=-1)) |
602 |
| - disabled = jp.logical_or(disabled, jp.all(rpy > data.rpy_limit_high, axis=-1)) |
603 |
| - disabled = jp.logical_or(disabled, data.target_gate == -1) |
604 |
| - contacts = jp.any(jp.logical_and(contacts[:, None, :], data.contact_masks), axis=-1) |
605 |
| - disabled = jp.logical_or(disabled, contacts) |
606 |
| - return disabled |
607 |
| - |
608 |
| - @staticmethod |
609 |
| - def _gate_passed( |
610 |
| - drone_pos: Array, last_drone_pos: Array, gate_pos: Array, gate_quat: Array |
611 |
| - ) -> bool: |
612 |
| - """Check if the drone has passed a gate. |
613 |
| -
|
614 |
| - Returns: |
615 |
| - True if the drone has passed a gate, else False. |
616 |
| - """ |
617 |
| - gate_rot = JaxR.from_quat(gate_quat) |
618 |
| - gate_size = (0.45, 0.45) |
619 |
| - last_pos_local = gate_rot.apply(last_drone_pos - gate_pos, inverse=True) |
620 |
| - pos_local = gate_rot.apply(drone_pos - gate_pos, inverse=True) |
621 |
| - # Check if the line between the last position and the current position intersects the plane. |
622 |
| - # If so, calculate the point of the intersection and check if it is within the gate box. |
623 |
| - passed_plane = (last_pos_local[..., 1] < 0) & (pos_local[..., 1] > 0) |
624 |
| - alpha = -last_pos_local[..., 1] / (pos_local[..., 1] - last_pos_local[..., 1]) |
625 |
| - x_intersect = alpha * (pos_local[..., 0]) + (1 - alpha) * last_pos_local[..., 0] |
626 |
| - z_intersect = alpha * (pos_local[..., 2]) + (1 - alpha) * last_pos_local[..., 2] |
627 |
| - in_box = (abs(x_intersect) < gate_size[0] / 2) & (abs(z_intersect) < gate_size[1] / 2) |
628 |
| - return passed_plane & in_box |
629 |
| - |
630 |
| - @staticmethod |
631 |
| - def _visited(drone_pos: Array, target_pos: Array, sensor_range: float, visited: Array) -> Array: |
632 |
| - dpos = drone_pos[..., None, :2] - target_pos[:, None, :, :2] |
633 |
| - return jp.logical_or(visited, jp.linalg.norm(dpos, axis=-1) < sensor_range) |
634 |
| - |
635 |
| - @staticmethod |
636 |
| - @jax.jit |
637 |
| - def _warp_disabled_drones(data: SimData, mask: Array) -> SimData: |
638 |
| - """Warp the disabled drones below the ground.""" |
639 |
| - pos = jax.numpy.where(mask[..., None], -1, data.states.pos) |
640 |
| - return data.replace(states=data.states.replace(pos=pos)) |
641 |
| - |
642 | 621 |
|
643 | 622 | # region Factories
|
644 | 623 |
|
|
0 commit comments