diff --git a/src/mjlab/tasks/manipulation/config/yam/env_cfgs.py b/src/mjlab/tasks/manipulation/config/yam/env_cfgs.py
index 5a962e95e..41849456a 100644
--- a/src/mjlab/tasks/manipulation/config/yam/env_cfgs.py
+++ b/src/mjlab/tasks/manipulation/config/yam/env_cfgs.py
@@ -9,7 +9,7 @@
 from mjlab.envs.mdp.actions import JointPositionActionCfg
 from mjlab.sensor import ContactSensorCfg
 from mjlab.tasks.manipulation.lift_cube_env_cfg import make_lift_cube_env_cfg
-from mjlab.tasks.manipulation.mdp import LiftingCommandCfg
+from mjlab.tasks.manipulation.mdp import PositionCommandCfg
 
 
 def get_cube_spec(cube_size: float = 0.02, mass: float = 0.05) -> mujoco.MjSpec:
@@ -33,7 +33,10 @@ def yam_lift_cube_env_cfg(
   cfg.scene.entities = {
     "robot": get_yam_robot_cfg(),
-    "cube": EntityCfg(spec_fn=get_cube_spec),
+    "cube": EntityCfg(
+      spec_fn=get_cube_spec,
+      init_state=EntityCfg.InitialStateCfg(pos=(0.3, 0.0, 0.035)),
+    ),
   }
 
   joint_pos_action = cfg.actions["joint_pos"]
@@ -42,7 +45,7 @@ def yam_lift_cube_env_cfg(
   assert cfg.commands is not None
   lift_command = cfg.commands["lift_height"]
-  assert isinstance(lift_command, LiftingCommandCfg)
+  assert isinstance(lift_command, PositionCommandCfg)
 
   cfg.observations["policy"].terms["ee_to_cube"].params["asset_cfg"].site_names = (
     "grasp_site",
diff --git a/src/mjlab/tasks/manipulation/lift_cube_env_cfg.py b/src/mjlab/tasks/manipulation/lift_cube_env_cfg.py
index 0480d2518..2de768ffe 100644
--- a/src/mjlab/tasks/manipulation/lift_cube_env_cfg.py
+++ b/src/mjlab/tasks/manipulation/lift_cube_env_cfg.py
@@ -15,7 +15,7 @@
 from mjlab.sensor import ContactMatch, ContactSensorCfg
 from mjlab.sim import MujocoCfg, SimulationCfg
 from mjlab.tasks.manipulation import mdp as manipulation_mdp
-from mjlab.tasks.manipulation.mdp import LiftingCommandCfg
+from mjlab.tasks.manipulation.mdp import PositionCommandCfg
 from mjlab.tasks.velocity import mdp
 from mjlab.terrains import TerrainImporterCfg
 from mjlab.utils.noise import UniformNoiseCfg as Unoise
@@ -70,28 +70,32 @@
   }
 
   commands: dict[str, CommandTermCfg] = {
-    "lift_height": LiftingCommandCfg(
-      asset_name="cube",
+    "lift_height": PositionCommandCfg(
       resampling_time_range=(8.0, 12.0),
       debug_vis=True,
       difficulty="dynamic",
-      object_pose_range=LiftingCommandCfg.ObjectPoseRangeCfg(
-        x=(0.2, 0.4),
-        y=(-0.2, 0.2),
-        z=(0.02, 0.05),
-        yaw=(-3.14, 3.14),
-      ),
     )
   }
 
   events = {
-    # For positioning the base of the robot at env_origins.
-    "reset_base": EventTermCfg(
-      func=mdp.reset_root_state_uniform,
+    # For positioning all scene entities at env_origins.
+ "reset_scene": EventTermCfg( + func=mdp.reset_scene_to_default, mode="startup", + ), + "reset_object": EventTermCfg( + func=mdp.reset_root_state_uniform, + mode="interval", + interval_range_s=(8.0, 12.0), params={ - "pose_range": {}, + "pose_range": { + "x": (-0.1, 0.1), + "y": (-0.2, 0.2), + "z": (-0.015, 0.015), + "yaw": (-3.14, 3.14), + }, "velocity_range": {}, + "asset_cfg": SceneEntityCfg("cube"), }, ), "reset_robot_joints": EventTermCfg( diff --git a/src/mjlab/tasks/manipulation/mdp/commands.py b/src/mjlab/tasks/manipulation/mdp/commands.py index 4d7d5eb3b..141e0c8b6 100644 --- a/src/mjlab/tasks/manipulation/mdp/commands.py +++ b/src/mjlab/tasks/manipulation/mdp/commands.py @@ -1,16 +1,13 @@ from __future__ import annotations -import math from dataclasses import dataclass, field from typing import TYPE_CHECKING, Literal import torch -from mjlab.entity import Entity from mjlab.managers.command_manager import CommandTerm from mjlab.managers.manager_term_config import CommandTermCfg from mjlab.utils.lab_api.math import ( - quat_from_euler_xyz, sample_uniform, ) @@ -19,49 +16,24 @@ from mjlab.viewer.debug_visualizer import DebugVisualizer -class LiftingCommand(CommandTerm): - cfg: LiftingCommandCfg +class PositionCommand(CommandTerm): + cfg: PositionCommandCfg - def __init__(self, cfg: LiftingCommandCfg, env: ManagerBasedRlEnv): + def __init__(self, cfg: PositionCommandCfg, env: ManagerBasedRlEnv): super().__init__(cfg, env) - self.object: Entity = env.scene[cfg.asset_name] self.target_pos = torch.zeros(self.num_envs, 3, device=self.device) - self.episode_success = torch.zeros(self.num_envs, device=self.device) - - self.metrics["object_height"] = torch.zeros(self.num_envs, device=self.device) - self.metrics["position_error"] = torch.zeros(self.num_envs, device=self.device) - self.metrics["at_goal"] = torch.zeros(self.num_envs, device=self.device) - self.metrics["episode_success"] = torch.zeros(self.num_envs, device=self.device) @property def command(self) -> torch.Tensor: return self.target_pos def _update_metrics(self) -> None: - object_pos_w = self.object.data.root_link_pos_w - object_height = object_pos_w[:, 2] - position_error = torch.norm(self.target_pos - object_pos_w, dim=-1) - at_goal = (position_error < self.cfg.success_threshold).float() - - # Latch episode_success to 1 once goal is reached. - self.episode_success = torch.maximum(self.episode_success, at_goal) - - self.metrics["object_height"] = object_height - self.metrics["position_error"] = position_error - self.metrics["at_goal"] = at_goal - self.metrics["episode_success"] = self.episode_success - - def compute_success(self) -> torch.Tensor: - position_error = self.metrics["position_error"] - return position_error < self.cfg.success_threshold + pass def _resample_command(self, env_ids: torch.Tensor) -> None: n = len(env_ids) - # Reset episode success for resampled envs. - self.episode_success[env_ids] = 0.0 - # Set target position based on difficulty mode. if self.cfg.difficulty == "fixed": target_pos = torch.tensor( @@ -76,28 +48,6 @@ def _resample_command(self, env_ids: torch.Tensor) -> None: target_pos = sample_uniform(lower, upper, (n, 3), device=self.device) self.target_pos[env_ids] = target_pos + self._env.scene.env_origins[env_ids] - # Reset object to new position. 
diff --git a/src/mjlab/tasks/manipulation/mdp/observations.py b/src/mjlab/tasks/manipulation/mdp/observations.py
index 1d1bbd496..9d2914035 100644
--- a/src/mjlab/tasks/manipulation/mdp/observations.py
+++ b/src/mjlab/tasks/manipulation/mdp/observations.py
@@ -6,7 +6,7 @@
 from mjlab.entity import Entity
 from mjlab.managers.scene_entity_config import SceneEntityCfg
-from mjlab.tasks.manipulation.mdp.commands import LiftingCommand
+from mjlab.tasks.manipulation.mdp.commands import PositionCommand
 from mjlab.utils.lab_api.math import quat_apply, quat_inv
 
 if TYPE_CHECKING:
@@ -38,9 +38,9 @@ def object_position_error(
 ) -> torch.Tensor:
   """3D position error between object and target position (target - current)."""
   command = env.command_manager.get_term(command_name)
-  if not isinstance(command, LiftingCommand):
+  if not isinstance(command, PositionCommand):
     raise TypeError(
-      f"Command '{command_name}' must be a LiftingCommand, got {type(command)}"
+      f"Command '{command_name}' must be a PositionCommand, got {type(command)}"
     )
   obj: Entity = env.scene[object_name]
   obj_pos_w = obj.data.root_link_pos_w
diff --git a/src/mjlab/tasks/manipulation/mdp/rewards.py b/src/mjlab/tasks/manipulation/mdp/rewards.py
index b5792cf94..1a2b6c9d3 100644
--- a/src/mjlab/tasks/manipulation/mdp/rewards.py
+++ b/src/mjlab/tasks/manipulation/mdp/rewards.py
@@ -5,8 +5,9 @@
 import torch
 
 from mjlab.entity import Entity
+from mjlab.managers.manager_term_config import RewardTermCfg
 from mjlab.managers.scene_entity_config import SceneEntityCfg
-from mjlab.tasks.manipulation.mdp.commands import LiftingCommand
+from mjlab.tasks.manipulation.mdp.commands import PositionCommand
 
 if TYPE_CHECKING:
   from mjlab.envs import ManagerBasedRlEnv
@@ -29,7 +30,7 @@ def staged_position_reward(
   """
   robot: Entity = env.scene[asset_cfg.name]
   obj: Entity = env.scene[object_name]
-  command = cast(LiftingCommand, env.command_manager.get_term(command_name))
+  command = cast(PositionCommand, env.command_manager.get_term(command_name))
   ee_pos_w = robot.data.site_pos_w[:, asset_cfg.site_ids].squeeze(1)
   obj_pos_w = obj.data.root_link_pos_w
   reach_error = torch.sum(torch.square(ee_pos_w - obj_pos_w), dim=-1)
@@ -39,18 +40,39 @@ def staged_position_reward(
   return reaching * (1.0 + bringing)
 
 
-def bring_object_reward(
-  env: ManagerBasedRlEnv,
-  command_name: str,
-  object_name: str,
-  std: float,
-) -> torch.Tensor:
-  obj: Entity = env.scene[object_name]
-  command = cast(LiftingCommand, env.command_manager.get_term(command_name))
-  position_error = torch.sum(
-    torch.square(command.target_pos - obj.data.root_link_pos_w), dim=-1
-  )
-  return torch.exp(-position_error / std**2)
+class bring_object_reward:
+  """Reward bringing the object to the commanded target position; logs success metrics."""
+
+  def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv):
+    self.episode_success = torch.zeros(env.num_envs, device=env.device)
+
+  def __call__(
+    self,
+    env: ManagerBasedRlEnv,
+    command_name: str,
+    object_name: str,
+    std: float,
+    success_threshold: float = 0.05,
+  ) -> torch.Tensor:
+    obj: Entity = env.scene[object_name]
+    command = cast(PositionCommand, env.command_manager.get_term(command_name))
+
+    obj_pos_w = obj.data.root_link_pos_w
+    position_error = command.target_pos - obj_pos_w
+    position_error_sq = torch.sum(torch.square(position_error), dim=-1)
+
+    position_error_norm = torch.norm(position_error, dim=-1)
+    at_goal = (position_error_norm < success_threshold).float()
+
+    self.episode_success[env.reset_buf] = 0.0
+    self.episode_success = torch.maximum(self.episode_success, at_goal)
+
+    env.extras["log"]["Metrics/object_height"] = obj_pos_w[:, 2]
+    env.extras["log"]["Metrics/position_error"] = position_error_norm
+    env.extras["log"]["Metrics/at_goal"] = at_goal
+    env.extras["log"]["Metrics/episode_success"] = self.episode_success
+
+    return torch.exp(-position_error_sq / std**2)
 
 
 def joint_velocity_hinge_penalty(
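Review note: bring_object_reward is now a class-based term so it can latch per-env episode success across steps. A sketch of how it could be registered, assuming mjlab's RewardTermCfg follows the usual func/weight/params layout; the key, weight, and std value here are illustrative, while the param names match the __call__ signature above:

  rewards = {
    "bring_object": RewardTermCfg(
      func=manipulation_mdp.bring_object_reward,  # class is instantiated by the reward manager
      weight=2.0,  # illustrative weight
      params={
        "command_name": "lift_height",
        "object_name": "cube",
        "std": 0.3,  # illustrative kernel width
        "success_threshold": 0.05,
      },
    ),
  }

This assumes the reward manager constructs class-based terms with (cfg, env), matching the __init__ signature, so the episode_success buffer persists across steps and is cleared through env.reset_buf.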