Skip to content

Commit 41acff1

Browse files
committed
[wip] Implement track randomization
1 parent 4c25828 commit 41acff1

File tree

9 files changed

+219
-2351
lines changed

9 files changed

+219
-2351
lines changed

config/level3.toml

+43-42
Original file line numberDiff line numberDiff line change
@@ -17,32 +17,17 @@ practice_without_track_objects = false
1717

1818
[sim]
1919
# Physics options:
20-
# "pyb": PyBullet
21-
# "dyn": Mathematical dynamics model
22-
# "pyb_gnd": PyBullet with ground effect
23-
# "pyb_drag": PyBullet with drag
24-
# "pyb_dw": PyBullet with downwash
25-
# "pyb_gnd_drag_dw": PyBullet with ground effect, drag, and downwash.
26-
physics = "pyb"
20+
physics = "analytical"
2721
camera_view = [5.0, -40.0, -40.0, 0.5, -1.0, 0.5]
2822
sim_freq = 500 # Simulation frequency, in Hz
29-
ctrl_freq = 500 # Controller frequency, in Hz. This frequency is used to simulate the onboard controller, NOT for the environment's step function
23+
attitude_freq = 500 # Controller frequency, in Hz. This frequency is used to simulate the onboard controller, NOT for the environment's step function
3024
gui = false # Enable/disable PyBullet's GUI
3125

32-
[sim.disturbances.action]
33-
type = "GaussianNoise"
34-
std = 0.001
35-
36-
[sim.disturbances.dynamics]
37-
type = "UniformNoise"
38-
low = [-0.1, -0.1, -0.1]
39-
high = [0.1, 0.1, 0.1]
40-
4126
[env]
4227
id = "DroneRacing-v0" # Either "DroneRacing-v0" or "DroneRacingThrust-v0". If using "DroneRacingThrust-v0", the drone will use the thrust controller instead of the position controller.
4328
reseed = false # Whether to re-seed the random number generator between episodes
4429
seed = 1337 # Random seed
45-
freq = 60 # Frequency of the environment's step function, in Hz
30+
freq = 50 # Frequency of the environment's step function, in Hz
4631
symbolic = false # Whether to include symbolic expressions in the info dict. Note: This can interfere with multiprocessing! If you want to parallelize your training, set this to false.
4732
sensor_range = 0.45 # Range at which the exact location of gates and obstacles becomes visible to the drone. Objects that are not in the drone's sensor range report their nominal position.
4833

@@ -73,43 +58,59 @@ pos = [0.0, 1.0, 1.4]
7358
pos = [-0.5, 0.0, 1.4]
7459

7560
[env.track.drone]
76-
pos = [1.0, 1.0, 0.05]
61+
pos = [1.0, 1.0, 0.07]
7762
rpy = [0, 0, 0]
7863
vel = [0, 0, 0]
79-
ang_vel = [0, 0, 0]
64+
rpy_rates = [0, 0, 0]
65+
66+
[env.disturbances.action]
67+
fn = "normal"
68+
scale = 0.001
69+
70+
[env.disturbances.dynamics]
71+
fn = "uniform"
72+
[env.disturbances.dynamics.kwargs]
73+
minval = [-0.1, -0.1, -0.1]
74+
maxval = [0.1, 0.1, 0.1]
8075

8176
[env.randomization.drone_pos]
82-
type = "uniform" # Everything that can be used as a distribution in numpy.random
83-
# Kwargs that are permissible in the np random function
84-
low = [-0.1, -0.1, 0.0]
85-
high = [0.1, 0.1, 0.02]
77+
fn = "uniform"
78+
[env.randomization.drone_pos.kwargs]
79+
minval = [-0.1, -0.1, 0.0]
80+
maxval = [0.1, 0.1, 0.02]
8681

8782
[env.randomization.drone_rpy]
88-
type = "uniform"
89-
low = [-0.1, -0.1, -0.1]
90-
high = [0.1, 0.1, 0.1]
83+
fn = "uniform"
84+
[env.randomization.drone_rpy.kwargs]
85+
minval = [-0.1, -0.1, -0.1]
86+
maxval = [0.1, 0.1, 0.1]
9187

9288
[env.randomization.drone_mass]
93-
type = "uniform"
94-
low = -0.01
95-
high = 0.01
89+
fn = "uniform"
90+
[env.randomization.drone_mass.kwargs]
91+
minval = -0.01
92+
maxval = 0.01
9693

