diff --git a/notebooks/create_new_task.ipynb b/notebooks/create_new_task.ipynb new file mode 100644 index 00000000..90ce0112 --- /dev/null +++ b/notebooks/create_new_task.ipynb @@ -0,0 +1,865 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PO76KS1i-MwA" + }, + "source": [ + "# **šŸ¤– CartPole Tutorial with MJLab**\n", + "\n", + "This notebook demonstrates how to create a custom reinforcement learning task using MJLab. We'll build a CartPole environment from scratch, including:\n", + "\n", + "1. **Robot Definition** - Define the CartPole model in MuJoCo XML\n", + "2. **Task Configuration** - Set up observations, actions, rewards, and terminations\n", + "3. **Training** - Train a policy using PPO\n", + "4. **Evaluation** - Visualize the simulation with the trained policy\n", + "\n", + "> **Note**: This tutorial is created based on the official MJLab documentation [\"Create a New Task\"](https://github.com/mujocolab/mjlab/blob/main/docs/create_new_task.md)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3ywZTgfR3C_w" + }, + "source": [ + "## **šŸ“¦ Setup and Installation**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "id": "dtLMJHzy3Nee" + }, + "outputs": [], + "source": [ + "# Install mujoco-warp\n", + "!pip install git+https://github.com/google-deepmind/mujoco_warp@9fc294d86955a303619a254cefae809a41adb274 -q\n", + "\n", + "# Clone the mjlab repository\n", + "!if [ ! -d \"mjlab\" ]; then git clone -q https://github.com/mujocolab/mjlab.git; fi\n", + "%cd /content/mjlab\n", + "\n", + "# Install mjlab in editable mode\n", + "!pip install -e . -q\n", + "\n", + "print(\"āœ“ Installation complete!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SSf2943z3b0s" + }, + "source": [ + "### **šŸ”‘ WandB Setup**\n", + "\n", + "Configure Weights & Biases for experiment tracking. Add your WandB API key to Colab Secrets:\n", + "- `WANDB_API_KEY`: from [wandb.ai/authorize](https://wandb.ai/authorize)\n", + "- `WANDB_ENTITY`: your wandb entity name" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KC9ywCnm3dGg" + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from google.colab import userdata\n", + "\n", + "try:\n", + " # Set this to use wandb logger\n", + " os.environ[\"WANDB_API_KEY\"] = userdata.get(\"WANDB_API_KEY\")\n", + " os.environ[\"WANDB_ENTITY\"] = userdata.get(\"WANDB_ENTITY\")\n", + "\n", + " print(\"āœ“ WandB configured successfully!\")\n", + "except (AttributeError, KeyError):\n", + " # Set this to disable wandb logger\n", + " os.environ['WANDB_MODE'] = 'disabled'\n", + "\n", + " print(\"⚠ WandB secrets not found. 
Training will proceed without WandB logging.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mispfmy73lmq" + }, + "source": [ + "---\n", + "\n", + "## **šŸ¤– Step 1: Define the Robot**\n", + "\n", + "We'll create a simple CartPole robot with:\n", + "- A sliding cart (1 DOF)\n", + "- A hinged pole (1 DOF)\n", + "- A velocity actuator to control the cart" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-FvJYPWD3scd" + }, + "source": [ + "### **šŸ“ Structure Directories**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OP-yET-R3ofN" + }, + "outputs": [], + "source": [ + "# Create the cartpole robot directory structure\n", + "!mkdir -p /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/\n", + "!mkdir -p /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/xmls\n", + "\n", + "print(\"āœ“ Directory structure created\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MRyN1Pok3u25" + }, + "source": [ + "### **šŸ“ Create MuJoCo XML Model**\n", + "\n", + "This XML defines the CartPole physics:\n", + "- **Ground plane** for visualization\n", + "- **Cart body** with a sliding joint (±2m range)\n", + "- **Pole body** with a hinge joint (±90° range)\n", + "- **Velocity actuator** for cart control" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gWGyFX5V3yWc" + }, + "outputs": [], + "source": [ + "%%writefile /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/xmls/cartpole.xml\n", + "<mujoco model=\"cartpole\">\n", + " <!-- Minimal CartPole model; dimensions, masses and gains are illustrative -->\n", + " <worldbody>\n", + " <!-- Ground plane, mainly for visualization -->\n", + " <geom name=\"floor\" type=\"plane\" size=\"3 3 0.1\"/>\n", + " <!-- Cart: slide joint along x, +/- 2 m -->\n", + " <body name=\"cart\" pos=\"0 0 1\">\n", + " <joint name=\"slide\" type=\"slide\" axis=\"1 0 0\" range=\"-2 2\"/>\n", + " <geom name=\"cart_geom\" type=\"box\" size=\"0.2 0.15 0.1\" mass=\"1\"/>\n", + " <!-- Pole: hinge joint about y, +/- 90 degrees -->\n", + " <body name=\"pole\">\n", + " <joint name=\"hinge\" type=\"hinge\" axis=\"0 1 0\" range=\"-90 90\"/>\n", + " <geom name=\"pole_geom\" type=\"capsule\" fromto=\"0 0 0 0 0 0.6\" size=\"0.045\" mass=\"0.1\"/>\n", + " </body>\n", + " </body>\n", + " </worldbody>\n", + " <actuator>\n", + " <!-- Velocity servo on the cart's slide joint -->\n", + " <velocity name=\"slide\" joint=\"slide\" kv=\"10\" ctrlrange=\"-20 20\"/>\n", + " </actuator>\n", + "</mujoco>" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MpYCG9jI31dZ" + }, + "source": [ + "### **āš™ļø Create Robot Configuration**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HDhiyDTn4AVa" + }, + "outputs": [], + "source": [ + "%%writefile /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/cartpole_constants.py\n", + "from pathlib import Path\n", + "import mujoco\n", + "\n", + "from mjlab import MJLAB_SRC_PATH\n", + "from mjlab.entity import Entity, EntityCfg, EntityArticulationInfoCfg\n", + "from mjlab.actuator import XmlVelocityActuatorCfg\n", + "\n", + "CARTPOLE_XML: Path = (\n", + " MJLAB_SRC_PATH / \"asset_zoo\" / \"robots\" / \"cartpole\" / \"xmls\" / \"cartpole.xml\"\n", + ")\n", + "assert CARTPOLE_XML.exists(), f\"XML not found: {CARTPOLE_XML}\"\n", + "\n", + "def get_spec() -> mujoco.MjSpec:\n", + " return mujoco.MjSpec.from_file(str(CARTPOLE_XML))\n", + "\n", + "def get_cartpole_robot_cfg() -> EntityCfg:\n", + " \"\"\"Get a fresh CartPole robot configuration instance.\"\"\"\n", + " actuators = (\n", + " XmlVelocityActuatorCfg(\n", + " joint_names_expr=(\"slide\",),\n", + " ),\n", + " )\n", + " articulation = EntityArticulationInfoCfg(actuators=actuators)\n", + " return EntityCfg(\n", + " spec_fn=get_spec,\n", + " articulation=articulation\n", + " )\n", + "\n", + "# if __name__ == \"__main__\":\n", + "# import mujoco.viewer as viewer\n", + "# robot = Entity(get_cartpole_robot_cfg())\n", + "# viewer.launch(robot.spec.compile())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-WSaDod04FwN" + }, + "outputs": [], + "source": [ + "%%writefile /content/mjlab/src/mjlab/asset_zoo/robots/cartpole/__init__.py\n", + "# Empty __init__.py to mark the directory as a Python package" + ] + }, + { +
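"cell_type": "markdown", + "metadata": {}, + "source": [ + "### **🧪 Quick XML Sanity Check (Optional)**\n", + "\n", + "Before wrapping the XML in an mjlab `Entity`, you can check that it compiles with plain MuJoCo. This is a minimal sketch: it assumes the `%%writefile` cell above has been executed and uses only the standard `mujoco.MjModel.from_xml_path` API." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import mujoco\n", + "\n", + "# Minimal sanity check: compile the hand-written XML directly with MuJoCo.\n", + "# This only verifies the model is well-formed; the mjlab Entity is built in the next step.\n", + "xml_path = \"/content/mjlab/src/mjlab/asset_zoo/robots/cartpole/xmls/cartpole.xml\"\n", + "m = mujoco.MjModel.from_xml_path(xml_path)\n", + "print(f\"āœ“ XML compiles: {m.njnt} joints, {m.nu} actuator(s)\")" + ] + }, + { +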
"cell_type": "code", + "execution_count": null, + "metadata": { + "id": "W1tiBPfp_oVP" + }, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "# Append src dir to python path\n", + "mjlab_src = \"/content/mjlab/src\"\n", + "if mjlab_src not in sys.path:\n", + " sys.path.insert(0, mjlab_src)\n", + " print(f\"āœ“ Added {mjlab_src} to Python path\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ToWF84qC4Hfg" + }, + "source": [ + "### **āœ… Verify Robot Setup**\n", + "\n", + "Let's test that the robot can be loaded correctly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5tVsvqzQ4J9h" + }, + "outputs": [], + "source": [ + "from mjlab.asset_zoo.robots.cartpole.cartpole_constants import get_cartpole_robot_cfg\n", + "\n", + "from mjlab.entity import Entity\n", + "\n", + "# Load the robot\n", + "robot = Entity(get_cartpole_robot_cfg())\n", + "model = robot.spec.compile()\n", + "\n", + "# Display robot information\n", + "print(\"āœ“ CartPole robot loaded successfully!\")\n", + "print(f\" • Degrees of Freedom (DOF): {model.nv}\")\n", + "print(f\" • Number of Actuators: {model.nu}\")\n", + "print(f\" • Bodies: {model.nbody}\")\n", + "print(f\" • Joints: {model.njnt}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "e2_9dixlHON1" + }, + "source": [ + "### **šŸ“‹ Register the Robot**\n", + "\n", + "Add the CartPole robot to the asset zoo registry." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8qDIF__lHPcb" + }, + "outputs": [], + "source": [ + "# Add CartPole import to robots __init__.py\n", + "with open(\"/content/mjlab/src/mjlab/asset_zoo/robots/__init__.py\", \"a\") as f:\n", + " f.write(\"\\n# CartPole robot\\n\")\n", + " f.write(\"from mjlab.asset_zoo.robots.cartpole.cartpole_constants import \")\n", + " f.write(\"get_cartpole_robot_cfg as get_cartpole_robot_cfg\\n\")\n", + "\n", + "print(\"āœ“ CartPole robot registered in asset zoo\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6lVD_L6PHWNm" + }, + "source": [ + "---\n", + "\n", + "## **šŸŽÆ Step 2: Define the Task (MDP)**\n", + "\n", + "Now we'll define the Markov Decision Process:\n", + "- **Observations**: pole angle, angular velocity, cart position, cart velocity\n", + "- **Actions**: cart velocity commands\n", + "- **Rewards**: upright reward + effort penalty\n", + "- **Terminations**: pole tips over or timeout\n", + "- **Events**: random pushes for robustness" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RQxe4TBrHb-I" + }, + "source": [ + "### **šŸ“ Create Task Directory**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "nWBqdkziHc2G" + }, + "outputs": [], + "source": [ + "!mkdir -p /content/mjlab/src/mjlab/tasks/cartpole\n", + "\n", + "print(\"āœ“ Task directory created\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GJfjPpm0Hhj1" + }, + "source": [ + "### **šŸ“ Create Environment Configuration**\n", + "\n", + "This file contains the MDP (Markov Decision Process) components:\n", + "1. **Scene Config**: 64 parallel environments\n", + "2. **Actions**: Joint velocity control with 20.0 scale\n", + "3. **Observations**: Normalized state variables\n", + "4. **Rewards**: Upright reward (5.0) + effort penalty (-0.01)\n", + "5. **Events**: Joint resets + random pushes\n", + "6. 
**Terminations**: Pole tipped (>30°) or timeout (10s)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "javx9XDIHkFI" + }, + "outputs": [], + "source": [ + "%%writefile /content/mjlab/src/mjlab/tasks/cartpole/env_cfg.py\n", + "\"\"\"CartPole task environment configuration.\"\"\"\n", + "\n", + "import math\n", + "import torch\n", + "\n", + "from mjlab.envs import ManagerBasedRlEnvCfg\n", + "from mjlab.envs.mdp.actions import JointVelocityActionCfg\n", + "from mjlab.managers.manager_term_config import (\n", + " ObservationGroupCfg,\n", + " ObservationTermCfg,\n", + " RewardTermCfg,\n", + " TerminationTermCfg,\n", + " EventTermCfg,\n", + ")\n", + "from mjlab.managers.scene_entity_config import SceneEntityCfg\n", + "from mjlab.scene import SceneCfg\n", + "from mjlab.sim import MujocoCfg, SimulationCfg\n", + "from mjlab.viewer import ViewerConfig\n", + "from mjlab.asset_zoo.robots.cartpole.cartpole_constants import get_cartpole_robot_cfg\n", + "from mjlab.envs import mdp\n", + "\n", + "\n", + "def cartpole_env_cfg(play: bool = False) -> ManagerBasedRlEnvCfg:\n", + " \"\"\"Create CartPole environment configuration.\n", + "\n", + " Args:\n", + " play: If True, disables corruption and extends episode length for evaluation.\n", + " \"\"\"\n", + "\n", + " # ==============================================================================\n", + " # Scene Configuration\n", + " # ==============================================================================\n", + "\n", + " scene_cfg = SceneCfg(\n", + " num_envs=64 if not play else 16, # Fewer envs for play mode\n", + " extent=1.0, # Spacing between environments\n", + " entities={\"robot\": get_cartpole_robot_cfg()},\n", + " )\n", + "\n", + " viewer_cfg = ViewerConfig(\n", + " origin_type=ViewerConfig.OriginType.ASSET_BODY,\n", + " asset_name=\"robot\",\n", + " body_name=\"pole\",\n", + " distance=3.0,\n", + " elevation=10.0,\n", + " azimuth=90.0,\n", + " )\n", + "\n", + " sim_cfg = SimulationCfg(\n", + " mujoco=MujocoCfg(\n", + " timestep=0.02, # 50 Hz control\n", + " iterations=1,\n", + " ),\n", + " )\n", + "\n", + " # ==============================================================================\n", + " # Actions\n", + " # ==============================================================================\n", + "\n", + " actions = {\n", + " \"joint_pos\": JointVelocityActionCfg(\n", + " asset_name=\"robot\",\n", + " actuator_names=(\".*\",),\n", + " scale=20.0,\n", + " use_default_offset=False,\n", + " ),\n", + " }\n", + "\n", + " # ==============================================================================\n", + " # Observations\n", + " # ==============================================================================\n", + "\n", + " policy_terms = {\n", + " \"angle\": ObservationTermCfg(\n", + " func=lambda env: env.sim.data.qpos[:, 1:2] / math.pi\n", + " ),\n", + " \"ang_vel\": ObservationTermCfg(\n", + " func=lambda env: env.sim.data.qvel[:, 1:2] / 5.0\n", + " ),\n", + " \"cart_pos\": ObservationTermCfg(\n", + " func=lambda env: env.sim.data.qpos[:, 0:1] / 2.0\n", + " ),\n", + " \"cart_vel\": ObservationTermCfg(\n", + " func=lambda env: env.sim.data.qvel[:, 0:1] / 20.0\n", + " ),\n", + " }\n", + "\n", + " observations = {\n", + " \"policy\": ObservationGroupCfg(\n", + " terms=policy_terms,\n", + " concatenate_terms=True,\n", + " enable_corruption=not play, # Disable corruption in play mode\n", + " ),\n", + " \"critic\": ObservationGroupCfg(\n", + " terms=policy_terms, # Critic uses same observations\n", + " 
concatenate_terms=True,\n", + " enable_corruption=False,\n", + " ),\n", + " }\n", + "\n", + " # ==============================================================================\n", + " # Rewards\n", + " # ==============================================================================\n", + "\n", + " def compute_upright_reward(env):\n", + " \"\"\"Reward for keeping pole upright (cosine of angle).\"\"\"\n", + " return env.sim.data.qpos[:, 1].cos()\n", + "\n", + " def compute_effort_penalty(env):\n", + " \"\"\"Penalty for control effort.\"\"\"\n", + " return -0.01 * (env.sim.data.ctrl[:, 0] ** 2)\n", + "\n", + " rewards = {\n", + " \"upright\": RewardTermCfg(func=compute_upright_reward, weight=5.0),\n", + " \"effort\": RewardTermCfg(func=compute_effort_penalty, weight=1.0),\n", + " }\n", + "\n", + " # ==============================================================================\n", + " # Events\n", + " # ==============================================================================\n", + "\n", + " def random_push_cart(env, env_ids, force_range=(-5, 5)):\n", + " \"\"\"Apply random force to cart for robustness training.\"\"\"\n", + " n = len(env_ids)\n", + " random_forces = (\n", + " torch.rand(n, device=env.device) *\n", + " (force_range[1] - force_range[0]) +\n", + " force_range[0]\n", + " )\n", + " env.sim.data.qfrc_applied[env_ids, 0] = random_forces\n", + "\n", + " events = {\n", + " \"reset_robot_joints\": EventTermCfg(\n", + " func=mdp.reset_joints_by_offset,\n", + " mode=\"reset\",\n", + " params={\n", + " \"asset_cfg\": SceneEntityCfg(\"robot\"),\n", + " \"position_range\": (-0.1, 0.1),\n", + " \"velocity_range\": (-0.1, 0.1),\n", + " },\n", + " ),\n", + " }\n", + "\n", + " # Add random pushes only in training mode\n", + " if not play:\n", + " events[\"random_push\"] = EventTermCfg(\n", + " func=random_push_cart,\n", + " mode=\"interval\",\n", + " interval_range_s=(1.0, 2.0),\n", + " params={\"force_range\": (-20.0, 20.0)},\n", + " )\n", + "\n", + " # ==============================================================================\n", + " # Terminations\n", + " # ==============================================================================\n", + "\n", + " def check_pole_tipped(env):\n", + " \"\"\"Check if pole has tipped beyond 30 degrees.\"\"\"\n", + " return env.sim.data.qpos[:, 1].abs() > math.radians(30)\n", + "\n", + " terminations = {\n", + " \"timeout\": TerminationTermCfg(func=mdp.time_out, time_out=True),\n", + " \"tipped\": TerminationTermCfg(func=check_pole_tipped, time_out=False),\n", + " }\n", + "\n", + " # ==============================================================================\n", + " # Environment Configuration\n", + " # ==============================================================================\n", + "\n", + " return ManagerBasedRlEnvCfg(\n", + " scene=scene_cfg,\n", + " observations=observations,\n", + " actions=actions,\n", + " rewards=rewards,\n", + " events=events,\n", + " terminations=terminations,\n", + " sim=sim_cfg,\n", + " viewer=viewer_cfg,\n", + " decimation=1, # No action repeat\n", + " episode_length_s=int(1e9) if play else 10.0, # Infinite for play, 10s for training\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fC5maMjzSj_X" + }, + "source": [ + "### **āš™ļø Create RL Configuration**\n", + "\n", + "This file defines the PPO (Proximal Policy Optimization) training parameters." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "C81zZm6mSj_X" + }, + "outputs": [], + "source": [ + "%%writefile /content/mjlab/src/mjlab/tasks/cartpole/rl_cfg.py\n", + "\"\"\"RL configuration for CartPole task.\"\"\"\n", + "\n", + "from mjlab.rl.config import (\n", + " RslRlOnPolicyRunnerCfg,\n", + " RslRlPpoActorCriticCfg,\n", + " RslRlPpoAlgorithmCfg,\n", + ")\n", + "\n", + "\n", + "def cartpole_ppo_runner_cfg() -> RslRlOnPolicyRunnerCfg:\n", + " \"\"\"Create RL runner configuration for CartPole task.\"\"\"\n", + " return RslRlOnPolicyRunnerCfg(\n", + " policy=RslRlPpoActorCriticCfg(\n", + " init_noise_std=1.0,\n", + " actor_obs_normalization=True,\n", + " critic_obs_normalization=True,\n", + " actor_hidden_dims=(256, 128, 64), # Smaller network for simpler task\n", + " critic_hidden_dims=(256, 128, 64),\n", + " activation=\"elu\",\n", + " ),\n", + " algorithm=RslRlPpoAlgorithmCfg(\n", + " value_loss_coef=1.0,\n", + " use_clipped_value_loss=True,\n", + " clip_param=0.2,\n", + " entropy_coef=0.01,\n", + " num_learning_epochs=5,\n", + " num_mini_batches=4,\n", + " learning_rate=1.0e-3,\n", + " schedule=\"adaptive\",\n", + " gamma=0.99,\n", + " lam=0.95,\n", + " desired_kl=0.01,\n", + " max_grad_norm=1.0,\n", + " ),\n", + " experiment_name=\"cartpole\",\n", + " save_interval=50,\n", + " num_steps_per_env=24,\n", + " max_iterations=5_000, # Fewer iterations for simpler task\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Oc8-AHGcHt78" + }, + "source": [ + "### **šŸ“‹ Register the Task Environment**\n", + "\n", + "Register the CartPole task with mjlab registry." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YitUGUBRHxD4" + }, + "outputs": [], + "source": [ + "%%writefile /content/mjlab/src/mjlab/tasks/cartpole/__init__.py\n", + "\"\"\"CartPole task registration.\"\"\"\n", + "\n", + "from mjlab.tasks.registry import register_mjlab_task\n", + "from mjlab.tasks.velocity.rl import VelocityOnPolicyRunner\n", + "\n", + "from .env_cfg import cartpole_env_cfg\n", + "from .rl_cfg import cartpole_ppo_runner_cfg\n", + "\n", + "register_mjlab_task(\n", + " task_id=\"Mjlab-Cartpole\",\n", + " env_cfg=cartpole_env_cfg(),\n", + " play_env_cfg=cartpole_env_cfg(play=True),\n", + " rl_cfg=cartpole_ppo_runner_cfg(),\n", + " runner_cls=VelocityOnPolicyRunner,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "K7wqLZR1rnGn" + }, + "source": [ + "---\n", + "\n", + "## **šŸš€ Step 3: Train the Agent**\n", + "\n", + "Now let's train a PPO policy to balance the CartPole!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Hht_hF4trqP2" + }, + "outputs": [], + "source": [ + "!python -m mjlab.scripts.train Mjlab-Cartpole --agent.max-iterations 201 --agent.save-interval 20" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xCaqPznGrx8H" + }, + "source": [ + "### **šŸ“ Locate Training Checkpoints**\n", + "\n", + "After training, checkpoints are saved locally." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "uPnmHYu8r0uY" + }, + "outputs": [], + "source": [ + "import os\n", + "import re\n", + "from pathlib import Path\n", + "\n", + "# Find the most recent training run\n", + "log_dir = Path(\"/content/mjlab/logs/rsl_rl/cartpole\")\n", + "if log_dir.exists():\n", + " runs = sorted(log_dir.glob(\"*\"), key=os.path.getmtime, reverse=True)\n", + " if runs:\n", + " latest_run = runs[0]\n", + " print(f\"āœ“ Latest training run: {latest_run.name}\\n\")\n", + "\n", + " # List checkpoints - sorted by iteration number\n", + " checkpoints = list(latest_run.glob(\"model_*.pt\"))\n", + " if checkpoints:\n", + " # Extract iteration number and sort numerically\n", + " def get_iteration(ckpt):\n", + " match = re.search(r\"model_(\\d+)\\.pt\", ckpt.name)\n", + " return int(match.group(1)) if match else 0\n", + "\n", + " checkpoints = sorted(checkpoints, key=get_iteration)\n", + "\n", + " print(f\"Found {len(checkpoints)} checkpoints:\")\n", + " for ckpt in checkpoints[-5:]: # Show last 5\n", + " size_mb = ckpt.stat().st_size / (1024 * 1024)\n", + " iteration = get_iteration(ckpt)\n", + " print(f\" • {ckpt.name} (iteration {iteration}, {size_mb:.2f} MB)\")\n", + "\n", + " # Store the last checkpoint path\n", + " last_checkpoint = str(checkpoints[-1])\n", + " print(f\"\\nšŸ’¾ Last checkpoint: {last_checkpoint}\")\n", + " else:\n", + " print(\"⚠ No checkpoints found yet\")\n", + " else:\n", + " print(\"⚠ No training runs found\")\n", + "else:\n", + " print(\"⚠ Log directory not found. Have you run training yet?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eWFS9Pw7r2uH" + }, + "source": [ + "---\n", + "\n", + "## **šŸŽ® Step 4: Visualize the Trained Policy**\n", + "\n", + "Let's see the trained policy in action!" 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "78PgHtpfr5sb" + }, + "source": [ + "### **🌐 Launch Viser API**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_9tGiFyBr2bW" + }, + "outputs": [], + "source": [ + "import subprocess\n", + "import sys\n", + "\n", + "process = subprocess.Popen(\n", + " [\n", + " \"python\",\n", + " \"-m\",\n", + " \"mjlab.scripts.play\",\n", + " \"Mjlab-Cartpole\",\n", + " \"--checkpoint_file\",\n", + " last_checkpoint,\n", + " \"--num_envs\",\n", + " \"4\",\n", + " ],\n", + " stdout=subprocess.PIPE,\n", + " stderr=subprocess.STDOUT,\n", + " universal_newlines=True,\n", + " bufsize=1,\n", + ")\n", + "\n", + "detected_port = None\n", + "\n", + "for line in process.stdout:\n", + " print(line, end=\"\")\n", + " sys.stdout.flush()\n", + "\n", + " # Extract port number from viser output\n", + " port_match = re.search(r\":(\\d{4})\", line)\n", + " if port_match and \"viser\" in line.lower():\n", + " detected_port = int(port_match.group(1))\n", + " print(\"\\n\" + \"=\" * 34)\n", + " print(f\"āœ… Server is running on port {detected_port}!\")\n", + " print(\"=\" * 34)\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XgzJXyBXssZS" + }, + "source": [ + "### **šŸ–„ļø Embed Client as iframe**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ll89QnuSuUxx" + }, + "outputs": [], + "source": [ + "from google.colab import output\n", + "\n", + "port = detected_port if detected_port else 8081\n", + "output.serve_kernel_port_as_iframe(port=port, height=700)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [], + "toc_visible": true, + "include_colab_link": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file