diff --git a/docs/source/actuators.rst b/docs/source/actuators.rst index 400e17079..eebcb2371 100644 --- a/docs/source/actuators.rst +++ b/docs/source/actuators.rst @@ -541,7 +541,7 @@ Domain Randomization .. code-block:: python from mjlab.envs.mdp import events - from mjlab.managers.manager_term_config import EventTermCfg + from mjlab.managers.event_manager import EventTermCfg from mjlab.managers.scene_entity_config import SceneEntityCfg EventTermCfg( diff --git a/docs/source/faq.rst b/docs/source/faq.rst index dc448e1a9..483d8a36e 100644 --- a/docs/source/faq.rst +++ b/docs/source/faq.rst @@ -81,7 +81,7 @@ My training crashes with NaN errors from dataclasses import dataclass, field from mjlab.envs.mdp.terminations import nan_detection - from mjlab.managers.manager_term_config import TerminationTermCfg + from mjlab.managers.termination_manager import TerminationTermCfg @dataclass class TerminationCfg: diff --git a/docs/source/nan_guard.rst b/docs/source/nan_guard.rst index cbc7cb465..2cacee24c 100644 --- a/docs/source/nan_guard.rst +++ b/docs/source/nan_guard.rst @@ -119,7 +119,7 @@ While ``nan_guard`` helps **debug** NaN issues by capturing states, you can also .. code-block:: python from mjlab.envs.mdp.terminations import nan_detection - from mjlab.managers.manager_term_config import TerminationTermCfg + from mjlab.managers.termination_manager import TerminationTermCfg # In your termination config: nan_term: TerminationTermCfg = field( diff --git a/docs/source/observation.rst b/docs/source/observation.rst index 759e10cec..238d3fefc 100644 --- a/docs/source/observation.rst +++ b/docs/source/observation.rst @@ -13,7 +13,7 @@ TL;DR .. code-block:: python - from mjlab.managers.manager_term_config import ObservationTermCfg + from mjlab.managers.observation_manager import ObservationTermCfg joint_vel: ObservationTermCfg = ObservationTermCfg( func=joint_vel, diff --git a/docs/source/randomization.rst b/docs/source/randomization.rst index d1983dd0a..25ccb987e 100644 --- a/docs/source/randomization.rst +++ b/docs/source/randomization.rst @@ -15,7 +15,7 @@ apply the draw. .. code-block:: python - from mjlab.managers.manager_term_config import EventTermCfg + from mjlab.managers.event_manager import EventTermCfg from mjlab.managers.scene_entity_config import SceneEntityCfg from mjlab.envs import mdp diff --git a/scripts/benchmarks/measure_throughput.py b/scripts/benchmarks/measure_throughput.py index 17504a05d..73f7ace8b 100644 --- a/scripts/benchmarks/measure_throughput.py +++ b/scripts/benchmarks/measure_throughput.py @@ -126,7 +126,7 @@ def benchmark_task(task: str, cfg: ThroughputConfig) -> BenchmarkResult: env_cfg.scene.num_envs = cfg.num_envs # Handle tracking task motion file. 
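The documentation hunks above are all the same mechanical substitution: each term config now lives in the module of the manager that consumes it, rather than in a shared ``manager_term_config`` module. For a config file that touches several managers, the new import pattern (paths taken directly from the hunks above) looks like:

.. code-block:: python

   # Old (removed): everything came from one shared module.
   # from mjlab.managers.manager_term_config import EventTermCfg, RewardTermCfg

   # New: each term config is imported from its manager's module.
   from mjlab.managers.event_manager import EventTermCfg
   from mjlab.managers.observation_manager import ObservationGroupCfg, ObservationTermCfg
   from mjlab.managers.reward_manager import RewardTermCfg
   from mjlab.managers.termination_manager import TerminationTermCfg
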
- if env_cfg.commands is not None: + if len(env_cfg.commands) > 0: motion_cmd = env_cfg.commands.get("motion") if isinstance(motion_cmd, MotionCommandCfg): api = wandb.Api() diff --git a/src/mjlab/envs/manager_based_rl_env.py b/src/mjlab/envs/manager_based_rl_env.py index 67af64aa4..94de3807b 100644 --- a/src/mjlab/envs/manager_based_rl_env.py +++ b/src/mjlab/envs/manager_based_rl_env.py @@ -10,22 +10,21 @@ from mjlab.envs import types from mjlab.envs.mdp.events import reset_scene_to_default -from mjlab.managers.action_manager import ActionManager -from mjlab.managers.command_manager import CommandManager, NullCommandManager -from mjlab.managers.curriculum_manager import CurriculumManager, NullCurriculumManager -from mjlab.managers.event_manager import EventManager -from mjlab.managers.manager_term_config import ( - ActionTermCfg, +from mjlab.managers.action_manager import ActionManager, ActionTermCfg +from mjlab.managers.command_manager import ( + CommandManager, CommandTermCfg, + NullCommandManager, +) +from mjlab.managers.curriculum_manager import ( + CurriculumManager, CurriculumTermCfg, - EventTermCfg, - ObservationGroupCfg, - RewardTermCfg, - TerminationTermCfg, + NullCurriculumManager, ) -from mjlab.managers.observation_manager import ObservationManager -from mjlab.managers.reward_manager import RewardManager -from mjlab.managers.termination_manager import TerminationManager +from mjlab.managers.event_manager import EventManager, EventTermCfg +from mjlab.managers.observation_manager import ObservationGroupCfg, ObservationManager +from mjlab.managers.reward_manager import RewardManager, RewardTermCfg +from mjlab.managers.termination_manager import TerminationManager, TerminationTermCfg from mjlab.scene import Scene from mjlab.scene.scene import SceneCfg from mjlab.sim import SimulationCfg @@ -41,14 +40,36 @@ @dataclass(kw_only=True) class ManagerBasedRlEnvCfg: - """Configuration for a manager-based RL environment.""" + """Configuration for a manager-based RL environment. + + This config defines all aspects of an RL environment: the physical scene, + observations, actions, rewards, terminations, and optional features like + commands and curriculum learning. + + The environment step size is ``sim.mujoco.timestep * decimation``. For example, + with a 2ms physics timestep and decimation=10, the environment runs at 50Hz. + """ # Base environment configuration. + decimation: int - """Number of simulation steps per environment step.""" + """Number of physics simulation steps per environment step. Higher values mean + coarser control frequency. Environment step duration = physics_dt * decimation.""" + scene: SceneCfg + """Scene configuration defining terrain, entities, and sensors. The scene + specifies ``num_envs``, the number of parallel environments.""" + observations: dict[str, ObservationGroupCfg] + """Observation groups configuration. Each group (e.g., "policy", "critic") contains + observation terms that are concatenated. Groups can have different settings for + noise, history, and delay. Can be empty for environments without observations.""" + actions: dict[str, ActionTermCfg] + """Action terms configuration. Each term controls a specific entity/aspect + (e.g., joint positions). Action dimensions are concatenated across terms. 
+ Can be empty for observation-only environments.""" + events: dict[str, EventTermCfg] = field( default_factory=lambda: { "reset_scene_to_default": EventTermCfg( @@ -57,25 +78,44 @@ class ManagerBasedRlEnvCfg: ) } ) + """Event terms for domain randomization and state resets. Default includes + ``reset_scene_to_default`` which resets entities to their initial state. + Can be set to empty to disable all events including default reset.""" + seed: int | None = None + """Random seed for reproducibility. If None, a random seed is used. The actual + seed used is stored back into this field after initialization.""" + sim: SimulationCfg = field(default_factory=SimulationCfg) + """Simulation configuration including physics timestep, solver iterations, + contact parameters, and NaN guarding.""" + viewer: ViewerConfig = field(default_factory=ViewerConfig) + """Viewer configuration for rendering (camera position, resolution, etc.).""" # RL-specific configuration. + episode_length_s: float = 0.0 """Duration of an episode (in seconds). Episode length in steps is computed as: ceil(episode_length_s / (sim.mujoco.timestep * decimation)) """ + rewards: dict[str, RewardTermCfg] = field(default_factory=dict) - """Reward terms configuration.""" + """Reward terms configuration. Can be empty for unsupervised environments.""" + terminations: dict[str, TerminationTermCfg] = field(default_factory=dict) - """Termination terms configuration.""" - commands: dict[str, CommandTermCfg] | None = None - """Command terms configuration. If None, no commands are used.""" - curriculum: dict[str, CurriculumTermCfg] | None = None - """Curriculum terms configuration. If None, no curriculum is used.""" + """Termination terms configuration. Can be empty for infinite episodes (no + terminations). Use ``mdp.time_out`` with ``time_out=True`` for episode time limits.""" + + commands: dict[str, CommandTermCfg] = field(default_factory=dict) + """Command generator terms (e.g., velocity targets). Can be empty if the + task has no goal commands.""" + + curriculum: dict[str, CurriculumTermCfg] = field(default_factory=dict) + """Curriculum terms for adaptive difficulty. Can be empty to disable.""" + is_finite_horizon: bool = False """Whether the task has a finite or infinite horizon. Defaults to False (infinite). @@ -85,6 +125,7 @@ class ManagerBasedRlEnvCfg: receives a truncated done signal to bootstrap the value of continuing beyond the limit. """ + scale_rewards_by_dt: bool = True """Whether to multiply rewards by the environment step duration (dt). @@ -227,7 +268,7 @@ def load_managers(self) -> None: # Command manager (must be before observation manager since observations # may reference commands). 
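Because ``commands`` and ``curriculum`` now default to empty dicts instead of ``None``, a minimal task config can simply omit them, and the manager wiring below reduces to emptiness checks. A minimal sketch, assuming hypothetical ``my_scene``, ``my_obs_group``, and ``my_action_term`` values defined elsewhere:

.. code-block:: python

   from mjlab.envs import ManagerBasedRlEnvCfg

   cfg = ManagerBasedRlEnvCfg(
       decimation=10,  # env step = 10 physics steps
       scene=my_scene,  # hypothetical SceneCfg
       observations={"policy": my_obs_group},  # hypothetical ObservationGroupCfg
       actions={"joint_pos": my_action_term},  # hypothetical ActionTermCfg
       episode_length_s=20.0,
       # commands / curriculum default to {}; no explicit None needed.
   )
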
- if self.cfg.commands is not None: + if len(self.cfg.commands) > 0: self.command_manager = CommandManager(self.cfg.commands, self) else: self.command_manager = NullCommandManager() @@ -247,7 +288,7 @@ def load_managers(self) -> None: self.cfg.rewards, self, scale_by_dt=self.cfg.scale_rewards_by_dt ) print_info(f"[INFO] {self.reward_manager}") - if self.cfg.curriculum is not None: + if len(self.cfg.curriculum) > 0: self.curriculum_manager = CurriculumManager(self.cfg.curriculum, self) else: self.curriculum_manager = NullCurriculumManager() diff --git a/src/mjlab/envs/mdp/actions/actions.py b/src/mjlab/envs/mdp/actions/actions.py index b06626a22..72354fc06 100644 --- a/src/mjlab/envs/mdp/actions/actions.py +++ b/src/mjlab/envs/mdp/actions/actions.py @@ -8,8 +8,7 @@ import torch from mjlab.actuator.actuator import TransmissionType -from mjlab.managers.action_manager import ActionTerm -from mjlab.managers.manager_term_config import ActionTermCfg +from mjlab.managers.action_manager import ActionTerm, ActionTermCfg from mjlab.utils.lab_api.string import resolve_matching_names_values if TYPE_CHECKING: diff --git a/src/mjlab/envs/mdp/rewards.py b/src/mjlab/envs/mdp/rewards.py index af2882a66..edfe5a4f1 100644 --- a/src/mjlab/envs/mdp/rewards.py +++ b/src/mjlab/envs/mdp/rewards.py @@ -7,7 +7,7 @@ import torch from mjlab.entity import Entity -from mjlab.managers.manager_term_config import RewardTermCfg +from mjlab.managers.reward_manager import RewardTermCfg from mjlab.managers.scene_entity_config import SceneEntityCfg from mjlab.utils.lab_api.string import ( resolve_matching_names_values, diff --git a/src/mjlab/managers/__init__.py b/src/mjlab/managers/__init__.py index 621d63f1f..114f198c1 100644 --- a/src/mjlab/managers/__init__.py +++ b/src/mjlab/managers/__init__.py @@ -1,10 +1,30 @@ """Environment managers.""" +from mjlab.managers.action_manager import ActionManager as ActionManager +from mjlab.managers.action_manager import ActionTerm as ActionTerm +from mjlab.managers.action_manager import ActionTermCfg as ActionTermCfg from mjlab.managers.command_manager import CommandManager as CommandManager from mjlab.managers.command_manager import CommandTerm as CommandTerm +from mjlab.managers.command_manager import CommandTermCfg as CommandTermCfg from mjlab.managers.command_manager import NullCommandManager as NullCommandManager from mjlab.managers.curriculum_manager import CurriculumManager as CurriculumManager +from mjlab.managers.curriculum_manager import CurriculumTermCfg as CurriculumTermCfg from mjlab.managers.curriculum_manager import ( NullCurriculumManager as NullCurriculumManager, ) -from mjlab.managers.manager_term_config import CommandTermCfg as CommandTermCfg +from mjlab.managers.event_manager import EventManager as EventManager +from mjlab.managers.event_manager import EventMode as EventMode +from mjlab.managers.event_manager import EventTermCfg as EventTermCfg +from mjlab.managers.manager_base import ManagerBase as ManagerBase +from mjlab.managers.manager_base import ManagerTermBase as ManagerTermBase +from mjlab.managers.manager_base import ManagerTermBaseCfg as ManagerTermBaseCfg +from mjlab.managers.observation_manager import ( + ObservationGroupCfg as ObservationGroupCfg, +) +from mjlab.managers.observation_manager import ObservationManager as ObservationManager +from mjlab.managers.observation_manager import ObservationTermCfg as ObservationTermCfg +from mjlab.managers.reward_manager import RewardManager as RewardManager +from mjlab.managers.reward_manager import RewardTermCfg 
as RewardTermCfg +from mjlab.managers.scene_entity_config import SceneEntityCfg as SceneEntityCfg +from mjlab.managers.termination_manager import TerminationManager as TerminationManager +from mjlab.managers.termination_manager import TerminationTermCfg as TerminationTermCfg diff --git a/src/mjlab/managers/action_manager.py b/src/mjlab/managers/action_manager.py index e5daa3086..4bd936f01 100644 --- a/src/mjlab/managers/action_manager.py +++ b/src/mjlab/managers/action_manager.py @@ -3,6 +3,7 @@ from __future__ import annotations import abc +from dataclasses import dataclass from typing import TYPE_CHECKING, Sequence import torch @@ -12,7 +13,27 @@ if TYPE_CHECKING: from mjlab.envs import ManagerBasedRlEnv - from mjlab.managers.manager_term_config import ActionTermCfg + + +@dataclass(kw_only=True) +class ActionTermCfg(abc.ABC): + """Configuration for an action term. + + Action terms process raw actions from the policy and apply them to entities + in the scene (e.g., setting joint positions, velocities, or efforts). + """ + + entity_name: str + """Name of the entity in the scene that this action term controls.""" + + clip: dict[str, tuple] | None = None + """Optional clipping bounds per transmission type. Maps transmission name + (e.g., 'position', 'velocity') to (min, max) tuple.""" + + @abc.abstractmethod + def build(self, env: ManagerBasedRlEnv) -> ActionTerm: + """Build the action term from this config.""" + raise NotImplementedError class ActionTerm(ManagerTermBase): @@ -47,6 +68,13 @@ def raw_action(self) -> torch.Tensor: class ActionManager(ManagerBase): + """Manages action processing for the environment. + + The action manager aggregates multiple action terms, each controlling a different + entity or aspect of the simulation. It splits the policy's action tensor and + routes each slice to the appropriate action term. + """ + def __init__(self, cfg: dict[str, ActionTermCfg], env: ManagerBasedRlEnv): self.cfg = cfg super().__init__(env=env) diff --git a/src/mjlab/managers/command_manager.py b/src/mjlab/managers/command_manager.py index bc6881929..b044a08e8 100644 --- a/src/mjlab/managers/command_manager.py +++ b/src/mjlab/managers/command_manager.py @@ -3,6 +3,7 @@ from __future__ import annotations import abc +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Sequence import torch @@ -12,10 +13,34 @@ if TYPE_CHECKING: from mjlab.envs.manager_based_rl_env import ManagerBasedRlEnv - from mjlab.managers.manager_term_config import CommandTermCfg from mjlab.viewer.debug_visualizer import DebugVisualizer +@dataclass(kw_only=True) +class CommandTermCfg(abc.ABC): + """Configuration for a command generator term. + + Command terms generate goal commands for the agent (e.g., target velocity, + target position). Commands are automatically resampled at configurable + intervals and can track metrics for logging. + """ + + resampling_time_range: tuple[float, float] + """Time range in seconds for command resampling. When the timer expires, a new + command is sampled and the timer is reset to a value uniformly drawn from + ``[min, max]``. Set both values equal for fixed-interval resampling.""" + + debug_vis: bool = False + """Whether to enable debug visualization for this command term. 
When True, + the command term's ``_debug_vis_impl`` method is called each frame to render + visual aids (e.g., velocity arrows, target markers).""" + + @abc.abstractmethod + def build(self, env: ManagerBasedRlEnv) -> CommandTerm: + """Build the command term from this config.""" + raise NotImplementedError + + class CommandTerm(ManagerTermBase): """Base class for command terms.""" @@ -83,6 +108,13 @@ def _update_command(self) -> None: class CommandManager(ManagerBase): + """Manages command generation for the environment. + + The command manager generates and updates goal commands for the agent (e.g., + target velocity, target position). Commands are resampled at configurable + intervals and can track metrics for logging. + """ + _env: ManagerBasedRlEnv def __init__(self, cfg: dict[str, CommandTermCfg], env: ManagerBasedRlEnv): diff --git a/src/mjlab/managers/curriculum_manager.py b/src/mjlab/managers/curriculum_manager.py index cd180c063..d34b25aa6 100644 --- a/src/mjlab/managers/curriculum_manager.py +++ b/src/mjlab/managers/curriculum_manager.py @@ -3,19 +3,37 @@ from __future__ import annotations from copy import deepcopy +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Sequence import torch from prettytable import PrettyTable -from mjlab.managers.manager_base import ManagerBase -from mjlab.managers.manager_term_config import CurriculumTermCfg +from mjlab.managers.manager_base import ManagerBase, ManagerTermBaseCfg if TYPE_CHECKING: from mjlab.envs.manager_based_rl_env import ManagerBasedRlEnv +@dataclass(kw_only=True) +class CurriculumTermCfg(ManagerTermBaseCfg): + """Configuration for a curriculum term. + + Curriculum terms modify environment parameters during training to implement + curriculum learning strategies (e.g., gradually increasing task difficulty). + """ + + pass + + class CurriculumManager(ManagerBase): + """Manages curriculum learning for the environment. + + The curriculum manager updates environment parameters during training based + on agent performance. Each term can modify different aspects of the task + difficulty (e.g., terrain complexity, command ranges). + """ + _env: ManagerBasedRlEnv def __init__(self, cfg: dict[str, CurriculumTermCfg], env: ManagerBasedRlEnv): diff --git a/src/mjlab/managers/event_manager.py b/src/mjlab/managers/event_manager.py index 03ab162db..98c256866 100644 --- a/src/mjlab/managers/event_manager.py +++ b/src/mjlab/managers/event_manager.py @@ -3,19 +3,74 @@ from __future__ import annotations from copy import deepcopy -from typing import TYPE_CHECKING +from dataclasses import dataclass +from typing import TYPE_CHECKING, Literal import torch from prettytable import PrettyTable -from mjlab.managers.manager_base import ManagerBase -from mjlab.managers.manager_term_config import EventMode, EventTermCfg +from mjlab.managers.manager_base import ManagerBase, ManagerTermBaseCfg if TYPE_CHECKING: from mjlab.envs import ManagerBasedRlEnv +EventMode = Literal["startup", "reset", "interval"] + + +@dataclass(kw_only=True) +class EventTermCfg(ManagerTermBaseCfg): + """Configuration for an event term. + + Event terms trigger operations at specific simulation events. They're commonly + used for domain randomization, state resets, and periodic perturbations. + + The three modes determine when the event fires: + + - ``"startup"``: Once when the environment initializes. Use for parameters that + should be randomized per-environment but stay constant across all episodes + (e.g., per-environment domain randomization).
+ + - ``"reset"``: On every episode reset. Use for parameters that should vary + between episodes (e.g., initial robot pose, domain randomization). + + - ``"interval"``: Periodically during simulation, controlled by + ``interval_range_s``. Use for perturbations that should happen during + episodes (e.g., pushing the robot, external disturbances). + """ + + mode: EventMode + """When the event triggers: ``"startup"`` (once at init), ``"reset"`` (every + episode), or ``"interval"`` (periodically during simulation).""" + + interval_range_s: tuple[float, float] | None = None + """Time range in seconds for interval mode. The next trigger time is uniformly + sampled from ``[min, max]``. Required when ``mode="interval"``.""" + + is_global_time: bool = False + """Whether all environments share the same timer. If True, all envs trigger + simultaneously. If False (default), each env has an independent timer that + resets on episode reset. Only applies to ``mode="interval"``.""" + + min_step_count_between_reset: int = 0 + """Minimum environment steps between triggers. Prevents the event from firing + too frequently when episodes reset rapidly. Only applies to ``mode="reset"``. + Set to 0 (default) to trigger on every reset.""" + + domain_randomization: bool = False + """Whether this event performs domain randomization. If True, the field name + from ``params["field"]`` is tracked and exposed via + ``EventManager.domain_randomization_fields`` for logging/debugging.""" + + class EventManager(ManagerBase): + """Manages event-based operations for the environment. + + The event manager triggers operations at different simulation events: startup + (once at initialization), reset (on episode reset), or interval (periodically + during simulation). Common uses include domain randomization and state resets. + """ + _env: ManagerBasedRlEnv def __init__(self, cfg: dict[str, EventTermCfg], env: ManagerBasedRlEnv): diff --git a/src/mjlab/managers/manager_base.py b/src/mjlab/managers/manager_base.py index ff4743636..049df2d4c 100644 --- a/src/mjlab/managers/manager_base.py +++ b/src/mjlab/managers/manager_base.py @@ -3,6 +3,7 @@ import abc import inspect from collections.abc import Sequence +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any import torch @@ -11,7 +12,53 @@ if TYPE_CHECKING: from mjlab.envs import ManagerBasedRlEnv - from mjlab.managers.manager_term_config import ManagerTermBaseCfg + + +@dataclass +class ManagerTermBaseCfg: + """Base configuration for manager terms. + + This is the base config for terms in observation, reward, termination, curriculum, + and event managers. It provides a common interface for specifying a callable + and its parameters. + + The ``func`` field accepts either a function or a class: + + **Function-based terms** are simpler and suitable for stateless computations: + + .. code-block:: python + + RewardTermCfg(func=mdp.joint_torques_l2, weight=-0.01) + + **Class-based terms** are instantiated with ``(cfg, env)`` and useful when you need + to: + + - Cache computed values at initialization (e.g., resolve regex patterns to indices) + - Maintain state across calls + - Perform expensive setup once rather than every call + + .. 
code-block:: python + + class posture: + def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv): + # Resolve std dict to tensor once at init + self.std = resolve_std_to_tensor(cfg.params["std"], env) + + def __call__(self, env, **kwargs) -> torch.Tensor: + # Use cached self.std + return compute_posture_reward(env, self.std) + + RewardTermCfg(func=posture, params={"std": {".*knee.*": 0.3}}, weight=1.0) + + Class-based terms can optionally implement ``reset(env_ids)`` for per-episode state. + """ + + func: Any + """The callable that computes this term's value. Can be a function or a class. + Classes are auto-instantiated with ``(cfg=term_cfg, env=env)``.""" + + params: dict[str, Any] = field(default_factory=lambda: {}) + """Additional keyword arguments passed to func when called.""" class ManagerTermBase: diff --git a/src/mjlab/managers/manager_term_config.py b/src/mjlab/managers/manager_term_config.py deleted file mode 100644 index 6a5aac57c..000000000 --- a/src/mjlab/managers/manager_term_config.py +++ /dev/null @@ -1,174 +0,0 @@ -from __future__ import annotations - -import abc -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Literal - -import torch - -from mjlab.managers.command_manager import CommandTerm -from mjlab.utils.noise.noise_cfg import NoiseCfg, NoiseModelCfg - -if TYPE_CHECKING: - from mjlab.envs import ManagerBasedRlEnv - from mjlab.managers.action_manager import ActionTerm - - -@dataclass -class ManagerTermBaseCfg: - func: Any - params: dict[str, Any] = field(default_factory=lambda: {}) - - -## -# Action manager. -## - - -@dataclass(kw_only=True) -class ActionTermCfg(abc.ABC): - """Configuration for an action term.""" - - entity_name: str - clip: dict[str, tuple] | None = None - - @abc.abstractmethod - def build(self, env: ManagerBasedRlEnv) -> ActionTerm: - """Build the action term from this config.""" - raise NotImplementedError - - -## -# Command manager. -## - - -@dataclass(kw_only=True) -class CommandTermCfg(abc.ABC): - """Configuration for a command generator term.""" - - resampling_time_range: tuple[float, float] - debug_vis: bool = False - - @abc.abstractmethod - def build(self, env: ManagerBasedRlEnv) -> CommandTerm: - """Build the command term from this config.""" - raise NotImplementedError - - -## -# Curriculum manager. -## - - -@dataclass(kw_only=True) -class CurriculumTermCfg(ManagerTermBaseCfg): - pass - - -## -# Event manager. -## - - -EventMode = Literal["startup", "reset", "interval"] - - -@dataclass(kw_only=True) -class EventTermCfg(ManagerTermBaseCfg): - """Configuration for an event term.""" - - mode: EventMode - interval_range_s: tuple[float, float] | None = None - is_global_time: bool = False - min_step_count_between_reset: int = 0 - domain_randomization: bool = False - """Whether this event term performs domain randomization. If True, the field - name (from params["field"]) will be tracked for domain randomization purposes.""" - - -## -# Observation manager. -## - - -@dataclass -class ObservationTermCfg(ManagerTermBaseCfg): - """Configuration for an observation term. - - Processing pipeline: compute → noise → clip → scale → delay → history. - Delay models sensor latency. History provides temporal context. Both are optional - and can be combined. - """ - - noise: NoiseCfg | NoiseModelCfg | None = None - """Noise model to apply to the observation.""" - clip: tuple[float, float] | None = None - """Range (min, max) to clip the observation values.""" - scale: tuple[float, ...] 
| float | torch.Tensor | None = None - """Scaling factor(s) to multiply the observation by.""" - delay_min_lag: int = 0 - """Minimum lag (in steps) for delayed observations. Lag sampled uniformly from - [min_lag, max_lag]. Convert to ms: lag * (1000 / control_hz).""" - delay_max_lag: int = 0 - """Maximum lag (in steps) for delayed observations. Use min=max for constant delay.""" - delay_per_env: bool = True - """If True, each environment samples its own lag. If False, all environments share - the same lag at each step.""" - delay_hold_prob: float = 0.0 - """Probability of reusing the previous lag instead of resampling. Useful for - temporally correlated latency patterns.""" - delay_update_period: int = 0 - """Resample lag every N steps (models multi-rate sensors). If 0, update every step.""" - delay_per_env_phase: bool = True - """If True and update_period > 0, stagger update timing across envs to avoid - synchronized resampling.""" - history_length: int = 0 - """Number of past observations to keep in history. 0 = no history.""" - flatten_history_dim: bool = True - """Whether to flatten the history dimension into observation. - - When True and concatenate_terms=True, uses term-major ordering: - [A_t0, A_t1, ..., A_tH-1, B_t0, B_t1, ..., B_tH-1, ...] - See docs/api/observation_history_delay.md for details on ordering.""" - - -@dataclass -class ObservationGroupCfg: - """Configuration for an observation group. - - The `terms` field contains a dictionary mapping term names to their configurations. - """ - - terms: dict[str, ObservationTermCfg] - concatenate_terms: bool = True - concatenate_dim: int = -1 - enable_corruption: bool = False - history_length: int | None = None - flatten_history_dim: bool = True - - -## -# Reward manager. -## - - -@dataclass(kw_only=True) -class RewardTermCfg(ManagerTermBaseCfg): - """Configuration for a reward term.""" - - func: Any - weight: float - - -## -# Termination manager. -## - - -@dataclass -class TerminationTermCfg(ManagerTermBaseCfg): - """Configuration for a termination term.""" - - time_out: bool = False - """Whether the term contributes towards episodic timeouts.""" diff --git a/src/mjlab/managers/observation_manager.py b/src/mjlab/managers/observation_manager.py index 026b48d00..311d30309 100644 --- a/src/mjlab/managers/observation_manager.py +++ b/src/mjlab/managers/observation_manager.py @@ -1,19 +1,111 @@ """Observation manager for computing observations.""" from copy import deepcopy +from dataclasses import dataclass from typing import Sequence import numpy as np import torch from prettytable import PrettyTable -from mjlab.managers.manager_base import ManagerBase -from mjlab.managers.manager_term_config import ObservationGroupCfg, ObservationTermCfg +from mjlab.managers.manager_base import ManagerBase, ManagerTermBaseCfg from mjlab.utils.buffers import CircularBuffer, DelayBuffer from mjlab.utils.noise import noise_cfg, noise_model +from mjlab.utils.noise.noise_cfg import NoiseCfg, NoiseModelCfg + + +@dataclass +class ObservationTermCfg(ManagerTermBaseCfg): + """Configuration for an observation term. + + Processing pipeline: compute → noise → clip → scale → delay → history. + Delay models sensor latency. History provides temporal context. Both are optional + and can be combined. + """ + + noise: NoiseCfg | NoiseModelCfg | None = None + """Noise model to apply to the observation.""" + + clip: tuple[float, float] | None = None + """Range (min, max) to clip the observation values.""" + + scale: tuple[float, ...] 
| float | torch.Tensor | None = None + """Scaling factor(s) to multiply the observation by.""" + + delay_min_lag: int = 0 + """Minimum lag (in steps) for delayed observations. Lag sampled uniformly from + [min_lag, max_lag]. Convert to ms: lag * (1000 / control_hz).""" + + delay_max_lag: int = 0 + """Maximum lag (in steps) for delayed observations. Use min=max for constant delay.""" + + delay_per_env: bool = True + """If True, each environment samples its own lag. If False, all environments share + the same lag at each step.""" + + delay_hold_prob: float = 0.0 + """Probability of reusing the previous lag instead of resampling. Useful for + temporally correlated latency patterns.""" + + delay_update_period: int = 0 + """Resample lag every N steps (models multi-rate sensors). If 0, update every step.""" + + delay_per_env_phase: bool = True + """If True and update_period > 0, stagger update timing across envs to avoid + synchronized resampling.""" + + history_length: int = 0 + """Number of past observations to keep in history. 0 = no history.""" + + flatten_history_dim: bool = True + """Whether to flatten the history dimension into observation. + + When True and concatenate_terms=True, uses term-major ordering: + [A_t0, A_t1, ..., A_tH-1, B_t0, B_t1, ..., B_tH-1, ...] + See docs/api/observation_history_delay.md for details on ordering.""" + + +@dataclass +class ObservationGroupCfg: + """Configuration for an observation group. + + An observation group bundles multiple observation terms together. Groups are + typically used to separate observations for different purposes (e.g., "policy" + for the actor, "critic" for the value function). + """ + + terms: dict[str, ObservationTermCfg] + """Dictionary mapping term names to their configurations.""" + + concatenate_terms: bool = True + """Whether to concatenate all terms into a single tensor. If False, returns + a dict mapping term names to their individual tensors.""" + + concatenate_dim: int = -1 + """Dimension along which to concatenate terms. Default -1 (last dimension).""" + + enable_corruption: bool = False + """Whether to apply noise corruption to observations. Set to True during + training for domain randomization, False during evaluation.""" + + history_length: int | None = None + """Group-level history length override. If set, applies to all terms in + this group. If None, each term uses its own ``history_length`` setting.""" + + flatten_history_dim: bool = True + """Whether to flatten history into the observation dimension. If True, + observations have shape ``(num_envs, obs_dim * history_length)``. If False, + shape is ``(num_envs, history_length, obs_dim)``.""" class ObservationManager(ManagerBase): + """Manages observation computation for the environment. + + The observation manager computes observations from multiple terms organized + into groups. Each term can have noise, clipping, scaling, delay, and history + applied. Groups can optionally concatenate their terms into a single tensor. 
+ """ + def __init__(self, cfg: dict[str, ObservationGroupCfg], env): self.cfg = deepcopy(cfg) super().__init__(env=env) diff --git a/src/mjlab/managers/reward_manager.py b/src/mjlab/managers/reward_manager.py index 51cf8dfd0..6afb5c7b6 100644 --- a/src/mjlab/managers/reward_manager.py +++ b/src/mjlab/managers/reward_manager.py @@ -3,18 +3,29 @@ from __future__ import annotations from copy import deepcopy -from typing import TYPE_CHECKING +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any import torch from prettytable import PrettyTable -from mjlab.managers.manager_base import ManagerBase -from mjlab.managers.manager_term_config import RewardTermCfg +from mjlab.managers.manager_base import ManagerBase, ManagerTermBaseCfg if TYPE_CHECKING: from mjlab.envs.manager_based_rl_env import ManagerBasedRlEnv +@dataclass(kw_only=True) +class RewardTermCfg(ManagerTermBaseCfg): + """Configuration for a reward term.""" + + func: Any + """The callable that computes this reward term's value.""" + + weight: float + """Weight multiplier for this reward term.""" + + class RewardManager(ManagerBase): """Manages reward computation by aggregating weighted reward terms. diff --git a/src/mjlab/managers/termination_manager.py b/src/mjlab/managers/termination_manager.py index baac956fc..d31808609 100644 --- a/src/mjlab/managers/termination_manager.py +++ b/src/mjlab/managers/termination_manager.py @@ -3,19 +3,34 @@ from __future__ import annotations from copy import deepcopy +from dataclasses import dataclass from typing import TYPE_CHECKING, Sequence import torch from prettytable import PrettyTable -from mjlab.managers.manager_base import ManagerBase -from mjlab.managers.manager_term_config import TerminationTermCfg +from mjlab.managers.manager_base import ManagerBase, ManagerTermBaseCfg if TYPE_CHECKING: from mjlab.envs.manager_based_rl_env import ManagerBasedRlEnv +@dataclass +class TerminationTermCfg(ManagerTermBaseCfg): + """Configuration for a termination term.""" + + time_out: bool = False + """Whether the term contributes towards episodic timeouts.""" + + class TerminationManager(ManagerBase): + """Manages termination conditions for the environment. + + The termination manager aggregates multiple termination terms to compute + episode done signals. Terms can be either truncations (time-based) or + terminations (failure conditions). + """ + _env: ManagerBasedRlEnv def __init__(self, cfg: dict[str, TerminationTermCfg], env: ManagerBasedRlEnv): diff --git a/src/mjlab/scripts/play.py b/src/mjlab/scripts/play.py index 4d385b5e3..b905edb8d 100644 --- a/src/mjlab/scripts/play.py +++ b/src/mjlab/scripts/play.py @@ -52,21 +52,17 @@ def run_play(task_id: str, cfg: PlayConfig): TRAINED_MODE = not DUMMY_MODE # Check if this is a tracking task by checking for motion command. - is_tracking_task = ( - env_cfg.commands is not None - and "motion" in env_cfg.commands - and isinstance(env_cfg.commands["motion"], MotionCommandCfg) + is_tracking_task = "motion" in env_cfg.commands and isinstance( + env_cfg.commands["motion"], MotionCommandCfg ) if is_tracking_task and cfg._demo_mode: # Demo mode: use uniform sampling to see more diversity with num_envs > 1. 
- assert env_cfg.commands is not None motion_cmd = env_cfg.commands["motion"] assert isinstance(motion_cmd, MotionCommandCfg) motion_cmd.sampling_mode = "uniform" if is_tracking_task: - assert env_cfg.commands is not None motion_cmd = env_cfg.commands["motion"] assert isinstance(motion_cmd, MotionCommandCfg) diff --git a/src/mjlab/scripts/train.py b/src/mjlab/scripts/train.py index 1a845dd46..69f1503b3 100644 --- a/src/mjlab/scripts/train.py +++ b/src/mjlab/scripts/train.py @@ -68,10 +68,8 @@ def run_train(task_id: str, cfg: TrainConfig, log_dir: Path) -> None: registry_name: str | None = None # Check if this is a tracking task by checking for motion command. - is_tracking_task = ( - cfg.env.commands is not None - and "motion" in cfg.env.commands - and isinstance(cfg.env.commands["motion"], MotionCommandCfg) + is_tracking_task = "motion" in cfg.env.commands and isinstance( + cfg.env.commands["motion"], MotionCommandCfg ) if is_tracking_task: @@ -87,7 +85,6 @@ def run_train(task_id: str, cfg: TrainConfig, log_dir: Path) -> None: api = wandb.Api() artifact = api.artifact(registry_name) - assert cfg.env.commands is not None motion_cmd = cfg.env.commands["motion"] assert isinstance(motion_cmd, MotionCommandCfg) motion_cmd.motion_file = str(Path(artifact.download()) / "motion.npz") diff --git a/src/mjlab/tasks/manipulation/config/yam/env_cfgs.py b/src/mjlab/tasks/manipulation/config/yam/env_cfgs.py index 5a962e95e..58684c1c6 100644 --- a/src/mjlab/tasks/manipulation/config/yam/env_cfgs.py +++ b/src/mjlab/tasks/manipulation/config/yam/env_cfgs.py @@ -40,7 +40,6 @@ def yam_lift_cube_env_cfg( assert isinstance(joint_pos_action, JointPositionActionCfg) joint_pos_action.scale = YAM_ACTION_SCALE - assert cfg.commands is not None lift_command = cfg.commands["lift_height"] assert isinstance(lift_command, LiftingCommandCfg) diff --git a/src/mjlab/tasks/manipulation/lift_cube_env_cfg.py b/src/mjlab/tasks/manipulation/lift_cube_env_cfg.py index 2e4ddcece..0554fcb01 100644 --- a/src/mjlab/tasks/manipulation/lift_cube_env_cfg.py +++ b/src/mjlab/tasks/manipulation/lift_cube_env_cfg.py @@ -1,16 +1,13 @@ from mjlab.envs import ManagerBasedRlEnvCfg from mjlab.envs.mdp.actions import JointPositionActionCfg -from mjlab.managers.manager_term_config import ( - ActionTermCfg, - CommandTermCfg, - CurriculumTermCfg, - EventTermCfg, - ObservationGroupCfg, - ObservationTermCfg, - RewardTermCfg, - TerminationTermCfg, -) +from mjlab.managers.action_manager import ActionTermCfg +from mjlab.managers.command_manager import CommandTermCfg +from mjlab.managers.curriculum_manager import CurriculumTermCfg +from mjlab.managers.event_manager import EventTermCfg +from mjlab.managers.observation_manager import ObservationGroupCfg, ObservationTermCfg +from mjlab.managers.reward_manager import RewardTermCfg from mjlab.managers.scene_entity_config import SceneEntityCfg +from mjlab.managers.termination_manager import TerminationTermCfg from mjlab.scene import SceneCfg from mjlab.sensor import ContactMatch, ContactSensorCfg from mjlab.sim import MujocoCfg, SimulationCfg diff --git a/src/mjlab/tasks/manipulation/mdp/commands.py b/src/mjlab/tasks/manipulation/mdp/commands.py index 32764595a..b0877563e 100644 --- a/src/mjlab/tasks/manipulation/mdp/commands.py +++ b/src/mjlab/tasks/manipulation/mdp/commands.py @@ -7,8 +7,7 @@ import torch from mjlab.entity import Entity -from mjlab.managers.command_manager import CommandTerm -from mjlab.managers.manager_term_config import CommandTermCfg +from mjlab.managers.command_manager import 
CommandTerm, CommandTermCfg from mjlab.utils.lab_api.math import ( quat_from_euler_xyz, sample_uniform, diff --git a/src/mjlab/tasks/tracking/config/g1/env_cfgs.py b/src/mjlab/tasks/tracking/config/g1/env_cfgs.py index 2f63e02b3..563eae90d 100644 --- a/src/mjlab/tasks/tracking/config/g1/env_cfgs.py +++ b/src/mjlab/tasks/tracking/config/g1/env_cfgs.py @@ -6,7 +6,7 @@ ) from mjlab.envs import ManagerBasedRlEnvCfg from mjlab.envs.mdp.actions import JointPositionActionCfg -from mjlab.managers.manager_term_config import ObservationGroupCfg +from mjlab.managers.observation_manager import ObservationGroupCfg from mjlab.sensor import ContactMatch, ContactSensorCfg from mjlab.tasks.tracking.mdp import MotionCommandCfg from mjlab.tasks.tracking.tracking_env_cfg import make_tracking_env_cfg @@ -35,7 +35,6 @@ def unitree_g1_flat_tracking_env_cfg( assert isinstance(joint_pos_action, JointPositionActionCfg) joint_pos_action.scale = G1_ACTION_SCALE - assert cfg.commands is not None motion_cmd = cfg.commands["motion"] assert isinstance(motion_cmd, MotionCommandCfg) motion_cmd.anchor_body_name = "torso_link" diff --git a/src/mjlab/tasks/tracking/scripts/evaluate.py b/src/mjlab/tasks/tracking/scripts/evaluate.py index e7a8d94d0..1983802f2 100644 --- a/src/mjlab/tasks/tracking/scripts/evaluate.py +++ b/src/mjlab/tasks/tracking/scripts/evaluate.py @@ -52,7 +52,6 @@ def run_evaluate(task_id: str, cfg: EvaluateConfig) -> dict[str, float]: env_cfg = load_env_cfg(task_id, play=False) agent_cfg = load_rl_cfg(task_id) - assert env_cfg.commands is not None motion_cmd = env_cfg.commands.get("motion") if not isinstance(motion_cmd, MotionCommandCfg): raise ValueError(f"Task {task_id} is not a tracking task.") diff --git a/src/mjlab/tasks/tracking/tracking_env_cfg.py b/src/mjlab/tasks/tracking/tracking_env_cfg.py index adadabf08..ed978861f 100644 --- a/src/mjlab/tasks/tracking/tracking_env_cfg.py +++ b/src/mjlab/tasks/tracking/tracking_env_cfg.py @@ -11,16 +11,13 @@ from mjlab.envs import ManagerBasedRlEnvCfg from mjlab.envs.mdp.actions import JointPositionActionCfg -from mjlab.managers.manager_term_config import ( - ActionTermCfg, - CommandTermCfg, - EventTermCfg, - ObservationGroupCfg, - ObservationTermCfg, - RewardTermCfg, - TerminationTermCfg, -) +from mjlab.managers.action_manager import ActionTermCfg +from mjlab.managers.command_manager import CommandTermCfg +from mjlab.managers.event_manager import EventTermCfg +from mjlab.managers.observation_manager import ObservationGroupCfg, ObservationTermCfg +from mjlab.managers.reward_manager import RewardTermCfg from mjlab.managers.scene_entity_config import SceneEntityCfg +from mjlab.managers.termination_manager import TerminationTermCfg from mjlab.scene import SceneCfg from mjlab.sim import MujocoCfg, SimulationCfg from mjlab.tasks.tracking import mdp diff --git a/src/mjlab/tasks/velocity/config/g1/env_cfgs.py b/src/mjlab/tasks/velocity/config/g1/env_cfgs.py index 8df322225..2b093afa5 100644 --- a/src/mjlab/tasks/velocity/config/g1/env_cfgs.py +++ b/src/mjlab/tasks/velocity/config/g1/env_cfgs.py @@ -7,10 +7,8 @@ from mjlab.envs import ManagerBasedRlEnvCfg from mjlab.envs import mdp as envs_mdp from mjlab.envs.mdp.actions import JointPositionActionCfg -from mjlab.managers.manager_term_config import ( - EventTermCfg, - RewardTermCfg, -) +from mjlab.managers.event_manager import EventTermCfg +from mjlab.managers.reward_manager import RewardTermCfg from mjlab.sensor import ContactMatch, ContactSensorCfg from mjlab.tasks.velocity import mdp from mjlab.tasks.velocity.mdp 
import UniformVelocityCommandCfg @@ -60,7 +58,6 @@ def unitree_g1_rough_env_cfg(play: bool = False) -> ManagerBasedRlEnvCfg: cfg.viewer.body_name = "torso_link" - assert cfg.commands is not None twist_cmd = cfg.commands["twist"] assert isinstance(twist_cmd, UniformVelocityCommandCfg) twist_cmd.viz.z_offset = 1.15 @@ -169,14 +166,11 @@ def unitree_g1_flat_env_cfg(play: bool = False) -> ManagerBasedRlEnvCfg: cfg.scene.terrain.terrain_generator = None # Disable terrain curriculum. - assert cfg.curriculum is not None assert "terrain_levels" in cfg.curriculum del cfg.curriculum["terrain_levels"] if play: - commands = cfg.commands - assert commands is not None - twist_cmd = commands["twist"] + twist_cmd = cfg.commands["twist"] assert isinstance(twist_cmd, UniformVelocityCommandCfg) twist_cmd.ranges.lin_vel_x = (-1.5, 2.0) twist_cmd.ranges.ang_vel_z = (-0.7, 0.7) diff --git a/src/mjlab/tasks/velocity/config/go1/env_cfgs.py b/src/mjlab/tasks/velocity/config/go1/env_cfgs.py index 17e3d039a..fc7d377d5 100644 --- a/src/mjlab/tasks/velocity/config/go1/env_cfgs.py +++ b/src/mjlab/tasks/velocity/config/go1/env_cfgs.py @@ -6,7 +6,7 @@ ) from mjlab.envs import ManagerBasedRlEnvCfg from mjlab.envs.mdp.actions import JointPositionActionCfg -from mjlab.managers.manager_term_config import TerminationTermCfg +from mjlab.managers.termination_manager import TerminationTermCfg from mjlab.sensor import ContactMatch, ContactSensorCfg from mjlab.tasks.velocity import mdp from mjlab.tasks.velocity.velocity_env_cfg import make_velocity_env_cfg @@ -122,7 +122,6 @@ def unitree_go1_flat_env_cfg(play: bool = False) -> ManagerBasedRlEnvCfg: cfg.scene.terrain.terrain_generator = None # Disable terrain curriculum. - assert cfg.curriculum is not None del cfg.curriculum["terrain_levels"] return cfg diff --git a/src/mjlab/tasks/velocity/mdp/rewards.py b/src/mjlab/tasks/velocity/mdp/rewards.py index 265044af0..ff0199e53 100644 --- a/src/mjlab/tasks/velocity/mdp/rewards.py +++ b/src/mjlab/tasks/velocity/mdp/rewards.py @@ -5,7 +5,7 @@ import torch from mjlab.entity import Entity -from mjlab.managers.manager_term_config import RewardTermCfg +from mjlab.managers.reward_manager import RewardTermCfg from mjlab.managers.scene_entity_config import SceneEntityCfg from mjlab.sensor import BuiltinSensor, ContactSensor from mjlab.utils.lab_api.math import quat_apply_inverse diff --git a/src/mjlab/tasks/velocity/mdp/velocity_command.py b/src/mjlab/tasks/velocity/mdp/velocity_command.py index e420a390d..346751b20 100644 --- a/src/mjlab/tasks/velocity/mdp/velocity_command.py +++ b/src/mjlab/tasks/velocity/mdp/velocity_command.py @@ -7,8 +7,7 @@ import torch from mjlab.entity import Entity -from mjlab.managers.command_manager import CommandTerm -from mjlab.managers.manager_term_config import CommandTermCfg +from mjlab.managers.command_manager import CommandTerm, CommandTermCfg from mjlab.utils.lab_api.math import ( matrix_from_quat, quat_apply, diff --git a/src/mjlab/tasks/velocity/velocity_env_cfg.py b/src/mjlab/tasks/velocity/velocity_env_cfg.py index 34f2f6113..2be11cf48 100644 --- a/src/mjlab/tasks/velocity/velocity_env_cfg.py +++ b/src/mjlab/tasks/velocity/velocity_env_cfg.py @@ -9,17 +9,14 @@ from mjlab.envs import ManagerBasedRlEnvCfg from mjlab.envs.mdp.actions import JointPositionActionCfg -from mjlab.managers.manager_term_config import ( - ActionTermCfg, - CommandTermCfg, - CurriculumTermCfg, - EventTermCfg, - ObservationGroupCfg, - ObservationTermCfg, - RewardTermCfg, - TerminationTermCfg, -) +from mjlab.managers.action_manager 
import ActionTermCfg +from mjlab.managers.command_manager import CommandTermCfg +from mjlab.managers.curriculum_manager import CurriculumTermCfg +from mjlab.managers.event_manager import EventTermCfg +from mjlab.managers.observation_manager import ObservationGroupCfg, ObservationTermCfg +from mjlab.managers.reward_manager import RewardTermCfg from mjlab.managers.scene_entity_config import SceneEntityCfg +from mjlab.managers.termination_manager import TerminationTermCfg from mjlab.scene import SceneCfg from mjlab.sim import MujocoCfg, SimulationCfg from mjlab.tasks.velocity import mdp diff --git a/tests/test_encoder_bias.py b/tests/test_encoder_bias.py index 2c5e4762c..5d7eb8f64 100644 --- a/tests/test_encoder_bias.py +++ b/tests/test_encoder_bias.py @@ -18,11 +18,8 @@ from mjlab.actuator import BuiltinPositionActuatorCfg from mjlab.entity import EntityArticulationInfoCfg, EntityCfg from mjlab.envs import ManagerBasedRlEnv, ManagerBasedRlEnvCfg, mdp -from mjlab.managers.manager_term_config import ( - EventTermCfg, - ObservationGroupCfg, - ObservationTermCfg, -) +from mjlab.managers.event_manager import EventTermCfg +from mjlab.managers.observation_manager import ObservationGroupCfg, ObservationTermCfg from mjlab.managers.scene_entity_config import SceneEntityCfg from mjlab.scene import SceneCfg from mjlab.sim import MujocoCfg, SimulationCfg diff --git a/tests/test_events.py b/tests/test_events.py index 550a9022b..f2dc7ee47 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -8,8 +8,7 @@ from mjlab import actuator from mjlab.envs.mdp import events -from mjlab.managers.event_manager import EventManager -from mjlab.managers.manager_term_config import EventTermCfg +from mjlab.managers.event_manager import EventManager, EventTermCfg from mjlab.managers.scene_entity_config import SceneEntityCfg diff --git a/tests/test_manager_config_immutability.py b/tests/test_manager_config_immutability.py index cfdf980f8..3bc17edc4 100644 --- a/tests/test_manager_config_immutability.py +++ b/tests/test_manager_config_immutability.py @@ -13,13 +13,12 @@ from conftest import get_test_device from mjlab.entity import Entity, EntityCfg -from mjlab.managers.manager_term_config import ( +from mjlab.managers.observation_manager import ( ObservationGroupCfg, + ObservationManager, ObservationTermCfg, - RewardTermCfg, ) -from mjlab.managers.observation_manager import ObservationManager -from mjlab.managers.reward_manager import RewardManager +from mjlab.managers.reward_manager import RewardManager, RewardTermCfg from mjlab.managers.scene_entity_config import SceneEntityCfg from mjlab.sim.sim import Simulation, SimulationCfg diff --git a/tests/test_observation_delay.py b/tests/test_observation_delay.py index f13e3b3c2..87447064c 100644 --- a/tests/test_observation_delay.py +++ b/tests/test_observation_delay.py @@ -6,8 +6,11 @@ import torch from conftest import get_test_device -from mjlab.managers.manager_term_config import ObservationGroupCfg, ObservationTermCfg -from mjlab.managers.observation_manager import ObservationManager +from mjlab.managers.observation_manager import ( + ObservationGroupCfg, + ObservationManager, + ObservationTermCfg, +) @pytest.fixture diff --git a/tests/test_observation_history.py b/tests/test_observation_history.py index d49ba40d3..21b94b9bd 100644 --- a/tests/test_observation_history.py +++ b/tests/test_observation_history.py @@ -6,8 +6,11 @@ import torch from conftest import get_test_device -from mjlab.managers.manager_term_config import ObservationGroupCfg, ObservationTermCfg 
-from mjlab.managers.observation_manager import ObservationManager +from mjlab.managers.observation_manager import ( + ObservationGroupCfg, + ObservationManager, + ObservationTermCfg, +) @pytest.fixture diff --git a/tests/test_rewards.py b/tests/test_rewards.py index b98ca5c02..f2ad35782 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -10,8 +10,7 @@ from mjlab.actuator import BuiltinPositionActuatorCfg from mjlab.entity import Entity, EntityArticulationInfoCfg, EntityCfg from mjlab.envs.mdp.rewards import electrical_power_cost -from mjlab.managers.manager_term_config import RewardTermCfg -from mjlab.managers.reward_manager import RewardManager +from mjlab.managers.reward_manager import RewardManager, RewardTermCfg from mjlab.managers.scene_entity_config import SceneEntityCfg from mjlab.sim.sim import Simulation, SimulationCfg diff --git a/tests/test_rl_exporter.py b/tests/test_rl_exporter.py index bb8b2355c..3d6c5d142 100644 --- a/tests/test_rl_exporter.py +++ b/tests/test_rl_exporter.py @@ -11,10 +11,7 @@ from mjlab.actuator import XmlMotorActuatorCfg from mjlab.entity import EntityArticulationInfoCfg, EntityCfg from mjlab.envs import ManagerBasedRlEnv, ManagerBasedRlEnvCfg, mdp -from mjlab.managers.manager_term_config import ( - ObservationGroupCfg, - ObservationTermCfg, -) +from mjlab.managers.observation_manager import ObservationGroupCfg, ObservationTermCfg from mjlab.rl.exporter_utils import ( attach_metadata_to_onnx, get_base_metadata, diff --git a/tests/test_task_configs.py b/tests/test_task_configs.py index 7d69e49a2..4582b651f 100644 --- a/tests/test_task_configs.py +++ b/tests/test_task_configs.py @@ -3,7 +3,7 @@ import pytest from mjlab.envs import ManagerBasedRlEnvCfg -from mjlab.managers.manager_term_config import ObservationGroupCfg +from mjlab.managers.observation_manager import ObservationGroupCfg from mjlab.tasks.registry import list_tasks, load_env_cfg diff --git a/tests/test_terminations.py b/tests/test_terminations.py index c09bf1ccf..964337eb8 100644 --- a/tests/test_terminations.py +++ b/tests/test_terminations.py @@ -8,8 +8,7 @@ from conftest import get_test_device from mjlab.envs.mdp.terminations import nan_detection -from mjlab.managers.manager_term_config import TerminationTermCfg -from mjlab.managers.termination_manager import TerminationManager +from mjlab.managers.termination_manager import TerminationManager, TerminationTermCfg from mjlab.sim.sim import Simulation, SimulationCfg diff --git a/tests/test_tracking_task.py b/tests/test_tracking_task.py index e81ebda62..656483db9 100644 --- a/tests/test_tracking_task.py +++ b/tests/test_tracking_task.py @@ -25,7 +25,6 @@ def test_tracking_tasks_have_motion_command(tracking_task_ids: list[str]) -> Non for task_id in tracking_task_ids: cfg = load_env_cfg(task_id) - assert cfg.commands is not None, f"Task {task_id} has no commands" assert "motion" in cfg.commands, f"Task {task_id} missing 'motion' command" motion_cmd = cfg.commands["motion"] @@ -83,7 +82,6 @@ def test_tracking_play_disables_rsi_randomization() -> None: for task_id in tracking_tasks: cfg = load_env_cfg(task_id, play=True) - assert cfg.commands is not None, f"Task {task_id} (play mode) has no commands" motion_cmd = cfg.commands["motion"] assert isinstance(motion_cmd, MotionCommandCfg), ( f"Task {task_id} (play mode) motion command is not MotionCommandCfg" @@ -109,7 +107,6 @@ def test_tracking_play_uses_start_sampling_mode() -> None: for task_id in tracking_tasks: cfg = load_env_cfg(task_id, play=True) - assert cfg.commands is not 
None, f"Task {task_id} (play mode) has no commands" motion_cmd = cfg.commands["motion"] assert isinstance(motion_cmd, MotionCommandCfg), ( f"Task {task_id} (play mode) motion command is not MotionCommandCfg" diff --git a/tests/test_velocity_task.py b/tests/test_velocity_task.py index da55e2646..9e449e919 100644 --- a/tests/test_velocity_task.py +++ b/tests/test_velocity_task.py @@ -43,7 +43,6 @@ def test_velocity_tasks_have_twist_command(velocity_task_ids: list[str]) -> None for task_id in velocity_task_ids: cfg = load_env_cfg(task_id) - assert cfg.commands is not None, f"Task {task_id} has no commands" assert "twist" in cfg.commands, f"Task {task_id} missing 'twist' command" twist_cmd = cfg.commands["twist"] diff --git a/tests/test_xml_actuator.py b/tests/test_xml_actuator.py index e78888671..e8b90fc0b 100644 --- a/tests/test_xml_actuator.py +++ b/tests/test_xml_actuator.py @@ -7,10 +7,7 @@ from mjlab.actuator import XmlMotorActuatorCfg from mjlab.entity import Entity, EntityArticulationInfoCfg, EntityCfg from mjlab.envs import ManagerBasedRlEnv, ManagerBasedRlEnvCfg, mdp -from mjlab.managers.manager_term_config import ( - ObservationGroupCfg, - ObservationTermCfg, -) +from mjlab.managers.observation_manager import ObservationGroupCfg, ObservationTermCfg from mjlab.scene import SceneCfg from mjlab.sim import MujocoCfg, SimulationCfg from mjlab.terrains import TerrainImporterCfg