Skip to content

Commit 4c25828

Browse files
committed
[wip, broken] Add race randomization
1 parent e877fcb commit 4c25828

File tree

3 files changed

+139
-8
lines changed

3 files changed

+139
-8
lines changed

lsy_drone_racing/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
"""LSY drone racing package for the Autonomous Drone Racing class @ TUM."""
2+
from crazyflow.utils import enable_cache
23

34
import lsy_drone_racing.envs # noqa: F401, register environments with gymnasium
5+
6+
enable_cache() # Enable persistent caching of jax functions

lsy_drone_racing/envs/drone_racing_env.py

+70-8
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,28 @@
2828
import copy as copy
2929
import logging
3030
from pathlib import Path
31-
from typing import TYPE_CHECKING
31+
from typing import TYPE_CHECKING, Callable
3232

3333
import gymnasium
3434
import mujoco
3535
import numpy as np
3636
from crazyflow import Sim
37+
from crazyflow.sim.sim import identity
3738
from gymnasium import spaces
3839
from scipy.spatial.transform import Rotation as R
3940

41+
from lsy_drone_racing.envs.utils import (
42+
randomize_drone_inertia_fn,
43+
randomize_drone_mass_fn,
44+
randomize_drone_pos_fn,
45+
randomize_drone_quat_fn,
46+
)
4047
from lsy_drone_racing.sim.noise import NoiseList
4148
from lsy_drone_racing.utils import check_gate_pass
4249

4350
if TYPE_CHECKING:
51+
from crazyflow.sim.structs import SimData
52+
from jax import Array
4453
from numpy.typing import NDArray
4554

4655
logger = logging.getLogger(__name__)
@@ -106,6 +115,7 @@ def __init__(self, config: dict):
106115
)
107116
if config.sim.sim_freq % config.env.freq != 0:
108117
raise ValueError(f"({config.sim.sim_freq=}) is no multiple of ({config.env.freq=})")
118+
109119
self.action_space = spaces.Box(low=-1, high=1, shape=(13,))
110120
n_gates, n_obstacles = len(config.env.track.gates), len(config.env.track.obstacles)
111121
self.observation_space = spaces.Dict(
@@ -134,16 +144,20 @@ def __init__(self, config: dict):
134144
"ang_vel": spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float64),
135145
}
136146
)
147+
137148
self.target_gate = 0
138149
self.symbolic = self.sim.symbolic() if config.env.symbolic else None
139150
self._steps = 0
140151
self._last_drone_pos = np.zeros(3)
141152
self.gates, self.obstacles, self.drone = self.load_track(config.env.track)
142153
self.n_gates = len(config.env.track.gates)
143154
self.disturbances = self.load_disturbances(config.env.get("disturbances", None))
155+
self.randomization = self.load_randomizations(config.env.get("randomization", None))
144156
self.contact_mask = np.ones((self.sim.n_worlds, 29), dtype=bool)
145157
self.contact_mask[..., 0] = 0 # Ignore contacts with the floor
146158

159+
self.setup_sim()
160+
147161
self.gates_visited = np.array([False] * len(config.env.track.gates))
148162
self.obstacles_visited = np.array([False] * len(config.env.track.obstacles))
149163

