diff --git a/examples/reinforce/README.md b/examples/reinforce/README.md
new file mode 100644
index 000000000..f032b892d
--- /dev/null
+++ b/examples/reinforce/README.md
@@ -0,0 +1,77 @@
+# REINFORCE Algorithm
+
+Implementation of the REINFORCE policy gradient algorithm ([Williams, 1992](https://doi.org/10.1007/BF00992696)) for solving the CartPole-v1 environment using JAX and Flax NNX.
+
+## Overview
+
+This example demonstrates:
+- Policy gradient learning with Monte Carlo returns
+- A neural network policy built with Flax NNX
+- Per-episode return normalization to reduce gradient variance
+- JAX transformations for efficient gradient computation
+
+The agent learns to balance a pole on a cart, reaching an average reward of 490+ (the CartPole-v1 solving threshold is 475) within roughly 480 episodes.
+
+## Requirements
+
+This example requires `gymnax` for the environment and `tqdm` for progress tracking:
+```bash
+pip install -r requirements.txt
+```
+
+Required packages:
+- `jax>=0.4.13`
+- `flax==0.10.6` (NNX API)
+- `optax>=0.1.7`
+- `gymnax>=0.0.6`
+- `tqdm>=4.65.0`
+- `orbax-checkpoint>=0.4.0`
+
+## Implementation Details
+
+### Algorithm
+- **Policy**: 3-layer MLP (128-128-2 units) with leaky ReLU activations
+- **Optimization**: Adam with exponential learning rate decay (initial 1e-3, decay rate 0.8 every 100 optimizer steps)
+- **Loss**: Policy gradient loss (negative log-probability of taken actions, weighted by returns)
+- **Returns**: Discounted returns with gamma=0.99, normalized per episode
+- **Gradient clipping**: Global norm clipping at 1.0
+
+### Key Features
+- Pure JAX/Flax implementation with JIT compilation
+- Efficient episode rollouts using Gymnax
+- Xavier weight initialization
+- Exponential learning rate decay
+- Early stopping when the performance threshold is reached
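+
+For orientation, here is a minimal sketch of the loss and return computation described above. It mirrors the notebook's definitions but is illustrative only; `model` stands in for any callable mapping a batch of observations to action logits, and the authoritative code lives in `simple_reinforce.ipynb`:
+
+```python
+import jax
+import jax.numpy as jnp
+
+def reinforce_loss(model, obs, actions, returns):
+    # log pi(a|s) for the actions that were actually taken
+    log_probs = jax.nn.log_softmax(model(obs))
+    taken = jnp.take_along_axis(log_probs, actions[:, None], axis=1).squeeze()
+    # Negate because optimizers minimize; we want to maximize E[log pi * G]
+    return -jnp.mean(taken * returns)
+
+def discounted_returns(rewards, gamma=0.99):
+    # G[t] = r[t] + gamma * G[t+1], accumulated backwards over the episode
+    G, out = 0.0, []
+    for r in reversed(rewards):
+        G = r + gamma * G
+        out.insert(0, G)
+    return jnp.array(out)
+```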
+
+## How to run
+
+The example ships as a single notebook in this directory:
+- `simple_reinforce.ipynb`: Standard REINFORCE implementation
+
+The notebook includes:
+- Environment setup
+- Policy network definition
+- Training loop
+- Performance monitoring
+
+## Training Results
+
+The implementation tracks various metrics during training:
+
+1. Reward Progress:
+![Reward Plot](training_rewards.png)
+
+2. Trained Agent:
+
+![CartPole Agent](anim.gif)
+
+## References
+
+Williams, R. J. (1992). Simple statistical gradient-following algorithms for connectionist reinforcement learning. *Machine Learning*, 8(3-4), 229-256.
+[DOI: 10.1007/BF00992696](https://doi.org/10.1007/BF00992696)
+
+## Additional Resources
+
+- [Flax NNX Documentation](https://flax.readthedocs.io/en/latest/nnx/index.html)
+- [Gymnax Documentation](https://github.com/RobertTLange/gymnax)
+- [Policy Gradient Methods (Sutton & Barto)](http://incompleteideas.net/book/RLbook2020.pdf)
diff --git a/examples/reinforce/anim.gif b/examples/reinforce/anim.gif
new file mode 100644
index 000000000..1c07239f2
Binary files /dev/null and b/examples/reinforce/anim.gif differ
diff --git a/examples/reinforce/requirements.txt b/examples/reinforce/requirements.txt
new file mode 100644
index 000000000..4bc17913d
--- /dev/null
+++ b/examples/reinforce/requirements.txt
@@ -0,0 +1,6 @@
+jax>=0.4.13
+flax==0.10.6
+optax>=0.1.7
+gymnax>=0.0.6
+tqdm>=4.65.0
+orbax-checkpoint>=0.4.0
\ No newline at end of file
diff --git a/examples/reinforce/simple_reinforce.ipynb b/examples/reinforce/simple_reinforce.ipynb
new file mode 100644
index 000000000..c94036165
--- /dev/null
+++ b/examples/reinforce/simple_reinforce.ipynb
@@ -0,0 +1,448 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "outputId": "ebd66a11-5ebd-4b24-ea0b-ee3b5dc48dcb"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Note: you may need to restart the kernel to use updated packages.\n"
+     ]
+    }
+   ],
+   "source": [
+    "%pip install -r requirements.txt"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "vscode": {
+     "languageId": "markdown"
+    }
+   },
+   "source": [
+    "The **JAX RL Training Workflow**:\n",
+    "\n",
+    "1. Run the entire notebook end-to-end and check the outputs.\n",
+    "   - This will train a policy network on the CartPole environment\n",
+    "   - You'll be able to track training progress using the plots\n",
+    "2. Experiment with different hyperparameters in the `train()` function:\n",
+    "   - Learning rate (currently 1e-3)\n",
+    "   - Gamma discount factor (currently 0.99)\n",
+    "   - Number of episodes (currently 1000)\n",
+    "3. Update the code in the Policy class and training loop:\n",
+    "   - Modify the network architecture\n",
+    "   - Add new metrics to track during training\n",
+    "   - Try different optimizers or gradient clipping values\n",
+    "   - Experiment with different action selection strategies (see the sketch in the next cell)\n",
+    "4. The trained model can be saved (e.g. with `orbax-checkpoint`) and loaded for:\n",
+    "   - Further training\n",
+    "   - Evaluation\n",
+    "   - Deployment in different environments"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import jax\n",
+    "import gymnax\n",
+    "from flax import nnx\n",
+    "import jax.numpy as jnp\n",
+    "import matplotlib.pyplot as plt\n",
+    "import optax\n",
+    "import collections\n",
+    "from tqdm import tqdm\n",
+    "\n",
+    "key = jax.random.key(0)\n",
+    "key, key_reset, key_act, key_step = jax.random.split(key, 4)\n",
+    "\n",
+    "env, env_params = gymnax.make(\"CartPole-v1\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "outputId": "30e8b755-3e09-4acd-d37c-80bac5b20a0d"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "4"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "env.action_space(env_params).n  # 2 discrete actions (push left / push right)\n",
+    "env.observation_space(env_params).shape[0]  # 4-dimensional observation"
+   ]
+  },
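+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For reference (added note, standard form from Williams, 1992), the gradient estimator that the loss defined below implements is\n",
+    "\n",
+    "$$\\nabla_\\theta J(\\theta) = \\mathbb{E}_{\\tau \\sim \\pi_\\theta}\\left[\\sum_t \\nabla_\\theta \\log \\pi_\\theta(a_t \\mid s_t)\\, G_t\\right],$$\n",
+    "\n",
+    "where $G_t = \\sum_{k=t}^{T} \\gamma^{\\,k-t} r_k$ is the discounted return from timestep $t$. Minimizing $-\\frac{1}{T}\\sum_t \\log \\pi_\\theta(a_t \\mid s_t)\\, G_t$ with autodiff yields a sample estimate of this gradient."
+   ]
+  },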
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class Policy(nnx.Module):\n",
+    "    \"\"\"A neural network policy for the REINFORCE algorithm.\n",
+    "\n",
+    "    This policy network maps observations to action logits using a\n",
+    "    3-layer MLP with leaky ReLU activations. It implements both the\n",
+    "    forward pass and an action-selection method.\n",
+    "    \"\"\"\n",
+    "    def __init__(self, observation_space, action_space, rngs: nnx.Rngs):\n",
+    "        \"\"\"Initialize the policy network.\n",
+    "\n",
+    "        Args:\n",
+    "            observation_space: The environment's observation space object\n",
+    "            action_space: The environment's action space object\n",
+    "            rngs: Random number generator state for initialization\n",
+    "        \"\"\"\n",
+    "        super().__init__()\n",
+    "\n",
+    "        # Initialize layers with Xavier normal initialization for better gradient flow\n",
+    "        kernel_init = nnx.initializers.xavier_normal()\n",
+    "        self.layer1 = nnx.Linear(observation_space.shape[0], 128, rngs=rngs, kernel_init=kernel_init)\n",
+    "        self.layer2 = nnx.Linear(128, 128, rngs=rngs, kernel_init=kernel_init)\n",
+    "        self.layer3 = nnx.Linear(128, action_space.n, rngs=rngs, kernel_init=kernel_init)\n",
+    "\n",
+    "    def __call__(self, x):\n",
+    "        \"\"\"Forward pass through the network.\n",
+    "\n",
+    "        Args:\n",
+    "            x: Input observation tensor\n",
+    "        Returns:\n",
+    "            Logits over actions\n",
+    "        \"\"\"\n",
+    "        x = jax.nn.leaky_relu(self.layer1(x))\n",
+    "        x = jax.nn.leaky_relu(self.layer2(x))\n",
+    "        return self.layer3(x)\n",
+    "\n",
+    "    def select_action(self, x, key):\n",
+    "        \"\"\"Sample an action from the policy's probability distribution.\n",
+    "\n",
+    "        Args:\n",
+    "            x: Input observation tensor\n",
+    "            key: JAX random key for sampling\n",
+    "        Returns:\n",
+    "            Selected action index\n",
+    "        \"\"\"\n",
+    "        logits = self(x)\n",
+    "        return jax.random.categorical(key, logits)\n",
+    "\n",
+    "@nnx.jit\n",
+    "def loss(model, obs, actions, returns):\n",
+    "    \"\"\"Compute the REINFORCE policy gradient loss.\n",
+    "\n",
+    "    This implements the policy gradient loss: -E[log(π(a|s)) * G_t],\n",
+    "    where π(a|s) is the policy's probability of taking action a in state s,\n",
+    "    G_t is the discounted return from timestep t, and E[] denotes the\n",
+    "    expectation (implemented as a mean over the batch).\n",
+    "\n",
+    "    Args:\n",
+    "        model: The policy network\n",
+    "        obs: Batch of observations\n",
+    "        actions: Batch of actions taken\n",
+    "        returns: Batch of discounted returns\n",
+    "    Returns:\n",
+    "        Mean policy gradient loss\n",
+    "    \"\"\"\n",
+    "    log_probs = jax.nn.log_softmax(model(obs))\n",
+    "    # Get log probabilities of the actions that were actually taken\n",
+    "    log_prob_taken = jnp.take_along_axis(log_probs, actions[:, None], axis=1).squeeze()\n",
+    "    # Negative mean for gradient ascent (we want to maximize expected return)\n",
+    "    return -jnp.mean(log_prob_taken * returns)\n",
+    "\n",
+    "def compute_returns(rewards, gamma):\n",
+    "    \"\"\"Compute discounted returns for a sequence of rewards.\n",
+    "\n",
+    "    Args:\n",
+    "        rewards: List of rewards from an episode\n",
+    "        gamma: Discount factor (0 < gamma <= 1)\n",
+    "    Returns:\n",
+    "        Array of discounted returns for each timestep\n",
+    "    \"\"\"\n",
+    "    R = 0\n",
+    "    returns = []\n",
+    "    # Compute returns backwards: R[t] = r[t] + gamma * R[t+1]\n",
+    "    for r in reversed(rewards):\n",
+    "        R = r + gamma * R\n",
+    "        returns.insert(0, R)\n",
+    "    return jnp.array(returns)\n",
+    "\n",
+    "# Initialize the policy network with the environment's spaces\n",
+    "model = Policy(env.observation_space(env_params), env.action_space(env_params), rngs=nnx.Rngs(0))"
+   ]
+  },
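+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "An added sanity check plus a jittable `jax.lax.scan` variant of `compute_returns`. Both are illustrative only; the training loop below uses the Python-loop version defined above."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def compute_returns_scan(rewards, gamma):\n",
+    "    # Same recursion as compute_returns, expressed as a reverse scan\n",
+    "    def step(carry, r):\n",
+    "        G = r + gamma * carry\n",
+    "        return G, G\n",
+    "    _, Gs = jax.lax.scan(step, 0.0, jnp.asarray(rewards), reverse=True)\n",
+    "    return Gs\n",
+    "\n",
+    "# Worked example: three rewards of 1.0 with gamma = 0.5 give\n",
+    "# G2 = 1, G1 = 1 + 0.5*1 = 1.5, G0 = 1 + 0.5*1.5 = 1.75\n",
+    "compute_returns([1.0, 1.0, 1.0], gamma=0.5)"
+   ]
+  },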
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train(env, env_params, model, episodes: int = 50, learning_rate=1e-3, gamma=0.99):\n",
+    "    \"\"\"Train the policy using the REINFORCE algorithm.\n",
+    "\n",
+    "    Implements the REINFORCE algorithm to train a policy network through episodic interaction\n",
+    "    with the environment. Uses exponential learning rate decay and gradient clipping for stability.\n",
+    "\n",
+    "    Args:\n",
+    "        env: The Gymnax environment\n",
+    "        env_params: Environment parameters\n",
+    "        model: The policy network to train\n",
+    "        episodes: Number of training episodes\n",
+    "        learning_rate: Initial learning rate\n",
+    "        gamma: Discount factor for computing returns\n",
+    "    Returns:\n",
+    "        Tuple of (losses, rewards) lists tracking training progress\n",
+    "    \"\"\"\n",
+    "    # Set up a learning rate schedule with exponential decay:\n",
+    "    # decays by a factor of 0.8 every 100 optimizer steps (not episodes)\n",
+    "    lr_schedule = optax.exponential_decay(\n",
+    "        init_value=learning_rate,\n",
+    "        transition_steps=100,\n",
+    "        decay_rate=0.8,\n",
+    "    )\n",
+    "\n",
+    "    # Create optimizer with gradient clipping and Adam\n",
+    "    optimizer = nnx.Optimizer(\n",
+    "        model,\n",
+    "        optax.chain(\n",
+    "            optax.clip_by_global_norm(1.0),  # Prevent exploding gradients\n",
+    "            optax.adam(learning_rate=lr_schedule),\n",
+    "        ),\n",
+    "        wrt=nnx.Param\n",
+    "    )\n",
+    "\n",
+    "    grad_func = nnx.value_and_grad(loss)\n",
+    "    key = jax.random.key(0)\n",
+    "    # Keep track of the last 100 episode rewards for a running average\n",
+    "    total_rewards = collections.deque(maxlen=100)\n",
+    "\n",
+    "    all_losses = []\n",
+    "    all_rewards = []\n",
+    "    with tqdm(range(episodes)) as pbar:\n",
+    "        for i in pbar:\n",
+    "            # Lists to store episode data\n",
+    "            episode_obs = []\n",
+    "            episode_actions = []\n",
+    "            episode_rewards = []\n",
+    "\n",
+    "            # Run one episode\n",
+    "            done = False\n",
+    "            key, reset_key = jax.random.split(key)\n",
+    "            obs, state = env.reset(reset_key, env_params)\n",
+    "\n",
+    "            while not done:\n",
+    "                # Sample action from policy and step environment\n",
+    "                key, action_key, step_key = jax.random.split(key, 3)\n",
+    "                action = model.select_action(obs, action_key)\n",
+    "                next_obs, state, reward, done, _ = env.step(step_key, state, action, env_params)\n",
+    "\n",
+    "                # Store transition data\n",
+    "                episode_obs.append(obs)\n",
+    "                episode_actions.append(action)\n",
+    "                episode_rewards.append(reward)\n",
+    "                obs = next_obs\n",
+    "\n",
+    "            # Process episode data for training\n",
+    "            total_rewards.append(sum(episode_rewards))\n",
+    "            returns = compute_returns(episode_rewards, gamma)\n",
+    "\n",
+    "            # Convert lists to arrays for JAX operations\n",
+    "            final_obs = jnp.stack(episode_obs)\n",
+    "            final_actions = jnp.array(episode_actions)\n",
+    "            final_returns = jnp.array(returns)\n",
+    "\n",
+    "            # Normalize returns (mean=0, std=1) for stable gradient scaling;\n",
+    "            # this reduces variance in the policy gradient estimates\n",
+    "            final_returns = (final_returns - jnp.mean(final_returns)) / (jnp.std(final_returns) + 1e-8)\n",
+    "\n",
+    "            # Compute gradients and update policy\n",
+    "            value, grad = grad_func(model, final_obs, final_actions, final_returns)\n",
+    "            optimizer.update(grad)\n",
+    "\n",
+    "            # Track and display progress\n",
+    "            avg_reward = sum(total_rewards) / len(total_rewards)\n",
+    "            all_losses.append(value.item())\n",
+    "            all_rewards.append(avg_reward)\n",
+    "\n",
+    "            pbar.set_description(f\"Episode: {i}, Loss: {value.item():.4f}, Reward: {avg_reward:.2f}\")\n",
+    "            # Early stopping once we solve the environment\n",
+    "            if avg_reward >= 495:\n",
+    "                print(\"\\n\")\n",
+    "                print(f\"Score of {avg_reward} reached, stopping training\")\n",
+    "                break\n",
+    "\n",
+    "    return all_losses, all_rewards"
+   ]
+  },
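+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "An added, optional helper that is not part of the original training flow and is never called by it: `evaluate` rolls out the policy greedily (argmax over logits) so you can estimate performance after training without exploration noise."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def evaluate(model, env, env_params, episodes=10, seed=42):\n",
+    "    # Roll out the policy with greedy actions; no learning involved\n",
+    "    eval_key = jax.random.key(seed)\n",
+    "    scores = []\n",
+    "    for _ in range(episodes):\n",
+    "        eval_key, reset_key = jax.random.split(eval_key)\n",
+    "        obs, state = env.reset(reset_key, env_params)\n",
+    "        done, score = False, 0.0\n",
+    "        while not done:\n",
+    "            eval_key, step_key = jax.random.split(eval_key)\n",
+    "            action = jnp.argmax(model(obs))\n",
+    "            obs, state, reward, done, _ = env.step(step_key, state, action, env_params)\n",
+    "            score += reward\n",
+    "        scores.append(score)\n",
+    "    return sum(scores) / len(scores)"
+   ]
+  },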
"outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/1000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Plot only the reward and save it\n", + "plt.figure(figsize=(10, 6))\n", + "\n", + "plt.plot(rewards, linewidth=2)\n", + "plt.title('Average Reward over Episodes', fontsize=16, fontweight='bold')\n", + "plt.xlabel('Episode', fontsize=12)\n", + "plt.ylabel('Average Reward', fontsize=12)\n", + "plt.grid(True, alpha=0.3)\n", + "\n", + "# Save the plot\n", + "plt.savefig('training_rewards.png', dpi=300, bbox_inches='tight')\n", + "print(\"Plot saved as 'training_rewards.png'\")\n", + "\n", + "plt.show()\n", + "plt.close()" + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.1" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/examples/reinforce/training_rewards.png b/examples/reinforce/training_rewards.png new file mode 100644 index 000000000..1025ba03a Binary files /dev/null and b/examples/reinforce/training_rewards.png differ