Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions .github/workflows/pyright.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# CI workflow: run the pyright static type checker on every push / pull
# request targeting main. Documentation-only changes are skipped via
# paths-ignore. (Scrape-flattened indentation restored to valid YAML.)
name: Tests

on:
  push:
    branches: [main]
    paths-ignore:
      - '**.md'
      - 'Makefile'
      - 'LICENSE'
  pull_request:
    branches: [main]
    paths-ignore:
      - '**.md'
      - 'Makefile'
      - 'LICENSE'

env:
  # Make uv install strictly from uv.lock; fail if the lockfile is stale.
  UV_FROZEN: "1"

jobs:
  pyright:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Type-check against every Python version the project supports.
        python-version: ["3.10", "3.11", "3.12"]
    steps:
      - uses: actions/checkout@v4
      - name: Setup uv
        uses: astral-sh/setup-uv@v6
        with:
          python-version: ${{ matrix.python-version }}
          enable-cache: true
          # Pin the uv version so CI runs are reproducible.
          version: "0.8.13"
      - name: Install dependencies
        run: uv sync --all-extras --all-packages --group dev
      - name: Test with python ${{ matrix.python-version }}
        run: uv run pyright
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ videos/
__pycache__/
MUJOCO_LOG.TXT
debug.py
typings/
.vscode/
*.ipynb_checkpoints/
*.ipynb
motions/
*_rerun*
*.mp4
artifacts/
.venv/
.venv/
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ format:

.PHONY: test
test:
uv run pytest
uv run pytest
uv run pyright
18 changes: 10 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,18 @@ dependencies = [
"mujoco",
"trimesh",
"viser",
"wandb",
"moviepy",
"tensordict",
"rsl-rl-lib",
]

