Skip to content

Commit 9ef2f70

Browse files
committed
Adapt to mujoco 3.3 spec API changes. Fix typing. Fix tests.
1 parent 37a7b96 commit 9ef2f70

File tree

10 files changed

+47
-30
lines changed

10 files changed

+47
-30
lines changed

lsy_drone_racing/control/attitude_controller.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def compute_control(
129129
R_desired = np.vstack([x_axis_desired, y_axis_desired, z_axis_desired]).T
130130
euler_desired = R.from_matrix(R_desired).as_euler("xyz", degrees=False)
131131
thrust_desired, euler_desired
132-
return np.concatenate([[thrust_desired], euler_desired])
132+
return np.concatenate([[thrust_desired], euler_desired], dtype=np.float32)
133133

134134
def step_callback(
135135
self,

lsy_drone_racing/control/trajectory_controller.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def compute_control(
7171
array.
7272
"""
7373
target_pos = self.trajectory(min(self._tick / self._freq, self.t_total))
74-
return np.concatenate((target_pos, np.zeros(10)))
74+
return np.concatenate((target_pos, np.zeros(10)), dtype=np.float32)
7575

7676
def step_callback(
7777
self,

lsy_drone_racing/envs/drone_race.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@
1111
from lsy_drone_racing.envs.race_core import RaceCoreEnv, build_action_space, build_observation_space
1212

1313
if TYPE_CHECKING:
14-
import numpy as np
14+
from jax import Array
1515
from ml_collections import ConfigDict
16-
from numpy.typing import NDArray
1716

1817

1918
class DroneRaceEnv(RaceCoreEnv, Env):
@@ -83,7 +82,7 @@ def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[d
8382
info = {k: v[0, 0] for k, v in info.items()}
8483
return obs, info
8584

86-
def step(self, action: NDArray[np.floating]) -> tuple[dict, float, bool, bool, dict]:
85+
def step(self, action: Array) -> tuple[dict, float, bool, bool, dict]:
8786
"""Step the environment.
8887
8988
Args:
@@ -168,9 +167,7 @@ def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[d
168167
info = {k: v[:, 0] for k, v in info.items()}
169168
return obs, info
170169

171-
def step(
172-
self, action: NDArray[np.floating]
173-
) -> tuple[dict, NDArray[np.floating], NDArray[np.bool_], NDArray[np.bool_], dict]:
170+
def step(self, action: Array) -> tuple[dict, Array, Array, Array, dict]:
174171
"""Step the environment in all worlds.
175172
176173
Args:

lsy_drone_racing/envs/multi_drone_race.py

+9-10
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
"""Multi-agent drone racing environments."""
22

3-
from typing import Literal
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING, Literal
46

5-
import numpy as np
67
from gymnasium import Env
78
from gymnasium.vector import VectorEnv
89
from gymnasium.vector.utils import batch_space
9-
from ml_collections import ConfigDict
10-
from numpy.typing import NDArray
1110

1211
from lsy_drone_racing.envs.race_core import RaceCoreEnv, build_action_space, build_observation_space
1312

13+
if TYPE_CHECKING:
14+
from jax import Array
15+
from ml_collections import ConfigDict
16+
1417

1518
class MultiDroneRaceEnv(RaceCoreEnv, Env):
1619
"""Multi-agent drone racing environment.
@@ -87,9 +90,7 @@ def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[d
8790
info = {k: v[0] for k, v in info.items()}
8891
return obs, info
8992

90-
def step(
91-
self, action: NDArray[np.floating]
92-
) -> tuple[dict, NDArray[np.floating], NDArray[np.bool_], NDArray[np.bool_], dict]:
93+
def step(self, action: Array) -> tuple[dict, Array, Array, Array, dict]:
9394
"""Step the environment for all drones.
9495
9596
Args:
@@ -180,9 +181,7 @@ def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[d
180181
"""
181182
return self._reset(seed=seed, options=options)
182183

183-
def step(
184-
self, action: NDArray[np.floating]
185-
) -> tuple[dict, NDArray[np.floating], NDArray[np.bool_], NDArray[np.bool_], dict]:
184+
def step(self, action: Array) -> tuple[dict, Array, Array, Array, dict]:
186185
"""Step the environment for all drones.
187186
188187
Args:

lsy_drone_racing/envs/race_core.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from crazyflow.sim.symbolic import symbolic_attitude
2020
from flax.struct import dataclass
2121
from gymnasium import spaces
22-
from jax.scipy.spatial.transform import Rotation as JaxR
2322
from scipy.spatial.transform import Rotation as R
2423

2524
from lsy_drone_racing.envs.randomize import (
@@ -264,7 +263,7 @@ def __init__(
264263

265264
def _reset(
266265
self, *, seed: int | None = None, options: dict | None = None, mask: Array | None = None
267-
) -> tuple[dict[str, NDArray[np.floating]], dict]:
266+
) -> tuple[dict[str, Array], dict]:
268267
"""Reset the environment.
269268
270269
Args:
@@ -287,9 +286,7 @@ def _reset(
287286
self.data = self._reset_env_data(self.data, self.sim.data.states.pos, mask)
288287
return self.obs(), self.info()
289288

290-
def _step(
291-
self, action: NDArray[np.floating]
292-
) -> tuple[dict[str, NDArray[np.floating]], float, bool, bool, dict]:
289+
def _step(self, action: Array) -> tuple[dict[str, Array], float, bool, bool, dict]:
293290
"""Step the firmware_wrapper class and its environment.
294291
295292
This function should be called once at the rate of ctrl_freq. Step processes and high level
@@ -319,7 +316,7 @@ def _step(
319316
self._reset(mask=marked_for_reset)
320317
return self.obs(), self.reward(), self.terminated(), self.truncated(), self.info()
321318

322-
def apply_action(self, action: NDArray[np.floating]):
319+
def apply_action(self, action: Array):
323320
"""Apply the commanded state action to the simulation."""
324321
action = action.reshape((self.sim.n_worlds, self.sim.n_drones, -1))
325322
if "action" in self.disturbances:
@@ -342,7 +339,7 @@ def close(self):
342339
"""Close the environment by stopping the drone and landing back at the starting position."""
343340
self.sim.close()
344341

345-
def obs(self) -> dict[str, NDArray[np.floating]]:
342+
def obs(self) -> dict[str, Array]:
346343
"""Return the observation of the environment."""
347344
# Add the gate and obstacle poses to the info. If gates or obstacles are in sensor range,
348345
# use the actual pose, otherwise use the nominal pose.
@@ -564,13 +561,19 @@ def _load_track_into_sim(self, gate_spec: MjSpec, obstacle_spec: MjSpec):
564561
frame = self.sim.spec.worldbody.add_frame()
565562
n_gates, n_obstacles = len(self.gates["pos"]), len(self.obstacles["pos"])
566563
for i in range(n_gates):
567-
gate = frame.attach_body(gate_spec.find_body("gate"), "", f":{i}")
564+
gate_body = gate_spec.body("gate")
565+
if gate_body is None:
566+
raise ValueError("Gate body not found in gate spec")
567+
gate = frame.attach_body(gate_body, "", f":{i}")
568568
gate.pos = self.gates["pos"][i]
569569
# Convert from scipy order to MuJoCo order
570570
gate.quat = self.gates["quat"][i][[3, 0, 1, 2]]
571571
gate.mocap = True # Make mocap to modify the position of static bodies during sim
572572
for i in range(n_obstacles):
573-
obstacle = frame.attach_body(obstacle_spec.find_body("obstacle"), "", f":{i}")
573+
obstacle_body = obstacle_spec.body("obstacle")
574+
if obstacle_body is None:
575+
raise ValueError("Obstacle body not found in obstacle spec")
576+
obstacle = frame.attach_body(obstacle_body, "", f":{i}")
574577
obstacle.pos = self.obstacles["pos"][i]
575578
obstacle.mocap = True
576579
self.sim.build(data=False, default_data=False)

pyproject.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@ classifiers = [
1717
"Intended Audience :: Science/Research",
1818
]
1919

20+
# TODO: Include crazyflow once it's installable via pip
2021
dependencies = [
2122
"fire >= 0.6.0",
2223
"numpy >= 1.24.1, < 2.0.0",
23-
"PyYAML >= 6.0.1", # TODO: Remove after removing crazyswarm dependency
24-
"rospkg >= 1.5.1", # TODO: Remove after moving to cflib
24+
"PyYAML >= 6.0.1", # TODO: Remove after removing crazyswarm dependency
25+
"rospkg >= 1.5.1", # TODO: Remove after moving to cflib
2526
"scipy >= 1.10.1",
2627
"gymnasium >= 1.0.0",
2728
"toml >= 0.10.2",

scripts/sim.py

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import fire
1717
import gymnasium
18+
from gymnasium.wrappers.jax_to_numpy import JaxToNumpy
1819

1920
from lsy_drone_racing.utils import load_config, load_controller
2021

@@ -68,6 +69,7 @@ def simulate(
6869
random_resets=config.env.random_resets,
6970
seed=config.env.seed,
7071
)
72+
env = JaxToNumpy(env)
7173

7274
ep_times = []
7375
for _ in range(n_runs): # Run n_runs episodes with the controller

tests/conftest.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import jax
2+
3+
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
4+
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
5+
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
6+
# Do not enable XLA caches, crashes PyTest
7+
# jax.config.update("jax_persistent_cache_enable_xla_caches", "all")

tests/integration/test_controllers.py

+7
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import gymnasium
44
import numpy as np
55
import pytest
6+
from gymnasium.wrappers.jax_to_numpy import JaxToNumpy
67

78
from lsy_drone_racing.utils import load_config, load_controller
89

@@ -27,6 +28,8 @@ def test_controllers(controller_file: str):
2728
random_resets=config.env.random_resets,
2829
seed=config.env.seed,
2930
)
31+
env = JaxToNumpy(env)
32+
3033
obs, info = env.reset()
3134
ctrl = ctrl_cls(obs, info, config)
3235
while True:
@@ -59,6 +62,8 @@ def test_attitude_controller(physics: str):
5962
random_resets=config.env.random_resets,
6063
seed=config.env.seed,
6164
)
65+
env = JaxToNumpy(env)
66+
6267
obs, info = env.reset()
6368
ctrl = ctrl_cls(obs, info, config)
6469
while True:
@@ -98,6 +103,8 @@ def test_trajectory_controller_finish(yaw: float, physics: str):
98103
random_resets=config.env.random_resets,
99104
seed=config.env.seed,
100105
)
106+
env = JaxToNumpy(env)
107+
101108
obs, info = env.reset()
102109
ctrl = ctrl_cls(obs, info, config)
103110
while True:

tests/unit/envs/test_envs.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import gymnasium
55
import pytest
66
from gymnasium.utils.env_checker import check_env
7+
from gymnasium.wrappers.jax_to_numpy import JaxToNumpy
78

89
from lsy_drone_racing.utils import load_config
910

@@ -31,4 +32,4 @@ def test_passive_checker_wrapper_warnings(action_space: str):
3132
seed=config.env.seed,
3233
disable_env_checker=False,
3334
)
34-
check_env(env.unwrapped)
35+
check_env(JaxToNumpy(env.unwrapped))

0 commit comments

Comments
 (0)