Skip to content

Commit 0df29e2

Browse files
committed
Remove old code. Improve tests. Improve gate_passed function
1 parent 6eb3cf5 commit 0df29e2

File tree

5 files changed

+100
-165
lines changed

5 files changed

+100
-165
lines changed

lsy_drone_racing/envs/race_core.py

+27-48
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
randomize_gate_rpy_fn,
3232
randomize_obstacle_pos_fn,
3333
)
34+
from lsy_drone_racing.utils.utils import gate_passed
3435

3536
if TYPE_CHECKING:
3637
from crazyflow.sim.structs import SimData
@@ -457,15 +458,15 @@ def _step_env(
457458
n_gates = len(data.gate_mj_ids)
458459
disabled_drones = RaceCoreEnv._disabled_drones(drone_pos, drone_quat, contacts, data)
459460
gates_pos = mocap_pos[:, data.gate_mj_ids]
461+
obstacles_pos = mocap_pos[:, data.obstacle_mj_ids]
460462
# We need to convert the mocap quat from MuJoCo order to scipy order
461463
gates_quat = mocap_quat[:, data.gate_mj_ids][..., [3, 0, 1, 2]]
462-
obstacles_pos = mocap_pos[:, data.obstacle_mj_ids]
463464
# Extract the gate poses of the current target gates and check if the drones have passed
464465
# them between the last and current position
465466
gate_ids = data.gate_mj_ids[data.target_gate % n_gates]
466467
gate_pos = gates_pos[jp.arange(gates_pos.shape[0])[:, None], gate_ids]
467468
gate_quat = gates_quat[jp.arange(gates_quat.shape[0])[:, None], gate_ids]
468-
passed = RaceCoreEnv._gate_passed(drone_pos, data.last_drone_pos, gate_pos, gate_quat)
469+
passed = gate_passed(drone_pos, data.last_drone_pos, gate_pos, gate_quat, (0.45, 0.45))
469470
# Update the target gate index. Increment by one if drones have passed a gate
470471
target_gate = data.target_gate + passed * ~disabled_drones
471472
target_gate = jp.where(target_gate >= n_gates, -1, target_gate)
@@ -511,6 +512,30 @@ def _obs(
511512
obstacles_pos = jp.where(mask, real_pos[:, None], nominal_obstacle_pos[None, None])
512513
return gates_pos, gates_rpy, obstacles_pos
513514

515+
@staticmethod
516+
def _disabled_drones(pos: Array, quat: Array, contacts: Array, data: EnvData) -> Array:
517+
rpy = JaxR.from_quat(quat).as_euler("xyz")
518+
disabled = jp.logical_or(data.disabled_drones, jp.all(pos < data.pos_limit_low, axis=-1))
519+
disabled = jp.logical_or(disabled, jp.all(pos > data.pos_limit_high, axis=-1))
520+
disabled = jp.logical_or(disabled, jp.all(rpy < data.rpy_limit_low, axis=-1))
521+
disabled = jp.logical_or(disabled, jp.all(rpy > data.rpy_limit_high, axis=-1))
522+
disabled = jp.logical_or(disabled, data.target_gate == -1)
523+
contacts = jp.any(jp.logical_and(contacts[:, None, :], data.contact_masks), axis=-1)
524+
disabled = jp.logical_or(disabled, contacts)
525+
return disabled
526+
527+
@staticmethod
528+
def _visited(drone_pos: Array, target_pos: Array, sensor_range: float, visited: Array) -> Array:
529+
dpos = drone_pos[..., None, :2] - target_pos[:, None, :, :2]
530+
return jp.logical_or(visited, jp.linalg.norm(dpos, axis=-1) < sensor_range)
531+
532+
@staticmethod
533+
@jax.jit
534+
def _warp_disabled_drones(data: SimData, mask: Array) -> SimData:
535+
"""Warp the disabled drones below the ground."""
536+
pos = jax.numpy.where(mask[..., None], -1, data.states.pos)
537+
return data.replace(states=data.states.replace(pos=pos))
538+
514539
def _load_track(self, track: dict) -> tuple[dict, dict, dict]:
515540
"""Load the track from the config file."""
516541
gate_pos = np.array([g["pos"] for g in track.gates])
@@ -593,52 +618,6 @@ def _load_contact_masks(self, sim: Sim) -> Array:
593618
masks = np.tile(masks[None, ...], (sim.n_worlds, 1, 1))
594619
return masks
595620

596-
@staticmethod
597-
def _disabled_drones(pos: Array, quat: Array, contacts: Array, data: EnvData) -> Array:
598-
rpy = JaxR.from_quat(quat).as_euler("xyz")
599-
disabled = jp.logical_or(data.disabled_drones, jp.all(pos < data.pos_limit_low, axis=-1))
600-
disabled = jp.logical_or(disabled, jp.all(pos > data.pos_limit_high, axis=-1))
601-
disabled = jp.logical_or(disabled, jp.all(rpy < data.rpy_limit_low, axis=-1))
602-
disabled = jp.logical_or(disabled, jp.all(rpy > data.rpy_limit_high, axis=-1))
603-
disabled = jp.logical_or(disabled, data.target_gate == -1)
604-
contacts = jp.any(jp.logical_and(contacts[:, None, :], data.contact_masks), axis=-1)
605-
disabled = jp.logical_or(disabled, contacts)
606-
return disabled
607-
608-
@staticmethod
609-
def _gate_passed(
610-
drone_pos: Array, last_drone_pos: Array, gate_pos: Array, gate_quat: Array
611-
) -> bool:
612-
"""Check if the drone has passed a gate.
613-
614-
Returns:
615-
True if the drone has passed a gate, else False.
616-
"""
617-
gate_rot = JaxR.from_quat(gate_quat)
618-
gate_size = (0.45, 0.45)
619-
last_pos_local = gate_rot.apply(last_drone_pos - gate_pos, inverse=True)
620-
pos_local = gate_rot.apply(drone_pos - gate_pos, inverse=True)
621-
# Check if the line between the last position and the current position intersects the plane.
622-
# If so, calculate the point of the intersection and check if it is within the gate box.
623-
passed_plane = (last_pos_local[..., 1] < 0) & (pos_local[..., 1] > 0)
624-
alpha = -last_pos_local[..., 1] / (pos_local[..., 1] - last_pos_local[..., 1])
625-
x_intersect = alpha * (pos_local[..., 0]) + (1 - alpha) * last_pos_local[..., 0]
626-
z_intersect = alpha * (pos_local[..., 2]) + (1 - alpha) * last_pos_local[..., 2]
627-
in_box = (abs(x_intersect) < gate_size[0] / 2) & (abs(z_intersect) < gate_size[1] / 2)
628-
return passed_plane & in_box
629-
630-
@staticmethod
631-
def _visited(drone_pos: Array, target_pos: Array, sensor_range: float, visited: Array) -> Array:
632-
dpos = drone_pos[..., None, :2] - target_pos[:, None, :, :2]
633-
return jp.logical_or(visited, jp.linalg.norm(dpos, axis=-1) < sensor_range)
634-
635-
@staticmethod
636-
@jax.jit
637-
def _warp_disabled_drones(data: SimData, mask: Array) -> SimData:
638-
"""Warp the disabled drones below the ground."""
639-
pos = jax.numpy.where(mask[..., None], -1, data.states.pos)
640-
return data.replace(states=data.states.replace(pos=pos))
641-
642621

643622
# region Factories
644623

lsy_drone_racing/utils/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44
dependency for sim-only scripts.
55
"""
66

7-
from lsy_drone_racing.utils.utils import check_gate_pass, load_config, load_controller, map2pi
7+
from lsy_drone_racing.utils.utils import gate_passed, load_config, load_controller
88

9-
__all__ = ["load_config", "load_controller", "check_gate_pass", "map2pi"]
9+
__all__ = ["load_config", "load_controller", "gate_passed"]

lsy_drone_racing/utils/utils.py

+24-33
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,26 @@
66
import inspect
77
import logging
88
import sys
9+
from functools import partial
910
from typing import TYPE_CHECKING, Type
1011

11-
import numpy as np
12+
import jax
1213
import toml
14+
from jax.numpy import vectorize
15+
from jax.scipy.spatial.transform import Rotation as R
1316
from ml_collections import ConfigDict
14-
from scipy.spatial.transform import Rotation as R
1517

1618
from lsy_drone_racing.control.controller import BaseController
1719

1820
if TYPE_CHECKING:
1921
from pathlib import Path
2022
from typing import Any
2123

22-
from numpy.typing import NDArray
24+
from jax import Array
2325

2426
logger = logging.getLogger(__name__)
2527

2628

27-
def map2pi(angle: NDArray[np.floating]) -> NDArray[np.floating]:
28-
"""Map an angle or array of angles to the interval of [-pi, pi].
29-
30-
Args:
31-
angle: Number or array of numbers.
32-
33-
Returns:
34-
The remapped angles.
35-
"""
36-
return ((angle + np.pi) % (2 * np.pi)) - np.pi
37-
38-
3929
def load_controller(path: Path) -> Type[BaseController]:
4030
"""Load the controller module from the given path and return the Controller class.
4131
@@ -89,12 +79,14 @@ def load_config(path: Path) -> ConfigDict:
8979
return ConfigDict(toml.load(f))
9080

9181

92-
def check_gate_pass(
93-
gate_pos: np.ndarray,
94-
gate_rot: R,
95-
gate_size: np.ndarray,
96-
drone_pos: np.ndarray,
97-
last_drone_pos: np.ndarray,
82+
@jax.jit
83+
@partial(vectorize, signature="(3),(3),(3),(4)->()", excluded=[4])
84+
def gate_passed(
85+
drone_pos: Array,
86+
last_drone_pos: Array,
87+
gate_pos: Array,
88+
gate_quat: Array,
89+
gate_size: tuple[float, float],
9890
) -> bool:
9991
"""Check if the drone has passed the current gate.
10092
@@ -110,23 +102,22 @@ def check_gate_pass(
110102
goal changes.
111103
112104
Args:
113-
gate_pos: The position of the gate in the world frame.
114-
gate_rot: The rotation of the gate.
115-
gate_size: The size of the gate box in meters.
116105
drone_pos: The position of the drone in the world frame.
117106
last_drone_pos: The position of the drone in the world frame at the last time step.
107+
gate_pos: The position of the gate in the world frame.
108+
gate_quat: The rotation of the gate as a wxyz quaternion.
109+
gate_size: The size of the gate box in meters.
118110
"""
119111
# Transform last and current drone position into current gate frame.
120-
assert isinstance(gate_rot, R), "gate_rot has to be a Rotation object."
112+
gate_rot = R.from_quat(gate_quat)
121113
last_pos_local = gate_rot.apply(last_drone_pos - gate_pos, inverse=True)
122114
pos_local = gate_rot.apply(drone_pos - gate_pos, inverse=True)
123115
# Check the plane intersection. If passed, calculate the point of the intersection and check if
124116
# it is within the gate box.
125-
if last_pos_local[1] < 0 and pos_local[1] > 0: # Drone has passed the goal plane
126-
alpha = -last_pos_local[1] / (pos_local[1] - last_pos_local[1])
127-
x_intersect = alpha * (pos_local[0]) + (1 - alpha) * last_pos_local[0]
128-
z_intersect = alpha * (pos_local[2]) + (1 - alpha) * last_pos_local[2]
129-
# Divide gate size by 2 to get the distance from the center to the edges
130-
if abs(x_intersect) < gate_size[0] / 2 and abs(z_intersect) < gate_size[1] / 2:
131-
return True
132-
return False
117+
passed_plane = (last_pos_local[1] < 0) & (pos_local[1] > 0)
118+
alpha = -last_pos_local[1] / (pos_local[1] - last_pos_local[1])
119+
x_intersect = alpha * (pos_local[0]) + (1 - alpha) * last_pos_local[0]
120+
z_intersect = alpha * (pos_local[2]) + (1 - alpha) * last_pos_local[2]
121+
# Divide gate size by 2 to get the distance from the center to the edges
122+
in_box = (abs(x_intersect) < gate_size[0] / 2) & (abs(z_intersect) < gate_size[1] / 2)
123+
return passed_plane & in_box

tests/integration/test_envs.py

+30-47
Original file line numberDiff line numberDiff line change
@@ -6,65 +6,48 @@
66
import lsy_drone_racing # noqa: F401, environment registrations
77
from lsy_drone_racing.utils import load_config
88

9-
CONFIG_FILES = ["level0.toml", "level1.toml", "level2.toml", "level3.toml"]
10-
MULTI_CONFIG_FILES = ["multi_level0.toml", "multi_level3.toml"]
9+
CONFIG_FILES = {
10+
"DroneRacing-v0": ["level0.toml", "level1.toml", "level2.toml", "level3.toml"],
11+
"MultiDroneRacing-v0": ["multi_level0.toml", "multi_level3.toml"],
12+
}
13+
ENV_IDS = ["DroneRacing-v0", "MultiDroneRacing-v0"]
1114

1215

1316
@pytest.mark.parametrize("physics", ["analytical", "sys_id"])
14-
@pytest.mark.parametrize("config_file", CONFIG_FILES)
17+
@pytest.mark.parametrize(
18+
("env_id", "config_file"),
19+
[(env_id, config_file) for env_id in ENV_IDS for config_file in CONFIG_FILES[env_id]],
20+
)
1521
@pytest.mark.integration
16-
def test_envs(physics: str, config_file: str):
22+
def test_single_drone_envs(env_id: str, config_file: str, physics: str):
1723
"""Test the simulation environments with different physics modes and config files."""
1824
config = load_config(Path(__file__).parents[2] / "config" / config_file)
1925
assert hasattr(config.sim, "physics"), "Physics mode is not set"
2026
config.sim.physics = physics # override physics mode
2127
assert hasattr(config.env, "id"), "Environment ID is not set"
22-
config.env.id = "DroneRacing-v0" # override environment ID
2328

24-
env = gymnasium.make(
25-
"DroneRacing-v0",
26-
freq=config.env.freq,
27-
sim_config=config.sim,
28-
sensor_range=config.env.sensor_range,
29-
track=config.env.track,
30-
disturbances=config.env.get("disturbances"),
31-
randomizations=config.env.get("randomizations"),
32-
random_resets=config.env.random_resets,
33-
seed=config.env.seed,
34-
)
29+
kwargs = {
30+
"freq": config.env.freq,
31+
"sim_config": config.sim,
32+
"sensor_range": config.env.sensor_range,
33+
"track": config.env.track,
34+
"disturbances": config.env.get("disturbances"),
35+
"randomizations": config.env.get("randomizations"),
36+
"random_resets": config.env.random_resets,
37+
"seed": config.env.seed,
38+
}
39+
if "n_drones" in config.env:
40+
kwargs["n_drones"] = config.env.n_drones
41+
42+
env = gymnasium.make(env_id, **kwargs)
3543
env.reset()
36-
for _ in range(10):
37-
_, _, terminated, truncated, _ = env.step(env.action_space.sample())
38-
if terminated or truncated:
39-
break
44+
for _ in range(100):
45+
_, _, _, _, _ = env.step(env.action_space.sample())
4046
env.close()
4147

42-
43-
@pytest.mark.parametrize("physics", ["analytical", "sys_id"])
44-
@pytest.mark.parametrize("config_file", MULTI_CONFIG_FILES)
45-
@pytest.mark.integration
46-
def test_vec_envs(physics: str, config_file: str):
47-
"""Test the simulation environments with different physics modes and config files."""
48-
config = load_config(Path(__file__).parents[2] / "config" / config_file)
49-
assert hasattr(config.sim, "physics"), "Physics mode is not set"
50-
config.sim.physics = physics # override physics mode
51-
assert hasattr(config.env, "id"), "Environment ID is not set"
52-
config.env.id = "MultiDroneRacing-v0" # override environment ID
53-
54-
env = gymnasium.make_vec(
55-
"MultiDroneRacing-v0",
56-
num_envs=2,
57-
n_drones=config.env.n_drones,
58-
freq=config.env.freq,
59-
sim_config=config.sim,
60-
sensor_range=config.env.sensor_range,
61-
track=config.env.track,
62-
disturbances=config.env.get("disturbances"),
63-
randomizations=config.env.get("randomizations"),
64-
random_resets=config.env.random_resets,
65-
seed=config.env.seed,
66-
)
48+
kwargs["num_envs"] = 2
49+
env = gymnasium.make_vec(env_id, **kwargs)
6750
env.reset()
68-
for _ in range(10):
69-
env.step(env.action_space.sample())
51+
for _ in range(100):
52+
_, _, _, _, _ = env.step(env.action_space.sample())
7053
env.close()

tests/unit/utils/test_utils.py

+17-35
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from scipy.spatial.transform import Rotation as R
77

88
from lsy_drone_racing.control.controller import BaseController
9-
from lsy_drone_racing.utils import check_gate_pass, load_config, load_controller, map2pi
9+
from lsy_drone_racing.utils import gate_passed, load_config, load_controller
1010

1111

1212
@pytest.mark.unit
@@ -24,48 +24,30 @@ def test_load_controller():
2424

2525

2626
@pytest.mark.unit
27-
def test_map2pi():
28-
assert map2pi(0) == 0
29-
assert map2pi(np.pi) == -np.pi
30-
assert map2pi(-np.pi) == -np.pi
31-
assert map2pi(2 * np.pi) == 0
32-
assert map2pi(-2 * np.pi) == 0
33-
assert np.allclose(map2pi(np.arange(10) * 2 * np.pi), np.zeros(10))
34-
assert np.max(map2pi(np.linspace(-100, 100, num=1000))) <= np.pi
35-
assert np.min(map2pi(np.linspace(-100, 100, num=1000))) >= -np.pi
36-
37-
38-
@pytest.mark.unit
39-
def test_check_gate_pass():
27+
def test_gate_pass():
4028
# TODO: Check accelerated function in RaceCore instead
4129
gate_pos = np.array([0, 0, 0])
42-
gate_rot = R.from_euler("xyz", [0, 0, 0])
30+
gate_quat = R.identity().as_quat()
4331
gate_size = np.array([1, 1])
4432
# Test passing through the gate
45-
assert check_gate_pass(gate_pos, gate_rot, gate_size, np.array([0, 1, 0]), np.array([0, -1, 0]))
33+
drone_pos, last_drone_pos = np.array([0, 1, 0]), np.array([0, -1, 0])
34+
assert gate_passed(drone_pos, last_drone_pos, gate_pos, gate_quat, gate_size)
4635
# Test passing outside the gate boundaries
47-
assert not check_gate_pass(
48-
gate_pos, gate_rot, gate_size, np.array([2, 1, 0]), np.array([2, -1, 0])
49-
)
36+
drone_pos, last_drone_pos = np.array([2, 1, 0]), np.array([2, -1, 0])
37+
assert not gate_passed(drone_pos, last_drone_pos, gate_pos, gate_quat, gate_size)
5038
# Test passing close to the gate
51-
assert not check_gate_pass(
52-
gate_pos, gate_rot, gate_size, np.array([0.51, 1, 0]), np.array([0.51, -1, 0])
53-
)
39+
drone_pos, last_drone_pos = np.array([0.51, 1, 0]), np.array([0.51, -1, 0])
40+
assert not gate_passed(drone_pos, last_drone_pos, gate_pos, gate_quat, gate_size)
5441
# Test passing opposite direction
55-
assert not check_gate_pass(
56-
gate_pos, gate_rot, gate_size, np.array([0, -1, 0]), np.array([0, 1, 0])
57-
)
42+
assert not gate_passed(drone_pos, last_drone_pos, gate_pos, gate_quat, gate_size)
5843
# Test with rotated gate
59-
rotated_gate = R.from_euler("xyz", [0, np.pi / 4, 0])
60-
assert check_gate_pass(
61-
gate_pos, rotated_gate, gate_size, np.array([0.5, 0.5, 0]), np.array([-0.5, -0.5, 0])
62-
)
44+
rotated_gate_quat = R.from_euler("xyz", [0, np.pi / 4, 0]).as_quat()
45+
drone_pos, last_drone_pos = np.array([0.5, 0.5, 0]), np.array([-0.5, -0.5, 0])
46+
assert gate_passed(drone_pos, last_drone_pos, gate_pos, rotated_gate_quat, gate_size)
6347
# Test with moved gate
6448
moved_gate_pos = np.array([1, 1, 1])
65-
assert check_gate_pass(
66-
moved_gate_pos, gate_rot, gate_size, np.array([1, 2, 1]), np.array([1, 0, 1])
67-
)
49+
drone_pos, last_drone_pos = np.array([1, 2, 1]), np.array([1, 0, 1])
50+
assert gate_passed(drone_pos, last_drone_pos, moved_gate_pos, gate_quat, gate_size)
6851
# Test not crossing the plane
69-
assert not check_gate_pass(
70-
gate_pos, gate_rot, gate_size, np.array([0, -0.5, 0]), np.array([0, -1, 0])
71-
)
52+
drone_pos, last_drone_pos = np.array([0, -0.5, 0]), np.array([0, -1, 0])
53+
assert not gate_passed(drone_pos, last_drone_pos, gate_pos, gate_quat, gate_size)

0 commit comments

Comments
 (0)