[dependency-groups]
dev = [
{include-group = "lint"},
{include-group = "test"},
"pre-commit",
"ty",
"pyright>=1.1.386",
]
lint = [
"ruff"
Expand All @@ -68,7 +72,6 @@ rl = [
"moviepy >=1.0.0",
"rsl-rl-lib",
"tensorboard",
"wandb",
]

[[tool.uv.index]]
Expand Down Expand Up @@ -99,14 +102,13 @@ torch = [{ index = "pytorch-cu128", extra = "cu12" }]
indent-width = 2
exclude = [
"src/mjlab/third_party",
"typings",
]

[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "I", "B"]

[tool.ty.rules]
unresolved-attribute = "ignore"

[tool.ty.src]
include = ["src", "tests"]
exclude = ["src/mjlab/third_party"]
[tool.pyright]
pythonVersion = "3.10"
ignore = ["./typings", "./src/mjlab/third_party"]
stubPath = "typings"
2 changes: 1 addition & 1 deletion scripts/tracking/data/compare_npz.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from mjlab import Robot
from mjlab.asset_zoo.robots.unitree_g1.g1_constants import G1_ROBOT_CFG
from mjlab.entities import Robot

robot = Robot(G1_ROBOT_CFG)

Expand Down
9 changes: 6 additions & 3 deletions scripts/tracking/data/csv_to_npz.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import replace
from typing import Any

import numpy as np
import torch
Expand Down Expand Up @@ -111,10 +112,12 @@ def _slerp(
"""Spherical linear interpolation between two quaternions."""
slerped_quats = torch.zeros_like(a)
for i in range(a.shape[0]):
slerped_quats[i] = quat_slerp(a[i], b[i], blend[i])
slerped_quats[i] = quat_slerp(a[i], b[i], float(blend[i]))
return slerped_quats

def _compute_frame_blend(self, times: torch.Tensor) -> torch.Tensor:
def _compute_frame_blend(
self, times: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Computes the frame blend for the motion."""
phase = times / self.duration
index_0 = (phase * (self.input_frames - 1)).floor().long()
Expand Down Expand Up @@ -206,7 +209,7 @@ def run_sim(
robot_joint_indexes = robot.find_joints(joint_names, preserve_order=True)[0]

# ------- data logger -------------------------------------------------------
log = {
log: dict[str, Any] = {
"fps": [output_fps],
"joint_pos": [],
"joint_vel": [],
Expand Down
16 changes: 9 additions & 7 deletions scripts/tracking/rl/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tyro
import wandb

from mjlab.envs.manager_based_rl_env import ManagerBasedRlEnv
from mjlab.rl import RslRlOnPolicyRunnerCfg, RslRlVecEnvWrapper
from mjlab.tasks.tracking.rl import MotionTrackingOnPolicyRunner
from mjlab.tasks.tracking.tracking_env_cfg import TrackingEnvCfg
Expand Down Expand Up @@ -66,15 +67,16 @@ def main(
log_dir = resume_path.parent

env = gym.make(task, cfg=env_cfg, render_mode="rgb_array" if video else None)
assert isinstance(env, ManagerBasedRlEnv)
if video:
video_kwargs = {
"video_folder": log_dir / "videos" / "play",
"step_trigger": lambda step: step == 0,
"video_length": video_length,
"disable_logger": True,
}
print("[INFO] Recording videos during training.")
env = gym.wrappers.RecordVideo(env, **video_kwargs)
env = gym.wrappers.RecordVideo(
env,
video_folder=str(log_dir / "videos" / "play"),
step_trigger=lambda step: step == 0,
video_length=video_length,
disable_logger=True,
)

env = RslRlVecEnvWrapper(env, clip_actions=agent_cfg.clip_actions)

Expand Down
15 changes: 7 additions & 8 deletions scripts/tracking/rl/train.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Script to train RL agent with RSL-RL."""

import os
import pathlib
from dataclasses import asdict
from datetime import datetime
from pathlib import Path
from typing import cast

import gymnasium as gym
import tyro
import wandb

from mjlab.rl import RslRlOnPolicyRunnerCfg, RslRlVecEnvWrapper
from mjlab.tasks.tracking.rl import MotionTrackingOnPolicyRunner
Expand All @@ -34,17 +35,14 @@ def main(
):
configure_torch_backends()

env_cfg = cast(TrackingEnvCfg, load_cfg_from_registry(task, "env_cfg_entry_point"))
agent_cfg = cast(
RslRlOnPolicyRunnerCfg, load_cfg_from_registry(task, "rl_cfg_entry_point")
)
env_cfg = load_cfg_from_registry(task, "env_cfg_entry_point")
agent_cfg = load_cfg_from_registry(task, "rl_cfg_entry_point")
assert isinstance(env_cfg, TrackingEnvCfg)
assert isinstance(agent_cfg, RslRlOnPolicyRunnerCfg)

# Check if the registry name includes alias, if not, append ":latest".
if ":" not in registry_name:
registry_name += ":latest"
import pathlib

import wandb

api = wandb.Api()
artifact = api.artifact(registry_name)
Expand Down Expand Up @@ -73,6 +71,7 @@ def main(
env = gym.make(task, cfg=env_cfg)

# Save resume path before creating a new log_dir.
resume_path = None
if agent_cfg.resume:
resume_path = get_checkpoint_path(
log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint
Expand Down
10 changes: 7 additions & 3 deletions scripts/velocity/rl/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
import tyro
from rsl_rl.runners import OnPolicyRunner

from mjlab.envs.manager_based_env_config import ManagerBasedEnvCfg
from mjlab.rl import RslRlVecEnvWrapper
from mjlab.rl.config import RslRlOnPolicyRunnerCfg
from mjlab.third_party.isaaclab.isaaclab_tasks.utils.parse_cfg import (
load_cfg_from_registry,
)
Expand All @@ -29,7 +31,7 @@

def main(
task: str,
wandb_run_path: str,
wandb_run_path: Path,
motion_file: str | None = None,
num_envs: int | None = None,
device: str | None = None,
Expand All @@ -41,6 +43,8 @@ def main(
):
env_cfg = load_cfg_from_registry(task, "env_cfg_entry_point")
agent_cfg = load_cfg_from_registry(task, "rl_cfg_entry_point")
assert isinstance(env_cfg, ManagerBasedEnvCfg)
assert isinstance(agent_cfg, RslRlOnPolicyRunnerCfg)

env_cfg.sim.num_envs = num_envs or env_cfg.sim.num_envs
env_cfg.sim.device = device or env_cfg.sim.device
Expand Down Expand Up @@ -71,9 +75,9 @@ def main(
env = RslRlVecEnvWrapper(env, clip_actions=agent_cfg.clip_actions)

runner = OnPolicyRunner(
env, asdict(agent_cfg), log_dir=log_dir, device=agent_cfg.device
env, asdict(agent_cfg), log_dir=str(log_dir), device=agent_cfg.device
)
runner.load(resume_path, map_location=agent_cfg.device)
runner.load(str(resume_path), map_location=agent_cfg.device)

policy = runner.get_inference_policy(device=env.device)

Expand Down
19 changes: 12 additions & 7 deletions scripts/velocity/rl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import tyro
from rsl_rl.runners import OnPolicyRunner

from mjlab.envs.manager_based_rl_env_config import ManagerBasedRlEnvCfg
from mjlab.rl import RslRlVecEnvWrapper
from mjlab.rl.config import RslRlOnPolicyRunnerCfg
from mjlab.third_party.isaaclab.isaaclab_tasks.utils.parse_cfg import (
load_cfg_from_registry,
)
Expand All @@ -37,6 +39,8 @@ def main(
):
env_cfg = load_cfg_from_registry(task, "env_cfg_entry_point")
agent_cfg = load_cfg_from_registry(task, "rl_cfg_entry_point")
assert isinstance(env_cfg, ManagerBasedRlEnvCfg)
assert isinstance(agent_cfg, RslRlOnPolicyRunnerCfg)

env_cfg.sim.num_envs = num_envs or env_cfg.sim.num_envs
agent_cfg.max_iterations = max_iterations or agent_cfg.max_iterations
Expand All @@ -59,10 +63,11 @@ def main(
env = gym.make(task, cfg=env_cfg)

# Save resume path before creating a new log_dir.
if agent_cfg.resume:
resume_path = get_checkpoint_path(
log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint
)
resume_path = (
get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint)
if agent_cfg.resume
else None
)

# Wrap for video recording.
if video:
Expand All @@ -80,14 +85,14 @@ def main(
runner = OnPolicyRunner(
env,
asdict(agent_cfg),
log_dir=log_dir,
log_dir=str(log_dir),
device=agent_cfg.device,
)
runner.add_git_repo_to_log(__file__)

if agent_cfg.resume:
if resume_path is not None:
print(f"[INFO]: Loading model checkpoint from: {resume_path}")
runner.load(resume_path)
runner.load(str(resume_path))

dump_yaml(log_dir / "params" / "env.yaml", asdict(env_cfg))
dump_yaml(log_dir / "params" / "agent.yaml", asdict(agent_cfg))
Expand Down
4 changes: 2 additions & 2 deletions src/mjlab/entities/robots/robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, robot_cfg: RobotCfg):
if actuator.trntype != mujoco.mjtTrn.mjTRN_JOINT:
continue
self.actuator_to_joint[actuator.name] = actuator.target
self.jnt_actuators = list(self.actuator_to_joint.values())
self.joint_actuators = list(self.actuator_to_joint.values())

# Sensors.
self._sensor_names = [s.name for s in self._spec.sensors]
Expand Down Expand Up @@ -166,7 +166,7 @@ def initialize(
)

local_act_ids = resolve_matching_names(
self._actuator_names, self.jnt_actuators, True
self._actuator_names, self.joint_actuators, True
)[0]
self._actuator_ids_global = torch.tensor(
[indexing.actuator_local2global[lid] for lid in local_act_ids],
Expand Down
4 changes: 2 additions & 2 deletions src/mjlab/envs/manager_based_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def reset(
self,
*,
seed: int | None = None,
env_ids: torch.Tensor | slice | None = None,
env_ids: torch.Tensor | None = None,
options: dict[str, Any] | None = None,
) -> tuple[types.VecEnvObs, dict]:
del options # Unused.
Expand Down Expand Up @@ -140,7 +140,7 @@ def update_visualizers(self, scn) -> None:

# Private methods.

def _reset_idx(self, env_ids: torch.Tensor | slice | None = None) -> None:
def _reset_idx(self, env_ids: torch.Tensor | None = None) -> None:
self.scene.reset(env_ids)
if "reset" in self.event_manager.available_modes:
env_step_count = self._sim_step_counter // self.cfg.decimation
Expand Down
6 changes: 4 additions & 2 deletions src/mjlab/envs/manager_based_rl_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def load_managers(self) -> None:
if "startup" in self.event_manager.available_modes:
self.event_manager.apply(mode="startup")

def step(self, action: torch.Tensor) -> types.VecEnvStepReturn:
def step(self, action: torch.Tensor) -> types.VecEnvStepReturn: # pyright: ignore[reportIncompatibleMethodOverride]
self.action_manager.process_action(action.to(self.device))

for _ in range(self.cfg.decimation):
Expand Down Expand Up @@ -138,10 +138,12 @@ def _configure_gym_env_spaces(self) -> None:
has_concatenated_obs = self.observation_manager.group_obs_concatenate[group_name]
group_dim = self.observation_manager.group_obs_dim[group_name]
if has_concatenated_obs:
assert isinstance(group_dim, tuple)
self.single_observation_space[group_name] = gym.spaces.Box(
low=-math.inf, high=math.inf, shape=group_dim
)
else:
assert not isinstance(group_dim, tuple)
group_term_cfgs = self.observation_manager._group_obs_term_cfgs[group_name]
for term_name, term_dim, _term_cfg in zip(
group_term_names, group_dim, group_term_cfgs, strict=False
Expand All @@ -162,7 +164,7 @@ def _configure_gym_env_spaces(self) -> None:
self.single_action_space, self.num_envs
)

def _reset_idx(self, env_ids: torch.Tensor | slice | None = None) -> None:
def _reset_idx(self, env_ids: torch.Tensor | None = None) -> None:
self.curriculum_manager.compute(env_ids=env_ids)
# Reset the internal buffers of the scene elements.
self.scene.reset(env_ids)
Expand Down
Loading