Skip to content

Commit cac75f7

Browse files
committed
Add dynamics disturbances
1 parent 9348740 commit cac75f7

File tree

1 file changed

+28
-5
lines changed

1 file changed

+28
-5
lines changed

lsy_drone_racing/envs/drone_racing_env.py

+28-5
Original file line numberDiff line numberDiff line change
@@ -169,14 +169,15 @@ def reset(
169169
Observation and info.
170170
"""
171171
if not self.config.env.random_resets:
172+
self.np_random = np.random.default_rng(seed=self.config.env.seed)
172173
self.sim.seed(self.config.env.seed)
173174
if seed is not None:
175+
self.np_random = np.random.default_rng(seed=self.config.env.seed)
174176
self.sim.seed(seed)
175177
# Randomization of gates, obstacles and drones is compiled into the sim reset function with
176178
# the sim.reset_hook function, so we don't need to explicitly do it here
177179
self.sim.reset()
178180

179-
# TODO: Add disturbances
180181
self.target_gate = 0
181182
self._steps = 0
182183
self._last_drone_pos = self.sim.data.states.pos[0, 0]
@@ -199,9 +200,13 @@ def step(
199200
action: Full-state command [x, y, z, vx, vy, vz, ax, ay, az, yaw, rrate, prate, yrate]
200201
to follow.
201202
"""
202-
# TODO: Add action noise
203203
assert action.shape == self.action_space.shape, f"Invalid action shape: {action.shape}"
204-
self.sim.state_control(action.reshape((1, 1, 13)))
204+
action = action.reshape((1, 1, 13))
205+
if "action" in self.disturbances:
206+
key, subkey = jax.random.split(self.sim.data.core.rng_key)
207+
action += self.disturbances["action"](subkey, (1, 1, 13))
208+
self.sim.data = self.sim.data.replace(core=self.sim.data.core.replace(rng_key=key))
209+
self.sim.state_control(action)
205210
self.sim.step(self.sim.freq // self.config.env.freq)
206211
self.target_gate += self.gate_passed()
207212
if self.target_gate == self.n_gates:
@@ -308,7 +313,6 @@ def load_track(self, track: dict) -> tuple[dict, dict, dict]:
308313

309314
def load_disturbances(self, disturbances: dict | None = None) -> dict:
310315
"""Load the disturbances from the config."""
311-
# TODO: Add jax disturbances for the simulator dynamics
312316
if disturbances is None: # Default: no passive disturbances.
313317
return {}
314318
return {mode: self.load_random_fn(spec) for mode, spec in disturbances.items()}
@@ -344,6 +348,8 @@ def setup_sim(self):
344348
states = self.sim.data.states.replace(pos=pos, quat=quat, vel=vel, rpy_rates=rpy_rates)
345349
self.sim.data = self.sim.data.replace(states=states)
346350
self.sim.reset_hook = build_reset_hook(self.randomization)
351+
if "dynamics" in self.disturbances:
352+
self.sim.disturbance_fn = build_dynamics_disturbance_fn(self.disturbances["dynamics"])
347353
self.sim.build(mjx=False, data=False) # Save the reset state and rebuild the reset function
348354

349355
def _load_track_into_sim(self, gates: dict, obstacles: dict):
@@ -406,6 +412,20 @@ def reset_hook(data: SimData, mask: Array) -> SimData:
406412
return reset_hook
407413

408414

415+
def build_dynamics_disturbance_fn(
416+
fn: Callable[[jax.random.PRNGKey, tuple[int]], jax.Array],
417+
) -> Callable[[SimData], SimData]:
418+
"""Build the dynamics disturbance function for the simulation."""
419+
420+
def dynamics_disturbance(data: SimData) -> SimData:
421+
key, subkey = jax.random.split(data.core.rng_key)
422+
states = data.states
423+
states = states.replace(force=states.force + fn(subkey, states.force.shape)) # World frame
424+
return data.replace(states=states, core=data.core.replace(rng_key=key))
425+
426+
return dynamics_disturbance
427+
428+
409429
class DroneRacingThrustEnv(DroneRacingEnv):
410430
"""Drone racing environment with a collective thrust attitude command interface.
411431
@@ -433,7 +453,10 @@ def step(
433453
action: Thrust command [thrust, roll, pitch, yaw].
434454
"""
435455
assert action.shape == self.action_space.shape, f"Invalid action shape: {action.shape}"
436-
# TODO: Add action noise
456+
if "action" in self.disturbances:
457+
key, subkey = jax.random.split(self.sim.data.core.rng_key)
458+
action += self.disturbances["action"](subkey, (1, 1, 4))
459+
self.sim.data = self.sim.data.replace(core=self.sim.data.core.replace(rng_key=key))
437460
self.sim.attitude_control(action.reshape((1, 1, 4)).astype(np.float32))
438461
self.sim.step(self.sim.freq // self.config.env.freq)
439462
self.target_gate += self.gate_passed()

0 commit comments

Comments
 (0)