2 changes: 1 addition & 1 deletion docs/source/actuators.rst
@@ -541,7 +541,7 @@ Domain Randomization
.. code-block:: python

from mjlab.envs.mdp import events
from mjlab.managers.manager_term_config import EventTermCfg
from mjlab.managers.event_manager import EventTermCfg
from mjlab.managers.scene_entity_config import SceneEntityCfg

EventTermCfg(
2 changes: 1 addition & 1 deletion docs/source/faq.rst
@@ -81,7 +81,7 @@ My training crashes with NaN errors
from dataclasses import dataclass, field

from mjlab.envs.mdp.terminations import nan_detection
from mjlab.managers.manager_term_config import TerminationTermCfg
from mjlab.managers.termination_manager import TerminationTermCfg

@dataclass
class TerminationCfg:
2 changes: 1 addition & 1 deletion docs/source/nan_guard.rst
@@ -119,7 +119,7 @@ While ``nan_guard`` helps **debug** NaN issues by capturing states, you can also
.. code-block:: python

from mjlab.envs.mdp.terminations import nan_detection
from mjlab.managers.manager_term_config import TerminationTermCfg
from mjlab.managers.termination_manager import TerminationTermCfg

# In your termination config:
nan_term: TerminationTermCfg = field(
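For reference, a complete termination config wiring in ``nan_detection`` might look like the sketch below. This is not part of the diff; the ``TerminationTermCfg(func=...)`` signature is an assumption based on the ``func=`` pattern used by the other term configs in these docs.

from dataclasses import dataclass, field

from mjlab.envs.mdp.terminations import nan_detection
from mjlab.managers.termination_manager import TerminationTermCfg

@dataclass
class TerminationCfg:
  # Assumed signature: func= mirrors ObservationTermCfg(func=...) shown in observation.rst.
  nan_term: TerminationTermCfg = field(
    default_factory=lambda: TerminationTermCfg(func=nan_detection)
  )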
2 changes: 1 addition & 1 deletion docs/source/observation.rst
@@ -13,7 +13,7 @@ TL;DR

.. code-block:: python

from mjlab.managers.manager_term_config import ObservationTermCfg
from mjlab.managers.observation_manager import ObservationTermCfg

joint_vel: ObservationTermCfg = ObservationTermCfg(
func=joint_vel,
2 changes: 1 addition & 1 deletion docs/source/randomization.rst
@@ -15,7 +15,7 @@ apply the draw.

.. code-block:: python

from mjlab.managers.manager_term_config import EventTermCfg
from mjlab.managers.event_manager import EventTermCfg
from mjlab.managers.scene_entity_config import SceneEntityCfg
from mjlab.envs import mdp

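A minimal usage sketch with the relocated import; it is not part of the diff. The ``func`` and ``mode`` keywords are assumptions, inferred from the default ``reset_scene_to_default`` event shown later in ``manager_based_rl_env.py`` and from ``EventMode`` being re-exported by ``mjlab.managers``.

from mjlab.envs.mdp.events import reset_scene_to_default
from mjlab.managers.event_manager import EventTermCfg

events = {
  "reset_scene_to_default": EventTermCfg(
    func=reset_scene_to_default,  # assumed keyword
    mode="reset",                 # assumed keyword
  ),
}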
2 changes: 1 addition & 1 deletion scripts/benchmarks/measure_throughput.py
@@ -126,7 +126,7 @@ def benchmark_task(task: str, cfg: ThroughputConfig) -> BenchmarkResult:
env_cfg.scene.num_envs = cfg.num_envs

# Handle tracking task motion file.
if env_cfg.commands is not None:
if len(env_cfg.commands) > 0:
motion_cmd = env_cfg.commands.get("motion")
if isinstance(motion_cmd, MotionCommandCfg):
api = wandb.Api()
87 changes: 64 additions & 23 deletions src/mjlab/envs/manager_based_rl_env.py
@@ -10,22 +10,21 @@

from mjlab.envs import types
from mjlab.envs.mdp.events import reset_scene_to_default
from mjlab.managers.action_manager import ActionManager
from mjlab.managers.command_manager import CommandManager, NullCommandManager
from mjlab.managers.curriculum_manager import CurriculumManager, NullCurriculumManager
from mjlab.managers.event_manager import EventManager
from mjlab.managers.manager_term_config import (
ActionTermCfg,
from mjlab.managers.action_manager import ActionManager, ActionTermCfg
from mjlab.managers.command_manager import (
CommandManager,
CommandTermCfg,
NullCommandManager,
)
from mjlab.managers.curriculum_manager import (
CurriculumManager,
CurriculumTermCfg,
EventTermCfg,
ObservationGroupCfg,
RewardTermCfg,
TerminationTermCfg,
NullCurriculumManager,
)
from mjlab.managers.observation_manager import ObservationManager
from mjlab.managers.reward_manager import RewardManager
from mjlab.managers.termination_manager import TerminationManager
from mjlab.managers.event_manager import EventManager, EventTermCfg
from mjlab.managers.observation_manager import ObservationGroupCfg, ObservationManager
from mjlab.managers.reward_manager import RewardManager, RewardTermCfg
from mjlab.managers.termination_manager import TerminationManager, TerminationTermCfg
from mjlab.scene import Scene
from mjlab.scene.scene import SceneCfg
from mjlab.sim import SimulationCfg
@@ -41,14 +40,36 @@

@dataclass(kw_only=True)
class ManagerBasedRlEnvCfg:
"""Configuration for a manager-based RL environment."""
"""Configuration for a manager-based RL environment.

This config defines all aspects of an RL environment: the physical scene,
observations, actions, rewards, terminations, and optional features like
commands and curriculum learning.

The environment step size is ``sim.mujoco.timestep * decimation``. For example,
with a 2ms physics timestep and decimation=10, the environment runs at 50Hz.
"""

# Base environment configuration.

decimation: int
"""Number of simulation steps per environment step."""
"""Number of physics simulation steps per environment step. Higher values mean
coarser control frequency. Environment step duration = physics_dt * decimation."""

scene: SceneCfg
"""Scene configuration defining terrain, entities, and sensors. The scene
specifies ``num_envs``, the number of parallel environments."""

observations: dict[str, ObservationGroupCfg]
"""Observation groups configuration. Each group (e.g., "policy", "critic") contains
observation terms that are concatenated. Groups can have different settings for
noise, history, and delay. Can be empty for environments without observations."""

actions: dict[str, ActionTermCfg]
"""Action terms configuration. Each term controls a specific entity/aspect
(e.g., joint positions). Action dimensions are concatenated across terms.
Can be empty for observation-only environments."""

events: dict[str, EventTermCfg] = field(
default_factory=lambda: {
"reset_scene_to_default": EventTermCfg(
@@ -57,25 +78,44 @@ class ManagerBasedRlEnvCfg:
)
}
)
"""Event terms for domain randomization and state resets. Default includes
``reset_scene_to_default`` which resets entities to their initial state.
Can be set to empty to disable all events including default reset."""

seed: int | None = None
"""Random seed for reproducibility. If None, a random seed is used. The actual
seed used is stored back into this field after initialization."""

sim: SimulationCfg = field(default_factory=SimulationCfg)
"""Simulation configuration including physics timestep, solver iterations,
contact parameters, and NaN guarding."""

viewer: ViewerConfig = field(default_factory=ViewerConfig)
"""Viewer configuration for rendering (camera position, resolution, etc.)."""

# RL-specific configuration.

episode_length_s: float = 0.0
"""Duration of an episode (in seconds).

Episode length in steps is computed as:
ceil(episode_length_s / (sim.mujoco.timestep * decimation))
"""

rewards: dict[str, RewardTermCfg] = field(default_factory=dict)
"""Reward terms configuration."""
"""Reward terms configuration. Can be empty for unsupervised environments."""

terminations: dict[str, TerminationTermCfg] = field(default_factory=dict)
"""Termination terms configuration."""
commands: dict[str, CommandTermCfg] | None = None
"""Command terms configuration. If None, no commands are used."""
curriculum: dict[str, CurriculumTermCfg] | None = None
"""Curriculum terms configuration. If None, no curriculum is used."""
"""Termination terms configuration. Can be empty for infinite episodes (no
terminations). Use ``mdp.time_out`` with ``time_out=True`` for episode time limits."""

commands: dict[str, CommandTermCfg] = field(default_factory=dict)
"""Command generator terms (e.g., velocity targets). Can be empty if the
task has no goal commands."""

curriculum: dict[str, CurriculumTermCfg] = field(default_factory=dict)
"""Curriculum terms for adaptive difficulty. Can be empty to disable."""

is_finite_horizon: bool = False
"""Whether the task has a finite or infinite horizon. Defaults to False (infinite).

Expand All @@ -85,6 +125,7 @@ class ManagerBasedRlEnvCfg:
receives a truncated done signal to bootstrap the value of continuing beyond the
limit.
"""

scale_rewards_by_dt: bool = True
"""Whether to multiply rewards by the environment step duration (dt).

@@ -227,7 +268,7 @@ def load_managers(self) -> None:

# Command manager (must be before observation manager since observations
# may reference commands).
if self.cfg.commands is not None:
if len(self.cfg.commands) > 0:
self.command_manager = CommandManager(self.cfg.commands, self)
else:
self.command_manager = NullCommandManager()
@@ -247,7 +288,7 @@ def load_managers(self) -> None:
self.cfg.rewards, self, scale_by_dt=self.cfg.scale_rewards_by_dt
)
print_info(f"[INFO] {self.reward_manager}")
if self.cfg.curriculum is not None:
if len(self.cfg.curriculum) > 0:
self.curriculum_manager = CurriculumManager(self.cfg.curriculum, self)
else:
self.curriculum_manager = NullCurriculumManager()
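To illustrate the timing relationships spelled out in the new docstrings (a sketch only, assuming a config instance named ``cfg``):

import math

physics_dt = cfg.sim.mujoco.timestep    # e.g. 0.002 s
env_dt = physics_dt * cfg.decimation    # e.g. 0.002 * 10 = 0.02 s, i.e. 50 Hz control
episode_length_steps = math.ceil(cfg.episode_length_s / env_dt)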
3 changes: 1 addition & 2 deletions src/mjlab/envs/mdp/actions/actions.py
@@ -8,8 +8,7 @@
import torch

from mjlab.actuator.actuator import TransmissionType
from mjlab.managers.action_manager import ActionTerm
from mjlab.managers.manager_term_config import ActionTermCfg
from mjlab.managers.action_manager import ActionTerm, ActionTermCfg
from mjlab.utils.lab_api.string import resolve_matching_names_values

if TYPE_CHECKING:
2 changes: 1 addition & 1 deletion src/mjlab/envs/mdp/rewards.py
@@ -7,7 +7,7 @@
import torch

from mjlab.entity import Entity
from mjlab.managers.manager_term_config import RewardTermCfg
from mjlab.managers.reward_manager import RewardTermCfg
from mjlab.managers.scene_entity_config import SceneEntityCfg
from mjlab.utils.lab_api.string import (
resolve_matching_names_values,
22 changes: 21 additions & 1 deletion src/mjlab/managers/__init__.py
@@ -1,10 +1,30 @@
"""Environment managers."""

from mjlab.managers.action_manager import ActionManager as ActionManager
from mjlab.managers.action_manager import ActionTerm as ActionTerm
from mjlab.managers.action_manager import ActionTermCfg as ActionTermCfg
from mjlab.managers.command_manager import CommandManager as CommandManager
from mjlab.managers.command_manager import CommandTerm as CommandTerm
from mjlab.managers.command_manager import CommandTermCfg as CommandTermCfg
from mjlab.managers.command_manager import NullCommandManager as NullCommandManager
from mjlab.managers.curriculum_manager import CurriculumManager as CurriculumManager
from mjlab.managers.curriculum_manager import CurriculumTermCfg as CurriculumTermCfg
from mjlab.managers.curriculum_manager import (
NullCurriculumManager as NullCurriculumManager,
)
from mjlab.managers.manager_term_config import CommandTermCfg as CommandTermCfg
from mjlab.managers.event_manager import EventManager as EventManager
from mjlab.managers.event_manager import EventMode as EventMode
from mjlab.managers.event_manager import EventTermCfg as EventTermCfg
from mjlab.managers.manager_base import ManagerBase as ManagerBase
from mjlab.managers.manager_base import ManagerTermBase as ManagerTermBase
from mjlab.managers.manager_base import ManagerTermBaseCfg as ManagerTermBaseCfg
from mjlab.managers.observation_manager import (
ObservationGroupCfg as ObservationGroupCfg,
)
from mjlab.managers.observation_manager import ObservationManager as ObservationManager
from mjlab.managers.observation_manager import ObservationTermCfg as ObservationTermCfg
from mjlab.managers.reward_manager import RewardManager as RewardManager
from mjlab.managers.reward_manager import RewardTermCfg as RewardTermCfg
from mjlab.managers.scene_entity_config import SceneEntityCfg as SceneEntityCfg
from mjlab.managers.termination_manager import TerminationManager as TerminationManager
from mjlab.managers.termination_manager import TerminationTermCfg as TerminationTermCfg
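With these re-exports, downstream code can pull the term configs from the package root instead of the individual manager modules (a usage sketch, not part of the diff):

from mjlab.managers import (
  ActionTermCfg,
  CommandTermCfg,
  CurriculumTermCfg,
  EventTermCfg,
  ObservationTermCfg,
  RewardTermCfg,
  TerminationTermCfg,
)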
30 changes: 29 additions & 1 deletion src/mjlab/managers/action_manager.py
@@ -3,6 +3,7 @@
from __future__ import annotations

import abc
from dataclasses import dataclass
from typing import TYPE_CHECKING, Sequence

import torch
@@ -12,7 +13,27 @@

if TYPE_CHECKING:
from mjlab.envs import ManagerBasedRlEnv
from mjlab.managers.manager_term_config import ActionTermCfg


@dataclass(kw_only=True)
class ActionTermCfg(abc.ABC):
"""Configuration for an action term.

Action terms process raw actions from the policy and apply them to entities
in the scene (e.g., setting joint positions, velocities, or efforts).
"""

entity_name: str
"""Name of the entity in the scene that this action term controls."""

clip: dict[str, tuple] | None = None
"""Optional clipping bounds per transmission type. Maps transmission name
(e.g., 'position', 'velocity') to (min, max) tuple."""

@abc.abstractmethod
def build(self, env: ManagerBasedRlEnv) -> ActionTerm:
"""Build the action term from this config."""
raise NotImplementedError


class ActionTerm(ManagerTermBase):
@@ -47,6 +68,13 @@ def raw_action(self) -> torch.Tensor:


class ActionManager(ManagerBase):
"""Manages action processing for the environment.

The action manager aggregates multiple action terms, each controlling a different
entity or aspect of the simulation. It splits the policy's action tensor and
routes each slice to the appropriate action term.
"""

def __init__(self, cfg: dict[str, ActionTermCfg], env: ManagerBasedRlEnv):
self.cfg = cfg
super().__init__(env=env)
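The design change here is that each term config now knows how to construct its runtime term via ``build()``. A minimal sketch of the resulting construction pattern, assuming ``env`` is a ``ManagerBasedRlEnv`` and ``cfgs`` is a ``dict[str, ActionTermCfg]`` of concrete subclasses:

# Uniform construction: every concrete ActionTermCfg subclass implements build().
terms: dict[str, ActionTerm] = {name: term_cfg.build(env) for name, term_cfg in cfgs.items()}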
34 changes: 33 additions & 1 deletion src/mjlab/managers/command_manager.py
@@ -3,6 +3,7 @@
from __future__ import annotations

import abc
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Sequence

import torch
@@ -12,10 +13,34 @@

if TYPE_CHECKING:
from mjlab.envs.manager_based_rl_env import ManagerBasedRlEnv
from mjlab.managers.manager_term_config import CommandTermCfg
from mjlab.viewer.debug_visualizer import DebugVisualizer


@dataclass(kw_only=True)
class CommandTermCfg(abc.ABC):
"""Configuration for a command generator term.

Command terms generate goal commands for the agent (e.g., target velocity,
target position). Commands are automatically resampled at configurable
intervals and can track metrics for logging.
"""

resampling_time_range: tuple[float, float]
"""Time range in seconds for command resampling. When the timer expires, a new
command is sampled and the timer is reset to a value uniformly drawn from
``[min, max]``. Set both values equal for fixed-interval resampling."""

debug_vis: bool = False
"""Whether to enable debug visualization for this command term. When True,
the command term's ``_debug_vis_impl`` method is called each frame to render
visual aids (e.g., velocity arrows, target markers)."""

@abc.abstractmethod
def build(self, env: ManagerBasedRlEnv) -> CommandTerm:
"""Build the command term from this config."""
raise NotImplementedError


class CommandTerm(ManagerTermBase):
"""Base class for command terms."""

@@ -83,6 +108,13 @@ def _update_command(self) -> None:


class CommandManager(ManagerBase):
"""Manages command generation for the environment.

The command manager generates and updates goal commands for the agent (e.g.,
target velocity, target position). Commands are resampled at configurable
intervals and can track metrics for logging.
"""

_env: ManagerBasedRlEnv

def __init__(self, cfg: dict[str, CommandTermCfg], env: ManagerBasedRlEnv):
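A usage sketch mirroring ``load_managers()`` earlier in this diff: with ``commands`` now a plain dict, an empty dict selects the null manager (``cfg`` and ``env`` are assumed to be in scope):

from mjlab.managers.command_manager import CommandManager, NullCommandManager

command_manager = (
  CommandManager(cfg.commands, env)
  if len(cfg.commands) > 0
  else NullCommandManager()
)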