@@ -167,13 +181,6 @@ def reset(
167181
# the sim.reset_hook function, so we don't need to explicitly do it here
168182
self.sim.reset()
169183
# TODO: Add randomization of gates, obstacles, drone, and disturbances
170-
states = self.sim.data.states.replace(
171-
pos=self.drone["pos"].reshape((1, 1, 3)),
172-
quat=self.drone["quat"].reshape((1, 1, 4)),
173-
vel=self.drone["vel"].reshape((1, 1, 3)),
174-
rpy_rates=self.drone["rpy_rates"].reshape((1, 1, 3)),
175-
)
176-
self.sim.data = self.sim.data.replace(states=states)
177184
self.target_gate = 0
178185
self._steps = 0
179186
self._last_drone_pos[:] = self.sim.data.states.pos[0, 0]
@@ -335,6 +342,24 @@ def load_disturbances(self, disturbances: dict | None = None) -> dict:
335342
dist[mode] = NoiseList.from_specs([spec])
336343
return dist
337344

345+
def load_randomizations(self, randomizations: dict | None = None) -> dict:
346+
"""Load the randomization from the config."""
347+
if randomizations is None:
348+
return {}
349+
return {}
350+
351+
def setup_sim(self):
352+
"""Setup the simulation data and build the reset and step functions with custom hooks."""
353+
pos = self.drone["pos"].reshape(self.sim.data.states.pos.shape)
354+
quat = self.drone["quat"].reshape(self.sim.data.states.quat.shape)
355+
vel = self.drone["vel"].reshape(self.sim.data.states.vel.shape)
356+
rpy_rates = self.drone["rpy_rates"].reshape(self.sim.data.states.rpy_rates.shape)
357+
states = self.sim.data.states.replace(pos=pos, quat=quat, vel=vel, rpy_rates=rpy_rates)
358+
self.sim.data = self.sim.data.replace(states=states)
359+
reset_hook = build_reset_hook(self.randomization)
360+
self.sim.reset_hook = reset_hook
361+
self.sim.build(mjx=False, data=False) # Save the reset state and rebuild the reset function
362+
338363
def gate_passed(self) -> bool:
339364
"""Check if the drone has passed a gate.
340365
@@ -355,6 +380,43 @@ def close(self):
355380
self.sim.close()
356381

357382

383+
def build_reset_hook(randomizations: dict) -> Callable[[SimData, Array[bool]], SimData]:
384+
"""Build the reset hook for the simulation."""
385+
modify_drone_pos = identity
386+
if "drone_pos" in randomizations:
387+
modify_drone_pos = randomize_drone_pos_fn(randomizations["drone_pos"])
388+
modify_drone_quat = identity
389+
if "drone_rpy" in randomizations:
390+
modify_drone_quat = randomize_drone_quat_fn(randomizations["drone_rpy"])
391+
modify_drone_mass = identity
392+
if "drone_mass" in randomizations:
393+
modify_drone_mass = randomize_drone_mass_fn(randomizations["drone_mass"])
394+
modify_drone_inertia = identity
395+
if "drone_inertia" in randomizations:
396+
modify_drone_inertia = randomize_drone_inertia_fn(randomizations["drone_inertia"])
397+
modify_gate_pos = identity
398+
if "gate_pos" in randomizations:
399+
modify_gate_pos = randomize_gate_pos_fn(randomizations["gate_pos"])
400+
modify_gate_rpy = identity
401+
if "gate_rpy" in randomizations:
402+
modify_gate_rpy = randomize_gate_rpy_fn(randomizations["gate_rpy"])
403+
modify_obstacle_pos = identity
404+
if "obstacle_pos" in randomizations:
405+
modify_obstacle_pos = randomize_obstacle_pos_fn(randomizations["obstacle_pos"])
406+
407+
def reset_hook(data: SimData, mask: Array[bool]) -> SimData:
408+
data = modify_drone_pos(data, mask)
409+
data = modify_drone_quat(data, mask)
410+
data = modify_drone_mass(data, mask)
411+
data = modify_drone_inertia(data, mask)
412+
data = modify_gate_pos(data, mask)
413+
data = modify_gate_rpy(data, mask)
414+
data = modify_obstacle_pos(data, mask)
415+
return data
416+
417+
return reset_hook
418+
419+
358420
class DroneRacingThrustEnv(DroneRacingEnv):
359421
"""Drone racing environment with a collective thrust attitude command interface.
360422

lsy_drone_racing/envs/utils.py

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from typing import Callable
2+
3+
import jax
4+
import jax.numpy as jp
5+
from crazyflow.sim.structs import SimData
6+
from crazyflow.utils import leaf_replace
7+
from jax import Array
8+
from jax.scipy.spatial.transform import Rotation as R
9+
10+
11+
def randomize_drone_pos_fn(
12+
rng: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array],
13+
) -> Callable[[SimData, Array], SimData]:
14+
"""Create a function that randomizes the drone position."""
15+
16+
def randomize_drone_pos(data: SimData, mask: Array) -> SimData:
17+
key, subkey = jax.random.split(data.core.rng_key)
18+
drone_pos = data.states.pos + rng(subkey, shape=data.states.pos.shape)
19+
states = leaf_replace(data.states, mask, pos=drone_pos)
20+
return data.replace(core=data.core.replace(rng_key=key), states=states)
21+
22+
return randomize_drone_pos
23+
24+
25+
def randomize_drone_quat_fn(
26+
rng: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array],
27+
) -> Callable[[SimData, Array], SimData]:
28+
"""Create a function that randomizes the drone quaternion."""
29+
30+
def randomize_drone_quat(data: SimData, mask: Array) -> SimData:
31+
key, subkey = jax.random.split(data.core.rng_key)
32+
rpy = R.from_quat(data.states.quat).as_euler("xyz")
33+
quat = R.from_euler("xyz", rpy + rng(subkey, shape=rpy.shape)).as_quat()
34+
states = leaf_replace(data.states, mask, quat=quat)
35+
return data.replace(core=data.core.replace(rng_key=key), states=states)
36+
37+
return randomize_drone_quat
38+
39+
40+
def randomize_drone_mass_fn(
41+
rng: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array],
42+
) -> Callable[[SimData, Array], SimData]:
43+
"""Create a function that randomizes the drone mass."""
44+
45+
def randomize_drone_mass(data: SimData, mask: Array) -> SimData:
46+
key, subkey = jax.random.split(data.core.rng_key)
47+
mass = data.states.mass + rng(subkey, shape=data.params.mass.shape)
48+
states = leaf_replace(data.states, mask, mass=mass)
49+
return data.replace(core=data.core.replace(rng_key=key), states=states)
50+
51+
return randomize_drone_mass
52+
53+
54+
def randomize_drone_inertia_fn(
55+
rng: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array],
56+
) -> Callable[[SimData, Array], SimData]:
57+
"""Create a function that randomizes the drone inertia."""
58+
59+
def randomize_drone_inertia(data: SimData, mask: Array) -> SimData:
60+
key, subkey = jax.random.split(data.core.rng_key)
61+
J = data.params.J + rng(subkey, shape=data.params.J.shape)
62+
J_inv = jp.linalg.inv(J)
63+
states = leaf_replace(data.states, mask, J=J, J_inv=J_inv)
64+
return data.replace(core=data.core.replace(rng_key=key), states=states)
65+
66+
return randomize_drone_inertia

0 commit comments

Comments
 (0)