diff --git a/src/mjlab/envs/manager_based_env.py b/src/mjlab/envs/manager_based_env.py index 5eed68ed4..96b6ec2b8 100644 --- a/src/mjlab/envs/manager_based_env.py +++ b/src/mjlab/envs/manager_based_env.py @@ -74,15 +74,7 @@ def load_managers(self) -> None: self.event_manager = EventManager(self.cfg.events, self) print("[INFO] Event manager: ", self.event_manager) - expanded_model_fields: list[str] = [] - if "startup" in self.event_manager.available_modes: - for event_cfg in self.event_manager._mode_term_cfgs["startup"]: - if "field" in event_cfg.params: - expanded_model_fields.append(event_cfg.params["field"]) - # Special handling for actuator gain randomization. - if event_cfg.func.__name__ == "randomize_actuator_gains": - expanded_model_fields.extend(["actuator_gainprm", "actuator_biasprm"]) - self.sim.expand_model_fields(expanded_model_fields) + self.sim.expand_model_fields(self.event_manager.domain_randomization_fields) self.action_manager = ActionManager(self.cfg.actions, self) print("[INFO] Action Manager:", self.action_manager) diff --git a/src/mjlab/managers/event_manager.py b/src/mjlab/managers/event_manager.py index 709c38153..da259fde0 100644 --- a/src/mjlab/managers/event_manager.py +++ b/src/mjlab/managers/event_manager.py @@ -20,6 +20,7 @@ def __init__(self, cfg: object, env: ManagerBasedEnv): self._mode_term_names: dict[EventMode, list[str]] = dict() self._mode_term_cfgs: dict[EventMode, list[EventTermCfg]] = dict() self._mode_class_term_cfgs: dict[EventMode, list[EventTermCfg]] = dict() + self._domain_randomization_fields: list[str] = list() super().__init__(cfg=cfg, env=env) @@ -42,6 +43,15 @@ def __str__(self) -> str: table.add_row([index, name]) msg += table.get_string() msg += "\n" + if self._domain_randomization_fields: + table = PrettyTable() + table.title = "Domain Randomization Fields" + table.field_names = ["Index", "Field Name"] + table.align["Field Name"] = "l" + for index, field in enumerate(self._domain_randomization_fields): + table.add_row([index, field]) + msg += table.get_string() + msg += "\n" return msg # Properties. @@ -54,6 +64,10 @@ def active_terms(self) -> dict[EventMode, list[str]]: def available_modes(self) -> list[EventMode]: return list(self._mode_term_names.keys()) + @property + def domain_randomization_fields(self) -> list[str]: + return self._domain_randomization_fields + # Methods. def reset(self, env_ids: torch.Tensor | slice | None = None): @@ -182,3 +196,8 @@ def _prepare_terms(self) -> None: self._reset_term_last_triggered_step_id.append(step_count) no_trigger = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) self._reset_term_last_triggered_once.append(no_trigger) + + if term_cfg.func.__name__ == "randomize_field": + field_name = term_cfg.params["field"] + if field_name not in self._domain_randomization_fields: + self._domain_randomization_fields.append(field_name)