9794
[env.randomization.drone_inertia]
98-
type = "uniform"
99-
low = [-0.000001, -0.000001, -0.000001]
100-
high = [0.000001, 0.000001, 0.000001]
95+
fn = "uniform"
96+
[env.randomization.drone_inertia.kwargs]
97+
minval = [-0.000001, -0.000001, -0.000001]
98+
maxval = [0.000001, 0.000001, 0.000001]
10199

102100
[env.randomization.gate_pos]
103-
type = "uniform"
104-
low = [-0.15, -0.15, -0.1]
105-
high = [0.15, 0.15, 0.1]
101+
fn = "uniform"
102+
[env.randomization.gate_pos.kwargs]
103+
minval = [-0.15, -0.15, -0.1]
104+
maxval = [0.15, 0.15, 0.1]
106105

107106
[env.randomization.gate_rpy]
108-
type = "uniform"
109-
low = [0.0, 0.0, -0.1]
110-
high = [0.0, 0.0, 0.1]
107+
fn = "uniform"
108+
[env.randomization.gate_rpy.kwargs]
109+
minval = [0.0, 0.0, -0.1]
110+
maxval = [0.0, 0.0, 0.1]
111111

112112
[env.randomization.obstacle_pos]
113-
type = "uniform"
114-
low = [-0.15, -0.15, -0.05]
115-
high = [0.15, 0.15, 0.05]
113+
fn = "uniform"
114+
[env.randomization.obstacle_pos.kwargs]
115+
minval = [-0.15, -0.15, -0.05]
116+
maxval = [0.15, 0.15, 0.05]

lsy_drone_racing/envs/drone_racing_env.py

+67-77
Original file line numberDiff line numberDiff line change
@@ -27,24 +27,19 @@
2727

2828
import copy as copy
2929
import logging
30+
from functools import partial
3031
from pathlib import Path
31-
from typing import TYPE_CHECKING, Callable
32+
from typing import TYPE_CHECKING, Any, Callable
3233

3334
import gymnasium
35+
import jax
3436
import mujoco
3537
import numpy as np
3638
from crazyflow import Sim
37-
from crazyflow.sim.sim import identity
3839
from gymnasium import spaces
3940
from scipy.spatial.transform import Rotation as R
4041

41-
from lsy_drone_racing.envs.utils import (
42-
randomize_drone_inertia_fn,
43-
randomize_drone_mass_fn,
44-
randomize_drone_pos_fn,
45-
randomize_drone_quat_fn,
46-
)
47-
from lsy_drone_racing.sim.noise import NoiseList
42+
from lsy_drone_racing.envs.utils import randomize_sim_fn
4843
from lsy_drone_racing.utils import check_gate_pass
4944

5045
if TYPE_CHECKING:
@@ -92,8 +87,8 @@ class DroneRacingEnv(gymnasium.Env):
9287
low-level controller.
9388
"""
9489

95-
gate_spec_path = Path(__file__).parents[1] / "sim/assets/gate.urdf"
96-
obstacle_spec_path = Path(__file__).parents[1] / "sim/assets/obstacle.urdf"
90+
gate_spec_path = Path(__file__).parents[1] / "sim/assets/gate.xml"
91+
obstacle_spec_path = Path(__file__).parents[1] / "sim/assets/obstacle.xml"
9792

9893
def __init__(self, config: dict):
9994
"""Initialize the DroneRacingEnv.
@@ -153,7 +148,7 @@ def __init__(self, config: dict):
153148
self.n_gates = len(config.env.track.gates)
154149
self.disturbances = self.load_disturbances(config.env.get("disturbances", None))
155150
self.randomization = self.load_randomizations(config.env.get("randomization", None))
156-
self.contact_mask = np.ones((self.sim.n_worlds, 29), dtype=bool)
151+
self.contact_mask = np.ones((self.sim.n_worlds, 25), dtype=bool)
157152
self.contact_mask[..., 0] = 0 # Ignore contacts with the floor
158153

159154
self.setup_sim()
@@ -180,6 +175,7 @@ def reset(
180175
# Randomization of gates, obstacles and drones is compiled into the sim reset function with
181176
# the sim.reset_hook function, so we don't need to explicitly do it here
182177
self.sim.reset()
178+
183179
# TODO: Add randomization of gates, obstacles, drone, and disturbances
184180
self.target_gate = 0
185181
self._steps = 0
@@ -313,105 +309,99 @@ def load_track(self, track: dict) -> tuple[dict, dict, dict]:
313309
for k in ("pos", "rpy", "vel", "rpy_rates")
314310
}
315311
drone["quat"] = R.from_euler("xyz", drone["rpy"]).as_quat()
316-
# Load the models into the simulation and set their positions
317-
self._load_track_into_sim(gates, obstacles)
318312
return gates, obstacles, drone
319313

