Skip to content

Commit 37a7b96

Browse files
committed
Fix multi-drone envs with JaxToNumpy wrapper
1 parent 2a347cf commit 37a7b96

File tree

3 files changed

+38
-20
lines changed

3 files changed

+38
-20
lines changed

lsy_drone_racing/envs/multi_drone_race.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ def step(
101101
obs, reward, terminated, truncated, info = self._step(action)
102102
obs = {k: v[0] for k, v in obs.items()}
103103
info = {k: v[0] for k, v in info.items()}
104-
return obs, reward[0], terminated[0], truncated[0], info
104+
# TODO: Fix by moving towards pettingzoo API
105+
# https://pettingzoo.farama.org/api/parallel/
106+
return obs, reward[0, 0], terminated[0].all(), truncated[0].all(), info
105107

106108

107109
class VecMultiDroneRaceEnv(RaceCoreEnv, VectorEnv):

lsy_drone_racing/envs/race_core.py

+21-16
Original file line numberDiff line numberDiff line change
@@ -358,20 +358,20 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
358358
self.obstacles["nominal_pos"],
359359
)
360360
obs = {
361-
"pos": np.array(self.sim.data.states.pos, dtype=np.float32),
362-
"quat": np.array(self.sim.data.states.quat, dtype=np.float32),
363-
"vel": np.array(self.sim.data.states.vel, dtype=np.float32),
364-
"ang_vel": np.array(self.sim.data.states.ang_vel, dtype=np.float32),
365-
"target_gate": np.array(self.data.target_gate, dtype=int),
366-
"gates_pos": np.asarray(gates_pos, dtype=np.float32),
367-
"gates_quat": np.asarray(gates_quat, dtype=np.float32),
368-
"gates_visited": np.asarray(self.data.gates_visited, dtype=bool),
369-
"obstacles_pos": np.asarray(obstacles_pos, dtype=np.float32),
370-
"obstacles_visited": np.asarray(self.data.obstacles_visited, dtype=bool),
361+
"pos": self.sim.data.states.pos,
362+
"quat": self.sim.data.states.quat,
363+
"vel": self.sim.data.states.vel,
364+
"ang_vel": self.sim.data.states.ang_vel,
365+
"target_gate": self.data.target_gate,
366+
"gates_pos": gates_pos,
367+
"gates_quat": gates_quat,
368+
"gates_visited": self.data.gates_visited,
369+
"obstacles_pos": obstacles_pos,
370+
"obstacles_visited": self.data.obstacles_visited,
371371
}
372372
return obs
373373

374-
def reward(self) -> NDArray[np.float32]:
374+
def reward(self) -> Array:
375375
"""Compute the reward for the current state.
376376
377377
Note:
@@ -382,19 +382,19 @@ def reward(self) -> NDArray[np.float32]:
382382
Returns:
383383
Reward for the current state.
384384
"""
385-
return np.array(-1.0 * (self.data.target_gate == -1), dtype=np.float32)
385+
return -1.0 * (self.data.target_gate == -1) # Implicit float conversion
386386

387-
def terminated(self) -> NDArray[np.bool_]:
387+
def terminated(self) -> Array:
388388
"""Check if the episode is terminated.
389389
390390
Returns:
391391
True if all drones have been disabled, else False.
392392
"""
393-
return np.array(self.data.disabled_drones, dtype=bool)
393+
return self.data.disabled_drones
394394

395-
def truncated(self) -> NDArray[np.bool_]:
395+
def truncated(self) -> Array:
396396
"""Array of booleans indicating if the episode is truncated."""
397-
return np.tile(self.data.steps >= self.data.max_episode_steps, (self.sim.n_drones, 1))
397+
return self._truncated(self.data.steps, self.data.max_episode_steps, self.sim.n_drones)
398398

399399
def info(self) -> dict:
400400
"""Return an info dictionary containing additional information about the environment."""
@@ -494,6 +494,11 @@ def _obs(
494494
obstacles_pos = jp.where(mask, real_pos[:, None], nominal_obstacle_pos[None, None])
495495
return gates_pos, gates_quat, obstacles_pos
496496

497+
@staticmethod
498+
@partial(jax.jit, static_argnames="n_drones")
499+
def _truncated(steps: Array, max_episode_steps: Array, n_drones: int) -> Array:
500+
return jp.tile(steps >= max_episode_steps, (n_drones, 1))
501+
497502
@staticmethod
498503
def _disabled_drones(pos: Array, contacts: Array, data: EnvData) -> Array:
499504
disabled = jp.logical_or(data.disabled_drones, jp.any(pos < data.pos_limit_low, axis=-1))

scripts/multi_sim.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import fire
1717
import gymnasium
1818
import numpy as np
19+
from gymnasium.wrappers.jax_to_numpy import JaxToNumpy
1920

2021
from lsy_drone_racing.utils import load_config, load_controller
2122

@@ -69,7 +70,9 @@ def simulate(
6970
randomizations=config.env.get("randomizations"),
7071
random_resets=config.env.random_resets,
7172
seed=config.env.seed,
73+
action_space=config.env.action_space,
7274
)
75+
env = JaxToNumpy(env)
7376

7477
for _ in range(n_runs): # Run n_runs episodes with the controller
7578
obs, info = env.reset()
@@ -81,7 +84,9 @@ def simulate(
8184
curr_time = i / config.env.freq
8285

8386
action = controller.compute_control(obs, info)
84-
action = np.array([action] * config.env.n_drones * env.unwrapped.sim.n_worlds)
87+
action = np.array(
88+
[action] * config.env.n_drones * env.unwrapped.sim.n_worlds, dtype=np.float32
89+
)
8590
action[1, 0] += 0.2
8691
obs, reward, terminated, truncated, info = env.step(action)
8792
done = terminated | truncated
@@ -92,9 +97,15 @@ def simulate(
9297
# Synchronize the GUI.
9398
if config.sim.gui:
9499
if ((i * fps) % config.env.freq) < fps:
95-
env.render()
100+
try:
101+
env.render()
102+
# TODO: JaxToNumpy not working with None (returned by env.render()). Open issue
103+
# in gymnasium and fix this.
104+
except Exception as e:
105+
if not e.args[0].startswith("No known conversion for Jax type"):
106+
raise e
96107
i += 1
97-
if done.all():
108+
if done:
98109
break
99110

100111
controller.episode_callback() # Update the controller internal state and models.

0 commit comments

Comments
 (0)