From 9d0ea2cd01500c605638120269f1d871e54594f5 Mon Sep 17 00:00:00 2001 From: ttktjmt Date: Sat, 6 Dec 2025 13:53:51 +0900 Subject: [PATCH 1/2] Add create_new_task tutorial notebook --- notebooks/create_new_task.ipynb | 939 ++++++++++++++++++++++++++++++++ 1 file changed, 939 insertions(+) create mode 100644 notebooks/create_new_task.ipynb diff --git a/notebooks/create_new_task.ipynb b/notebooks/create_new_task.ipynb new file mode 100644 index 00000000..9d18ea99 --- /dev/null +++ b/notebooks/create_new_task.ipynb @@ -0,0 +1,939 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PO76KS1i-MwA" + }, + "source": [ + "# **šŸ¤– CartPole Tutorial with MJLab**\n", + "\n", + "This notebook demonstrates how to create a custom reinforcement learning task using MJLab. We'll build a CartPole environment from scratch, including:\n", + "\n", + "1. **Robot Definition** - Define the CartPole model in MuJoCo XML\n", + "2. **Task Configuration** - Set up observations, actions, rewards, and terminations\n", + "3. **Training** - Train a policy using PPO\n", + "4. **Evaluation** - Visualize the simulation with the trained policy\n", + "\n", + "> **Note**: This tutorial is created based on the official MJLab documentation [\"Create a New Task\"](https://github.com/mujocolab/mjlab/blob/main/docs/create_new_task.md)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3ywZTgfR3C_w" + }, + "source": [ + "## **šŸ“¦ Setup and Installation**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "collapsed": true, + "id": "dtLMJHzy3Nee", + "outputId": "4ff1cea6-db01-4d78-f074-c655bbf4dccd" + }, + "outputs": [], + "source": [ + "# Clone the mjlab repository\n", + "!if [ ! -d \"mjlab\" ]; then git clone -q https://github.com/mujocolab/mjlab.git; fi\n", + "%cd /content/mjlab\n", + "\n", + "# Install mjlab in editable mode\n", + "!uv pip install --system -e . -q\n", + "\n", + "print(\"āœ“ Installation complete!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SSf2943z3b0s" + }, + "source": [ + "### **šŸ”‘ WandB Setup (Optional)**\n", + "\n", + "Configure Weights & Biases for experiment tracking. Add your WandB API key to Colab Secrets:\n", + "- `WANDB_API_KEY`: from [wandb.ai/authorize](https://wandb.ai/authorize)\n", + "- `WANDB_ENTITY`: your wandb entity name" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KC9ywCnm3dGg", + "outputId": "e481f303-f938-49ae-f9c3-aff522c5bdf5" + }, + "outputs": [], + "source": [ + "import os\n", + "from google.colab import userdata\n", + "\n", + "try:\n", + " # Set this to disable wandb logger\n", + " # os.environ['WANDB_MODE'] = 'disabled'\n", + "\n", + " # Set this to use wandb logger\n", + " os.environ['WANDB_API_KEY'] = userdata.get('WANDB_API_KEY')\n", + " os.environ['WANDB_ENTITY'] = userdata.get('WANDB_ENTITY')\n", + "\n", + " print(\"āœ“ WandB configured successfully!\")\n", + "except (AttributeError, KeyError):\n", + " print(\"⚠ WandB secrets not found. Training will proceed without WandB logging.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mispfmy73lmq" + }, + "source": [ + "---\n", + "\n", + "## **šŸ¤– Step 1: Define the Robot**\n", + "\n", + "We'll create a simple CartPole robot with:\n", + "- A sliding cart (1 DOF)\n", + "- A hinged pole (1 DOF)\n", + "- A velocity actuator to control the cart" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-FvJYPWD3scd" + }, + "source": [ + "### **šŸ“ Structure Directories**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "OP-yET-R3ofN", + "outputId": "660793a3-be51-449f-ba5a-72c9b975a4aa" + }, + "outputs": [], + "source": [ + "# Create the cartpole robot directory structure\n", + "!mkdir -p /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/\n", + "!mkdir -p /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/xmls\n", + "\n", + "print(\"āœ“ Directory structure created\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MRyN1Pok3u25" + }, + "source": [ + "### **šŸ“ Create MuJoCo XML Model**\n", + "\n", + "This XML defines the CartPole physics:\n", + "- **Ground plane** for visualization\n", + "- **Cart body** with a sliding joint (±2m range)\n", + "- **Pole body** with a hinge joint (±90° range)\n", + "- **Velocity actuator** for cart control" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "gWGyFX5V3yWc", + "outputId": "3e4e196e-0cf4-4f4d-8d28-db41f974b3c6" + }, + "outputs": [], + "source": [ + "%%writefile /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/xmls/cartpole.xml\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MpYCG9jI31dZ" + }, + "source": [ + "### **āš™ļø Create Robot Configuration**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "HDhiyDTn4AVa", + "outputId": "7f4a9347-a820-4de3-801e-dc12cd2ccad6" + }, + "outputs": [], + "source": [ + "%%writefile /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/cartpole_constants.py\n", + "from pathlib import Path\n", + "import mujoco\n", + "\n", + "from mjlab import MJLAB_SRC_PATH\n", + "from mjlab.entity import Entity, EntityCfg, EntityArticulationInfoCfg\n", + "from mjlab.actuator import XmlVelocityActuatorCfg\n", + "\n", + "CARTPOLE_XML: Path = (\n", + " MJLAB_SRC_PATH / \"asset_zoo\" / \"robots\" / \"cartpole\" / \"xmls\" / \"cartpole.xml\"\n", + ")\n", + "assert CARTPOLE_XML.exists(), f\"XML not found: {CARTPOLE_XML}\"\n", + "\n", + "def get_spec() -> mujoco.MjSpec:\n", + " return mujoco.MjSpec.from_file(str(CARTPOLE_XML))\n", + "\n", + "def get_cartpole_robot_cfg() -> EntityCfg:\n", + " \"\"\"Get a fresh CartPole robot configuration instance.\"\"\"\n", + " actuators = (\n", + " XmlVelocityActuatorCfg(\n", + " joint_names_expr=(\"slide\",),\n", + " ),\n", + " )\n", + " articulation = EntityArticulationInfoCfg(actuators=actuators)\n", + " return EntityCfg(\n", + " spec_fn=get_spec,\n", + " articulation=articulation\n", + " )\n", + "\n", + "# if __name__ == \"__main__\":\n", + "# import mujoco.viewer as viewer\n", + "# robot = Entity(get_cartpole_robot_cfg())\n", + "# viewer.launch(robot.spec.compile())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "-WSaDod04FwN", + "outputId": "ac085f19-e3c0-4d8c-ec77-a4327305c2c3" + }, + "outputs": [], + "source": [ + "# Create __init__.py for the cartpole robot package\n", + "%%writefile /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/__init__.py\n", + "# Empty __init__.py to mark the directory as a Python package" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "W1tiBPfp_oVP", + "outputId": "4de0534e-81c8-43a6-f3eb-d2e4f6dd1217" + }, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "# Append src dir to python path\n", + "mjlab_src = '/content/mjlab/src'\n", + "if mjlab_src not in sys.path:\n", + " sys.path.insert(0, mjlab_src)\n", + " print(f\"āœ“ Added {mjlab_src} to Python path\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ToWF84qC4Hfg" + }, + "source": [ + "### **āœ… Verify Robot Setup**\n", + "\n", + "Let's test that the robot can be loaded correctly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5tVsvqzQ4J9h", + "outputId": "b27372be-85d9-4150-f69c-6ae0c784f8bd" + }, + "outputs": [], + "source": [ + "from mjlab.entity import Entity\n", + "from mjlab.asset_zoo.robots.cartpole.cartpole_constants import get_cartpole_robot_cfg\n", + "\n", + "# Load the robot\n", + "robot = Entity(get_cartpole_robot_cfg())\n", + "model = robot.spec.compile()\n", + "\n", + "# Display robot information\n", + "print(\"āœ“ CartPole robot loaded successfully!\")\n", + "print(f\" • Degrees of Freedom (DOF): {model.nv}\")\n", + "print(f\" • Number of Actuators: {model.nu}\")\n", + "print(f\" • Bodies: {model.nbody}\")\n", + "print(f\" • Joints: {model.njnt}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "e2_9dixlHON1" + }, + "source": [ + "### **šŸ“‹ Register the Robot**\n", + "\n", + "Add the CartPole robot to the asset zoo registry." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "8qDIF__lHPcb", + "outputId": "b328c3e5-784f-482c-85c2-32c641dfb09a" + }, + "outputs": [], + "source": [ + "# Add CartPole import to robots __init__.py\n", + "with open('/content/mjlab/src/mjlab/asset_zoo/robots/__init__.py', 'a') as f:\n", + " f.write('\\n# CartPole robot\\n')\n", + " f.write('from mjlab.asset_zoo.robots.cartpole.cartpole_constants import ')\n", + " f.write('get_cartpole_robot_cfg as get_cartpole_robot_cfg\\n')\n", + "\n", + "print(\"āœ“ CartPole robot registered in asset zoo\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6lVD_L6PHWNm" + }, + "source": [ + "---\n", + "\n", + "## **šŸŽÆ Step 2: Define the Task (MDP)**\n", + "\n", + "Now we'll define the Markov Decision Process:\n", + "- **Observations**: pole angle, angular velocity, cart position, cart velocity\n", + "- **Actions**: cart velocity commands\n", + "- **Rewards**: upright reward + effort penalty\n", + "- **Terminations**: pole tips over or timeout\n", + "- **Events**: random pushes for robustness" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RQxe4TBrHb-I" + }, + "source": [ + "### **šŸ“ Create Task Directory**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "nWBqdkziHc2G", + "outputId": "77fb7eb2-2b09-4f46-c678-bed897e46ef4" + }, + "outputs": [], + "source": [ + "!mkdir -p /content/mjlab/src/mjlab/tasks/cartpole\n", + "\n", + "print(\"āœ“ Task directory created\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GJfjPpm0Hhj1" + }, + "source": [ + "### **šŸ“ Create Environment Configuration**\n", + "\n", + "This file contains the MDP (Markov Decision Process) components:\n", + "1. **Scene Config**: 64 parallel environments\n", + "2. **Actions**: Joint velocity control with 20.0 scale\n", + "3. **Observations**: Normalized state variables\n", + "4. **Rewards**: Upright reward (5.0) + effort penalty (-0.01)\n", + "5. **Events**: Joint resets + random pushes\n", + "6. **Terminations**: Pole tipped (>30°) or timeout (10s)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "javx9XDIHkFI", + "outputId": "9ecf2cf3-1666-4562-99fd-165753289b8a" + }, + "outputs": [], + "source": [ + "%%writefile /content/mjlab/src/mjlab/tasks/cartpole/env_cfg.py\n", + "\"\"\"CartPole task environment configuration.\"\"\"\n", + "\n", + "import math\n", + "import torch\n", + "\n", + "from mjlab.envs import ManagerBasedRlEnvCfg\n", + "from mjlab.envs.mdp.actions import JointVelocityActionCfg\n", + "from mjlab.managers.manager_term_config import (\n", + " ObservationGroupCfg,\n", + " ObservationTermCfg,\n", + " RewardTermCfg,\n", + " TerminationTermCfg,\n", + " EventTermCfg,\n", + ")\n", + "from mjlab.managers.scene_entity_config import SceneEntityCfg\n", + "from mjlab.scene import SceneCfg\n", + "from mjlab.sim import MujocoCfg, SimulationCfg\n", + "from mjlab.viewer import ViewerConfig\n", + "from mjlab.asset_zoo.robots.cartpole.cartpole_constants import get_cartpole_robot_cfg\n", + "from mjlab.envs import mdp\n", + "\n", + "\n", + "def cartpole_env_cfg(play: bool = False) -> ManagerBasedRlEnvCfg:\n", + " \"\"\"Create CartPole environment configuration.\n", + "\n", + " Args:\n", + " play: If True, disables corruption and extends episode length for evaluation.\n", + " \"\"\"\n", + "\n", + " # ==============================================================================\n", + " # Scene Configuration\n", + " # ==============================================================================\n", + "\n", + " scene_cfg = SceneCfg(\n", + " num_envs=64 if not play else 16, # Fewer envs for play mode\n", + " extent=1.0, # Spacing between environments\n", + " entities={\"robot\": get_cartpole_robot_cfg()},\n", + " )\n", + "\n", + " viewer_cfg = ViewerConfig(\n", + " origin_type=ViewerConfig.OriginType.ASSET_BODY,\n", + " asset_name=\"robot\",\n", + " body_name=\"pole\",\n", + " distance=3.0,\n", + " elevation=10.0,\n", + " azimuth=90.0,\n", + " )\n", + "\n", + " sim_cfg = SimulationCfg(\n", + " mujoco=MujocoCfg(\n", + " timestep=0.02, # 50 Hz control\n", + " iterations=1,\n", + " ),\n", + " )\n", + "\n", + " # ==============================================================================\n", + " # Actions\n", + " # ==============================================================================\n", + "\n", + " actions = {\n", + " \"joint_pos\": JointVelocityActionCfg(\n", + " asset_name=\"robot\",\n", + " actuator_names=(\".*\",),\n", + " scale=20.0,\n", + " use_default_offset=False,\n", + " ),\n", + " }\n", + "\n", + " # ==============================================================================\n", + " # Observations\n", + " # ==============================================================================\n", + "\n", + " policy_terms = {\n", + " \"angle\": ObservationTermCfg(\n", + " func=lambda env: env.sim.data.qpos[:, 1:2] / math.pi\n", + " ),\n", + " \"ang_vel\": ObservationTermCfg(\n", + " func=lambda env: env.sim.data.qvel[:, 1:2] / 5.0\n", + " ),\n", + " \"cart_pos\": ObservationTermCfg(\n", + " func=lambda env: env.sim.data.qpos[:, 0:1] / 2.0\n", + " ),\n", + " \"cart_vel\": ObservationTermCfg(\n", + " func=lambda env: env.sim.data.qvel[:, 0:1] / 20.0\n", + " ),\n", + " }\n", + "\n", + " observations = {\n", + " \"policy\": ObservationGroupCfg(\n", + " terms=policy_terms,\n", + " concatenate_terms=True,\n", + " enable_corruption=not play, # Disable corruption in play mode\n", + " ),\n", + " \"critic\": ObservationGroupCfg(\n", + " terms=policy_terms, # Critic uses same observations\n", + " concatenate_terms=True,\n", + " enable_corruption=False,\n", + " ),\n", + " }\n", + "\n", + " # ==============================================================================\n", + " # Rewards\n", + " # ==============================================================================\n", + "\n", + " def compute_upright_reward(env):\n", + " \"\"\"Reward for keeping pole upright (cosine of angle).\"\"\"\n", + " return env.sim.data.qpos[:, 1].cos()\n", + "\n", + " def compute_effort_penalty(env):\n", + " \"\"\"Penalty for control effort.\"\"\"\n", + " return -0.01 * (env.sim.data.ctrl[:, 0] ** 2)\n", + "\n", + " rewards = {\n", + " \"upright\": RewardTermCfg(func=compute_upright_reward, weight=5.0),\n", + " \"effort\": RewardTermCfg(func=compute_effort_penalty, weight=1.0),\n", + " }\n", + "\n", + " # ==============================================================================\n", + " # Events\n", + " # ==============================================================================\n", + "\n", + " def random_push_cart(env, env_ids, force_range=(-5, 5)):\n", + " \"\"\"Apply random force to cart for robustness training.\"\"\"\n", + " n = len(env_ids)\n", + " random_forces = (\n", + " torch.rand(n, device=env.device) *\n", + " (force_range[1] - force_range[0]) +\n", + " force_range[0]\n", + " )\n", + " env.sim.data.qfrc_applied[env_ids, 0] = random_forces\n", + "\n", + " events = {\n", + " \"reset_robot_joints\": EventTermCfg(\n", + " func=mdp.reset_joints_by_offset,\n", + " mode=\"reset\",\n", + " params={\n", + " \"asset_cfg\": SceneEntityCfg(\"robot\"),\n", + " \"position_range\": (-0.1, 0.1),\n", + " \"velocity_range\": (-0.1, 0.1),\n", + " },\n", + " ),\n", + " }\n", + "\n", + " # Add random pushes only in training mode\n", + " if not play:\n", + " events[\"random_push\"] = EventTermCfg(\n", + " func=random_push_cart,\n", + " mode=\"interval\",\n", + " interval_range_s=(1.0, 2.0),\n", + " params={\"force_range\": (-20.0, 20.0)},\n", + " )\n", + "\n", + " # ==============================================================================\n", + " # Terminations\n", + " # ==============================================================================\n", + "\n", + " def check_pole_tipped(env):\n", + " \"\"\"Check if pole has tipped beyond 30 degrees.\"\"\"\n", + " return env.sim.data.qpos[:, 1].abs() > math.radians(30)\n", + "\n", + " terminations = {\n", + " \"timeout\": TerminationTermCfg(func=mdp.time_out, time_out=True),\n", + " \"tipped\": TerminationTermCfg(func=check_pole_tipped, time_out=False),\n", + " }\n", + "\n", + " # ==============================================================================\n", + " # Environment Configuration\n", + " # ==============================================================================\n", + "\n", + " return ManagerBasedRlEnvCfg(\n", + " scene=scene_cfg,\n", + " observations=observations,\n", + " actions=actions,\n", + " rewards=rewards,\n", + " events=events,\n", + " terminations=terminations,\n", + " sim=sim_cfg,\n", + " viewer=viewer_cfg,\n", + " decimation=1, # No action repeat\n", + " episode_length_s=int(1e9) if play else 10.0, # Infinite for play, 10s for training\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fC5maMjzSj_X" + }, + "source": [ + "### **āš™ļø Create RL Configuration**\n", + "\n", + "This file defines the PPO (Proximal Policy Optimization) training parameters." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "C81zZm6mSj_X", + "outputId": "a248d4d7-5c86-402e-9994-eeceff962a23" + }, + "outputs": [], + "source": [ + "%%writefile /content/mjlab/src/mjlab/tasks/cartpole/rl_cfg.py\n", + "\"\"\"RL configuration for CartPole task.\"\"\"\n", + "\n", + "from mjlab.rl.config import (\n", + " RslRlOnPolicyRunnerCfg,\n", + " RslRlPpoActorCriticCfg,\n", + " RslRlPpoAlgorithmCfg,\n", + ")\n", + "\n", + "\n", + "def cartpole_ppo_runner_cfg() -> RslRlOnPolicyRunnerCfg:\n", + " \"\"\"Create RL runner configuration for CartPole task.\"\"\"\n", + " return RslRlOnPolicyRunnerCfg(\n", + " policy=RslRlPpoActorCriticCfg(\n", + " init_noise_std=1.0,\n", + " actor_obs_normalization=True,\n", + " critic_obs_normalization=True,\n", + " actor_hidden_dims=(256, 128, 64), # Smaller network for simpler task\n", + " critic_hidden_dims=(256, 128, 64),\n", + " activation=\"elu\",\n", + " ),\n", + " algorithm=RslRlPpoAlgorithmCfg(\n", + " value_loss_coef=1.0,\n", + " use_clipped_value_loss=True,\n", + " clip_param=0.2,\n", + " entropy_coef=0.01,\n", + " num_learning_epochs=5,\n", + " num_mini_batches=4,\n", + " learning_rate=1.0e-3,\n", + " schedule=\"adaptive\",\n", + " gamma=0.99,\n", + " lam=0.95,\n", + " desired_kl=0.01,\n", + " max_grad_norm=1.0,\n", + " ),\n", + " experiment_name=\"cartpole\",\n", + " save_interval=50,\n", + " num_steps_per_env=24,\n", + " max_iterations=5_000, # Fewer iterations for simpler task\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Oc8-AHGcHt78" + }, + "source": [ + "### **šŸ“‹ Register the Task Environment**\n", + "\n", + "Register the CartPole task with mjlab registry." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "YitUGUBRHxD4", + "outputId": "cfc2413f-37c4-4043-a574-d38ab1d1776c" + }, + "outputs": [], + "source": [ + "%%writefile /content/mjlab/src/mjlab/tasks/cartpole/__init__.py\n", + "\"\"\"CartPole task registration.\"\"\"\n", + "\n", + "from mjlab.tasks.registry import register_mjlab_task\n", + "from mjlab.tasks.velocity.rl import VelocityOnPolicyRunner\n", + "\n", + "from .env_cfg import cartpole_env_cfg\n", + "from .rl_cfg import cartpole_ppo_runner_cfg\n", + "\n", + "register_mjlab_task(\n", + " task_id=\"Mjlab-Cartpole\",\n", + " env_cfg=cartpole_env_cfg(),\n", + " play_env_cfg=cartpole_env_cfg(play=True),\n", + " rl_cfg=cartpole_ppo_runner_cfg(),\n", + " runner_cls=VelocityOnPolicyRunner,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "K7wqLZR1rnGn" + }, + "source": [ + "---\n", + "\n", + "## **šŸš€ Step 3: Train the Agent**\n", + "\n", + "Now let's train a PPO policy to balance the CartPole!\n", + "\n", + "**Training Configuration:**\n", + "- Algorithm: PPO (Proximal Policy Optimization)\n", + "- Parallel Environments: 64\n", + "- Episode Length: 10 seconds (500 steps @ 50Hz)\n", + "- Total Steps: ~5-10 million (adjust as needed)\n", + "\n", + "**āš ļø You may need to create a project named \"mjlab\" on wandb UI manually when google colab doesn't have permission to create a new project.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Hht_hF4trqP2", + "outputId": "574f55f5-0f27-44d3-d0e5-34e024259d29" + }, + "outputs": [], + "source": [ + "!python /content/mjlab/src/mjlab/scripts/train.py Mjlab-Cartpole --agent.max-iterations 201 --agent.save-interval 20" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xCaqPznGrx8H" + }, + "source": [ + "### **šŸ“ Locate Training Checkpoints**\n", + "\n", + "After training, checkpoints are saved locally." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "uPnmHYu8r0uY", + "outputId": "67cbaf4e-85eb-4f44-8fdd-5d6c8307859b" + }, + "outputs": [], + "source": [ + "import os\n", + "import re\n", + "from pathlib import Path\n", + "\n", + "# Find the most recent training run\n", + "log_dir = Path(\"/content/mjlab/logs/rsl_rl/cartpole\")\n", + "if log_dir.exists():\n", + " runs = sorted(log_dir.glob(\"*\"), key=os.path.getmtime, reverse=True)\n", + " if runs:\n", + " latest_run = runs[0]\n", + " print(f\"āœ“ Latest training run: {latest_run.name}\\n\")\n", + "\n", + " # List checkpoints - sorted by iteration number\n", + " checkpoints = list(latest_run.glob(\"model_*.pt\"))\n", + " if checkpoints:\n", + " # Extract iteration number and sort numerically\n", + " def get_iteration(ckpt):\n", + " match = re.search(r'model_(\\d+)\\.pt', ckpt.name)\n", + " return int(match.group(1)) if match else 0\n", + "\n", + " checkpoints = sorted(checkpoints, key=get_iteration)\n", + "\n", + " print(f\"Found {len(checkpoints)} checkpoints:\")\n", + " for ckpt in checkpoints[-5:]: # Show last 5\n", + " size_mb = ckpt.stat().st_size / (1024 * 1024)\n", + " iteration = get_iteration(ckpt)\n", + " print(f\" • {ckpt.name} (iteration {iteration}, {size_mb:.2f} MB)\")\n", + "\n", + " # Store the last checkpoint path\n", + " last_checkpoint = str(checkpoints[-1])\n", + " print(f\"\\nšŸ’¾ Last checkpoint: {last_checkpoint}\")\n", + " else:\n", + " print(\"⚠ No checkpoints found yet\")\n", + " else:\n", + " print(\"⚠ No training runs found\")\n", + "else:\n", + " print(\"⚠ Log directory not found. Have you run training yet?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eWFS9Pw7r2uH" + }, + "source": [ + "---\n", + "\n", + "## **šŸŽ® Step 4: Visualize the Trained Policy**\n", + "\n", + "Let's see the trained policy in action!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "78PgHtpfr5sb" + }, + "source": [ + "### **🌐 Launch Viser API**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "_9tGiFyBr2bW", + "outputId": "79db7c02-f301-4005-c5d4-f2726b2b514d" + }, + "outputs": [], + "source": [ + "import subprocess\n", + "import sys\n", + "import re\n", + "\n", + "process = subprocess.Popen(\n", + " [\n", + " \"python\",\n", + " \"/content/mjlab/src/mjlab/scripts/play.py\",\n", + " \"Mjlab-Cartpole\",\n", + " \"--checkpoint_file\",\n", + " last_checkpoint,\n", + " \"--num_envs\",\n", + " \"4\",\n", + " ],\n", + " stdout=subprocess.PIPE,\n", + " stderr=subprocess.STDOUT,\n", + " universal_newlines=True,\n", + " bufsize=1,\n", + ")\n", + "\n", + "detected_port = None\n", + "\n", + "for line in process.stdout:\n", + " print(line, end=\"\")\n", + " sys.stdout.flush()\n", + "\n", + " # Extract port number from viser output\n", + " port_match = re.search(r':(\\d{4})', line)\n", + " if port_match and 'viser' in line.lower():\n", + " detected_port = int(port_match.group(1))\n", + " print(\"\\n\" + \"=\" * 34)\n", + " print(f\"āœ… Server is running on port {detected_port}!\")\n", + " print(\"=\" * 34)\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XgzJXyBXssZS" + }, + "source": [ + "### **šŸ–„ļø Embed Client as iframe**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 721 + }, + "id": "ll89QnuSuUxx", + "outputId": "067a7ace-dd6e-4de0-ce7a-99d5ac478fb4" + }, + "outputs": [], + "source": [ + "from google.colab import output\n", + "\n", + "port = detected_port if detected_port else 8081\n", + "output.serve_kernel_port_as_iframe(\n", + " port=port,\n", + " height=700\n", + ")" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "include_colab_link": true, + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} From fc50d751ae4ccd831cc2fd1800a571548e54993f Mon Sep 17 00:00:00 2001 From: Tatsuki Tsujimoto <55564973+ttktjmt@users.noreply.github.com> Date: Wed, 10 Dec 2025 13:44:17 +0900 Subject: [PATCH 2/2] Update the notebook --- notebooks/create_new_task.ipynb | 278 ++++++++++++-------------------- 1 file changed, 102 insertions(+), 176 deletions(-) diff --git a/notebooks/create_new_task.ipynb b/notebooks/create_new_task.ipynb index 9d18ea99..90ce0112 100644 --- a/notebooks/create_new_task.ipynb +++ b/notebooks/create_new_task.ipynb @@ -3,11 +3,11 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", - "id": "view-in-github" + "id": "view-in-github", + "colab_type": "text" }, "source": [ - "\"Open" + "\"Open" ] }, { @@ -41,21 +41,20 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, "collapsed": true, - "id": "dtLMJHzy3Nee", - "outputId": "4ff1cea6-db01-4d78-f074-c655bbf4dccd" + "id": "dtLMJHzy3Nee" }, "outputs": [], "source": [ + "# Install mujoco-warp\n", + "!pip install git+https://github.com/google-deepmind/mujoco_warp@9fc294d86955a303619a254cefae809a41adb274 -q\n", + "\n", "# Clone the mjlab repository\n", "!if [ ! -d \"mjlab\" ]; then git clone -q https://github.com/mujocolab/mjlab.git; fi\n", "%cd /content/mjlab\n", "\n", "# Install mjlab in editable mode\n", - "!uv pip install --system -e . -q\n", + "!pip install -e . -q\n", "\n", "print(\"āœ“ Installation complete!\")" ] @@ -66,7 +65,7 @@ "id": "SSf2943z3b0s" }, "source": [ - "### **šŸ”‘ WandB Setup (Optional)**\n", + "### **šŸ”‘ WandB Setup**\n", "\n", "Configure Weights & Biases for experiment tracking. Add your WandB API key to Colab Secrets:\n", "- `WANDB_API_KEY`: from [wandb.ai/authorize](https://wandb.ai/authorize)\n", @@ -77,28 +76,25 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "KC9ywCnm3dGg", - "outputId": "e481f303-f938-49ae-f9c3-aff522c5bdf5" + "id": "KC9ywCnm3dGg" }, "outputs": [], "source": [ "import os\n", + "\n", "from google.colab import userdata\n", "\n", "try:\n", - " # Set this to disable wandb logger\n", - " # os.environ['WANDB_MODE'] = 'disabled'\n", - "\n", - " # Set this to use wandb logger\n", - " os.environ['WANDB_API_KEY'] = userdata.get('WANDB_API_KEY')\n", - " os.environ['WANDB_ENTITY'] = userdata.get('WANDB_ENTITY')\n", + " # Set this to use wandb logger\n", + " os.environ[\"WANDB_API_KEY\"] = userdata.get(\"WANDB_API_KEY\")\n", + " os.environ[\"WANDB_ENTITY\"] = userdata.get(\"WANDB_ENTITY\")\n", "\n", - " print(\"āœ“ WandB configured successfully!\")\n", + " print(\"āœ“ WandB configured successfully!\")\n", "except (AttributeError, KeyError):\n", - " print(\"⚠ WandB secrets not found. Training will proceed without WandB logging.\")" + " # Set this to disable wandb logger\n", + " os.environ['WANDB_MODE'] = 'disabled'\n", + "\n", + " print(\"⚠ WandB secrets not found. Training will proceed without WandB logging.\")" ] }, { @@ -130,11 +126,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "OP-yET-R3ofN", - "outputId": "660793a3-be51-449f-ba5a-72c9b975a4aa" + "id": "OP-yET-R3ofN" }, "outputs": [], "source": [ @@ -164,11 +156,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "gWGyFX5V3yWc", - "outputId": "3e4e196e-0cf4-4f4d-8d28-db41f974b3c6" + "id": "gWGyFX5V3yWc" }, "outputs": [], "source": [ @@ -205,11 +193,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "HDhiyDTn4AVa", - "outputId": "7f4a9347-a820-4de3-801e-dc12cd2ccad6" + "id": "HDhiyDTn4AVa" }, "outputs": [], "source": [ @@ -252,11 +236,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "-WSaDod04FwN", - "outputId": "ac085f19-e3c0-4d8c-ec77-a4327305c2c3" + "id": "-WSaDod04FwN" }, "outputs": [], "source": [ @@ -269,21 +249,17 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "W1tiBPfp_oVP", - "outputId": "4de0534e-81c8-43a6-f3eb-d2e4f6dd1217" + "id": "W1tiBPfp_oVP" }, "outputs": [], "source": [ "import sys\n", "\n", "# Append src dir to python path\n", - "mjlab_src = '/content/mjlab/src'\n", + "mjlab_src = \"/content/mjlab/src\"\n", "if mjlab_src not in sys.path:\n", - " sys.path.insert(0, mjlab_src)\n", - " print(f\"āœ“ Added {mjlab_src} to Python path\")" + " sys.path.insert(0, mjlab_src)\n", + " print(f\"āœ“ Added {mjlab_src} to Python path\")" ] }, { @@ -301,17 +277,14 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "5tVsvqzQ4J9h", - "outputId": "b27372be-85d9-4150-f69c-6ae0c784f8bd" + "id": "5tVsvqzQ4J9h" }, "outputs": [], "source": [ - "from mjlab.entity import Entity\n", "from mjlab.asset_zoo.robots.cartpole.cartpole_constants import get_cartpole_robot_cfg\n", "\n", + "from mjlab.entity import Entity\n", + "\n", "# Load the robot\n", "robot = Entity(get_cartpole_robot_cfg())\n", "model = robot.spec.compile()\n", @@ -339,19 +312,15 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "8qDIF__lHPcb", - "outputId": "b328c3e5-784f-482c-85c2-32c641dfb09a" + "id": "8qDIF__lHPcb" }, "outputs": [], "source": [ "# Add CartPole import to robots __init__.py\n", - "with open('/content/mjlab/src/mjlab/asset_zoo/robots/__init__.py', 'a') as f:\n", - " f.write('\\n# CartPole robot\\n')\n", - " f.write('from mjlab.asset_zoo.robots.cartpole.cartpole_constants import ')\n", - " f.write('get_cartpole_robot_cfg as get_cartpole_robot_cfg\\n')\n", + "with open(\"/content/mjlab/src/mjlab/asset_zoo/robots/__init__.py\", \"a\") as f:\n", + " f.write(\"\\n# CartPole robot\\n\")\n", + " f.write(\"from mjlab.asset_zoo.robots.cartpole.cartpole_constants import \")\n", + " f.write(\"get_cartpole_robot_cfg as get_cartpole_robot_cfg\\n\")\n", "\n", "print(\"āœ“ CartPole robot registered in asset zoo\")" ] @@ -387,11 +356,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "nWBqdkziHc2G", - "outputId": "77fb7eb2-2b09-4f46-c678-bed897e46ef4" + "id": "nWBqdkziHc2G" }, "outputs": [], "source": [ @@ -421,11 +386,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "javx9XDIHkFI", - "outputId": "9ecf2cf3-1666-4562-99fd-165753289b8a" + "id": "javx9XDIHkFI" }, "outputs": [], "source": [ @@ -628,11 +589,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "C81zZm6mSj_X", - "outputId": "a248d4d7-5c86-402e-9994-eeceff962a23" + "id": "C81zZm6mSj_X" }, "outputs": [], "source": [ @@ -693,11 +650,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "YitUGUBRHxD4", - "outputId": "cfc2413f-37c4-4043-a574-d38ab1d1776c" + "id": "YitUGUBRHxD4" }, "outputs": [], "source": [ @@ -729,30 +682,18 @@ "\n", "## **šŸš€ Step 3: Train the Agent**\n", "\n", - "Now let's train a PPO policy to balance the CartPole!\n", - "\n", - "**Training Configuration:**\n", - "- Algorithm: PPO (Proximal Policy Optimization)\n", - "- Parallel Environments: 64\n", - "- Episode Length: 10 seconds (500 steps @ 50Hz)\n", - "- Total Steps: ~5-10 million (adjust as needed)\n", - "\n", - "**āš ļø You may need to create a project named \"mjlab\" on wandb UI manually when google colab doesn't have permission to create a new project.**" + "Now let's train a PPO policy to balance the CartPole!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Hht_hF4trqP2", - "outputId": "574f55f5-0f27-44d3-d0e5-34e024259d29" + "id": "Hht_hF4trqP2" }, "outputs": [], "source": [ - "!python /content/mjlab/src/mjlab/scripts/train.py Mjlab-Cartpole --agent.max-iterations 201 --agent.save-interval 20" + "!python -m mjlab.scripts.train Mjlab-Cartpole --agent.max-iterations 201 --agent.save-interval 20" ] }, { @@ -770,11 +711,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "uPnmHYu8r0uY", - "outputId": "67cbaf4e-85eb-4f44-8fdd-5d6c8307859b" + "id": "uPnmHYu8r0uY" }, "outputs": [], "source": [ @@ -785,36 +722,36 @@ "# Find the most recent training run\n", "log_dir = Path(\"/content/mjlab/logs/rsl_rl/cartpole\")\n", "if log_dir.exists():\n", - " runs = sorted(log_dir.glob(\"*\"), key=os.path.getmtime, reverse=True)\n", - " if runs:\n", - " latest_run = runs[0]\n", - " print(f\"āœ“ Latest training run: {latest_run.name}\\n\")\n", - "\n", - " # List checkpoints - sorted by iteration number\n", - " checkpoints = list(latest_run.glob(\"model_*.pt\"))\n", - " if checkpoints:\n", - " # Extract iteration number and sort numerically\n", - " def get_iteration(ckpt):\n", - " match = re.search(r'model_(\\d+)\\.pt', ckpt.name)\n", - " return int(match.group(1)) if match else 0\n", - "\n", - " checkpoints = sorted(checkpoints, key=get_iteration)\n", - "\n", - " print(f\"Found {len(checkpoints)} checkpoints:\")\n", - " for ckpt in checkpoints[-5:]: # Show last 5\n", - " size_mb = ckpt.stat().st_size / (1024 * 1024)\n", - " iteration = get_iteration(ckpt)\n", - " print(f\" • {ckpt.name} (iteration {iteration}, {size_mb:.2f} MB)\")\n", - "\n", - " # Store the last checkpoint path\n", - " last_checkpoint = str(checkpoints[-1])\n", - " print(f\"\\nšŸ’¾ Last checkpoint: {last_checkpoint}\")\n", - " else:\n", - " print(\"⚠ No checkpoints found yet\")\n", + " runs = sorted(log_dir.glob(\"*\"), key=os.path.getmtime, reverse=True)\n", + " if runs:\n", + " latest_run = runs[0]\n", + " print(f\"āœ“ Latest training run: {latest_run.name}\\n\")\n", + "\n", + " # List checkpoints - sorted by iteration number\n", + " checkpoints = list(latest_run.glob(\"model_*.pt\"))\n", + " if checkpoints:\n", + " # Extract iteration number and sort numerically\n", + " def get_iteration(ckpt):\n", + " match = re.search(r\"model_(\\d+)\\.pt\", ckpt.name)\n", + " return int(match.group(1)) if match else 0\n", + "\n", + " checkpoints = sorted(checkpoints, key=get_iteration)\n", + "\n", + " print(f\"Found {len(checkpoints)} checkpoints:\")\n", + " for ckpt in checkpoints[-5:]: # Show last 5\n", + " size_mb = ckpt.stat().st_size / (1024 * 1024)\n", + " iteration = get_iteration(ckpt)\n", + " print(f\" • {ckpt.name} (iteration {iteration}, {size_mb:.2f} MB)\")\n", + "\n", + " # Store the last checkpoint path\n", + " last_checkpoint = str(checkpoints[-1])\n", + " print(f\"\\nšŸ’¾ Last checkpoint: {last_checkpoint}\")\n", " else:\n", - " print(\"⚠ No training runs found\")\n", + " print(\"⚠ No checkpoints found yet\")\n", + " else:\n", + " print(\"⚠ No training runs found\")\n", "else:\n", - " print(\"⚠ Log directory not found. Have you run training yet?\")" + " print(\"⚠ Log directory not found. Have you run training yet?\")" ] }, { @@ -843,48 +780,44 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "_9tGiFyBr2bW", - "outputId": "79db7c02-f301-4005-c5d4-f2726b2b514d" + "id": "_9tGiFyBr2bW" }, "outputs": [], "source": [ "import subprocess\n", "import sys\n", - "import re\n", "\n", "process = subprocess.Popen(\n", - " [\n", - " \"python\",\n", - " \"/content/mjlab/src/mjlab/scripts/play.py\",\n", - " \"Mjlab-Cartpole\",\n", - " \"--checkpoint_file\",\n", - " last_checkpoint,\n", - " \"--num_envs\",\n", - " \"4\",\n", - " ],\n", - " stdout=subprocess.PIPE,\n", - " stderr=subprocess.STDOUT,\n", - " universal_newlines=True,\n", - " bufsize=1,\n", + " [\n", + " \"python\",\n", + " \"-m\",\n", + " \"mjlab.scripts.play\",\n", + " \"Mjlab-Cartpole\",\n", + " \"--checkpoint_file\",\n", + " last_checkpoint,\n", + " \"--num_envs\",\n", + " \"4\",\n", + " ],\n", + " stdout=subprocess.PIPE,\n", + " stderr=subprocess.STDOUT,\n", + " universal_newlines=True,\n", + " bufsize=1,\n", ")\n", "\n", "detected_port = None\n", "\n", "for line in process.stdout:\n", - " print(line, end=\"\")\n", - " sys.stdout.flush()\n", + " print(line, end=\"\")\n", + " sys.stdout.flush()\n", "\n", - " # Extract port number from viser output\n", - " port_match = re.search(r':(\\d{4})', line)\n", - " if port_match and 'viser' in line.lower():\n", - " detected_port = int(port_match.group(1))\n", - " print(\"\\n\" + \"=\" * 34)\n", - " print(f\"āœ… Server is running on port {detected_port}!\")\n", - " print(\"=\" * 34)\n", - " break" + " # Extract port number from viser output\n", + " port_match = re.search(r\":(\\d{4})\", line)\n", + " if port_match and \"viser\" in line.lower():\n", + " detected_port = int(port_match.group(1))\n", + " print(\"\\n\" + \"=\" * 34)\n", + " print(f\"āœ… Server is running on port {detected_port}!\")\n", + " print(\"=\" * 34)\n", + " break" ] }, { @@ -900,22 +833,14 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 721 - }, - "id": "ll89QnuSuUxx", - "outputId": "067a7ace-dd6e-4de0-ce7a-99d5ac478fb4" + "id": "ll89QnuSuUxx" }, "outputs": [], "source": [ "from google.colab import output\n", "\n", "port = detected_port if detected_port else 8081\n", - "output.serve_kernel_port_as_iframe(\n", - " port=port,\n", - " height=700\n", - ")" + "output.serve_kernel_port_as_iframe(port=port, height=700)" ] } ], @@ -923,8 +848,9 @@ "accelerator": "GPU", "colab": { "gpuType": "T4", - "include_colab_link": true, - "provenance": [] + "provenance": [], + "toc_visible": true, + "include_colab_link": true }, "kernelspec": { "display_name": "Python 3", @@ -936,4 +862,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file