320-
def _load_track_into_sim(self, gates: dict, obstacles: dict):
321-
"""Load the track into the simulation."""
322-
gate_spec = mujoco.MjSpec.from_file(str(self.gate_spec_path))
323-
obstacle_spec = mujoco.MjSpec.from_file(str(self.obstacle_spec_path))
324-
spec = self.sim.spec
325-
frame = spec.worldbody.add_frame()
326-
for i in range(len(gates["pos"])):
327-
gate = frame.attach_body(gate_spec.find_body("world"), "", f":g{i}")
328-
gate.pos = gates["pos"][i]
329-
quat = R.from_euler("xyz", gates["rpy"][i]).as_quat()
330-
gate.quat = quat[[3, 0, 1, 2]] # MuJoCo uses wxyz order instead of xyzw
331-
for i in range(len(obstacles["pos"])):
332-
obstacle = frame.attach_body(obstacle_spec.find_body("world"), "", f":o{i}")
333-
obstacle.pos = obstacles["pos"][i]
334-
self.sim.build()
335-
336314
def load_disturbances(self, disturbances: dict | None = None) -> dict:
337315
"""Load the disturbances from the config."""
338-
dist = {} # TODO: Add jax disturbances for the simulator dynamics
316+
# TODO: Add jax disturbances for the simulator dynamics
339317
if disturbances is None: # Default: no passive disturbances.
340-
return dist
341-
for mode, spec in disturbances.items():
342-
dist[mode] = NoiseList.from_specs([spec])
343-
return dist
318+
return {}
319+
return {mode: self.load_random_fn(spec) for mode, spec in disturbances.items()}
344320

345321
def load_randomizations(self, randomizations: dict | None = None) -> dict:
346322
"""Load the randomization from the config."""
347323
if randomizations is None:
348324
return {}
349-
return {}
325+
return {mode: self.load_random_fn(spec) for mode, spec in randomizations.items()}
326+
327+
@staticmethod
328+
def load_random_fn(fn_spec: dict) -> Callable:
329+
"""Convert a function spec to a function from jax.random."""
330+
offset, scale = np.array(fn_spec.get("offset", 0)), np.array(fn_spec.get("scale", 1))
331+
kwargs = fn_spec.get("kwargs", {})
332+
if "shape" in kwargs:
333+
raise KeyError("Shape must not be specified for randomization functions.")
334+
kwargs = {k: np.array(v) if isinstance(v, list) else v for k, v in kwargs.items()}
335+
jax_fn = partial(getattr(jax.random, fn_spec["fn"]), **kwargs)
336+
337+
def random_fn(*args: Any, **kwargs: Any) -> Array:
338+
return jax_fn(*args, **kwargs) * scale + offset
339+
340+
return random_fn
350341

351342
def setup_sim(self):
352343
"""Setup the simulation data and build the reset and step functions with custom hooks."""
344+
self._load_track_into_sim(self.gates, self.obstacles)
353345
pos = self.drone["pos"].reshape(self.sim.data.states.pos.shape)
354346
quat = self.drone["quat"].reshape(self.sim.data.states.quat.shape)
355347
vel = self.drone["vel"].reshape(self.sim.data.states.vel.shape)
356348
rpy_rates = self.drone["rpy_rates"].reshape(self.sim.data.states.rpy_rates.shape)
357349
states = self.sim.data.states.replace(pos=pos, quat=quat, vel=vel, rpy_rates=rpy_rates)
358350
self.sim.data = self.sim.data.replace(states=states)
359-
reset_hook = build_reset_hook(self.randomization)
360-
self.sim.reset_hook = reset_hook
351+
self.sim.reset_hook = build_reset_hook(self.randomization)
361352
self.sim.build(mjx=False, data=False) # Save the reset state and rebuild the reset function
362353

