diff --git a/notebooks/create_new_task.ipynb b/notebooks/create_new_task.ipynb
new file mode 100644
index 00000000..90ce0112
--- /dev/null
+++ b/notebooks/create_new_task.ipynb
@@ -0,0 +1,865 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "PO76KS1i-MwA"
+ },
+ "source": [
+ "# **š¤ CartPole Tutorial with MJLab**\n",
+ "\n",
+ "This notebook demonstrates how to create a custom reinforcement learning task using MJLab. We'll build a CartPole environment from scratch, including:\n",
+ "\n",
+ "1. **Robot Definition** - Define the CartPole model in MuJoCo XML\n",
+ "2. **Task Configuration** - Set up observations, actions, rewards, and terminations\n",
+ "3. **Training** - Train a policy using PPO\n",
+ "4. **Evaluation** - Visualize the simulation with the trained policy\n",
+ "\n",
+ "> **Note**: This tutorial is created based on the official MJLab documentation [\"Create a New Task\"](https://github.com/mujocolab/mjlab/blob/main/docs/create_new_task.md)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "3ywZTgfR3C_w"
+ },
+ "source": [
+ "## **š¦ Setup and Installation**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true,
+ "id": "dtLMJHzy3Nee"
+ },
+ "outputs": [],
+ "source": [
+ "# Install mujoco-warp\n",
+ "!pip install git+https://github.com/google-deepmind/mujoco_warp@9fc294d86955a303619a254cefae809a41adb274 -q\n",
+ "\n",
+ "# Clone the mjlab repository\n",
+ "!if [ ! -d \"mjlab\" ]; then git clone -q https://github.com/mujocolab/mjlab.git; fi\n",
+ "%cd /content/mjlab\n",
+ "\n",
+ "# Install mjlab in editable mode\n",
+ "!pip install -e . -q\n",
+ "\n",
+ "print(\"ā Installation complete!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "SSf2943z3b0s"
+ },
+ "source": [
+ "### **š WandB Setup**\n",
+ "\n",
+ "Configure Weights & Biases for experiment tracking. Add your WandB API key to Colab Secrets:\n",
+ "- `WANDB_API_KEY`: from [wandb.ai/authorize](https://wandb.ai/authorize)\n",
+ "- `WANDB_ENTITY`: your wandb entity name"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "KC9ywCnm3dGg"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "from google.colab import userdata\n",
+ "\n",
+ "try:\n",
+ " # Set this to use wandb logger\n",
+ " os.environ[\"WANDB_API_KEY\"] = userdata.get(\"WANDB_API_KEY\")\n",
+ " os.environ[\"WANDB_ENTITY\"] = userdata.get(\"WANDB_ENTITY\")\n",
+ "\n",
+ " print(\"ā WandB configured successfully!\")\n",
+ "except (AttributeError, KeyError):\n",
+ " # Set this to disable wandb logger\n",
+ " os.environ['WANDB_MODE'] = 'disabled'\n",
+ "\n",
+ " print(\"ā WandB secrets not found. Training will proceed without WandB logging.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "mispfmy73lmq"
+ },
+ "source": [
+ "---\n",
+ "\n",
+ "## **š¤ Step 1: Define the Robot**\n",
+ "\n",
+ "We'll create a simple CartPole robot with:\n",
+ "- A sliding cart (1 DOF)\n",
+ "- A hinged pole (1 DOF)\n",
+ "- A velocity actuator to control the cart"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-FvJYPWD3scd"
+ },
+ "source": [
+ "### **š Structure Directories**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "OP-yET-R3ofN"
+ },
+ "outputs": [],
+ "source": [
+ "# Create the cartpole robot directory structure\n",
+ "!mkdir -p /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/\n",
+ "!mkdir -p /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/xmls\n",
+ "\n",
+ "print(\"ā Directory structure created\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "MRyN1Pok3u25"
+ },
+ "source": [
+ "### **š Create MuJoCo XML Model**\n",
+ "\n",
+ "This XML defines the CartPole physics:\n",
+ "- **Ground plane** for visualization\n",
+ "- **Cart body** with a sliding joint (±2m range)\n",
+ "- **Pole body** with a hinge joint (±90° range)\n",
+ "- **Velocity actuator** for cart control"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "gWGyFX5V3yWc"
+ },
+ "outputs": [],
+ "source": [
+ "%%writefile /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/xmls/cartpole.xml\n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "MpYCG9jI31dZ"
+ },
+ "source": [
+ "### **āļø Create Robot Configuration**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "HDhiyDTn4AVa"
+ },
+ "outputs": [],
+ "source": [
+ "%%writefile /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/cartpole_constants.py\n",
+ "from pathlib import Path\n",
+ "import mujoco\n",
+ "\n",
+ "from mjlab import MJLAB_SRC_PATH\n",
+ "from mjlab.entity import Entity, EntityCfg, EntityArticulationInfoCfg\n",
+ "from mjlab.actuator import XmlVelocityActuatorCfg\n",
+ "\n",
+ "CARTPOLE_XML: Path = (\n",
+ " MJLAB_SRC_PATH / \"asset_zoo\" / \"robots\" / \"cartpole\" / \"xmls\" / \"cartpole.xml\"\n",
+ ")\n",
+ "assert CARTPOLE_XML.exists(), f\"XML not found: {CARTPOLE_XML}\"\n",
+ "\n",
+ "def get_spec() -> mujoco.MjSpec:\n",
+ " return mujoco.MjSpec.from_file(str(CARTPOLE_XML))\n",
+ "\n",
+ "def get_cartpole_robot_cfg() -> EntityCfg:\n",
+ " \"\"\"Get a fresh CartPole robot configuration instance.\"\"\"\n",
+ " actuators = (\n",
+ " XmlVelocityActuatorCfg(\n",
+ " joint_names_expr=(\"slide\",),\n",
+ " ),\n",
+ " )\n",
+ " articulation = EntityArticulationInfoCfg(actuators=actuators)\n",
+ " return EntityCfg(\n",
+ " spec_fn=get_spec,\n",
+ " articulation=articulation\n",
+ " )\n",
+ "\n",
+ "# if __name__ == \"__main__\":\n",
+ "# import mujoco.viewer as viewer\n",
+ "# robot = Entity(get_cartpole_robot_cfg())\n",
+ "# viewer.launch(robot.spec.compile())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-WSaDod04FwN"
+ },
+ "outputs": [],
+ "source": [
+ "# Create __init__.py for the cartpole robot package\n",
+ "%%writefile /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/__init__.py\n",
+ "# Empty __init__.py to mark the directory as a Python package"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "W1tiBPfp_oVP"
+ },
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "\n",
+ "# Append src dir to python path\n",
+ "mjlab_src = \"/content/mjlab/src\"\n",
+ "if mjlab_src not in sys.path:\n",
+ " sys.path.insert(0, mjlab_src)\n",
+ " print(f\"ā Added {mjlab_src} to Python path\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ToWF84qC4Hfg"
+ },
+ "source": [
+ "### **ā
Verify Robot Setup**\n",
+ "\n",
+ "Let's test that the robot can be loaded correctly."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "5tVsvqzQ4J9h"
+ },
+ "outputs": [],
+ "source": [
+ "from mjlab.asset_zoo.robots.cartpole.cartpole_constants import get_cartpole_robot_cfg\n",
+ "\n",
+ "from mjlab.entity import Entity\n",
+ "\n",
+ "# Load the robot\n",
+ "robot = Entity(get_cartpole_robot_cfg())\n",
+ "model = robot.spec.compile()\n",
+ "\n",
+ "# Display robot information\n",
+ "print(\"ā CartPole robot loaded successfully!\")\n",
+ "print(f\" ⢠Degrees of Freedom (DOF): {model.nv}\")\n",
+ "print(f\" ⢠Number of Actuators: {model.nu}\")\n",
+ "print(f\" ⢠Bodies: {model.nbody}\")\n",
+ "print(f\" ⢠Joints: {model.njnt}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "e2_9dixlHON1"
+ },
+ "source": [
+ "### **š Register the Robot**\n",
+ "\n",
+ "Add the CartPole robot to the asset zoo registry."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "8qDIF__lHPcb"
+ },
+ "outputs": [],
+ "source": [
+ "# Add CartPole import to robots __init__.py\n",
+ "with open(\"/content/mjlab/src/mjlab/asset_zoo/robots/__init__.py\", \"a\") as f:\n",
+ " f.write(\"\\n# CartPole robot\\n\")\n",
+ " f.write(\"from mjlab.asset_zoo.robots.cartpole.cartpole_constants import \")\n",
+ " f.write(\"get_cartpole_robot_cfg as get_cartpole_robot_cfg\\n\")\n",
+ "\n",
+ "print(\"ā CartPole robot registered in asset zoo\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6lVD_L6PHWNm"
+ },
+ "source": [
+ "---\n",
+ "\n",
+ "## **šÆ Step 2: Define the Task (MDP)**\n",
+ "\n",
+ "Now we'll define the Markov Decision Process:\n",
+ "- **Observations**: pole angle, angular velocity, cart position, cart velocity\n",
+ "- **Actions**: cart velocity commands\n",
+ "- **Rewards**: upright reward + effort penalty\n",
+ "- **Terminations**: pole tips over or timeout\n",
+ "- **Events**: random pushes for robustness"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "RQxe4TBrHb-I"
+ },
+ "source": [
+ "### **š Create Task Directory**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "nWBqdkziHc2G"
+ },
+ "outputs": [],
+ "source": [
+ "!mkdir -p /content/mjlab/src/mjlab/tasks/cartpole\n",
+ "\n",
+ "print(\"ā Task directory created\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GJfjPpm0Hhj1"
+ },
+ "source": [
+ "### **š Create Environment Configuration**\n",
+ "\n",
+ "This file contains the MDP (Markov Decision Process) components:\n",
+ "1. **Scene Config**: 64 parallel environments\n",
+ "2. **Actions**: Joint velocity control with 20.0 scale\n",
+ "3. **Observations**: Normalized state variables\n",
+ "4. **Rewards**: Upright reward (5.0) + effort penalty (-0.01)\n",
+ "5. **Events**: Joint resets + random pushes\n",
+ "6. **Terminations**: Pole tipped (>30°) or timeout (10s)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "javx9XDIHkFI"
+ },
+ "outputs": [],
+ "source": [
+ "%%writefile /content/mjlab/src/mjlab/tasks/cartpole/env_cfg.py\n",
+ "\"\"\"CartPole task environment configuration.\"\"\"\n",
+ "\n",
+ "import math\n",
+ "import torch\n",
+ "\n",
+ "from mjlab.envs import ManagerBasedRlEnvCfg\n",
+ "from mjlab.envs.mdp.actions import JointVelocityActionCfg\n",
+ "from mjlab.managers.manager_term_config import (\n",
+ " ObservationGroupCfg,\n",
+ " ObservationTermCfg,\n",
+ " RewardTermCfg,\n",
+ " TerminationTermCfg,\n",
+ " EventTermCfg,\n",
+ ")\n",
+ "from mjlab.managers.scene_entity_config import SceneEntityCfg\n",
+ "from mjlab.scene import SceneCfg\n",
+ "from mjlab.sim import MujocoCfg, SimulationCfg\n",
+ "from mjlab.viewer import ViewerConfig\n",
+ "from mjlab.asset_zoo.robots.cartpole.cartpole_constants import get_cartpole_robot_cfg\n",
+ "from mjlab.envs import mdp\n",
+ "\n",
+ "\n",
+ "def cartpole_env_cfg(play: bool = False) -> ManagerBasedRlEnvCfg:\n",
+ " \"\"\"Create CartPole environment configuration.\n",
+ "\n",
+ " Args:\n",
+ " play: If True, disables corruption and extends episode length for evaluation.\n",
+ " \"\"\"\n",
+ "\n",
+ " # ==============================================================================\n",
+ " # Scene Configuration\n",
+ " # ==============================================================================\n",
+ "\n",
+ " scene_cfg = SceneCfg(\n",
+ " num_envs=64 if not play else 16, # Fewer envs for play mode\n",
+ " extent=1.0, # Spacing between environments\n",
+ " entities={\"robot\": get_cartpole_robot_cfg()},\n",
+ " )\n",
+ "\n",
+ " viewer_cfg = ViewerConfig(\n",
+ " origin_type=ViewerConfig.OriginType.ASSET_BODY,\n",
+ " asset_name=\"robot\",\n",
+ " body_name=\"pole\",\n",
+ " distance=3.0,\n",
+ " elevation=10.0,\n",
+ " azimuth=90.0,\n",
+ " )\n",
+ "\n",
+ " sim_cfg = SimulationCfg(\n",
+ " mujoco=MujocoCfg(\n",
+ " timestep=0.02, # 50 Hz control\n",
+ " iterations=1,\n",
+ " ),\n",
+ " )\n",
+ "\n",
+ " # ==============================================================================\n",
+ " # Actions\n",
+ " # ==============================================================================\n",
+ "\n",
+ " actions = {\n",
+ " \"joint_pos\": JointVelocityActionCfg(\n",
+ " asset_name=\"robot\",\n",
+ " actuator_names=(\".*\",),\n",
+ " scale=20.0,\n",
+ " use_default_offset=False,\n",
+ " ),\n",
+ " }\n",
+ "\n",
+ " # ==============================================================================\n",
+ " # Observations\n",
+ " # ==============================================================================\n",
+ "\n",
+ " policy_terms = {\n",
+ " \"angle\": ObservationTermCfg(\n",
+ " func=lambda env: env.sim.data.qpos[:, 1:2] / math.pi\n",
+ " ),\n",
+ " \"ang_vel\": ObservationTermCfg(\n",
+ " func=lambda env: env.sim.data.qvel[:, 1:2] / 5.0\n",
+ " ),\n",
+ " \"cart_pos\": ObservationTermCfg(\n",
+ " func=lambda env: env.sim.data.qpos[:, 0:1] / 2.0\n",
+ " ),\n",
+ " \"cart_vel\": ObservationTermCfg(\n",
+ " func=lambda env: env.sim.data.qvel[:, 0:1] / 20.0\n",
+ " ),\n",
+ " }\n",
+ "\n",
+ " observations = {\n",
+ " \"policy\": ObservationGroupCfg(\n",
+ " terms=policy_terms,\n",
+ " concatenate_terms=True,\n",
+ " enable_corruption=not play, # Disable corruption in play mode\n",
+ " ),\n",
+ " \"critic\": ObservationGroupCfg(\n",
+ " terms=policy_terms, # Critic uses same observations\n",
+ " concatenate_terms=True,\n",
+ " enable_corruption=False,\n",
+ " ),\n",
+ " }\n",
+ "\n",
+ " # ==============================================================================\n",
+ " # Rewards\n",
+ " # ==============================================================================\n",
+ "\n",
+ " def compute_upright_reward(env):\n",
+ " \"\"\"Reward for keeping pole upright (cosine of angle).\"\"\"\n",
+ " return env.sim.data.qpos[:, 1].cos()\n",
+ "\n",
+ " def compute_effort_penalty(env):\n",
+ " \"\"\"Penalty for control effort.\"\"\"\n",
+ " return -0.01 * (env.sim.data.ctrl[:, 0] ** 2)\n",
+ "\n",
+ " rewards = {\n",
+ " \"upright\": RewardTermCfg(func=compute_upright_reward, weight=5.0),\n",
+ " \"effort\": RewardTermCfg(func=compute_effort_penalty, weight=1.0),\n",
+ " }\n",
+ "\n",
+ " # ==============================================================================\n",
+ " # Events\n",
+ " # ==============================================================================\n",
+ "\n",
+ " def random_push_cart(env, env_ids, force_range=(-5, 5)):\n",
+ " \"\"\"Apply random force to cart for robustness training.\"\"\"\n",
+ " n = len(env_ids)\n",
+ " random_forces = (\n",
+ " torch.rand(n, device=env.device) *\n",
+ " (force_range[1] - force_range[0]) +\n",
+ " force_range[0]\n",
+ " )\n",
+ " env.sim.data.qfrc_applied[env_ids, 0] = random_forces\n",
+ "\n",
+ " events = {\n",
+ " \"reset_robot_joints\": EventTermCfg(\n",
+ " func=mdp.reset_joints_by_offset,\n",
+ " mode=\"reset\",\n",
+ " params={\n",
+ " \"asset_cfg\": SceneEntityCfg(\"robot\"),\n",
+ " \"position_range\": (-0.1, 0.1),\n",
+ " \"velocity_range\": (-0.1, 0.1),\n",
+ " },\n",
+ " ),\n",
+ " }\n",
+ "\n",
+ " # Add random pushes only in training mode\n",
+ " if not play:\n",
+ " events[\"random_push\"] = EventTermCfg(\n",
+ " func=random_push_cart,\n",
+ " mode=\"interval\",\n",
+ " interval_range_s=(1.0, 2.0),\n",
+ " params={\"force_range\": (-20.0, 20.0)},\n",
+ " )\n",
+ "\n",
+ " # ==============================================================================\n",
+ " # Terminations\n",
+ " # ==============================================================================\n",
+ "\n",
+ " def check_pole_tipped(env):\n",
+ " \"\"\"Check if pole has tipped beyond 30 degrees.\"\"\"\n",
+ " return env.sim.data.qpos[:, 1].abs() > math.radians(30)\n",
+ "\n",
+ " terminations = {\n",
+ " \"timeout\": TerminationTermCfg(func=mdp.time_out, time_out=True),\n",
+ " \"tipped\": TerminationTermCfg(func=check_pole_tipped, time_out=False),\n",
+ " }\n",
+ "\n",
+ " # ==============================================================================\n",
+ " # Environment Configuration\n",
+ " # ==============================================================================\n",
+ "\n",
+ " return ManagerBasedRlEnvCfg(\n",
+ " scene=scene_cfg,\n",
+ " observations=observations,\n",
+ " actions=actions,\n",
+ " rewards=rewards,\n",
+ " events=events,\n",
+ " terminations=terminations,\n",
+ " sim=sim_cfg,\n",
+ " viewer=viewer_cfg,\n",
+ " decimation=1, # No action repeat\n",
+ " episode_length_s=int(1e9) if play else 10.0, # Infinite for play, 10s for training\n",
+ " )"
+ ]
+ },
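+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a quick sanity check, build the configuration we just wrote. The cell below loads `env_cfg.py` by file path with `importlib` rather than importing the `mjlab.tasks.cartpole` package, because the package `__init__.py` is only written in a later step. Printing `cfg.rewards` and `cfg.terminations` assumes `ManagerBasedRlEnvCfg` stores the dicts we pass in as-is."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import importlib.util\n",
+ "\n",
+ "# Load env_cfg.py directly by path (the package __init__.py does not exist yet).\n",
+ "spec = importlib.util.spec_from_file_location(\n",
+ " \"cartpole_env_cfg_check\",\n",
+ " \"/content/mjlab/src/mjlab/tasks/cartpole/env_cfg.py\",\n",
+ ")\n",
+ "env_cfg_module = importlib.util.module_from_spec(spec)\n",
+ "spec.loader.exec_module(env_cfg_module)\n",
+ "\n",
+ "cfg = env_cfg_module.cartpole_env_cfg()\n",
+ "print(\"✅ Environment config builds\")\n",
+ "print(f\" • Parallel envs: {cfg.scene.num_envs}\")\n",
+ "print(f\" • Reward terms: {list(cfg.rewards)}\")\n",
+ "print(f\" • Termination terms: {list(cfg.terminations)}\")"
+ ]
+ },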
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fC5maMjzSj_X"
+ },
+ "source": [
+ "### **āļø Create RL Configuration**\n",
+ "\n",
+ "This file defines the PPO (Proximal Policy Optimization) training parameters."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "C81zZm6mSj_X"
+ },
+ "outputs": [],
+ "source": [
+ "%%writefile /content/mjlab/src/mjlab/tasks/cartpole/rl_cfg.py\n",
+ "\"\"\"RL configuration for CartPole task.\"\"\"\n",
+ "\n",
+ "from mjlab.rl.config import (\n",
+ " RslRlOnPolicyRunnerCfg,\n",
+ " RslRlPpoActorCriticCfg,\n",
+ " RslRlPpoAlgorithmCfg,\n",
+ ")\n",
+ "\n",
+ "\n",
+ "def cartpole_ppo_runner_cfg() -> RslRlOnPolicyRunnerCfg:\n",
+ " \"\"\"Create RL runner configuration for CartPole task.\"\"\"\n",
+ " return RslRlOnPolicyRunnerCfg(\n",
+ " policy=RslRlPpoActorCriticCfg(\n",
+ " init_noise_std=1.0,\n",
+ " actor_obs_normalization=True,\n",
+ " critic_obs_normalization=True,\n",
+ " actor_hidden_dims=(256, 128, 64), # Smaller network for simpler task\n",
+ " critic_hidden_dims=(256, 128, 64),\n",
+ " activation=\"elu\",\n",
+ " ),\n",
+ " algorithm=RslRlPpoAlgorithmCfg(\n",
+ " value_loss_coef=1.0,\n",
+ " use_clipped_value_loss=True,\n",
+ " clip_param=0.2,\n",
+ " entropy_coef=0.01,\n",
+ " num_learning_epochs=5,\n",
+ " num_mini_batches=4,\n",
+ " learning_rate=1.0e-3,\n",
+ " schedule=\"adaptive\",\n",
+ " gamma=0.99,\n",
+ " lam=0.95,\n",
+ " desired_kl=0.01,\n",
+ " max_grad_norm=1.0,\n",
+ " ),\n",
+ " experiment_name=\"cartpole\",\n",
+ " save_interval=50,\n",
+ " num_steps_per_env=24,\n",
+ " max_iterations=5_000, # Fewer iterations for simpler task\n",
+ " )"
+ ]
+ },
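+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The same by-path trick verifies the RL config: build it and echo a few of the PPO settings defined above."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import importlib.util\n",
+ "\n",
+ "# Load rl_cfg.py by path, as above, and build the runner config.\n",
+ "spec = importlib.util.spec_from_file_location(\n",
+ " \"cartpole_rl_cfg_check\",\n",
+ " \"/content/mjlab/src/mjlab/tasks/cartpole/rl_cfg.py\",\n",
+ ")\n",
+ "rl_cfg_module = importlib.util.module_from_spec(spec)\n",
+ "spec.loader.exec_module(rl_cfg_module)\n",
+ "\n",
+ "rl_cfg = rl_cfg_module.cartpole_ppo_runner_cfg()\n",
+ "print(\"✅ RL config builds\")\n",
+ "print(f\" • Actor hidden dims: {rl_cfg.policy.actor_hidden_dims}\")\n",
+ "print(f\" • Learning rate: {rl_cfg.algorithm.learning_rate}\")\n",
+ "print(f\" • Max iterations: {rl_cfg.max_iterations}\")"
+ ]
+ },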
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Oc8-AHGcHt78"
+ },
+ "source": [
+ "### **š Register the Task Environment**\n",
+ "\n",
+ "Register the CartPole task with mjlab registry."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "YitUGUBRHxD4"
+ },
+ "outputs": [],
+ "source": [
+ "%%writefile /content/mjlab/src/mjlab/tasks/cartpole/__init__.py\n",
+ "\"\"\"CartPole task registration.\"\"\"\n",
+ "\n",
+ "from mjlab.tasks.registry import register_mjlab_task\n",
+ "from mjlab.tasks.velocity.rl import VelocityOnPolicyRunner\n",
+ "\n",
+ "from .env_cfg import cartpole_env_cfg\n",
+ "from .rl_cfg import cartpole_ppo_runner_cfg\n",
+ "\n",
+ "register_mjlab_task(\n",
+ " task_id=\"Mjlab-Cartpole\",\n",
+ " env_cfg=cartpole_env_cfg(),\n",
+ " play_env_cfg=cartpole_env_cfg(play=True),\n",
+ " rl_cfg=cartpole_ppo_runner_cfg(),\n",
+ " runner_cls=VelocityOnPolicyRunner,\n",
+ ")"
+ ]
+ },
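+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Before training, a cheap smoke test: importing the package executes `register_mjlab_task`, so a clean import means registration succeeded. Since `__init__.py` was created after this kernel started, we call `importlib.invalidate_caches()` first so the import system picks up the new file."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import importlib\n",
+ "\n",
+ "# __init__.py was just written; refresh the finder caches before importing.\n",
+ "importlib.invalidate_caches()\n",
+ "\n",
+ "# Importing the package executes register_mjlab_task(...) above.\n",
+ "import mjlab.tasks.cartpole  # noqa: F401\n",
+ "\n",
+ "print(\"✅ Mjlab-Cartpole task registered\")"
+ ]
+ },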
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "K7wqLZR1rnGn"
+ },
+ "source": [
+ "---\n",
+ "\n",
+ "## **š Step 3: Train the Agent**\n",
+ "\n",
+ "Now let's train a PPO policy to balance the CartPole!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Hht_hF4trqP2"
+ },
+ "outputs": [],
+ "source": [
+ "!python -m mjlab.scripts.train Mjlab-Cartpole --agent.max-iterations 201 --agent.save-interval 20"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xCaqPznGrx8H"
+ },
+ "source": [
+ "### **š Locate Training Checkpoints**\n",
+ "\n",
+ "After training, checkpoints are saved locally."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "uPnmHYu8r0uY"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import re\n",
+ "from pathlib import Path\n",
+ "\n",
+ "# Find the most recent training run\n",
+ "log_dir = Path(\"/content/mjlab/logs/rsl_rl/cartpole\")\n",
+ "if log_dir.exists():\n",
+ " runs = sorted(log_dir.glob(\"*\"), key=os.path.getmtime, reverse=True)\n",
+ " if runs:\n",
+ " latest_run = runs[0]\n",
+ " print(f\"ā Latest training run: {latest_run.name}\\n\")\n",
+ "\n",
+ " # List checkpoints - sorted by iteration number\n",
+ " checkpoints = list(latest_run.glob(\"model_*.pt\"))\n",
+ " if checkpoints:\n",
+ " # Extract iteration number and sort numerically\n",
+ " def get_iteration(ckpt):\n",
+ " match = re.search(r\"model_(\\d+)\\.pt\", ckpt.name)\n",
+ " return int(match.group(1)) if match else 0\n",
+ "\n",
+ " checkpoints = sorted(checkpoints, key=get_iteration)\n",
+ "\n",
+ " print(f\"Found {len(checkpoints)} checkpoints:\")\n",
+ " for ckpt in checkpoints[-5:]: # Show last 5\n",
+ " size_mb = ckpt.stat().st_size / (1024 * 1024)\n",
+ " iteration = get_iteration(ckpt)\n",
+ " print(f\" ⢠{ckpt.name} (iteration {iteration}, {size_mb:.2f} MB)\")\n",
+ "\n",
+ " # Store the last checkpoint path\n",
+ " last_checkpoint = str(checkpoints[-1])\n",
+ " print(f\"\\nš¾ Last checkpoint: {last_checkpoint}\")\n",
+ " else:\n",
+ " print(\"ā No checkpoints found yet\")\n",
+ " else:\n",
+ " print(\"ā No training runs found\")\n",
+ "else:\n",
+ " print(\"ā Log directory not found. Have you run training yet?\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "eWFS9Pw7r2uH"
+ },
+ "source": [
+ "---\n",
+ "\n",
+ "## **š® Step 4: Visualize the Trained Policy**\n",
+ "\n",
+ "Let's see the trained policy in action!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "78PgHtpfr5sb"
+ },
+ "source": [
+ "### **š Launch Viser API**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "_9tGiFyBr2bW"
+ },
+ "outputs": [],
+ "source": [
+ "import subprocess\n",
+ "import sys\n",
+ "\n",
+ "process = subprocess.Popen(\n",
+ " [\n",
+ " \"python\",\n",
+ " \"-m\",\n",
+ " \"mjlab.scripts.play\",\n",
+ " \"Mjlab-Cartpole\",\n",
+ " \"--checkpoint_file\",\n",
+ " last_checkpoint,\n",
+ " \"--num_envs\",\n",
+ " \"4\",\n",
+ " ],\n",
+ " stdout=subprocess.PIPE,\n",
+ " stderr=subprocess.STDOUT,\n",
+ " universal_newlines=True,\n",
+ " bufsize=1,\n",
+ ")\n",
+ "\n",
+ "detected_port = None\n",
+ "\n",
+ "for line in process.stdout:\n",
+ " print(line, end=\"\")\n",
+ " sys.stdout.flush()\n",
+ "\n",
+ " # Extract port number from viser output\n",
+ " port_match = re.search(r\":(\\d{4})\", line)\n",
+ " if port_match and \"viser\" in line.lower():\n",
+ " detected_port = int(port_match.group(1))\n",
+ " print(\"\\n\" + \"=\" * 34)\n",
+ " print(f\"ā
Server is running on port {detected_port}!\")\n",
+ " print(\"=\" * 34)\n",
+ " break"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XgzJXyBXssZS"
+ },
+ "source": [
+ "### **š„ļø Embed Client as iframe**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ll89QnuSuUxx"
+ },
+ "outputs": [],
+ "source": [
+ "from google.colab import output\n",
+ "\n",
+ "port = detected_port if detected_port else 8081\n",
+ "output.serve_kernel_port_as_iframe(port=port, height=700)"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": [],
+ "toc_visible": true,
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file