9 changes: 6 additions & 3 deletions src/mjlab/tasks/manipulation/config/yam/env_cfgs.py
@@ -9,7 +9,7 @@
from mjlab.envs.mdp.actions import JointPositionActionCfg
from mjlab.sensor import ContactSensorCfg
from mjlab.tasks.manipulation.lift_cube_env_cfg import make_lift_cube_env_cfg
-from mjlab.tasks.manipulation.mdp import LiftingCommandCfg
+from mjlab.tasks.manipulation.mdp import PositionCommandCfg


def get_cube_spec(cube_size: float = 0.02, mass: float = 0.05) -> mujoco.MjSpec:
@@ -33,7 +33,10 @@ def yam_lift_cube_env_cfg(

cfg.scene.entities = {
"robot": get_yam_robot_cfg(),
"cube": EntityCfg(spec_fn=get_cube_spec),
"cube": EntityCfg(
spec_fn=get_cube_spec,
init_state=EntityCfg.InitialStateCfg(pos=(0.3, 0.0, 0.035)),
),
}

joint_pos_action = cfg.actions["joint_pos"]
@@ -42,7 +45,7 @@

assert cfg.commands is not None
lift_command = cfg.commands["lift_height"]
-  assert isinstance(lift_command, LiftingCommandCfg)
+  assert isinstance(lift_command, PositionCommandCfg)

cfg.observations["policy"].terms["ee_to_cube"].params["asset_cfg"].site_names = (
"grasp_site",
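The new `init_state` pins the cube's default pose at (0.3, 0.0, 0.035); the interval reset event in the next file then perturbs that default per env. For context, a spec factory like `get_cube_spec` typically builds a free-floating box. This is a hedged sketch using the public `mujoco.MjSpec` API, not mjlab's actual implementation (friction, color, and naming may differ):

```python
import mujoco

def get_cube_spec_sketch(cube_size: float = 0.02, mass: float = 0.05) -> mujoco.MjSpec:
  """Build a minimal free-floating box that a gripper can pick up."""
  spec = mujoco.MjSpec()
  body = spec.worldbody.add_body(name="cube")
  body.add_freejoint()  # free joint so the cube's root pose can be written/reset
  body.add_geom(
    type=mujoco.mjtGeom.mjGEOM_BOX,
    size=[cube_size, cube_size, cube_size],  # half-extents, MuJoCo convention
    mass=mass,
  )
  return spec
```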
30 changes: 17 additions & 13 deletions src/mjlab/tasks/manipulation/lift_cube_env_cfg.py
@@ -15,7 +15,7 @@
from mjlab.sensor import ContactMatch, ContactSensorCfg
from mjlab.sim import MujocoCfg, SimulationCfg
from mjlab.tasks.manipulation import mdp as manipulation_mdp
-from mjlab.tasks.manipulation.mdp import LiftingCommandCfg
+from mjlab.tasks.manipulation.mdp import PositionCommandCfg
from mjlab.tasks.velocity import mdp
from mjlab.terrains import TerrainImporterCfg
from mjlab.utils.noise import UniformNoiseCfg as Unoise
@@ -70,28 +70,32 @@ def make_lift_cube_env_cfg() -> ManagerBasedRlEnvCfg:
}

commands: dict[str, CommandTermCfg] = {
"lift_height": LiftingCommandCfg(
asset_name="cube",
"lift_height": PositionCommandCfg(
resampling_time_range=(8.0, 12.0),
debug_vis=True,
difficulty="dynamic",
object_pose_range=LiftingCommandCfg.ObjectPoseRangeCfg(
x=(0.2, 0.4),
y=(-0.2, 0.2),
z=(0.02, 0.05),
yaw=(-3.14, 3.14),
),
)
}

events = {
-    # For positioning the base of the robot at env_origins.
-    "reset_base": EventTermCfg(
-      func=mdp.reset_root_state_uniform,
+    # Reset each env's scene to its default state at startup.
+    "reset_scene": EventTermCfg(
+      func=mdp.reset_scene_to_default,
mode="startup",
),
"reset_object": EventTermCfg(
func=mdp.reset_root_state_uniform,
mode="interval",
interval_range_s=(8.0, 12.0),
params={
"pose_range": {},
"pose_range": {
"x": (-0.1, 0.1),
"y": (-0.2, 0.2),
"z": (-0.015, 0.015),
"yaw": (-3.14, 3.14),
},
"velocity_range": {},
"asset_cfg": SceneEntityCfg("cube"),
},
),
"reset_robot_joints": EventTermCfg(
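Note that `interval_range_s=(8.0, 12.0)` on `reset_object` matches the command's `resampling_time_range`, so the cube is re-scattered on roughly the same cadence as new targets are drawn, and the offsets apply around the default pose set via `init_state` in env_cfgs.py (hence the small ±0.015 z-range around 0.035). A minimal sketch of the uniform sampling such a reset performs; names are illustrative, not mjlab's exact internals, and yaw handling is omitted:

```python
import torch

def sample_object_positions(
  default_pos_w: torch.Tensor,                 # (num_envs, 3) default cube positions
  pose_range: dict[str, tuple[float, float]],  # e.g. {"x": (-0.1, 0.1), ...}
  env_ids: torch.Tensor,                       # envs being reset
) -> torch.Tensor:
  """Uniformly perturb each selected env's default object position."""
  n = len(env_ids)
  lo = torch.tensor([pose_range.get(k, (0.0, 0.0))[0] for k in ("x", "y", "z")])
  hi = torch.tensor([pose_range.get(k, (0.0, 0.0))[1] for k in ("x", "y", "z")])
  offsets = lo + (hi - lo) * torch.rand(n, 3)
  return default_pos_w[env_ids] + offsets
```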
77 changes: 6 additions & 71 deletions src/mjlab/tasks/manipulation/mdp/commands.py
@@ -1,16 +1,13 @@
from __future__ import annotations

-import math
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Literal

import torch

-from mjlab.entity import Entity
from mjlab.managers.command_manager import CommandTerm
from mjlab.managers.manager_term_config import CommandTermCfg
from mjlab.utils.lab_api.math import (
-  quat_from_euler_xyz,
sample_uniform,
)

@@ -19,49 +16,24 @@
from mjlab.viewer.debug_visualizer import DebugVisualizer


-class LiftingCommand(CommandTerm):
-  cfg: LiftingCommandCfg
+class PositionCommand(CommandTerm):
+  cfg: PositionCommandCfg

-  def __init__(self, cfg: LiftingCommandCfg, env: ManagerBasedRlEnv):
+  def __init__(self, cfg: PositionCommandCfg, env: ManagerBasedRlEnv):
super().__init__(cfg, env)

-    self.object: Entity = env.scene[cfg.asset_name]
self.target_pos = torch.zeros(self.num_envs, 3, device=self.device)
-    self.episode_success = torch.zeros(self.num_envs, device=self.device)

self.metrics["object_height"] = torch.zeros(self.num_envs, device=self.device)
self.metrics["position_error"] = torch.zeros(self.num_envs, device=self.device)
self.metrics["at_goal"] = torch.zeros(self.num_envs, device=self.device)
self.metrics["episode_success"] = torch.zeros(self.num_envs, device=self.device)

@property
def command(self) -> torch.Tensor:
return self.target_pos

def _update_metrics(self) -> None:
-    object_pos_w = self.object.data.root_link_pos_w
-    object_height = object_pos_w[:, 2]
-    position_error = torch.norm(self.target_pos - object_pos_w, dim=-1)
-    at_goal = (position_error < self.cfg.success_threshold).float()
-
-    # Latch episode_success to 1 once goal is reached.
-    self.episode_success = torch.maximum(self.episode_success, at_goal)
-
-    self.metrics["object_height"] = object_height
-    self.metrics["position_error"] = position_error
-    self.metrics["at_goal"] = at_goal
-    self.metrics["episode_success"] = self.episode_success

-  def compute_success(self) -> torch.Tensor:
-    position_error = self.metrics["position_error"]
-    return position_error < self.cfg.success_threshold
+    pass

def _resample_command(self, env_ids: torch.Tensor) -> None:
n = len(env_ids)

-    # Reset episode success for resampled envs.
-    self.episode_success[env_ids] = 0.0

# Set target position based on difficulty mode.
if self.cfg.difficulty == "fixed":
target_pos = torch.tensor(
@@ -76,28 +48,6 @@ def _resample_command(self, env_ids: torch.Tensor) -> None:
target_pos = sample_uniform(lower, upper, (n, 3), device=self.device)
self.target_pos[env_ids] = target_pos + self._env.scene.env_origins[env_ids]

-    # Reset object to new position.
-    if self.cfg.object_pose_range is not None:
-      r = self.cfg.object_pose_range
-      lower = torch.tensor([r.x[0], r.y[0], r.z[0]], device=self.device)
-      upper = torch.tensor([r.x[1], r.y[1], r.z[1]], device=self.device)
-      pos = sample_uniform(lower, upper, (n, 3), device=self.device)
-      pos = pos + self._env.scene.env_origins[env_ids]
-
-      # Sample orientation (yaw only, keep upright).
-      yaw = sample_uniform(r.yaw[0], r.yaw[1], (n,), device=self.device)
-      quat = quat_from_euler_xyz(
-        torch.zeros(n, device=self.device),  # roll
-        torch.zeros(n, device=self.device),  # pitch
-        yaw,
-      )
-      pose = torch.cat([pos, quat], dim=-1)
-
-      velocity = torch.zeros(n, 6, device=self.device)
-
-      self.object.write_root_link_pose_to_sim(pose, env_ids=env_ids)
-      self.object.write_root_link_velocity_to_sim(velocity, env_ids=env_ids)

def _update_command(self) -> None:
pass

@@ -116,10 +66,8 @@ def _debug_vis_impl(self, visualizer: DebugVisualizer) -> None:


@dataclass(kw_only=True)
-class LiftingCommandCfg(CommandTermCfg):
-  asset_name: str
-  class_type: type[CommandTerm] = LiftingCommand
-  success_threshold: float = 0.05
+class PositionCommandCfg(CommandTermCfg):
+  class_type: type[CommandTerm] = PositionCommand
difficulty: Literal["fixed", "dynamic"] = "fixed"

@dataclass
@@ -135,19 +83,6 @@ class TargetPositionRangeCfg:
default_factory=TargetPositionRangeCfg
)

-  @dataclass
-  class ObjectPoseRangeCfg:
-    """Configuration for object pose sampling when resampling commands."""
-
-    x: tuple[float, float] = (0.3, 0.35)
-    y: tuple[float, float] = (-0.1, 0.1)
-    z: tuple[float, float] = (0.02, 0.05)
-    yaw: tuple[float, float] = (-math.pi, math.pi)
-
-  object_pose_range: ObjectPoseRangeCfg | None = field(
-    default_factory=ObjectPoseRangeCfg
-  )

@dataclass
class VizCfg:
target_color: tuple[float, float, float, float] = (1.0, 0.5, 0.0, 0.3)
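With the object reset moved into the event, `PositionCommand._resample_command` is now a pure target sampler: "fixed" pins a single goal point, "dynamic" draws uniformly from `target_position_range`. A hedged sketch of that branch, assuming `TargetPositionRangeCfg` exposes `x`/`y`/`z` (min, max) tuples like the removed `ObjectPoseRangeCfg` did; `fixed_target` is a placeholder for the attribute the diff truncates:

```python
import torch

def sample_targets(cfg, env_origins: torch.Tensor, env_ids: torch.Tensor) -> torch.Tensor:
  """One target position per selected env, offset by its env origin."""
  n = len(env_ids)
  device = env_origins.device
  if cfg.difficulty == "fixed":
    # `fixed_target` is a placeholder name, not confirmed by the diff.
    target = torch.tensor(cfg.fixed_target, device=device).expand(n, 3)
  else:  # "dynamic": uniform over the configured range
    r = cfg.target_position_range
    lower = torch.tensor([r.x[0], r.y[0], r.z[0]], device=device)
    upper = torch.tensor([r.x[1], r.y[1], r.z[1]], device=device)
    target = lower + (upper - lower) * torch.rand(n, 3, device=device)
  return target + env_origins[env_ids]
```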
6 changes: 3 additions & 3 deletions src/mjlab/tasks/manipulation/mdp/observations.py
@@ -6,7 +6,7 @@

from mjlab.entity import Entity
from mjlab.managers.scene_entity_config import SceneEntityCfg
-from mjlab.tasks.manipulation.mdp.commands import LiftingCommand
+from mjlab.tasks.manipulation.mdp.commands import PositionCommand
from mjlab.utils.lab_api.math import quat_apply, quat_inv

if TYPE_CHECKING:
@@ -38,9 +38,9 @@ def object_position_error(
) -> torch.Tensor:
"""3D position error between object and target position (target - current)."""
command = env.command_manager.get_term(command_name)
-  if not isinstance(command, LiftingCommand):
+  if not isinstance(command, PositionCommand):
raise TypeError(
f"Command '{command_name}' must be a LiftingCommand, got {type(command)}"
f"Command '{command_name}' must be a PositionCommand, got {type(command)}"
)
obj: Entity = env.scene[object_name]
obj_pos_w = obj.data.root_link_pos_w
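The `(target - current)` convention in the docstring means the observation is the vector pointing from the object toward the goal; given the `quat_apply`/`quat_inv` imports, the real term likely also rotates this error into a body frame, but the core computation reduces to:

```python
import torch

def object_position_error_sketch(
  target_pos_w: torch.Tensor,  # (num_envs, 3) commanded targets, world frame
  obj_pos_w: torch.Tensor,     # (num_envs, 3) object root positions, world frame
) -> torch.Tensor:
  return target_pos_w - obj_pos_w  # points from the object toward the goal
```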
50 changes: 36 additions & 14 deletions src/mjlab/tasks/manipulation/mdp/rewards.py
@@ -5,8 +5,9 @@
import torch

from mjlab.entity import Entity
+from mjlab.managers.manager_term_config import RewardTermCfg
from mjlab.managers.scene_entity_config import SceneEntityCfg
-from mjlab.tasks.manipulation.mdp.commands import LiftingCommand
+from mjlab.tasks.manipulation.mdp.commands import PositionCommand

if TYPE_CHECKING:
from mjlab.envs import ManagerBasedRlEnv
@@ -29,7 +30,7 @@ def staged_position_reward(
"""
robot: Entity = env.scene[asset_cfg.name]
obj: Entity = env.scene[object_name]
-  command = cast(LiftingCommand, env.command_manager.get_term(command_name))
+  command = cast(PositionCommand, env.command_manager.get_term(command_name))
ee_pos_w = robot.data.site_pos_w[:, asset_cfg.site_ids].squeeze(1)
obj_pos_w = obj.data.root_link_pos_w
reach_error = torch.sum(torch.square(ee_pos_w - obj_pos_w), dim=-1)
@@ -39,18 +40,39 @@
return reaching * (1.0 + bringing)


-def bring_object_reward(
-  env: ManagerBasedRlEnv,
-  command_name: str,
-  object_name: str,
-  std: float,
-) -> torch.Tensor:
-  obj: Entity = env.scene[object_name]
-  command = cast(LiftingCommand, env.command_manager.get_term(command_name))
-  position_error = torch.sum(
-    torch.square(command.target_pos - obj.data.root_link_pos_w), dim=-1
-  )
-  return torch.exp(-position_error / std**2)
+class bring_object_reward:
+  """Reward bringing the object to the commanded target position and log success metrics."""
+
+  def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRlEnv):
+    self.episode_success = torch.zeros(env.num_envs, device=env.device)
+
+  def __call__(
+    self,
+    env: ManagerBasedRlEnv,
+    command_name: str,
+    object_name: str,
+    std: float,
+    success_threshold: float = 0.05,
+  ) -> torch.Tensor:
+    obj: Entity = env.scene[object_name]
+    command = cast(PositionCommand, env.command_manager.get_term(command_name))
+
+    obj_pos_w = obj.data.root_link_pos_w
+    position_error = command.target_pos - obj_pos_w
+    position_error_sq = torch.sum(torch.square(position_error), dim=-1)
+
+    position_error_norm = torch.norm(position_error, dim=-1)
+    at_goal = (position_error_norm < success_threshold).float()
+
+    self.episode_success[env.reset_buf] = 0.0
+    self.episode_success = torch.maximum(self.episode_success, at_goal)
+
+    env.extras["log"]["Metrics/object_height"] = obj_pos_w[:, 2]
+    env.extras["log"]["Metrics/position_error"] = position_error_norm
+    env.extras["log"]["Metrics/at_goal"] = at_goal
+    env.extras["log"]["Metrics/episode_success"] = self.episode_success
+
+    return torch.exp(-position_error_sq / std**2)


def joint_velocity_hinge_penalty(
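Since `bring_object_reward` is now a class, the reward manager can instantiate it with `(cfg, env)` so the success latch persists across steps instead of living on the command term. A hedged registration sketch; the exact `RewardTermCfg` fields and the weight/std values here are assumptions, not taken from this PR:

```python
from mjlab.managers.manager_term_config import RewardTermCfg
from mjlab.tasks.manipulation.mdp.rewards import bring_object_reward

rewards = {
  "bring_object": RewardTermCfg(
    func=bring_object_reward,  # a class: instantiated once, then called each step
    weight=2.0,                # illustrative value
    params={
      "command_name": "lift_height",
      "object_name": "cube",
      "std": 0.1,              # illustrative value
      "success_threshold": 0.05,
    },
  ),
}
```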