|
39 | 39 | from gymnasium import spaces
|
40 | 40 | from scipy.spatial.transform import Rotation as R
|
41 | 41 |
|
42 |
| -from lsy_drone_racing.envs.randomize import randomize_sim_fn |
| 42 | +from lsy_drone_racing.envs.randomize import ( |
| 43 | + randomize_drone_inertia_fn, |
| 44 | + randomize_drone_mass_fn, |
| 45 | + randomize_drone_pos_fn, |
| 46 | + randomize_drone_quat_fn, |
| 47 | + randomize_gate_pos_fn, |
| 48 | + randomize_gate_rpy_fn, |
| 49 | + randomize_obstacle_pos_fn, |
| 50 | +) |
43 | 51 | from lsy_drone_racing.utils import check_gate_pass
|
44 | 52 |
|
45 | 53 | if TYPE_CHECKING:
|
@@ -147,7 +155,7 @@ def __init__(self, config: dict):
|
147 | 155 | self.gates, self.obstacles, self.drone = self.load_track(config.env.track)
|
148 | 156 | self.n_gates = len(config.env.track.gates)
|
149 | 157 | self.disturbances = self.load_disturbances(config.env.get("disturbances", None))
|
150 |
| - self.randomization = self.load_randomizations(config.env.get("randomization", None)) |
| 158 | + self.randomizations = self.load_randomizations(config.env.get("randomization", None)) |
151 | 159 | self.contact_mask = np.ones((self.sim.n_worlds, 25), dtype=bool)
|
152 | 160 | self.contact_mask[..., 0] = 0 # Ignore contacts with the floor
|
153 | 161 |
|
@@ -250,7 +258,7 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
|
250 | 258 | obstacles_pos[self.obstacles_visited] = self.obstacles["pos"][self.obstacles_visited]
|
251 | 259 | obs["obstacles_pos"] = obstacles_pos.astype(np.float32)
|
252 | 260 | obs["obstacles_visited"] = self.obstacles_visited
|
253 |
| - # TODO: Observation disturbances? |
| 261 | + # TODO: Decide on observation disturbances |
254 | 262 | return obs
|
255 | 263 |
|
256 | 264 | def reward(self) -> float:
|
@@ -347,7 +355,9 @@ def setup_sim(self):
|
347 | 355 | rpy_rates = self.drone["rpy_rates"].reshape(self.sim.data.states.rpy_rates.shape)
|
348 | 356 | states = self.sim.data.states.replace(pos=pos, quat=quat, vel=vel, rpy_rates=rpy_rates)
|
349 | 357 | self.sim.data = self.sim.data.replace(states=states)
|
350 |
| - self.sim.reset_hook = build_reset_hook(self.randomization) |
| 358 | + self.sim.reset_hook = build_reset_hook( |
| 359 | + self.randomizations, self.gates["mocap_ids"], self.obstacles["mocap_ids"] |
| 360 | + ) |
351 | 361 | if "dynamics" in self.disturbances:
|
352 | 362 | self.sim.disturbance_fn = build_dynamics_disturbance_fn(self.disturbances["dynamics"])
|
353 | 363 | self.sim.build(mjx=False, data=False) # Save the reset state and rebuild the reset function
|
@@ -400,13 +410,33 @@ def close(self):
|
400 | 410 | self.sim.close()
|
401 | 411 |
|
402 | 412 |
|
403 |
| -def build_reset_hook(randomizations: dict) -> Callable[[SimData, Array], SimData]: |
| 413 | +def build_reset_hook( |
| 414 | + randomizations: dict, gate_mocap_ids: list[int], obstacle_mocap_ids: list[int] |
| 415 | +) -> Callable[[SimData, Array], SimData]: |
404 | 416 | """Build the reset hook for the simulation."""
|
405 |
| - randomizations = [randomize_sim_fn(target, rng) for target, rng in randomizations.items()] |
| 417 | + randomization_fns = [] |
| 418 | + for target, rng in randomizations.items(): |
| 419 | + match target: |
| 420 | + case "drone_pos": |
| 421 | + randomization_fns.append(randomize_drone_pos_fn(rng)) |
| 422 | + case "drone_rpy": |
| 423 | + randomization_fns.append(randomize_drone_quat_fn(rng)) |
| 424 | + case "drone_mass": |
| 425 | + randomization_fns.append(randomize_drone_mass_fn(rng)) |
| 426 | + case "drone_inertia": |
| 427 | + randomization_fns.append(randomize_drone_inertia_fn(rng)) |
| 428 | + case "gate_pos": |
| 429 | + randomization_fns.append(randomize_gate_pos_fn(rng, gate_mocap_ids)) |
| 430 | + case "gate_rpy": |
| 431 | + randomization_fns.append(randomize_gate_rpy_fn(rng, gate_mocap_ids)) |
| 432 | + case "obstacle_pos": |
| 433 | + randomization_fns.append(randomize_obstacle_pos_fn(rng, obstacle_mocap_ids)) |
| 434 | + case _: |
| 435 | + raise ValueError(f"Invalid target: {target}") |
406 | 436 |
|
407 | 437 | def reset_hook(data: SimData, mask: Array) -> SimData:
|
408 |
| - for randomize in randomizations: |
409 |
| - data = randomize(data, mask) |
| 438 | + for randomize_fn in randomization_fns: |
| 439 | + data = randomize_fn(data, mask) |
410 | 440 | return data
|
411 | 441 |
|
412 | 442 | return reset_hook
|
|
0 commit comments