@@ -141,6 +141,8 @@ def __init__(self, config: dict):
141
141
self .gates , self .obstacles , self .drone = self .load_track (config .env .track )
142
142
self .n_gates = len (config .env .track .gates )
143
143
self .disturbances = self .load_disturbances (config .env .get ("disturbances" , None ))
144
+ self .contact_mask = np .ones ((self .sim .n_worlds , 29 ), dtype = bool )
145
+ self .contact_mask [..., 0 ] = 0 # Ignore contacts with the floor
144
146
145
147
self .gates_visited = np .array ([False ] * len (config .env .track .gates ))
146
148
self .obstacles_visited = np .array ([False ] * len (config .env .track .obstacles ))
@@ -161,6 +163,8 @@ def reset(
161
163
self .sim .seed (self .config .env .seed )
162
164
if seed is not None :
163
165
self .sim .seed (seed )
166
+ # Randomization of gates, obstacles and drones is compiled into the sim reset function with
167
+ # the sim.reset_hook function, so we don't need to explicitly do it here
164
168
self .sim .reset ()
165
169
# TODO: Add randomization of gates, obstacles, drone, and disturbances
166
170
states = self .sim .data .states .replace (
@@ -278,7 +282,7 @@ def terminated(self) -> bool:
278
282
}
279
283
if state not in self .state_space :
280
284
return True # Drone is out of bounds
281
- if self .sim .contacts ("drone:0" ).any ():
285
+ if np . logical_and ( self .sim .contacts ("drone:0" ), self . contact_mask ).any ():
282
286
return True
283
287
if self .sim .data .states .pos [0 , 0 , 2 ] < 0.0 :
284
288
return True
@@ -320,17 +324,11 @@ def _load_track_into_sim(self, gates: dict, obstacles: dict):
320
324
for i in range (len (obstacles ["pos" ])):
321
325
obstacle = frame .attach_body (obstacle_spec .find_body ("world" ), "" , f":o{ i } " )
322
326
obstacle .pos = obstacles ["pos" ][i ]
323
- # TODO: Simplify rebuilding the simulation after changing the mujoco model
324
- self .sim .mj_model , self .sim .mj_data , self .sim .mjx_model , mjx_data = self .sim .compile_mj (
325
- spec
326
- )
327
- self .sim .data = self .sim .data .replace (mjx_data = mjx_data )
328
- self .sim .default_data = self .sim .data .replace ()
329
327
self .sim .build ()
330
328
331
329
def load_disturbances (self , disturbances : dict | None = None ) -> dict :
332
330
"""Load the disturbances from the config."""
333
- dist = {}
331
+ dist = {} # TODO: Add jax disturbances for the simulator dynamics
334
332
if disturbances is None : # Default: no passive disturbances.
335
333
return dist
336
334
for mode , spec in disturbances .items ():
0 commit comments