Skip to content

Commit e877fcb

Browse files
committed
[wip,broken] Add contact masking. Prepare reset logic
1 parent 28fd1cb commit e877fcb

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

lsy_drone_racing/envs/drone_racing_env.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ def __init__(self, config: dict):
141141
self.gates, self.obstacles, self.drone = self.load_track(config.env.track)
142142
self.n_gates = len(config.env.track.gates)
143143
self.disturbances = self.load_disturbances(config.env.get("disturbances", None))
144+
self.contact_mask = np.ones((self.sim.n_worlds, 29), dtype=bool)
145+
self.contact_mask[..., 0] = 0 # Ignore contacts with the floor
144146

145147
self.gates_visited = np.array([False] * len(config.env.track.gates))
146148
self.obstacles_visited = np.array([False] * len(config.env.track.obstacles))
@@ -161,6 +163,8 @@ def reset(
161163
self.sim.seed(self.config.env.seed)
162164
if seed is not None:
163165
self.sim.seed(seed)
166+
# Randomization of gates, obstacles and drones is compiled into the sim reset function with
167+
# the sim.reset_hook function, so we don't need to explicitly do it here
164168
self.sim.reset()
165169
# TODO: Add randomization of gates, obstacles, drone, and disturbances
166170
states = self.sim.data.states.replace(
@@ -278,7 +282,7 @@ def terminated(self) -> bool:
278282
}
279283
if state not in self.state_space:
280284
return True # Drone is out of bounds
281-
if self.sim.contacts("drone:0").any():
285+
if np.logical_and(self.sim.contacts("drone:0"), self.contact_mask).any():
282286
return True
283287
if self.sim.data.states.pos[0, 0, 2] < 0.0:
284288
return True
@@ -320,17 +324,11 @@ def _load_track_into_sim(self, gates: dict, obstacles: dict):
320324
for i in range(len(obstacles["pos"])):
321325
obstacle = frame.attach_body(obstacle_spec.find_body("world"), "", f":o{i}")
322326
obstacle.pos = obstacles["pos"][i]
323-
# TODO: Simplify rebuilding the simulation after changing the mujoco model
324-
self.sim.mj_model, self.sim.mj_data, self.sim.mjx_model, mjx_data = self.sim.compile_mj(
325-
spec
326-
)
327-
self.sim.data = self.sim.data.replace(mjx_data=mjx_data)
328-
self.sim.default_data = self.sim.data.replace()
329327
self.sim.build()
330328

331329
def load_disturbances(self, disturbances: dict | None = None) -> dict:
332330
"""Load the disturbances from the config."""
333-
dist = {}
331+
dist = {} # TODO: Add jax disturbances for the simulator dynamics
334332
if disturbances is None: # Default: no passive disturbances.
335333
return dist
336334
for mode, spec in disturbances.items():

scripts/sim.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -117,5 +117,7 @@ def log_episode_stats(obs: dict, info: dict, config: Munch, curr_time: float):
117117

118118

119119
if __name__ == "__main__":
120-
logging.basicConfig(level=logging.INFO)
120+
logging.basicConfig()
121+
logging.getLogger("lsy_drone_racing").setLevel(logging.INFO)
122+
logger.setLevel(logging.INFO)
121123
fire.Fire(simulate)

0 commit comments

Comments
 (0)