354+
def _load_track_into_sim(self, gates: dict, obstacles: dict):
355+
"""Load the track into the simulation."""
356+
gate_spec = mujoco.MjSpec.from_file(str(self.gate_spec_path))
357+
obstacle_spec = mujoco.MjSpec.from_file(str(self.obstacle_spec_path))
358+
spec = self.sim.spec
359+
frame = spec.worldbody.add_frame()
360+
n_gates, n_obstacles = len(gates["pos"]), len(obstacles["pos"])
361+
for i in range(n_gates):
362+
gate = frame.attach_body(gate_spec.find_body("gate"), "", f":{i}")
363+
gate.pos = gates["pos"][i]
364+
gate.quat = R.from_euler("xyz", gates["rpy"][i]).as_quat()[[3, 0, 1, 2]] # MuJoCo uses wxyz order instead of xyzw
365+
gate.mocap = True # Mark as a mocap body so that the position of static bodies can be modified during sim
366+
for i in range(n_obstacles):
367+
obstacle = frame.attach_body(obstacle_spec.find_body("obstacle"), "", f":{i}")
368+
obstacle.pos = obstacles["pos"][i]
369+
obstacle.mocap = True
370+
self.sim.build(data=False, default_data=False)
371+
assert not hasattr(self.sim.data, "gate_pos")
372+
assert not hasattr(self.sim.data, "obstacle_pos")
373+
374+
gate_ids = [self.sim.mj_model.body(f"gate:{i}").id for i in range(n_gates)]
375+
gates["ids"] = gate_ids
376+
obstacle_ids = [self.sim.mj_model.body(f"obstacle:{i}").id for i in range(n_obstacles)]
377+
obstacles["ids"] = obstacle_ids
378+
363379
def gate_passed(self) -> bool:
364380
"""Check if the drone has passed a gate.
365381
366382
Returns:
367383
True if the drone has passed a gate, else False.
368384
"""
369-
if self.n_gates > 0 and self.target_gate < self.n_gates and self.target_gate != -1:
370-
gate_pos = self.gates["pos"][self.target_gate]
371-
gate_rot = R.from_euler("xyz", self.gates["rpy"][self.target_gate])
372-
drone_pos = self.sim.data.states.pos[0, 0]
373-
last_drone_pos = self._last_drone_pos
374-
gate_size = (0.45, 0.45)
375-
return check_gate_pass(gate_pos, gate_rot, gate_size, drone_pos, last_drone_pos)
376-
return False
385+
if self.n_gates <= 0 or self.target_gate >= self.n_gates or self.target_gate == -1:
386+
return False
387+
gate_pos = self.gates["pos"][self.target_gate]
388+
gate_rot = R.from_euler("xyz", self.gates["rpy"][self.target_gate])
389+
drone_pos = self.sim.data.states.pos[0, 0]
390+
gate_size = (0.45, 0.45)
391+
return check_gate_pass(gate_pos, gate_rot, gate_size, drone_pos, self._last_drone_pos)
377392

378393
def close(self):
379394
"""Close the environment by stopping the drone and landing back at the starting position."""
380395
self.sim.close()
381396

382397

383-
def build_reset_hook(randomizations: dict) -> Callable[[SimData, Array[bool]], SimData]:
398+
def build_reset_hook(randomizations: dict) -> Callable[[SimData, Array], SimData]:
384399
"""Build the reset hook for the simulation."""
385-
modify_drone_pos = identity
386-
if "drone_pos" in randomizations:
387-
modify_drone_pos = randomize_drone_pos_fn(randomizations["drone_pos"])
388-
modify_drone_quat = identity
389-
if "drone_rpy" in randomizations:
390-
modify_drone_quat = randomize_drone_quat_fn(randomizations["drone_rpy"])
391-
modify_drone_mass = identity
392-
if "drone_mass" in randomizations:
393-
modify_drone_mass = randomize_drone_mass_fn(randomizations["drone_mass"])
394-
modify_drone_inertia = identity
395-
if "drone_inertia" in randomizations:
396-
modify_drone_inertia = randomize_drone_inertia_fn(randomizations["drone_inertia"])
397-
modify_gate_pos = identity
398-
if "gate_pos" in randomizations:
399-
modify_gate_pos = randomize_gate_pos_fn(randomizations["gate_pos"])
400-
modify_gate_rpy = identity
401-
if "gate_rpy" in randomizations:
402-
modify_gate_rpy = randomize_gate_rpy_fn(randomizations["gate_rpy"])
403-
modify_obstacle_pos = identity
404-
if "obstacle_pos" in randomizations:
405-
modify_obstacle_pos = randomize_obstacle_pos_fn(randomizations["obstacle_pos"])
406-
407-
def reset_hook(data: SimData, mask: Array[bool]) -> SimData:
408-
data = modify_drone_pos(data, mask)
409-
data = modify_drone_quat(data, mask)
410-
data = modify_drone_mass(data, mask)
411-
data = modify_drone_inertia(data, mask)
412-
data = modify_gate_pos(data, mask)
413-
data = modify_gate_rpy(data, mask)
414-
data = modify_obstacle_pos(data, mask)
400+
randomizations = [randomize_sim_fn(target, rng) for target, rng in randomizations.items()]
401+
402+
def reset_hook(data: SimData, mask: Array) -> SimData:
403+
for randomize in randomizations:
404+
data = randomize(data, mask)
415405
return data
416406

417407
return reset_hook

0 commit comments

Comments
 (0)