diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml
index e53f87f73e..2eeb11a3cd 100644
--- a/src/MaxText/configs/base.yml
+++ b/src/MaxText/configs/base.yml
@@ -47,6 +47,7 @@ enable_checkpointing: True
 save_checkpoint_on_completion: True
 async_checkpointing: True
 checkpoint_period: 10_000
+max_num_checkpoints_to_keep: None
 # enables one replica to read the ckpt then broadcast to the rest
 enable_single_replica_ckpt_restoring: False
 
diff --git a/src/MaxText/configs/rl.yml b/src/MaxText/configs/rl.yml
index 38108dc82b..79081e644d 100644
--- a/src/MaxText/configs/rl.yml
+++ b/src/MaxText/configs/rl.yml
@@ -12,66 +12,190 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# RL Configuration
+# This config consolidates common parameters for RL training across different model sizes
+
 base_config: "base.yml"
 
-logical_axis_rules: [
-                      ['prefill_activation_length', ['data']],
-                      ['prefill_activation_norm_length', ['data']],
-                      ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-                      ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
-                      ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
-                      ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
-                      ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
-                      ['activation_length', ['context_autoregressive', 'sequence']],
-                      ['activation_length', ['context_autoregressive']],
-                      ['activation_q_length', ['context_autoregressive']],
-                      ['activation_kv_length', ['context_autoregressive']],
-                      ['activation_norm_length', ['tensor_sequence', 'sequence']],
-                      ['activation_embed', ['tensor_transpose']],
-                      ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-                      ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-                      ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
-                      ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
-                      ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-                      ['activation_vocab', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
-                      ['activation_vocab', ['tensor', 'tensor_transpose']],
-                      ['activation_vocab', 'tensor_sequence'],
-                      ['activation_vocab', ['sequence', 'context_autoregressive']],
-                      ['activation_stage', 'stage'],
-                      ['activation_exp', ['expert', 'context_autoregressive']],
-                      ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
-                      ['decode_length', []],
-                      ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
-                      ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context_autoregressive']],
-                      ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
-                      ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
-                      ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
-                      ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'expert']],
-                      ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
-                      ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
-                      ['embed', ['fsdp', 'sequence', 'expert']],
-                      ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
-                      ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
-                      ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
-                      ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive']],
-                      ['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']],
-                      ['layers', 'stage'],
-                      ['kv', []],
-                      ['kv_head_dim', []],
-                      ['cache_batch_prefill', []],
-                      ['cache_batch', ['context_autoregressive']],
-                      ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
-                      ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
-                      ['cache_kv', []],
-                      ['cache_sequence', ['context_autoregressive']],
-                      ['cache_scale_sequence', ['context_autoregressive']],
-                      ['exp', ['expert', 'context_autoregressive']],
-                      ['paged_kv_heads', []],
-                      ['num_pages', ['tensor']],
-                      ['tokens_per_page', []],
-                      ['paged_kv_head_dim_size', []],
-                    ]
-# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
-data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
-
-return_log_prob: True
\ No newline at end of file
+# ====== Hardware =====
+trainer_devices_fraction: 0.5
+sampler_devices_fraction: 0.5
+chips_per_vm: 4  # depends on hardware, for v5p this is 4
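+# Illustrative example: on a 64-chip slice, the 0.5 / 0.5 split above assigns
+# 32 chips to the GRPO trainer and 32 chips to the vLLM sampler.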
+
+# ====== Debug ======
+debug: True
+
+# ====== Reproducibility ======
+data_shuffle_seed: 42
+
+# ====== Checkpoint saving ======
+save_interval_steps: 500
+max_to_keep: 4
+
+# ====== GRPO ======
+
+# The number of responses the policy generates for each prompt within a single
+# training step. This corresponds to `G` in Algorithm 1 of the GRPO paper; the
+# "group" in GRPO comes from here.
+num_generations: 2
+
+# === other GRPO configs ===
+# The number of iterations per batch (μ in GRPO Algorithm 1).
+num_iterations: 1
+
+# The coefficient for the KL divergence penalty (β) in the GRPO loss function.
+# Important to keep this high enough; otherwise, the KL divergence can
+# increase unchecked.
+grpo_beta: 0.08
+# Epsilon value for clipping (ε in the GRPO loss in the paper). Similar to PPO,
+# for stable updates.
+grpo_epsilon: 0.2
+loss_algo: 'grpo' # grpo or gspo-token
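+# For reference (a rough sketch of the objective, not the exact Tunix
+# implementation): the per-token GRPO loss is approximately
+#   min(r * A, clip(r, 1 - grpo_epsilon, 1 + grpo_epsilon) * A) - grpo_beta * KL(policy || reference)
+# where r is the ratio of current to old policy probabilities and A is the
+# group-relative advantage computed over the `num_generations` samples per prompt.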
+
+
+# ====== Models ======
+# for MaxText
+# Model and tokenizer configuration. Override these via CLI:
+#   model_name, tokenizer_path, load_parameters_path
+# Model-specific overrides (examples):
+# For Llama3.1-8B:
+#   model_name: llama3.1-8b
+#   tokenizer_path: meta-llama/Llama-3.1-8B-Instruct
+#
+# For Llama3.1-70B with Pathways:
+#   model_name: llama3.1-70b
+#   tokenizer_path: meta-llama/Llama-3.1-70B-Instruct
+
+async_checkpointing: 'false'
+checkpoint_period: 5
+skip_jax_distributed_system: True
+weight_dtype: 'bfloat16'
+attention: 'dot_product'
+remat_policy: 'custom'
+decoder_layer_input: 'offload'
+query_proj: 'offload'
+key_proj: 'offload'
+value_proj: 'offload'
+# for vLLM
+hf_model_name: 'meta-llama/Llama-3.1-70B-Instruct'
+
+# ====== Training ======
+batch_size: 1
+# Increase `num_batches` (and the resulting number of training steps) for better results.
+# num_batches = 3738
+num_batches: 4  # 200
+# Keep `num_test_batches` low so that evaluation runs quickly. It can be
+# increased to a max of 330 (if batch size is 4).
+num_test_batches: 5  # 200
+train_fraction: 1.0
+
+eval_interval: 10  # ignored if train_fraction is 1.0
+
+num_epochs: 1  # can potentially train for more epochs
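+# For reference, the demo notebook derives the total number of training steps as
+#   steps = num_batches * num_iterations * train_fraction * num_epochs
+# which is 4 with the defaults above; `steps` can also be overridden directly,
+# as in the README examples.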
+
+gradient_clipping_threshold: 0.1
+
+
+# Evaluation sampling (greedy decoding by default)
+eval_temperature: 0.01
+eval_top_k: 1
+eval_top_p: 1.0
+# Alternative presets:
+#   standard (some randomness): eval_temperature: 0.7, eval_top_k: 50, eval_top_p: 0.95
+#   liberal: eval_temperature: 0.85, eval_top_k: 2000, eval_top_p: 1.0
+
+
+# ====== Inference ======
+# Important to keep a high-ish temperature for varied, diverse responses during
+# training.
+decode_sampling_temperature: 0.9
+decode_sampling_top_k: 50
+decode_sampling_nucleus_p: 1.0
+
+# for vLLM
+# === Generation during GRPO training ===
+# Max lengths for prompt and completion
+max_prefill_predict_length: 256
+max_target_length: 768
+kv_cache_buffer: 256
+hbm_utilization_vllm: 0.72
+swap_space_vllm_gb: 2
+
+# ====== Checkpoint Configuration ======
+enable_checkpointing: True
+async_checkpointing: True
+checkpoint_period: 50
+max_num_checkpoints_to_keep: 10
+
+# ====== Reward ======
+
+reward_exact_format_match: 3.0
+reward_white_space_format_match: 1.5
+reward_partial_format_match: 0.5
+reward_ratio_guess_to_answer_high: 0.5
+reward_ratio_guess_to_answer_low: 0.25
+penalty_incorrect_format: -0.5
+penalty_incorrect_answer: -1.0
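+# How these values are applied (per the reward functions in the GRPO demo): an
+# exact format match earns reward_exact_format_match; a correct answer earns
+# reward_exact_format_match as well (reward_white_space_format_match if it only
+# matches after stripping whitespace); a close answer earns
+# reward_ratio_guess_to_answer_high when guess/answer falls in [0.9, 1.1] and
+# reward_ratio_guess_to_answer_low when it falls in [0.8, 1.2]; otherwise the
+# penalties apply.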
+
+# ====== Special tokens for GSM8K reasoning ======
+reasoning_start_token: ''
+reasoning_end_token: ''
+solution_start_token: ''
+solution_end_token: ''
+
+# ====== System prompt and Templates ======
+
+system_prompt: |
+  You are given a problem. Think about the problem and provide your reasoning. Place it between {reasoning_start_token} and {reasoning_end_token}. Then, provide the final answer (i.e., just one numerical value) between {solution_start_token} and {solution_end_token}.
+
+template: |
+  user
+  {system_prompt}
+  
+  {question}
+  model
+
+
+# TODO: fix this
+# Dataset Configuration
+dataset_type: hf  # Huggingface input pipeline
+hf_path: 'gsm8k'
+hf_data_split: 'main'
+hf_data_files: 'train'
+
+
+# Pathways Inference Configuration
+# For multi-host/multi-slice setups
+use_pathways_reshard: False
+inference_devices_per_replica: 4
+inference_replicas: 1
+
+# Tokenizer Settings
+add_bos: False
+add_eos: False
+return_log_prob: True
+
+# Performance and Memory
+dtype: bfloat16
+
+# Splash Attention Block Sizes
+# Tuned for GRPO workloads
+sa_block_q: 128
+sa_block_kv: 128
+sa_block_kv_compute: 128
+sa_block_q_dkv: 128
+sa_block_kv_dkv: 128
+sa_block_kv_dkv_compute: 128
+sa_block_q_dq: 128
+sa_block_kv_dq: 128
+sa_use_fused_bwd_kernel: False
+sa_q_layout: "HEAD_DIM_MINOR"
+sa_k_layout: "HEAD_DIM_MINOR"
+sa_v_layout: "HEAD_DIM_MINOR"
\ No newline at end of file
diff --git a/src/MaxText/evaluate_rl.py b/src/MaxText/evaluate_rl.py
new file mode 100644
index 0000000000..99657c3460
--- /dev/null
+++ b/src/MaxText/evaluate_rl.py
@@ -0,0 +1,275 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=bare-except, consider-using-generator
+
+from tqdm.auto import tqdm
+
+from tunix.rl.rollout.base_rollout import RolloutConfig
+
+from MaxText import rl_utils
+
+# ## Evaluate
+#
+#
+# Before we train the model, let's evaluate the model on the test set so we can
+# see the improvement post training.
+#
+# We evaluate it in two ways:
+#
+# **Quantitative**
+#
+# * **Answer Accuracy**: percentage of samples for which the model predicts the
+# correct final numerical answer
+# * **Answer (Partial) Accuracy**: percentage of samples for which the model
+# predicts a final numerical answer such that the `model answer / answer`
+# ratio lies between 0.9 and 1.1.
+# * **Format Accuracy**: percentage of samples for which the model outputs the
+# correct format, i.e., reasoning between the reasoning special tokens, and the
+# final answer between the solution special tokens.
+#
+# **Qualitative**
+#
+# We'll also print outputs for a few given questions so that we can compare the generated output later.
+#
+
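+# A minimal usage sketch (illustrative; assumes `mt_config`, `test_dataset`, and
+# `rl_cluster` have already been built by the RL trainer):
+#
+#   from MaxText.evaluate_rl import evaluate
+#   (corr, total, acc, partial_acc, fmt_acc), samples = evaluate(
+#       mt_config, test_dataset, rl_cluster, num_passes=1, make_lst=True)
+#   print(f"accuracy={acc:.1f}% partial={partial_acc:.1f}% format={fmt_acc:.1f}%")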
+
+def generate_responses(
+    mt_config,
+    prompts,
+    rl_cluster,
+    num_passes=1,
+    temperature=0.7,
+    top_k=50,
+    top_p=0.95,
+):
+  """
+  Generate responses for a batch of prompts across multiple passes.
+
+  Args:
+      mt_config: MaxText RL config; eval sampling settings are read from its
+        eval_temperature / eval_top_k / eval_top_p fields
+      prompts: List of prompts to generate responses for
+      rl_cluster: Model cluster for generation
+      num_passes: Number of generation passes
+      temperature: Sampling temperature (currently superseded by mt_config)
+      top_k: Top-k sampling parameter (currently superseded by mt_config)
+      top_p: Top-p sampling parameter (currently superseded by mt_config)
+
+  Returns:
+      List of lists containing responses for each prompt across passes
+  """
+  multiple_call_responses = [[] for _ in range(len(prompts))]
+
+  for p in range(num_passes):
+    responses = rl_cluster.rollout.generate(
+        prompts,
+        rollout_config=RolloutConfig(
+            max_tokens_to_generate=mt_config.max_target_length,
+            temperature=mt_config.eval_temperature,
+            top_k=mt_config.eval_top_k,
+            top_p=mt_config.eval_top_p,
+        ),
+    )
+    responses = responses.text
+
+    if mt_config.debug:
+      print(f"Pass {p+1}/{num_passes}, responses: {responses}")
+
+    for idx, response in enumerate(responses):
+      multiple_call_responses[idx].append(response)
+
+  return multiple_call_responses
+
+
+def score_responses(mt_config, question, responses, answer):
+  """
+  Score a set of responses for a single question.
+
+  Args:
+      mt_config: MaxText RL config (provides the match regexes and debug flag)
+      question: The evaluation question
+      responses: List of generated responses for this question
+      answer: The correct answer
+
+  Returns:
+      Tuple of (is_correct, is_partially_correct, has_correct_format)
+  """
+  match_format = rl_utils.get_match_format_regex(mt_config)
+  match_numbers = rl_utils.get_match_numbers_regex(mt_config)
+
+  if mt_config.debug:
+    print("========================================")
+    print(f"Evaluation Question: {question}")
+    print(f"Evaluation Answer: {answer}")
+    print(f"Evaluation Responses: {responses}")
+    print("========================================")
+
+  is_correct = False
+  is_partially_correct = False
+  has_correct_format = False
+
+  for response in responses:
+    # Extract numerical response
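+    # "-1000000" is a sentinel used when no number can be extracted, chosen so it
+    # will not match (or be numerically close to) the reference answer.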
+    extracted_response = guess.group(1) if (guess := match_numbers.search(response)) is not None else "-1000000"
+
+    if mt_config.debug:
+      print(f"Evaluation extracted_response: {extracted_response}")
+
+    # Check exact correctness
+    try:
+      if float(extracted_response.strip()) == float(answer.strip()):
+        is_correct = True
+
+      # Check partial correctness (within 10%)
+      ratio = float(extracted_response.strip()) / float(answer.strip())
+      if 0.9 <= ratio <= 1.1:
+        is_partially_correct = True
+    except Exception as e:
+      if mt_config.debug:
+        print(f"Evaluation Exception: {e}")
+        print("SKIPPED")
+
+    # Check format correctness
+    if match_format.search(response) is not None:
+      has_correct_format = True
+
+    # Early exit if all criteria are met
+    if is_correct and is_partially_correct and has_correct_format:
+      break
+
+  return is_correct, is_partially_correct, has_correct_format
+
+
+def evaluate(
+    mt_config,
+    dataset,
+    rl_cluster,
+    temperature=0.7,
+    top_k=50,
+    top_p=0.95,
+    num_passes=1,
+    corr_lst=False,
+    make_lst=False,
+):
+  """
+  Computes accuracy and percentage of outputs matching the format.
+
+  Args:
+      mt_config: MaxText RL config
+      dataset: The evaluation dataset
+      rl_cluster: Model cluster for generation
+      temperature: Sampling temperature (currently superseded by mt_config)
+      top_k: Top-k sampling parameter (currently superseded by mt_config)
+      top_p: Top-p sampling parameter (currently superseded by mt_config)
+      num_passes: Number of generation passes
+      corr_lst: If True, only include correct responses in the returned list
+      make_lst: If True, also return a list of (question, answer, responses)
+
+  Returns:
+      Tuple of statistics and optionally the response list
+  """
+  response_lst = []
+  corr = 0
+  partially_corr = 0
+  corr_format = 0
+  total = 0
+
+  for batch in tqdm(dataset):
+    answers = batch["answer"]
+    questions = batch["question"]
+    prompts = batch["prompts"]
+
+    # Generate responses for all prompts in the batch
+    multiple_call_responses = generate_responses(
+        mt_config=mt_config,
+        prompts=prompts,
+        rl_cluster=rl_cluster,
+        num_passes=num_passes,
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p,
+    )
+
+    # Score each question-answer pair
+    for question, responses, answer in zip(questions, multiple_call_responses, answers):
+      is_correct, is_partially_correct, has_correct_format = score_responses(
+          mt_config=mt_config,
+          question=question,
+          responses=responses,
+          answer=answer,
+      )
+
+      # Update counters
+      if is_correct:
+        corr += 1
+        if corr_lst and make_lst:
+          response_lst.append((question, answer, responses))
+      else:
+        if not corr_lst and make_lst:
+          response_lst.append((question, answer, responses))
+
+      if is_partially_correct:
+        partially_corr += 1
+
+      if has_correct_format:
+        corr_format += 1
+
+      total += 1
+
+      # Print progress every 10 items
+      if total % 10 == 0:
+        print(
+            f"===> {corr=}, {total=}, {corr / total * 100=}, "
+            f"{partially_corr / total * 100=}, {corr_format / total * 100=}"
+        )
+
+  # Prepare return values
+  to_return = (
+      corr,
+      total,
+      corr / total * 100,
+      partially_corr / total * 100,
+      corr_format / total * 100,
+  )
+
+  if make_lst:
+    return to_return, response_lst
+  return to_return
diff --git a/src/MaxText/examples/README.md b/src/MaxText/examples/README_how_to_run_examples.md
similarity index 82%
rename from src/MaxText/examples/README.md
rename to src/MaxText/examples/README_how_to_run_examples.md
index a5b46934e3..a6dc4ae798 100644
--- a/src/MaxText/examples/README.md
+++ b/src/MaxText/examples/README_how_to_run_examples.md
@@ -127,7 +127,38 @@ Use the link for Jupyter Lab as a link for "Connect to a local runtime" in Colla
 
 ### GRPO Training
 
-- **`grpo_llama3_demo.ipynb`** β GRPO training on math dataset
+- **`grpo_llama3_1_8b_demo.ipynb`** – GRPO training on math dataset (Colab/notebook)
+- **`grpo_runner.py`** – Unified CLI for GRPO training (any model)
+
+#### GRPO Colab Usage
+
+For interactive GRPO training in Google Colab or Jupyter:
+
+1. **Open** `src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb`
+2. **Enable TPU runtime** (Runtime → Change runtime type → TPU)
+3. **Run cells** to train Llama3.1-8B with GRPO on GSM8K dataset
+
+#### GRPO Python Script Usage
+
+```bash
+# Llama3.1-8B
+python3 src/MaxText/examples/grpo_runner.py \
+  --model_name=llama3.1-8b \
+  --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
+  --load_parameters_path=gs://path/to/checkpoint \
+  --base_output_directory=/tmp/grpo_output \
+  --hf_access_token=$HF_TOKEN \
+  --steps=100
+
+# Qwen2.5-7B
+python3 src/MaxText/examples/grpo_runner.py \
+  --model_name=qwen2.5-7b \
+  --tokenizer_path=Qwen/Qwen2.5-7B-Instruct \
+  --load_parameters_path=gs://path/to/checkpoint \
+  --base_output_directory=/tmp/grpo_output \
+  --hf_access_token=$HF_TOKEN \
+  --steps=100
+```
 
 ## Common Pitfalls & Debugging
 
diff --git a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py
index def23ec943..33d7fcd33c 100644
--- a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py
+++ b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py
@@ -14,6 +14,11 @@
 
 # pylint: disable=bare-except, consider-using-generator
 """ 
+DEPRECATED: This file is deprecated and kept for reference only.
+Please use the new unified CLI interface: rl_trainer.py
+
+See GRPO_README.md for migration guide and usage examples.
+
 This tutorial demonstrates training the Llama3.1 70B-IT model on
  the GSM8K math reasoning benchmark using Group Relative Policy Optimization (GRPO).
    GRPO can enhance your model's problem-solving skills on mathematical word problems,
@@ -34,7 +39,7 @@
 # We use Tunix as the library for GRPO.
 # And we use vLLM as the library for efficient model inference and generation.
 #
-# In this tutorial we use a single host TPUVM such as `v6e-8/v5p-8`. Let's get started!
+# In this tutorial we use `v5p-256` or `v5p-128`. Let's get started!
 
 
 # ## Install necessary libraries
diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb b/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb
index 3e82bb1376..41fbeccfcb 100644
--- a/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb
+++ b/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb
@@ -4,36 +4,33 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "# GRPO Llama3.1-8B-Instruct Demo: Group Relative Policy Optimization\n",
+    "# GRPO Llama3.1-8B Demo: Direct Function Call\n",
     "\n",
-    "This tutorial demonstrates training the Llama3.1 8B-Instruct model on the GSM8K math reasoning benchmark using Group Relative Policy Optimization (GRPO). GRPO can enhance your model's problem-solving skills on mathematical word problems, coding problems, etc.\n",
+    "This notebook demonstrates GRPO training by directly calling the `rl_train` function from `rl_trainer.py`.\n",
     "\n",
     "## What is GRPO?\n",
     "\n",
-    "GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by:\n",
+    "GRPO (Group Relative Policy Optimization) is an RL algorithm that enhances reasoning abilities of LLMs by:\n",
+    "1. Generating multiple responses for each prompt\n",
+    "2. Evaluating responses using reward models  \n",
+    "3. Calculating relative advantages to update the policy\n",
     "\n",
-    "1. Generating multiple responses for a given prompt\n",
-    "2. Evaluating these responses using a reward model\n",
-    "3. Calculating a relative advantage based on the group's performance to update the policy\n",
     "\n",
-    "## Libraries Used\n",
-    "\n",
-    "- **Tunix**: Library for GRPO implementation\n",
-    "- **vLLM**: Library for efficient model inference and generation\n",
-    "- **MaxText**: For model creation and training infrastructure\n",
+    "This notebook imports and calls the `rl_train` function \n",
     "\n",
     "## Hardware Requirements\n",
     "\n",
-    "This tutorial uses a single host TPUVM such as `v6e-8/v5p-8`.\n"
+    "- Single host TPUVM (v6e-8/v5p-8) or multi-host with Pathways\n",
+    "- Sufficient memory for Llama3.1-8B model"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Install Necessary Libraries\n",
+    "## Setup\n",
     "\n",
-    "First, let's install the required dependencies:\n"
+    "Install dependencies and set up the environment:"
    ]
   },
   {
@@ -42,12 +39,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "### (Optional) Run this if you just have this file and nothing else\n",
-    "\n",
-    "# 1. Clone the MaxText repository (from AIβHypercomputer)\n",
+    "# Clone MaxText repository\n",
     "!git clone https://github.com/AI-Hypercomputer/maxtext.git\n",
-    "\n",
-    "# 2. Navigate into the cloned directory\n",
     "%cd maxtext"
    ]
   },
@@ -57,35 +50,28 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "### (Optional) Do not run this if you already installed the dependencies\n",
-    "\n",
-    "# 3. Ensure setup.sh is executable\n",
+    "# Install dependencies\n",
     "!chmod +x setup.sh\n",
-    "\n",
-    "# 4. Execute the setup script\n",
     "!./setup.sh\n",
     "\n",
-    "# Install vllm requirements\n",
+    "# Install GRPO-specific dependencies\n",
     "!./src/MaxText/examples/install_tunix_vllm_requirement.sh\n",
     "\n",
-    "# force numpy version\n",
-    "!pip install --force-reinstall numpy==2.1.2\n",
-    "# install nest_asyncio\n",
-    "!pip install nest_asyncio\n",
+    "# Install additional requirements\n",
+    "%pip install --force-reinstall numpy==2.1.2\n",
+    "%pip install nest_asyncio\n",
     "\n",
     "import nest_asyncio\n",
-    "\n",
-    "nest_asyncio.apply()\n",
-    "# To fix \"This event loop is already running\" error in Colab"
+    "nest_asyncio.apply()  # Fix for Colab event loop"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Imports\n",
+    "## Configuration\n",
     "\n",
-    "Import all necessary libraries for GRPO training:\n"
+    "Set up the training parameters:"
    ]
   },
   {
@@ -94,50 +80,22 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Set up the variables for the script\n",
+    "# Configuration for GRPO training\n",
     "import os\n",
-    "import sys\n",
     "\n",
-    "# Set the MaxText home directory (where you cloned the maxtext repo)\n",
+    "# Set up paths\n",
     "MAXTEXT_REPO_ROOT = os.path.expanduser(\"~\") + \"/maxtext\"\n",
-    "print(f\"MaxText Home directory (from Python): {MAXTEXT_REPO_ROOT}\")\n",
-    "\n",
-    "DEBUG = False  # set to True to run in debug mode, for more print statements\n",
-    "# set this to the path of the checkpoint you want to load, gs:// supported\n",
-    "MODEL_CHECKPOINT_PATH = \"/path/to/scanned/model/ckpt_load_dir/\""
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import functools\n",
-    "from pprint import pprint\n",
-    "import re\n",
-    "from flax import nnx\n",
-    "from flax.linen import partitioning as nn_partitioning\n",
-    "import grain\n",
-    "import humanize\n",
-    "import jax\n",
-    "import optax\n",
-    "from orbax import checkpoint as ocp\n",
-    "import tensorflow_datasets as tfds\n",
-    "from tqdm.auto import tqdm\n",
-    "from tunix.rl import rl_cluster as rl_cluster_lib\n",
-    "from tunix.rl.rollout import base_rollout\n",
-    "from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner\n",
-    "from tunix.sft import metrics_logger\n",
+    "print(f\"MaxText Home directory: {MAXTEXT_REPO_ROOT}\")\n",
     "\n",
-    "from transformers import AutoTokenizer\n",
+    "# Training configuration\n",
+    "MODEL_CHECKPOINT_PATH = \"gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items\"\n",
+    "OUTPUT_DIRECTORY = \"/tmp/grpo_output\"\n",
+    "STEPS = 10  # Reduced for demo purposes\n",
+    "HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"your_hf_token_here\")\n",
     "\n",
-    "from flax import linen as nn\n",
-    "from tunix.models.llama3 import model as llama3_lib\n",
-    "import numpy as np\n",
-    "from etils import epath\n",
-    "\n",
-    "from tunix.rl.rollout.base_rollout import RolloutConfig"
+    "print(f\"Model checkpoint: {MODEL_CHECKPOINT_PATH}\")\n",
+    "print(f\"Output directory: {OUTPUT_DIRECTORY}\")\n",
+    "print(f\"Training steps: {STEPS}\")"
    ]
   },
   {
@@ -146,669 +104,24 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "# Import GRPO training function directly\n",
+    "import sys\n",
+    "import os\n",
     "from pathlib import Path\n",
-    "from typing import Optional, Dict, Any\n",
-    "\n",
-    "maxtext_path = Path(f\"{MAXTEXT_REPO_ROOT}\") / \"src\" / \"MaxText\"\n",
     "\n",
-    "# Change working directory to MaxText project root\n",
-    "os.chdir(maxtext_path)\n",
+    "# Add MaxText to Python path\n",
+    "maxtext_path = Path(MAXTEXT_REPO_ROOT) / \"src\" / \"MaxText\"\n",
     "sys.path.insert(0, str(maxtext_path))\n",
     "\n",
-    "from MaxText import model_creation_utils, pyconfig\n",
-    "from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter\n",
-    "\n",
-    "print(f\"β Changed working directory to: {os.getcwd()}\")\n",
-    "print(f\"β MaxText project root: {maxtext_path}\")\n",
-    "print(f\"β Added to Python path: {maxtext_path}\")\n",
-    "\n",
-    "if not jax.distributed.is_initialized():\n",
-    "  jax.distributed.initialize()\n",
-    "jax.devices()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Hugging Face Authentication Setup\n",
-    "import os\n",
-    "from huggingface_hub import login\n",
-    "\n",
-    "# Use your Hugging Face token (recommended)\n",
-    "# Get your token from: https://huggingface.co/settings/tokens\n",
-    "\n",
-    "os.environ[\"HF_TOKEN\"] = \"hf_your_token_here\"\n",
-    "login(token=os.environ[\"HF_TOKEN\"])"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Hyperparameters\n",
-    "\n",
-    "Let's define the configuration we are going to use. Note that this is by no means a \"perfect\" set of hyperparameters. To get good results, you might have to train the model for longer.\n",
-    "\n",
-    "### Data Configuration\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# ====== Data ======\n",
-    "TRAIN_DATA_DIR = f\"{MAXTEXT_REPO_ROOT}/data/train\"\n",
-    "TEST_DATA_DIR = f\"{MAXTEXT_REPO_ROOT}/data/test\"\n",
-    "if not os.path.exists(TRAIN_DATA_DIR):\n",
-    "  os.makedirs(TRAIN_DATA_DIR)\n",
-    "if not os.path.exists(TEST_DATA_DIR):\n",
-    "  os.makedirs(TEST_DATA_DIR)\n",
-    "TRAIN_FRACTION = 1.0"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Checkpoint and Logging Configuration\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# ====== Checkpoint directory =====\n",
-    "LOG_DIR = f\"{MAXTEXT_REPO_ROOT}/content/tensorboard/grpo/logs_llama3/\"\n",
-    "if not os.path.exists(LOG_DIR):\n",
-    "  os.makedirs(LOG_DIR)\n",
-    "\n",
-    "# ===== Profiling =====\n",
-    "PROFILE_DIR = f\"{MAXTEXT_REPO_ROOT}/content/jax_traces/grpo/profiles_llama3/\"\n",
-    "if not os.path.exists(PROFILE_DIR):\n",
-    "  os.makedirs(PROFILE_DIR)\n",
-    "\n",
-    "# ====== Checkpoint saving ======\n",
-    "CKPT_DIR = f\"{MAXTEXT_REPO_ROOT}/content/ckpts_llama3/\"\n",
-    "\n",
-    "if not os.path.exists(CKPT_DIR):\n",
-    "  os.makedirs(CKPT_DIR)\n",
-    "\n",
-    "SAVE_INTERVAL_STEPS = 500\n",
-    "MAX_TO_KEEP = 4\n",
-    "\n",
-    "# ====== Reproducibility ======\n",
-    "SEED = 42"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### GRPO Configuration\n",
-    "\n",
-    "GRPO-specific hyperparameters for generation and training:\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# ====== GRPO ======\n",
-    "# === Generation during GRPO training ===\n",
-    "MAX_PROMPT_LENGTH = 256\n",
-    "TOTAL_GENERATION_STEPS = 768\n",
-    "# Important to keep a high-ish temperature for varied, diverse responses during\n",
-    "# training.\n",
-    "TEMPERATURE = 0.9\n",
-    "TOP_P = 1.0\n",
-    "TOP_K = 50\n",
-    "# The number of times the policy generates multiple responses for a given prompt\n",
-    "# within a single training step. This corresponds to `G` in Algorithm 1 in the\n",
-    "# paper. The \"group\" in GRPO comes from here.\n",
-    "NUM_GENERATIONS = 2\n",
-    "\n",
-    "# === other GRPO configs ===\n",
-    "# The number of iterations per batch (π in GRPO algo 1).\n",
-    "NUM_ITERATIONS = 1\n",
-    "# The coefficient for the KL divergence penalty (π½) in the GRPO loss function.\n",
-    "# Important to keep a high enough value for this, otherwise, the KL divergence\n",
-    "# can increase unchecked.\n",
-    "BETA = 0.08\n",
-    "# Epsilon value for clipping (π in GRPO loss in paper). Similar to PPO, for\n",
-    "# stable updates.\n",
-    "EPSILON = 0.2"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Training Configuration\n",
-    "\n",
-    "Training hyperparameters including batch size, learning rate, and optimization settings:\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# ====== Training ======\n",
-    "BATCH_SIZE = 1\n",
-    "# Increase `NUM_BATCHES` and `MAX_STEPS` for better results.\n",
-    "NUM_BATCHES = 4  # 200\n",
-    "# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be\n",
-    "# increased to a max. of 330 (if batch size is 4).\n",
-    "NUM_TEST_BATCHES = 5  # 200\n",
-    "\n",
-    "SEQUENCE_LENGTH = 1024\n",
-    "\n",
-    "EVAL_EVERY_N_STEPS = 10  # this doesn't matter if `TRAIN_FRACTION = 1.0`.\n",
-    "NUM_EPOCHS = 1  # can potentially train for more epochs\n",
-    "\n",
-    "# Number of training steps.\n",
-    "MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)\n",
-    "\n",
-    "# === AdamW, warmup, cosine scheduler ===\n",
-    "LEARNING_RATE = 3e-6\n",
-    "B1 = 0.9\n",
-    "B2 = 0.99\n",
-    "WEIGHT_DECAY = 0.1\n",
-    "# == Cosine decay with warmup scheduler ==\n",
-    "# Linearly increase learning rate from 0. to 5e-6 in the first 10% training\n",
-    "# steps, and then gradually decrease the learning rate to 0 using cosine\n",
-    "# scheduler.\n",
-    "WARMUP_STEPS = int(0.1 * MAX_STEPS)\n",
-    "# == Grad clipping ==\n",
-    "# Grad clipping to prevent large gradients. Found this\n",
-    "# important to keep KL divergence in check.\n",
-    "MAX_GRAD_NORM = 0.1"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Inference and Reward Configuration\n",
-    "\n",
-    "Configuration for model inference and reward function parameters:\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# ====== Inference ======\n",
-    "GENERATION_CONFIGS = {\n",
-    "    # greedy search\n",
-    "    \"greedy\": {\"temperature\": 0.01, \"top_k\": 1, \"top_p\": 1.0},\n",
-    "    # some randomness\n",
-    "    \"standard\": {\"temperature\": 0.7, \"top_k\": 50, \"top_p\": 0.95},\n",
-    "    # liberal\n",
-    "    \"liberal\": {\"temperature\": 0.85, \"top_k\": 2000, \"top_p\": 1.0},\n",
-    "}\n",
-    "\n",
-    "# ====== Reward ======\n",
-    "REWARD_EXACT_FORMAT_MATCH = 3.0\n",
-    "REWARD_WHITE_SPACE_FORMAT_MATCH = 1.5\n",
-    "REWARD_PARTIAL_FORMAT_MATCH = 0.5\n",
-    "REWARD_RATIO_GUESS_TO_ANSWER_HIGH = 0.5\n",
-    "REWARD_RATIO_GUESS_TO_ANSWER_LOW = 0.25\n",
-    "PENALTY_INCORRECT_FORMAT = -0.5\n",
-    "PENALTY_INCORRECT_ANSWER = -1.0"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Utility Functions\n",
-    "\n",
-    "Helper functions for monitoring memory usage and other utilities:\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def show_hbm_usage():\n",
-    "  \"\"\"Displays memory usage per device.\"\"\"\n",
-    "  fmt_size = functools.partial(humanize.naturalsize, binary=True)\n",
-    "\n",
-    "  for d in jax.local_devices():\n",
-    "    stats = d.memory_stats()\n",
-    "    used = stats[\"bytes_in_use\"]\n",
-    "    limit = stats[\"bytes_limit\"]\n",
-    "    print(f\"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}\")"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Data Preprocessing\n",
-    "\n",
-    "First, let's define some special tokens. We instruct the model to first reason between the `` and `` tokens. After reasoning, we expect it to provide the answer between the `` and `` tokens.\n",
-    "\n",
-    "We use OpenAI's GSM8K dataset. GSM8K comprises grade school math word problems.\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "model_tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-3.1-8B-Instruct\")\n",
-    "\n",
-    "\n",
-    "reasoning_start = \"\"\n",
-    "reasoning_end = \"\"\n",
-    "solution_start = \"\"\n",
-    "solution_end = \"\"\n",
-    "\n",
-    "\n",
-    "SYSTEM_PROMPT = f\"\"\"You are given a problem. Think about the problem and \\\n",
-    "provide your reasoning. Place it between {reasoning_start} and \\\n",
-    "{reasoning_end}. Then, provide the final answer (i.e., just one numerical \\\n",
-    "value) between {solution_start} and {solution_end}.\"\"\"\n",
-    "\n",
-    "TEMPLATE = \"\"\"user\n",
-    "{system_prompt}\n",
-    "\n",
-    "{question}\n",
-    "model\"\"\""
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def extract_hash_answer(text: str) -> str | None:\n",
-    "  if DEBUG:\n",
-    "    print(f\"Extracting answer from: {text}\")\n",
-    "  if \"####\" not in text:\n",
-    "    return None\n",
-    "  return text.split(\"####\")[1].strip()\n",
-    "\n",
-    "\n",
-    "def get_dataset(data_dir, split=\"train\") -> grain.MapDataset:\n",
-    "  # Download data\n",
-    "  if not os.path.exists(data_dir):\n",
-    "    os.makedirs(data_dir)\n",
-    "\n",
-    "  data = tfds.data_source(\n",
-    "      \"gsm8k\",\n",
-    "      split=split,\n",
-    "      data_dir=data_dir,\n",
-    "      builder_kwargs={\"file_format\": tfds.core.FileFormat.ARRAY_RECORD},\n",
-    "      download=True,\n",
-    "  )\n",
-    "\n",
-    "  loaded_dataset = (\n",
-    "      grain.MapDataset.source(data)\n",
-    "      .shuffle(seed=SEED)\n",
-    "      .map(\n",
-    "          lambda x: {\n",
-    "              # passed to model forward pass\n",
-    "              \"prompts\": model_tokenizer.apply_chat_template(\n",
-    "                  [\n",
-    "                      {\n",
-    "                          \"role\": \"user\",\n",
-    "                          \"content\": TEMPLATE.format(\n",
-    "                              system_prompt=SYSTEM_PROMPT,\n",
-    "                              question=x[\"question\"].decode(\"utf-8\"),\n",
-    "                          ),\n",
-    "                      },\n",
-    "                  ],\n",
-    "                  tokenize=False,\n",
-    "                  add_generation_prompt=True,\n",
-    "              ),\n",
-    "              # passed to reward functions\n",
-    "              \"question\": x[\"question\"].decode(\"utf-8\"),\n",
-    "              # passed to reward functions\n",
-    "              \"answer\": extract_hash_answer(x[\"answer\"].decode(\"utf-8\")),\n",
-    "          }\n",
-    "      )\n",
-    "  )\n",
-    "  return loaded_dataset"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "dataset = get_dataset(TRAIN_DATA_DIR, \"train\").batch(BATCH_SIZE)[:NUM_BATCHES]\n",
-    "\n",
-    "if TRAIN_FRACTION == 1.0:\n",
-    "  train_dataset = dataset.repeat(NUM_EPOCHS)\n",
-    "  val_dataset = None\n",
-    "else:\n",
-    "  train_dataset = dataset[: int(len(dataset) * TRAIN_FRACTION)]\n",
-    "  train_dataset = train_dataset.repeat(NUM_EPOCHS)\n",
-    "\n",
-    "  val_dataset = dataset[int(len(dataset) * TRAIN_FRACTION) :].repeat(NUM_EPOCHS)\n",
-    "\n",
-    "test_dataset = get_dataset(TEST_DATA_DIR, \"test\").batch(BATCH_SIZE)[:NUM_TEST_BATCHES]\n",
-    "\n",
-    "\n",
-    "# Let's see how one batch of the dataset looks like!\n",
-    "\n",
-    "\n",
-    "if DEBUG:\n",
-    "  for ele in train_dataset[:1]:\n",
-    "    pprint(ele)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Load the Policy Model and the Reference Model\n",
-    "\n",
-    "The policy model is the model which is actually trained and whose weights are updated. The reference model is the model with which we compute KL divergence. This is to ensure that the policy updates are not huge and that it does not deviate too much from the reference model.\n",
-    "\n",
-    "Typically, the reference model is the base model, and the policy model is the same base model, but with potentially LoRA parameters where only the LoRA parameters are updated. This script is not using LoRA, so both the reference and policy models are the same.\n",
-    "\n",
-    "Note: We perform full precision (fp32) training. You can, however, leverage Qwix for QAT.\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "print(\"HBM usage before loading model:\")\n",
-    "show_hbm_usage()"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Load MaxText Model\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def get_ref_maxtext_model(config):\n",
-    "\n",
-    "  model, mesh = model_creation_utils.create_nnx_model(config)\n",
-    "  with mesh:\n",
-    "    tunix_model = TunixMaxTextAdapter(\n",
-    "        base_model=model,\n",
-    "    )\n",
-    "\n",
-    "    model_config = llama3_lib.ModelConfig.llama3_1_8b()\n",
-    "    tunix_model.config = model_config\n",
-    "\n",
-    "  return tunix_model, mesh\n",
-    "\n",
-    "\n",
-    "model_config = llama3_lib.ModelConfig.llama3_1_8b()"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Load Reference Model\n",
-    "\n",
-    "Note: pass the path to your scanned checkpoint for \"load_parameters_path\". To create a scanned checkpoint, you can use `/maxtext/src/MaxText/utils/ckpt_conversion/to_maxtext.py`\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Load the reference model\n",
-    "config_ref = pyconfig.initialize(\n",
-    "    [\n",
-    "        \"\",\n",
-    "        f\"{MAXTEXT_REPO_ROOT}/src/MaxText/configs/base.yml\",\n",
-    "    ],\n",
-    "    base_output_directory=\"dummy\",  # This is not used in Tunix.\n",
-    "    run_name=\"test-tunix-maxtext-llama3.1-8b\",\n",
-    "    tokenizer_type=\"tiktoken\",\n",
-    "    tokenizer_path=\"assets/tokenizer_llama3.tiktoken\",\n",
-    "    load_parameters_path=f\"{MODEL_CHECKPOINT_PATH}\",\n",
-    "    max_target_length=SEQUENCE_LENGTH,\n",
-    "    async_checkpointing=\"false\",\n",
-    "    model_name=\"llama3.1-8b\",\n",
-    "    skip_jax_distributed_system=\"true\",\n",
-    "    weight_dtype=\"bfloat16\",\n",
-    "    attention=\"dot_product\",\n",
-    "    remat_policy=\"custom\",\n",
-    "    decoder_layer_input=\"offload\",\n",
-    "    query_proj=\"offload\",\n",
-    "    key_proj=\"offload\",\n",
-    "    value_proj=\"offload\",\n",
-    ")\n",
-    "\n",
-    "llama3_1_8b, mesh = get_ref_maxtext_model(config_ref)\n",
-    "\n",
-    "llama3_1_8b.config = model_config\n",
-    "\n",
-    "nnx.display(llama3_1_8b)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "if DEBUG:\n",
-    "  print(\"Model initialized successfully\")\n",
-    "  print(f\"Model mesh shape: {mesh.shape}\")\n",
-    "  print(f\"Model config: {model_config}\")\n",
-    "\n",
-    "  # Sanity check that weights are loaded correctly\n",
-    "  _maxtext_state_flatten = nnx.state(llama3_1_8b).flat_state()\n",
-    "  maxtext_state_flatten = {\".\".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten}\n",
-    "  print(\n",
-    "      f\"maxtext_state_flatten[base.token_embedder.embedding].value={maxtext_state_flatten['base.token_embedder.embedding'].value}\"\n",
-    "  )\n",
-    "\n",
-    "\n",
-    "# See the memory use after loading the reference model:\n",
-    "print(\"HBM usage after loading ref model:\")\n",
-    "show_hbm_usage()"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Load Policy Model\n",
-    "\n",
-    "Note: pass the path to your scanned checkpoint for \"load_parameters_path\". To create a scanned checkpoint, you can use `/maxtext/src/MaxText/utils/ckpt_conversion/to_maxtext.py`\n",
-    "\n",
-    "TODO: @mazumdera: change this to use lora\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "config_policy = pyconfig.initialize(\n",
-    "    [\n",
-    "        \"\",\n",
-    "        f\"{MAXTEXT_REPO_ROOT}/src/MaxText/configs/base.yml\",\n",
-    "    ],\n",
-    "    base_output_directory=\"dummy\",  # This is not used in Tunix.\n",
-    "    run_name=\"test-tunix-maxtext-llama3.1-8b\",  # This is not used in Tunix.\n",
-    "    tokenizer_type=\"tiktoken\",\n",
-    "    tokenizer_path=\"assets/tokenizer_llama3.tiktoken\",\n",
-    "    load_parameters_path=f\"{MODEL_CHECKPOINT_PATH}\",\n",
-    "    max_target_length=SEQUENCE_LENGTH,\n",
-    "    async_checkpointing=\"false\",\n",
-    "    model_name=\"llama3.1-8b\",\n",
-    "    skip_jax_distributed_system=\"true\",\n",
-    "    weight_dtype=\"bfloat16\",\n",
-    "    attention=\"dot_product\",\n",
-    "    remat_policy=\"custom\",\n",
-    "    decoder_layer_input=\"offload\",\n",
-    "    query_proj=\"offload\",\n",
-    "    key_proj=\"offload\",\n",
-    "    value_proj=\"offload\",\n",
-    ")\n",
-    "llama3_1_8b_policy, mesh_policy = get_ref_maxtext_model(config_policy)\n",
-    "\n",
-    "llama3_1_8b_policy.config = model_config\n",
-    "\n",
-    "nnx.display(llama3_1_8b_policy)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "if DEBUG:\n",
-    "  print(\"Model initialized successfully\")\n",
-    "  print(f\"Model mesh shape: {mesh_policy.shape}\")\n",
-    "  print(f\"Model config: {model_config}\")\n",
-    "\n",
-    "  # Sanity check that weights are loaded correctly\n",
-    "  _maxtext_state_flatten = nnx.state(llama3_1_8b_policy).flat_state()\n",
-    "  maxtext_state_flatten = {\".\".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten}\n",
-    "  print(\n",
-    "      f\"maxtext_state_flatten[base.token_embedder.embedding].value={maxtext_state_flatten['base.token_embedder.embedding'].value}\"\n",
-    "  )\n",
-    "\n",
-    "# See memory usage after loading the policy model:\n",
-    "print(\"HBM usage after loading policy model:\")\n",
-    "show_hbm_usage()"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Define Reward Functions\n",
-    "\n",
-    "We define four reward functions:\n",
-    "\n",
-    "1. **Format Matching**: Reward if the format of the output exactly matches the instruction given in `TEMPLATE`\n",
-    "2. **Approximate Format Matching**: Reward if the format of the output approximately matches the instruction given in `TEMPLATE`\n",
-    "3. **Answer Correctness**: Reward if the answer is correct/partially correct\n",
-    "4. **Number Extraction**: Sometimes, the text between ``, `` might not be one number. So, extract the number, and reward the model if the answer is correct.\n",
-    "\n",
-    "The reward functions are inspired from [here](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb).\n",
-    "\n",
-    "First off, let's define a RegEx for checking whether the format matches.\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "match_format = re.compile(\n",
-    "    rf\"^[\\s]{{0,}}\" rf\"{reasoning_start}.+?{reasoning_end}.*?\" rf\"{solution_start}(.+?){solution_end}\" rf\"[\\s]{{0,}}$\",\n",
-    "    flags=re.MULTILINE | re.DOTALL,\n",
-    ")\n",
-    "\n",
-    "match_format.search(\n",
-    "    f\"{reasoning_start}Let me\" f\" think!{reasoning_end}{solution_start}2{solution_end}\",\n",
-    ")"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Reward Function 1: Exact Format Matching\n",
-    "\n",
-    "Give the model a reward of 3 points if the format matches exactly.\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def match_format_exactly(prompts, completions, **kargs):\n",
-    "  scores = []\n",
-    "  for completion in completions:\n",
-    "    score = 0\n",
-    "    response = completion\n",
-    "    # Match if format is seen exactly!\n",
-    "    if match_format.search(response) is not None:\n",
-    "      score += REWARD_EXACT_FORMAT_MATCH\n",
-    "    scores.append(score)\n",
-    "  return scores"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Reward Function 2: Approximate Format Matching\n",
-    "\n",
-    "We also reward the model if the format of the output matches partially.\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def match_format_approximately(prompts, completions, **kargs):\n",
-    "  scores = []\n",
-    "\n",
-    "  for completion in completions:\n",
-    "    score = 0\n",
-    "    response = completion\n",
-    "    # Count how many keywords are seen - we penalize if too many!\n",
-    "    # If we see 1, then plus some points!\n",
-    "    score += REWARD_PARTIAL_FORMAT_MATCH if response.count(reasoning_start) == 1 else PENALTY_INCORRECT_FORMAT\n",
-    "    score += REWARD_PARTIAL_FORMAT_MATCH if response.count(reasoning_end) == 1 else PENALTY_INCORRECT_FORMAT\n",
-    "    score += REWARD_PARTIAL_FORMAT_MATCH if response.count(solution_start) == 1 else PENALTY_INCORRECT_FORMAT\n",
-    "    score += REWARD_PARTIAL_FORMAT_MATCH if response.count(solution_end) == 1 else PENALTY_INCORRECT_FORMAT\n",
-    "    scores.append(score)\n",
-    "  return scores"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Reward Function 3: Answer Correctness\n",
+    "# Import required modules\n",
+    "from MaxText import pyconfig\n",
+    "from MaxText.rl.rl_trainer import rl_train\n",
     "\n",
-    "Reward the model if the answer is correct. A reward is also given if the answer does not match exactly, i.e., based on how close the answer is to the correct value.\n"
+    "print(\"β
 Successfully imported GRPO training function\")\n",
+    "print(f\"π MaxText path: {maxtext_path}\")\n",
+    "print(\"\\n\" + \"=\"*80)\n",
+    "print(\"Starting GRPO Training...\")\n",
+    "print(\"=\"*80)"
    ]
   },
   {
@@ -817,47 +130,32 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def check_answer(prompts, completions, answer, **kargs):\n",
-    "  responses = completions\n",
-    "\n",
-    "  extracted_responses = [guess.group(1) if (guess := match_format.search(r)) is not None else None for r in responses]\n",
+    "# Build configuration for GRPO training\n",
+    "config_argv = [\n",
+    "    \"\",  # Placeholder for argv[0]\n",
+    "    \"src/MaxText/configs/grpo.yml\",  # Base config\n",
+    "    f\"model_name=llama3.1-8b\",\n",
+    "    f\"tokenizer_path=meta-llama/Llama-3.1-8B-Instruct\",\n",
+    "    f\"load_parameters_path={MODEL_CHECKPOINT_PATH}\",\n",
+    "    f\"base_output_directory={OUTPUT_DIRECTORY}\",\n",
+    "    f\"hf_access_token={HF_TOKEN}\",\n",
+    "    f\"steps={STEPS}\",\n",
+    "    \"per_device_batch_size=1\",\n",
+    "    \"learning_rate=3e-6\",\n",
+    "    \"num_generations=2\",\n",
+    "    \"grpo_beta=0.08\",\n",
+    "    \"grpo_epsilon=0.2\",\n",
+    "    \"chips_per_vm=4\"\n",
+    "]\n",
     "\n",
-    "  scores = []\n",
-    "  for guess, true_answer in zip(extracted_responses, answer):\n",
-    "    score = 0\n",
-    "    if guess is None:\n",
-    "      scores.append(0)\n",
-    "      continue\n",
-    "    # Correct answer gets 3 points!\n",
-    "    if guess == true_answer:\n",
-    "      score += REWARD_EXACT_FORMAT_MATCH\n",
-    "    # Match if spaces are seen\n",
-    "    elif guess.strip() == true_answer.strip():\n",
-    "      score += REWARD_WHITE_SPACE_FORMAT_MATCH\n",
-    "    else:\n",
-    "      # We also reward it if the answer is close via ratios!\n",
-    "      # Ie if the answer is within some range, reward it!\n",
-    "      try:\n",
-    "        ratio = float(guess) / float(true_answer)\n",
-    "        if ratio >= 0.9 and ratio <= 1.1:\n",
-    "          score += REWARD_RATIO_GUESS_TO_ANSWER_HIGH\n",
-    "        elif ratio >= 0.8 and ratio <= 1.2:\n",
-    "          score += REWARD_RATIO_GUESS_TO_ANSWER_LOW\n",
-    "        else:\n",
-    "          score += PENALTY_INCORRECT_ANSWER  # Penalize wrong answers\n",
-    "      except:\n",
-    "        score += PENALTY_INCORRECT_FORMAT  # Penalize\n",
-    "    scores.append(score)\n",
-    "  return scores"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Reward Function 4: Number Extraction\n",
+    "# Create configuration object\n",
+    "config = pyconfig.Config()\n",
+    "config.parse_flags(config_argv)\n",
     "\n",
-    "Sometimes, the text between `` and `` might not be one number; it can be a sentence. So, we extract the number and compare the answer.\n"
+    "print(\"β
 Configuration created successfully\")\n",
+    "print(f\"π Training steps: {config.steps}\")\n",
+    "print(f\"π Output directory: {config.base_output_directory}\")\n",
+    "print(f\"π€ Model: {config.model_name}\")"
    ]
   },
   {
@@ -866,490 +164,55 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "match_numbers = re.compile(rf\"{solution_start}.*?([\\d\\.]{{1,}})\", flags=re.MULTILINE | re.DOTALL)\n",
-    "match_numbers.findall(f\"{solution_start}  0.34  {solution_end}\")\n",
-    "\n",
-    "\n",
-    "def check_numbers(prompts, completions, answer, **kargs):\n",
-    "  question = kargs[\"question\"]\n",
-    "  responses = completions\n",
-    "\n",
-    "  extracted_responses = [guess.group(1) if (guess := match_numbers.search(r)) is not None else None for r in responses]\n",
-    "\n",
-    "  scores = []\n",
-    "  if DEBUG:\n",
-    "    print(\"START ============================\")\n",
-    "    print(f\"Question: {question[0]}\")\n",
-    "    print(f\"Answer: {answer[0]}\")\n",
-    "    print(f\"Response: {responses[0]}\")\n",
-    "    print(f\"Extracted: {extracted_responses[0]}\")\n",
-    "    print(\"END ==============================\")\n",
-    "  for guess, true_answer in zip(extracted_responses, answer):\n",
-    "    if guess is None:\n",
-    "      scores.append(0)\n",
-    "      continue\n",
-    "    # Convert to numbers\n",
-    "    try:\n",
-    "      true_answer = float(true_answer.strip())\n",
-    "      guess = float(guess.strip())\n",
-    "      scores.append(1.5 if guess == true_answer else 0.0)\n",
-    "    except:\n",
-    "      scores.append(0)\n",
-    "      continue\n",
-    "  return scores"
+    "# Execute GRPO training directly\n",
+    "try:\n",
+    "    # Call the rl_train function\n",
+    "    grpo_trainer, rl_cluster = rl_train(config)\n",
+    "    \n",
+    "    print(\"\\n\" + \"=\"*80)\n",
+    "    print(\"β
 GRPO Training Completed Successfully!\")\n",
+    "    print(\"=\"*80)\n",
+    "    print(f\"π Checkpoints and logs saved to: {config.base_output_directory}\")\n",
+    "    print(f\"π― Final model ready for inference!\")\n",
+    "    \n",
+    "except Exception as e:\n",
+    "    print(\"\\n\" + \"=\"*80)\n",
+    "    print(\"β GRPO Training Failed!\")\n",
+    "    print(\"=\"*80)\n",
+    "    print(f\"Error: {str(e)}\")\n",
+    "    print(\"\\nPlease check the error message and try again.\")"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Evaluation Functions\n",
-    "\n",
-    "Before we train the model, let's evaluate the model on the test set so we can see the improvement post training.\n",
-    "\n",
-    "We evaluate it in two ways:\n",
-    "\n",
-    "**Quantitative**\n",
-    "- **Answer Accuracy**: percentage of samples for which the model predicts the correct final numerical answer\n",
-    "- **Answer (Partial) Accuracy**: percentage of samples for which the model predicts a final numerical answer such that the `model answer / answer` ratio lies between 0.9 and 1.1.\n",
-    "- **Format Accuracy**: percentage of samples for which the model outputs the correct format, i.e., reasoning between the reasoning special tokens, and the final answer between the ``, `` tokens.\n",
-    "\n",
-    "**Qualitative**\n",
-    "\n",
-    "We'll also print outputs for a few given questions so that we can compare the generated output later.\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def generate_responses(\n",
-    "    prompts,\n",
-    "    rl_cluster,\n",
-    "    num_passes=1,\n",
-    "    temperature=0.7,\n",
-    "    top_k=50,\n",
-    "    top_p=0.95,\n",
-    "):\n",
-    "  \"\"\"\n",
-    "  Generate responses for a batch of prompts across multiple passes.\n",
-    "\n",
-    "  Args:\n",
-    "      prompts: List of prompts to generate responses for\n",
-    "      rl_cluster: Model cluster for generation\n",
-    "      num_passes: Number of generation passes\n",
-    "      temperature: Sampling temperature\n",
-    "      top_k: Top-k sampling parameter\n",
-    "      top_p: Top-p sampling parameter\n",
-    "\n",
-    "  Returns:\n",
-    "      List of lists containing responses for each prompt across passes\n",
-    "  \"\"\"\n",
-    "  multiple_call_responses = [[] for _ in range(len(prompts))]\n",
-    "\n",
-    "  for p in range(num_passes):\n",
-    "    responses = rl_cluster.rollout.generate(\n",
-    "        prompts,\n",
-    "        rollout_config=RolloutConfig(\n",
-    "            max_tokens_to_generate=TOTAL_GENERATION_STEPS,\n",
-    "            temperature=temperature,\n",
-    "            top_k=top_k,\n",
-    "            top_p=top_p,\n",
-    "        ),\n",
-    "    )\n",
-    "    responses = responses.text\n",
-    "\n",
-    "    if DEBUG:\n",
-    "      print(f\"Pass {p+1}/{num_passes}, responses: {responses}\")\n",
-    "\n",
-    "    for idx, response in enumerate(responses):\n",
-    "      multiple_call_responses[idx].append(response)\n",
-    "\n",
-    "  return multiple_call_responses"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def score_responses(question, responses, answer):\n",
-    "  \"\"\"\n",
-    "  Score a set of responses for a single question.\n",
-    "\n",
-    "  Args:\n",
-    "      question: The evaluation question\n",
-    "      responses: List of generated responses for this question\n",
-    "      answer: The correct answer\n",
-    "\n",
-    "  Returns:\n",
-    "      Tuple of (is_correct, is_partially_correct, has_correct_format)\n",
-    "  \"\"\"\n",
-    "  if DEBUG:\n",
-    "    print(\"========================================\")\n",
-    "    print(f\"Evaluation Question: {question}\")\n",
-    "    print(f\"Evaluation Answer: {answer}\")\n",
-    "    print(f\"Evaluation Responses: {responses}\")\n",
-    "    print(\"========================================\")\n",
-    "\n",
-    "  is_correct = False\n",
-    "  is_partially_correct = False\n",
-    "  has_correct_format = False\n",
-    "\n",
-    "  for response in responses:\n",
-    "    # Extract numerical response\n",
-    "    extracted_response = guess.group(1) if (guess := match_numbers.search(response)) is not None else \"-1000000\"\n",
-    "\n",
-    "    if DEBUG:\n",
-    "      print(f\"Evaluation extracted_response: {extracted_response}\")\n",
-    "\n",
-    "    # Check exact correctness\n",
-    "    try:\n",
-    "      if float(extracted_response.strip()) == float(answer.strip()):\n",
-    "        is_correct = True\n",
-    "\n",
-    "      # Check partial correctness (within 10%)\n",
-    "      ratio = float(extracted_response.strip()) / float(answer.strip())\n",
-    "      if 0.9 <= ratio <= 1.1:\n",
-    "        is_partially_correct = True\n",
-    "    except Exception as e:\n",
-    "      if DEBUG:\n",
-    "        print(f\"Evaluation Exception: {e}\")\n",
-    "        print(\"SKIPPED\")\n",
-    "\n",
-    "    # Check format correctness\n",
-    "    if match_format.search(response) is not None:\n",
-    "      has_correct_format = True\n",
-    "\n",
-    "    # Early exit if all criteria are met\n",
-    "    if is_correct and is_partially_correct and has_correct_format:\n",
-    "      break\n",
-    "\n",
-    "  return is_correct, is_partially_correct, has_correct_format"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def evaluate(\n",
-    "    dataset,\n",
-    "    rl_cluster,\n",
-    "    temperature=0.7,\n",
-    "    top_k=50,\n",
-    "    top_p=0.95,\n",
-    "    num_passes=1,\n",
-    "    corr_lst=False,\n",
-    "    make_lst=False,\n",
-    "):\n",
-    "  \"\"\"\n",
-    "  Computes accuracy and percentage of outputs matching the format.\n",
-    "\n",
-    "  Args:\n",
-    "      dataset: The evaluation dataset\n",
-    "      rl_cluster: Model cluster for generation\n",
-    "      temperature: Sampling temperature\n",
-    "      top_k: Top-k sampling parameter\n",
-    "      top_p: Top-p sampling parameter\n",
-    "      num_passes: Number of generation passes\n",
-    "      corr_lst: If True, only include correct responses in the list\n",
-    "      make_lst: If True, return a list of (question, answer, responses)\n",
-    "\n",
-    "  Returns:\n",
-    "      Tuple of statistics and optionally the response list\n",
-    "  \"\"\"\n",
-    "  response_lst = []\n",
-    "  corr = 0\n",
-    "  partially_corr = 0\n",
-    "  corr_format = 0\n",
-    "  total = 0\n",
-    "\n",
-    "  for batch in tqdm(dataset):\n",
-    "    answers = batch[\"answer\"]\n",
-    "    questions = batch[\"question\"]\n",
-    "    prompts = batch[\"prompts\"]\n",
-    "\n",
-    "    # Generate responses for all prompts in the batch\n",
-    "    multiple_call_responses = generate_responses(\n",
-    "        prompts=prompts,\n",
-    "        rl_cluster=rl_cluster,\n",
-    "        num_passes=num_passes,\n",
-    "        temperature=temperature,\n",
-    "        top_k=top_k,\n",
-    "        top_p=top_p,\n",
-    "    )\n",
-    "\n",
-    "    # Score each question-answer pair\n",
-    "    for question, responses, answer in zip(questions, multiple_call_responses, answers):\n",
-    "      is_correct, is_partially_correct, has_correct_format = score_responses(\n",
-    "          question=question,\n",
-    "          responses=responses,\n",
-    "          answer=answer,\n",
-    "      )\n",
-    "\n",
-    "      # Update counters\n",
-    "      if is_correct:\n",
-    "        corr += 1\n",
-    "        if corr_lst and make_lst:\n",
-    "          response_lst.append((question, answer, responses))\n",
-    "      else:\n",
-    "        if not corr_lst and make_lst:\n",
-    "          response_lst.append((question, answer, responses))\n",
-    "\n",
-    "      if is_partially_correct:\n",
-    "        partially_corr += 1\n",
-    "\n",
-    "      if has_correct_format:\n",
-    "        corr_format += 1\n",
-    "\n",
-    "      total += 1\n",
-    "\n",
-    "      # Print progress every 10 items\n",
-    "      if total % 10 == 0:\n",
-    "        print(\n",
-    "            f\"===> {corr=}, {total=}, {corr / total * 100=}, \"\n",
-    "            f\"{partially_corr / total * 100=}, {corr_format / total * 100=}\"\n",
-    "        )\n",
-    "\n",
-    "  # Prepare return values\n",
-    "  to_return = (\n",
-    "      corr,\n",
-    "      total,\n",
-    "      corr / total * 100,\n",
-    "      partially_corr / total * 100,\n",
-    "      corr_format / total * 100,\n",
-    "  )\n",
-    "\n",
-    "  if make_lst:\n",
-    "    return to_return, response_lst\n",
-    "  return to_return"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Training Setup\n",
-    "\n",
-    "Let's set up all the configs first - checkpointing, metric logging and training. We then train the model.\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Ckpt saving\n",
-    "checkpointing_options = ocp.CheckpointManagerOptions(save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP)\n",
-    "\n",
-    "# Metrics logger\n",
-    "metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=LOG_DIR, flush_every_n_steps=20)\n",
-    "\n",
-    "\n",
-    "# Logs\n",
-    "print(f\"TensorBoard logs directory: {LOG_DIR}\")\n",
-    "print(f\"tensorboard --logdir {LOG_DIR} --port=8086\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Optimizer, learning rate scheduler, gradient clipping\n",
-    "optimizer = optax.adamw(\n",
-    "    learning_rate=optax.schedules.warmup_cosine_decay_schedule(\n",
-    "        init_value=0.0,\n",
-    "        peak_value=LEARNING_RATE,\n",
-    "        warmup_steps=WARMUP_STEPS,\n",
-    "        decay_steps=MAX_STEPS,\n",
-    "        end_value=0.0,\n",
-    "    ),\n",
-    "    b1=B1,\n",
-    "    b2=B2,\n",
-    "    weight_decay=WEIGHT_DECAY,\n",
-    ")\n",
-    "\n",
-    "if MAX_GRAD_NORM is not None:\n",
-    "  optimizer = optax.chain(\n",
-    "      optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),\n",
-    "      optimizer,\n",
-    "  )"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# RL Cluster config\n",
-    "# Note that we use vLLM as the rollout engine.\n",
-    "# and we are using Tensor Parallelism for rollout\n",
-    "\n",
-    "cluster_config = rl_cluster_lib.ClusterConfig(\n",
-    "    role_to_mesh={\n",
-    "        rl_cluster_lib.Role.ACTOR: mesh,\n",
-    "        rl_cluster_lib.Role.REFERENCE: mesh,\n",
-    "        rl_cluster_lib.Role.ROLLOUT: mesh,\n",
-    "    },\n",
-    "    rollout_engine=\"vllm\",\n",
-    "    offload_to_cpu=False,\n",
-    "    training_config=rl_cluster_lib.RLTrainingConfig(\n",
-    "        actor_optimizer=optimizer,\n",
-    "        eval_every_n_steps=EVAL_EVERY_N_STEPS,\n",
-    "        max_steps=MAX_STEPS,\n",
-    "        gradient_accumulation_steps=1,\n",
-    "        # metrics logging\n",
-    "        metrics_logging_options=metrics_logging_options,\n",
-    "        # checkpoint saving\n",
-    "        checkpoint_root_directory=CKPT_DIR,\n",
-    "        checkpointing_options=checkpointing_options,\n",
-    "    ),\n",
-    "    rollout_config=base_rollout.RolloutConfig(\n",
-    "        max_tokens_to_generate=TOTAL_GENERATION_STEPS,\n",
-    "        max_prompt_length=MAX_PROMPT_LENGTH,\n",
-    "        kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,\n",
-    "        temperature=TEMPERATURE,\n",
-    "        top_p=TOP_P,\n",
-    "        top_k=TOP_K,\n",
-    "    ),\n",
-    "    rollout_vllm_model_version=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
-    "    rollout_vllm_hbm_utilization=0.2,\n",
-    "    rollout_vllm_tpu_backend_type=\"jax\",\n",
-    ")\n",
-    "\n",
-    "grpo_config = GrpoConfig(\n",
-    "    num_generations=NUM_GENERATIONS,\n",
-    "    num_iterations=NUM_ITERATIONS,\n",
-    "    beta=BETA,\n",
-    "    epsilon=EPSILON,\n",
-    ")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# RL cluster\n",
-    "\n",
-    "rl_cluster = rl_cluster_lib.RLCluster(\n",
-    "    actor=llama3_1_8b_policy,\n",
-    "    reference=llama3_1_8b,\n",
-    "    tokenizer=model_tokenizer,\n",
-    "    cluster_config=cluster_config,\n",
-    ")\n",
-    "\n",
-    "# GRPO Trainer\n",
-    "grpo_trainer = GrpoLearner(\n",
-    "    rl_cluster=rl_cluster,\n",
-    "    reward_fns=[\n",
-    "        match_format_exactly,\n",
-    "        match_format_approximately,\n",
-    "        check_answer,\n",
-    "        check_numbers,\n",
-    "    ],\n",
-    "    grpo_config=grpo_config,\n",
-    ")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "if DEBUG:\n",
-    "  # verify if vllm sampler works\n",
-    "  output = rl_cluster.rollout.generate(\n",
-    "      [\"The capital of France is\"],\n",
-    "      rollout_config=RolloutConfig(max_tokens_to_generate=64, temperature=0.1),\n",
-    "  )\n",
-    "\n",
-    "  print(f\"Output: {output}\")"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Evaluate Before Training\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(\n",
-    "    test_dataset,\n",
-    "    rl_cluster,\n",
-    "    **GENERATION_CONFIGS[\"greedy\"],\n",
-    ")\n",
-    "print(f\"Pre GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,\" f\" {format_accuracy=}%\")"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Start Training\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "jax.profiler.start_trace(PROFILE_DIR)\n",
-    "with mesh, nn_partitioning.axis_rules(config_policy.logical_axis_rules):\n",
-    "  grpo_trainer.train(dataset)\n",
-    "jax.profiler.stop_trace()\n",
-    "\n",
-    "print(\"HBM usage after training:\")\n",
-    "show_hbm_usage()"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Final Evaluation\n",
-    "\n",
-    "Let's evaluate our model after training!\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(\n",
-    "    test_dataset,\n",
-    "    rl_cluster,\n",
-    "    **GENERATION_CONFIGS[\"greedy\"],\n",
-    ")\n",
-    "print(f\"Post GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,\" f\" {format_accuracy=}%\")"
+    "### π **Learn More**\n",
+    "- See `src/MaxText/examples/grpo_runner.py` for CLI usage\n",
+    "- Check `src/MaxText/configs/grpo.yml` for configuration options\n",
+    "- Read `src/MaxText/examples/README.md` for more examples"
    ]
   }
  ],
  "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
   "language_info": {
-   "name": "python"
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.5"
   }
  },
  "nbformat": 4,
- "nbformat_minor": 2
+ "nbformat_minor": 4
 }
diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo.py b/src/MaxText/examples/grpo_llama3_1_8b_demo.py
index e533228689..3ba2d537d8 100644
--- a/src/MaxText/examples/grpo_llama3_1_8b_demo.py
+++ b/src/MaxText/examples/grpo_llama3_1_8b_demo.py
@@ -14,6 +14,11 @@
 
 # pylint: disable=bare-except, consider-using-generator
 """ 
+DEPRECATED: This file is deprecated and kept for reference only.
+Please use the new unified CLI interface: rl_trainer.py
+
+See GRPO_README.md for migration guide and usage examples.
+
 This tutorial demonstrates training the Llama3.1 8B-IT model on
  the GSM8K math reasoning benchmark using Group Relative Policy Optimization (GRPO).
    GRPO can enhance your model's problem-solving skills on mathematical word problems,
diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo_detailed.ipynb b/src/MaxText/examples/grpo_llama3_1_8b_demo_detailed.ipynb
new file mode 100644
index 0000000000..75394b9506
--- /dev/null
+++ b/src/MaxText/examples/grpo_llama3_1_8b_demo_detailed.ipynb
@@ -0,0 +1,1371 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# GRPO Llama3.1-8B-Instruct Demo: Group Relative Policy Optimization\n",
+    "\n",
+    "This tutorial demonstrates training the Llama3.1 8B-Instruct model on the GSM8K math reasoning benchmark using Group Relative Policy Optimization (GRPO). GRPO can enhance your model's problem-solving skills on mathematical word problems, coding problems, etc.\n",
+    "\n",
+    "## What is GRPO?\n",
+    "\n",
+    "GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by:\n",
+    "\n",
+    "1. Generating multiple responses for a given prompt\n",
+    "2. Evaluating these responses using a reward model\n",
+    "3. Calculating a relative advantage based on the group's performance to update the policy\n",
+    "\n",
+    "## Libraries Used\n",
+    "\n",
+    "- **Tunix**: Library for GRPO implementation\n",
+    "- **vLLM**: Library for efficient model inference and generation\n",
+    "- **MaxText**: For model creation and training infrastructure\n",
+    "\n",
+    "## Hardware Requirements\n",
+    "\n",
+    "This tutorial uses a single host TPUVM such as `v6e-8/v5p-8`.\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Install Necessary Libraries\n",
+    "\n",
+    "First, let's install the required dependencies:\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "### (Optional) Run this if you just have this file and nothing else\n",
+    "\n",
+    "# 1. Clone the MaxText repository (from AIβHypercomputer)\n",
+    "!git clone https://github.com/AI-Hypercomputer/maxtext.git\n",
+    "\n",
+    "# 2. Navigate into the cloned directory\n",
+    "%cd maxtext"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "### (Optional) Do not run this if you already installed the dependencies\n",
+    "\n",
+    "# 3. Ensure setup.sh is executable\n",
+    "!chmod +x setup.sh\n",
+    "\n",
+    "# 4. Execute the setup script\n",
+    "!./setup.sh\n",
+    "\n",
+    "# Install vllm requirements\n",
+    "!./src/MaxText/examples/install_tunix_vllm_requirement.sh\n",
+    "\n",
+    "# force numpy version\n",
+    "!pip install --force-reinstall numpy==2.1.2\n",
+    "# install nest_asyncio\n",
+    "!pip install nest_asyncio\n",
+    "\n",
+    "import nest_asyncio\n",
+    "\n",
+    "nest_asyncio.apply()\n",
+    "# To fix \"This event loop is already running\" error in Colab"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Imports\n",
+    "\n",
+    "Import all necessary libraries for GRPO training:\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Set up the variables for the script\n",
+    "import os\n",
+    "import sys\n",
+    "\n",
+    "# Set the MaxText home directory (where you cloned the maxtext repo)\n",
+    "MAXTEXT_REPO_ROOT = os.path.expanduser(\"~\") + \"/maxtext\"\n",
+    "print(f\"MaxText Home directory (from Python): {MAXTEXT_REPO_ROOT}\")\n",
+    "\n",
+    "DEBUG = False  # set to True to run in debug mode, for more print statements\n",
+    "# set this to the path of the checkpoint you want to load, gs:// supported\n",
+    "MODEL_CHECKPOINT_PATH = \"/path/to/scanned/model/ckpt_load_dir/\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Run GRPO training using the unified script\n",
+    "import subprocess\n",
+    "import sys\n",
+    "\n",
+    "# Build the command\n",
+    "cmd = [\n",
+    "    \"python3\", \"src/MaxText/examples/grpo_runner.py\",\n",
+    "    \"--model_name=llama3.1-8b\",\n",
+    "    \"--tokenizer_path=meta-llama/Llama-3.1-8B-Instruct\", \n",
+    "    f\"--load_parameters_path={MODEL_CHECKPOINT_PATH}\",\n",
+    "    f\"--base_output_directory={OUTPUT_DIRECTORY}\",\n",
+    "    f\"--hf_access_token={HF_TOKEN}\",\n",
+    "    f\"--steps={STEPS}\",\n",
+    "    \"--per_device_batch_size=1\",\n",
+    "    \"--learning_rate=3e-6\",\n",
+    "    \"--num_generations=2\",\n",
+    "    \"--grpo_beta=0.08\",\n",
+    "    \"--grpo_epsilon=0.2\"\n",
+    "]\n",
+    "\n",
+    "print(\"Running GRPO training with the following command:\")\n",
+    "print(\" \".join(cmd))\n",
+    "print(\"\\n\" + \"=\"*80)\n",
+    "print(\"Starting GRPO Training...\")\n",
+    "print(\"=\"*80)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Execute the GRPO training\n",
+    "result = subprocess.run(cmd, capture_output=False, text=True)\n",
+    "\n",
+    "if result.returncode == 0:\n",
+    "    print(\"\\n\" + \"=\"*80)\n",
+    "    print(\"β
 GRPO Training Completed Successfully!\")\n",
+    "    print(\"=\"*80)\n",
+    "    print(f\"π Checkpoints saved to: {OUTPUT_DIRECTORY}\")\n",
+    "    print(f\"π Logs available in: {OUTPUT_DIRECTORY}/logs\")\n",
+    "else:\n",
+    "    print(\"\\n\" + \"=\"*80)\n",
+    "    print(\"β GRPO Training Failed!\")\n",
+    "    print(\"=\"*80)\n",
+    "    print(f\"Exit code: {result.returncode}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "## Summary\n",
+    "\n",
+    "This simplified notebook demonstrates GRPO training using the unified `grpo_runner.py` script. The key benefits are:\n",
+    "\n",
+    "### β
 **Simplified Approach**\n",
+    "- **Single command** - All GRPO logic is consolidated\n",
+    "- **Easy configuration** - Just set parameters and run\n",
+    "- **Consistent interface** - Same as CLI usage\n",
+    "\n",
+    "### β
 **What Happened**\n",
+    "1. **Model Loading** - Llama3.1-8B with Tunix adapter\n",
+    "2. **Dataset Processing** - GSM8K math reasoning dataset\n",
+    "3. **GRPO Training** - Multiple reward functions for math problems\n",
+    "4. **Checkpointing** - Model weights saved for inference\n",
+    "\n",
+    "### β
 **Next Steps**\n",
+    "- **Inference** - Use the trained model for math problem solving\n",
+    "- **Evaluation** - Test on GSM8K test set\n",
+    "- **Customization** - Modify parameters for different models/datasets\n",
+    "\n",
+    "### π **Learn More**\n",
+    "- See `src/MaxText/examples/grpo_runner.py` for CLI usage\n",
+    "- Check `src/MaxText/configs/grpo.yml` for configuration options\n",
+    "- Read `src/MaxText/examples/README.md` for more examples\n",
+    "\n",
+    "# Use your Hugging Face token (recommended)\n",
+    "# Get your token from: https://huggingface.co/settings/tokens\n",
+    "\n",
+    "os.environ[\"HF_TOKEN\"] = \"hf_your_token_here\"\n",
+    "login(token=os.environ[\"HF_TOKEN\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Hyperparameters\n",
+    "\n",
+    "Let's define the configuration we are going to use. Note that this is by no means a \"perfect\" set of hyperparameters. To get good results, you might have to train the model for longer.\n",
+    "\n",
+    "### Data Configuration\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ====== Data ======\n",
+    "TRAIN_DATA_DIR = f\"{MAXTEXT_REPO_ROOT}/data/train\"\n",
+    "TEST_DATA_DIR = f\"{MAXTEXT_REPO_ROOT}/data/test\"\n",
+    "if not os.path.exists(TRAIN_DATA_DIR):\n",
+    "  os.makedirs(TRAIN_DATA_DIR)\n",
+    "if not os.path.exists(TEST_DATA_DIR):\n",
+    "  os.makedirs(TEST_DATA_DIR)\n",
+    "TRAIN_FRACTION = 1.0"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Checkpoint and Logging Configuration\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ====== Checkpoint directory =====\n",
+    "LOG_DIR = f\"{MAXTEXT_REPO_ROOT}/content/tensorboard/grpo/logs_llama3/\"\n",
+    "if not os.path.exists(LOG_DIR):\n",
+    "  os.makedirs(LOG_DIR)\n",
+    "\n",
+    "# ===== Profiling =====\n",
+    "PROFILE_DIR = f\"{MAXTEXT_REPO_ROOT}/content/jax_traces/grpo/profiles_llama3/\"\n",
+    "if not os.path.exists(PROFILE_DIR):\n",
+    "  os.makedirs(PROFILE_DIR)\n",
+    "\n",
+    "# ====== Checkpoint saving ======\n",
+    "CKPT_DIR = f\"{MAXTEXT_REPO_ROOT}/content/ckpts_llama3/\"\n",
+    "\n",
+    "if not os.path.exists(CKPT_DIR):\n",
+    "  os.makedirs(CKPT_DIR)\n",
+    "\n",
+    "SAVE_INTERVAL_STEPS = 500\n",
+    "MAX_TO_KEEP = 4\n",
+    "\n",
+    "# ====== Reproducibility ======\n",
+    "SEED = 42"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### GRPO Configuration\n",
+    "\n",
+    "GRPO-specific hyperparameters for generation and training:\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ====== GRPO ======\n",
+    "# === Generation during GRPO training ===\n",
+    "MAX_PROMPT_LENGTH = 256\n",
+    "TOTAL_GENERATION_STEPS = 768\n",
+    "# Important to keep a high-ish temperature for varied, diverse responses during\n",
+    "# training.\n",
+    "TEMPERATURE = 0.9\n",
+    "TOP_P = 1.0\n",
+    "TOP_K = 50\n",
+    "# The number of times the policy generates multiple responses for a given prompt\n",
+    "# within a single training step. This corresponds to `G` in Algorithm 1 in the\n",
+    "# paper. The \"group\" in GRPO comes from here.\n",
+    "NUM_GENERATIONS = 2\n",
+    "\n",
+    "# === other GRPO configs ===\n",
+    "# The number of iterations per batch (π in GRPO algo 1).\n",
+    "NUM_ITERATIONS = 1\n",
+    "# The coefficient for the KL divergence penalty (π½) in the GRPO loss function.\n",
+    "# Important to keep a high enough value for this, otherwise, the KL divergence\n",
+    "# can increase unchecked.\n",
+    "BETA = 0.08\n",
+    "# Epsilon value for clipping (π in GRPO loss in paper). Similar to PPO, for\n",
+    "# stable updates.\n",
+    "EPSILON = 0.2"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Training Configuration\n",
+    "\n",
+    "Training hyperparameters including batch size, learning rate, and optimization settings:\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ====== Training ======\n",
+    "BATCH_SIZE = 1\n",
+    "# Increase `NUM_BATCHES` and `MAX_STEPS` for better results.\n",
+    "NUM_BATCHES = 4  # 200\n",
+    "# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be\n",
+    "# increased to a max. of 330 (if batch size is 4).\n",
+    "NUM_TEST_BATCHES = 5  # 200\n",
+    "\n",
+    "SEQUENCE_LENGTH = 1024\n",
+    "\n",
+    "EVAL_EVERY_N_STEPS = 10  # this doesn't matter if `TRAIN_FRACTION = 1.0`.\n",
+    "NUM_EPOCHS = 1  # can potentially train for more epochs\n",
+    "\n",
+    "# Number of training steps.\n",
+    "MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)\n",
+    "\n",
+    "# === AdamW, warmup, cosine scheduler ===\n",
+    "LEARNING_RATE = 3e-6\n",
+    "B1 = 0.9\n",
+    "B2 = 0.99\n",
+    "WEIGHT_DECAY = 0.1\n",
+    "# == Cosine decay with warmup scheduler ==\n",
+    "# Linearly increase learning rate from 0. to 5e-6 in the first 10% training\n",
+    "# steps, and then gradually decrease the learning rate to 0 using cosine\n",
+    "# scheduler.\n",
+    "WARMUP_STEPS = int(0.1 * MAX_STEPS)\n",
+    "# == Grad clipping ==\n",
+    "# Grad clipping to prevent large gradients. Found this\n",
+    "# important to keep KL divergence in check.\n",
+    "MAX_GRAD_NORM = 0.1"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Inference and Reward Configuration\n",
+    "\n",
+    "Configuration for model inference and reward function parameters:\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ====== Inference ======\n",
+    "GENERATION_CONFIGS = {\n",
+    "    # greedy search\n",
+    "    \"greedy\": {\"temperature\": 0.01, \"top_k\": 1, \"top_p\": 1.0},\n",
+    "    # some randomness\n",
+    "    \"standard\": {\"temperature\": 0.7, \"top_k\": 50, \"top_p\": 0.95},\n",
+    "    # liberal\n",
+    "    \"liberal\": {\"temperature\": 0.85, \"top_k\": 2000, \"top_p\": 1.0},\n",
+    "}\n",
+    "\n",
+    "# ====== Reward ======\n",
+    "REWARD_EXACT_FORMAT_MATCH = 3.0\n",
+    "REWARD_WHITE_SPACE_FORMAT_MATCH = 1.5\n",
+    "REWARD_PARTIAL_FORMAT_MATCH = 0.5\n",
+    "REWARD_RATIO_GUESS_TO_ANSWER_HIGH = 0.5\n",
+    "REWARD_RATIO_GUESS_TO_ANSWER_LOW = 0.25\n",
+    "PENALTY_INCORRECT_FORMAT = -0.5\n",
+    "PENALTY_INCORRECT_ANSWER = -1.0"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Utility Functions\n",
+    "\n",
+    "Helper functions for monitoring memory usage and other utilities:\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def show_hbm_usage():\n",
+    "  \"\"\"Displays memory usage per device.\"\"\"\n",
+    "  fmt_size = functools.partial(humanize.naturalsize, binary=True)\n",
+    "\n",
+    "  for d in jax.local_devices():\n",
+    "    stats = d.memory_stats()\n",
+    "    used = stats[\"bytes_in_use\"]\n",
+    "    limit = stats[\"bytes_limit\"]\n",
+    "    print(f\"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Data Preprocessing\n",
+    "\n",
+    "First, let's define some special tokens. We instruct the model to first reason between the `` and `` tokens. After reasoning, we expect it to provide the answer between the `` and `` tokens.\n",
+    "\n",
+    "We use OpenAI's GSM8K dataset. GSM8K comprises grade school math word problems.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-3.1-8B-Instruct\")\n",
+    "\n",
+    "\n",
+    "reasoning_start = \"\"\n",
+    "reasoning_end = \"\"\n",
+    "solution_start = \"\"\n",
+    "solution_end = \"\"\n",
+    "\n",
+    "\n",
+    "SYSTEM_PROMPT = f\"\"\"You are given a problem. Think about the problem and \\\n",
+    "provide your reasoning. Place it between {reasoning_start} and \\\n",
+    "{reasoning_end}. Then, provide the final answer (i.e., just one numerical \\\n",
+    "value) between {solution_start} and {solution_end}.\"\"\"\n",
+    "\n",
+    "TEMPLATE = \"\"\"user\n",
+    "{system_prompt}\n",
+    "\n",
+    "{question}\n",
+    "model\"\"\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def extract_hash_answer(text: str) -> str | None:\n",
+    "  if DEBUG:\n",
+    "    print(f\"Extracting answer from: {text}\")\n",
+    "  if \"####\" not in text:\n",
+    "    return None\n",
+    "  return text.split(\"####\")[1].strip()\n",
+    "\n",
+    "\n",
+    "def get_dataset(data_dir, split=\"train\") -> grain.MapDataset:\n",
+    "  # Download data\n",
+    "  if not os.path.exists(data_dir):\n",
+    "    os.makedirs(data_dir)\n",
+    "\n",
+    "  data = tfds.data_source(\n",
+    "      \"gsm8k\",\n",
+    "      split=split,\n",
+    "      data_dir=data_dir,\n",
+    "      builder_kwargs={\"file_format\": tfds.core.FileFormat.ARRAY_RECORD},\n",
+    "      download=True,\n",
+    "  )\n",
+    "\n",
+    "  loaded_dataset = (\n",
+    "      grain.MapDataset.source(data)\n",
+    "      .shuffle(seed=SEED)\n",
+    "      .map(\n",
+    "          lambda x: {\n",
+    "              # passed to model forward pass\n",
+    "              \"prompts\": model_tokenizer.apply_chat_template(\n",
+    "                  [\n",
+    "                      {\n",
+    "                          \"role\": \"user\",\n",
+    "                          \"content\": TEMPLATE.format(\n",
+    "                              system_prompt=SYSTEM_PROMPT,\n",
+    "                              question=x[\"question\"].decode(\"utf-8\"),\n",
+    "                          ),\n",
+    "                      },\n",
+    "                  ],\n",
+    "                  tokenize=False,\n",
+    "                  add_generation_prompt=True,\n",
+    "              ),\n",
+    "              # passed to reward functions\n",
+    "              \"question\": x[\"question\"].decode(\"utf-8\"),\n",
+    "              # passed to reward functions\n",
+    "              \"answer\": extract_hash_answer(x[\"answer\"].decode(\"utf-8\")),\n",
+    "          }\n",
+    "      )\n",
+    "  )\n",
+    "  return loaded_dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = get_dataset(TRAIN_DATA_DIR, \"train\").batch(BATCH_SIZE)[:NUM_BATCHES]\n",
+    "\n",
+    "if TRAIN_FRACTION == 1.0:\n",
+    "  train_dataset = dataset.repeat(NUM_EPOCHS)\n",
+    "  val_dataset = None\n",
+    "else:\n",
+    "  train_dataset = dataset[: int(len(dataset) * TRAIN_FRACTION)]\n",
+    "  train_dataset = train_dataset.repeat(NUM_EPOCHS)\n",
+    "\n",
+    "  val_dataset = dataset[int(len(dataset) * TRAIN_FRACTION) :].repeat(NUM_EPOCHS)\n",
+    "\n",
+    "test_dataset = get_dataset(TEST_DATA_DIR, \"test\").batch(BATCH_SIZE)[:NUM_TEST_BATCHES]\n",
+    "\n",
+    "\n",
+    "# Let's see how one batch of the dataset looks like!\n",
+    "\n",
+    "\n",
+    "if DEBUG:\n",
+    "  for ele in train_dataset[:1]:\n",
+    "    pprint(ele)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Load the Policy Model and the Reference Model\n",
+    "\n",
+    "The policy model is the model which is actually trained and whose weights are updated. The reference model is the model with which we compute KL divergence. This is to ensure that the policy updates are not huge and that it does not deviate too much from the reference model.\n",
+    "\n",
+    "Typically, the reference model is the base model, and the policy model is the same base model, but with potentially LoRA parameters where only the LoRA parameters are updated. This script is not using LoRA, so both the reference and policy models are the same.\n",
+    "\n",
+    "Note: We perform full precision (fp32) training. You can, however, leverage Qwix for QAT.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(\"HBM usage before loading model:\")\n",
+    "show_hbm_usage()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Load MaxText Model\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_ref_maxtext_model(config):\n",
+    "\n",
+    "  model, mesh = model_creation_utils.create_nnx_model(config)\n",
+    "  with mesh:\n",
+    "    tunix_model = TunixMaxTextAdapter(\n",
+    "        base_model=model,\n",
+    "    )\n",
+    "\n",
+    "    model_config = llama3_lib.ModelConfig.llama3_1_8b()\n",
+    "    tunix_model.config = model_config\n",
+    "\n",
+    "  return tunix_model, mesh\n",
+    "\n",
+    "\n",
+    "model_config = llama3_lib.ModelConfig.llama3_1_8b()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Load Reference Model\n",
+    "\n",
+    "Note: pass the path to your scanned checkpoint for \"load_parameters_path\". To create a scanned checkpoint, you can use `/maxtext/src/MaxText/utils/ckpt_conversion/to_maxtext.py`\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Load the reference model\n",
+    "config_ref = pyconfig.initialize(\n",
+    "    [\n",
+    "        \"\",\n",
+    "        f\"{MAXTEXT_REPO_ROOT}/src/MaxText/configs/base.yml\",\n",
+    "    ],\n",
+    "    base_output_directory=\"dummy\",  # This is not used in Tunix.\n",
+    "    run_name=\"test-tunix-maxtext-llama3.1-8b\",\n",
+    "    tokenizer_type=\"tiktoken\",\n",
+    "    tokenizer_path=\"assets/tokenizer_llama3.tiktoken\",\n",
+    "    load_parameters_path=f\"{MODEL_CHECKPOINT_PATH}\",\n",
+    "    max_target_length=SEQUENCE_LENGTH,\n",
+    "    async_checkpointing=\"false\",\n",
+    "    model_name=\"llama3.1-8b\",\n",
+    "    skip_jax_distributed_system=\"true\",\n",
+    "    weight_dtype=\"bfloat16\",\n",
+    "    attention=\"dot_product\",\n",
+    "    remat_policy=\"custom\",\n",
+    "    decoder_layer_input=\"offload\",\n",
+    "    query_proj=\"offload\",\n",
+    "    key_proj=\"offload\",\n",
+    "    value_proj=\"offload\",\n",
+    ")\n",
+    "\n",
+    "llama3_1_8b, mesh = get_ref_maxtext_model(config_ref)\n",
+    "\n",
+    "llama3_1_8b.config = model_config\n",
+    "\n",
+    "nnx.display(llama3_1_8b)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if DEBUG:\n",
+    "  print(\"Model initialized successfully\")\n",
+    "  print(f\"Model mesh shape: {mesh.shape}\")\n",
+    "  print(f\"Model config: {model_config}\")\n",
+    "\n",
+    "  # Sanity check that weights are loaded correctly\n",
+    "  _maxtext_state_flatten = nnx.state(llama3_1_8b).flat_state()\n",
+    "  maxtext_state_flatten = {\".\".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten}\n",
+    "  print(\n",
+    "      f\"maxtext_state_flatten[base.token_embedder.embedding].value={maxtext_state_flatten['base.token_embedder.embedding'].value}\"\n",
+    "  )\n",
+    "\n",
+    "\n",
+    "# See the memory use after loading the reference model:\n",
+    "print(\"HBM usage after loading ref model:\")\n",
+    "show_hbm_usage()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Load Policy Model\n",
+    "\n",
+    "Note: pass the path to your scanned checkpoint for \"load_parameters_path\". To create a scanned checkpoint, you can use `/maxtext/src/MaxText/utils/ckpt_conversion/to_maxtext.py`\n",
+    "\n",
+    "TODO: @mazumdera: change this to use lora\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "config_policy = pyconfig.initialize(\n",
+    "    [\n",
+    "        \"\",\n",
+    "        f\"{MAXTEXT_REPO_ROOT}/src/MaxText/configs/base.yml\",\n",
+    "    ],\n",
+    "    base_output_directory=\"dummy\",  # This is not used in Tunix.\n",
+    "    run_name=\"test-tunix-maxtext-llama3.1-8b\",  # This is not used in Tunix.\n",
+    "    tokenizer_type=\"tiktoken\",\n",
+    "    tokenizer_path=\"assets/tokenizer_llama3.tiktoken\",\n",
+    "    load_parameters_path=f\"{MODEL_CHECKPOINT_PATH}\",\n",
+    "    max_target_length=SEQUENCE_LENGTH,\n",
+    "    async_checkpointing=\"false\",\n",
+    "    model_name=\"llama3.1-8b\",\n",
+    "    skip_jax_distributed_system=\"true\",\n",
+    "    weight_dtype=\"bfloat16\",\n",
+    "    attention=\"dot_product\",\n",
+    "    remat_policy=\"custom\",\n",
+    "    decoder_layer_input=\"offload\",\n",
+    "    query_proj=\"offload\",\n",
+    "    key_proj=\"offload\",\n",
+    "    value_proj=\"offload\",\n",
+    ")\n",
+    "llama3_1_8b_policy, mesh_policy = get_ref_maxtext_model(config_policy)\n",
+    "\n",
+    "llama3_1_8b_policy.config = model_config\n",
+    "\n",
+    "nnx.display(llama3_1_8b_policy)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if DEBUG:\n",
+    "  print(\"Model initialized successfully\")\n",
+    "  print(f\"Model mesh shape: {mesh_policy.shape}\")\n",
+    "  print(f\"Model config: {model_config}\")\n",
+    "\n",
+    "  # Sanity check that weights are loaded correctly\n",
+    "  _maxtext_state_flatten = nnx.state(llama3_1_8b_policy).flat_state()\n",
+    "  maxtext_state_flatten = {\".\".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten}\n",
+    "  print(\n",
+    "      f\"maxtext_state_flatten[base.token_embedder.embedding].value={maxtext_state_flatten['base.token_embedder.embedding'].value}\"\n",
+    "  )\n",
+    "\n",
+    "# See memory usage after loading the policy model:\n",
+    "print(\"HBM usage after loading policy model:\")\n",
+    "show_hbm_usage()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Define Reward Functions\n",
+    "\n",
+    "We define four reward functions:\n",
+    "\n",
+    "1. **Format Matching**: Reward if the format of the output exactly matches the instruction given in `TEMPLATE`\n",
+    "2. **Approximate Format Matching**: Reward if the format of the output approximately matches the instruction given in `TEMPLATE`\n",
+    "3. **Answer Correctness**: Reward if the answer is correct/partially correct\n",
+    "4. **Number Extraction**: Sometimes, the text between ``, `` might not be one number. So, extract the number, and reward the model if the answer is correct.\n",
+    "\n",
+    "The reward functions are inspired from [here](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb).\n",
+    "\n",
+    "First off, let's define a RegEx for checking whether the format matches.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "match_format = re.compile(\n",
+    "    rf\"^[\\s]{{0,}}\" rf\"{reasoning_start}.+?{reasoning_end}.*?\" rf\"{solution_start}(.+?){solution_end}\" rf\"[\\s]{{0,}}$\",\n",
+    "    flags=re.MULTILINE | re.DOTALL,\n",
+    ")\n",
+    "\n",
+    "match_format.search(\n",
+    "    f\"{reasoning_start}Let me\" f\" think!{reasoning_end}{solution_start}2{solution_end}\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Reward Function 1: Exact Format Matching\n",
+    "\n",
+    "Give the model a reward of 3 points if the format matches exactly.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def match_format_exactly(prompts, completions, **kargs):\n",
+    "  scores = []\n",
+    "  for completion in completions:\n",
+    "    score = 0\n",
+    "    response = completion\n",
+    "    # Match if format is seen exactly!\n",
+    "    if match_format.search(response) is not None:\n",
+    "      score += REWARD_EXACT_FORMAT_MATCH\n",
+    "    scores.append(score)\n",
+    "  return scores"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Reward Function 2: Approximate Format Matching\n",
+    "\n",
+    "We also reward the model if the format of the output matches partially.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def match_format_approximately(prompts, completions, **kargs):\n",
+    "  scores = []\n",
+    "\n",
+    "  for completion in completions:\n",
+    "    score = 0\n",
+    "    response = completion\n",
+    "    # Count how many keywords are seen - we penalize if too many!\n",
+    "    # If we see 1, then plus some points!\n",
+    "    score += REWARD_PARTIAL_FORMAT_MATCH if response.count(reasoning_start) == 1 else PENALTY_INCORRECT_FORMAT\n",
+    "    score += REWARD_PARTIAL_FORMAT_MATCH if response.count(reasoning_end) == 1 else PENALTY_INCORRECT_FORMAT\n",
+    "    score += REWARD_PARTIAL_FORMAT_MATCH if response.count(solution_start) == 1 else PENALTY_INCORRECT_FORMAT\n",
+    "    score += REWARD_PARTIAL_FORMAT_MATCH if response.count(solution_end) == 1 else PENALTY_INCORRECT_FORMAT\n",
+    "    scores.append(score)\n",
+    "  return scores"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Reward Function 3: Answer Correctness\n",
+    "\n",
+    "Reward the model if the answer is correct. A reward is also given if the answer does not match exactly, i.e., based on how close the answer is to the correct value.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def check_answer(prompts, completions, answer, **kargs):\n",
+    "  responses = completions\n",
+    "\n",
+    "  extracted_responses = [guess.group(1) if (guess := match_format.search(r)) is not None else None for r in responses]\n",
+    "\n",
+    "  scores = []\n",
+    "  for guess, true_answer in zip(extracted_responses, answer):\n",
+    "    score = 0\n",
+    "    if guess is None:\n",
+    "      scores.append(0)\n",
+    "      continue\n",
+    "    # Correct answer gets 3 points!\n",
+    "    if guess == true_answer:\n",
+    "      score += REWARD_EXACT_FORMAT_MATCH\n",
+    "    # Match if spaces are seen\n",
+    "    elif guess.strip() == true_answer.strip():\n",
+    "      score += REWARD_WHITE_SPACE_FORMAT_MATCH\n",
+    "    else:\n",
+    "      # We also reward it if the answer is close via ratios!\n",
+    "      # Ie if the answer is within some range, reward it!\n",
+    "      try:\n",
+    "        ratio = float(guess) / float(true_answer)\n",
+    "        if ratio >= 0.9 and ratio <= 1.1:\n",
+    "          score += REWARD_RATIO_GUESS_TO_ANSWER_HIGH\n",
+    "        elif ratio >= 0.8 and ratio <= 1.2:\n",
+    "          score += REWARD_RATIO_GUESS_TO_ANSWER_LOW\n",
+    "        else:\n",
+    "          score += PENALTY_INCORRECT_ANSWER  # Penalize wrong answers\n",
+    "      except:\n",
+    "        score += PENALTY_INCORRECT_FORMAT  # Penalize\n",
+    "    scores.append(score)\n",
+    "  return scores"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Reward Function 4: Number Extraction\n",
+    "\n",
+    "Sometimes, the text between `` and `` might not be one number; it can be a sentence. So, we extract the number and compare the answer.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "match_numbers = re.compile(rf\"{solution_start}.*?([\\d\\.]{{1,}})\", flags=re.MULTILINE | re.DOTALL)\n",
+    "match_numbers.findall(f\"{solution_start}  0.34  {solution_end}\")\n",
+    "\n",
+    "\n",
+    "def check_numbers(prompts, completions, answer, **kargs):\n",
+    "  question = kargs[\"question\"]\n",
+    "  responses = completions\n",
+    "\n",
+    "  extracted_responses = [guess.group(1) if (guess := match_numbers.search(r)) is not None else None for r in responses]\n",
+    "\n",
+    "  scores = []\n",
+    "  if DEBUG:\n",
+    "    print(\"START ============================\")\n",
+    "    print(f\"Question: {question[0]}\")\n",
+    "    print(f\"Answer: {answer[0]}\")\n",
+    "    print(f\"Response: {responses[0]}\")\n",
+    "    print(f\"Extracted: {extracted_responses[0]}\")\n",
+    "    print(\"END ==============================\")\n",
+    "  for guess, true_answer in zip(extracted_responses, answer):\n",
+    "    if guess is None:\n",
+    "      scores.append(0)\n",
+    "      continue\n",
+    "    # Convert to numbers\n",
+    "    try:\n",
+    "      true_answer = float(true_answer.strip())\n",
+    "      guess = float(guess.strip())\n",
+    "      scores.append(1.5 if guess == true_answer else 0.0)\n",
+    "    except:\n",
+    "      scores.append(0)\n",
+    "      continue\n",
+    "  return scores"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Evaluation Functions\n",
+    "\n",
+    "Before we train the model, let's evaluate the model on the test set so we can see the improvement post training.\n",
+    "\n",
+    "We evaluate it in two ways:\n",
+    "\n",
+    "**Quantitative**\n",
+    "- **Answer Accuracy**: percentage of samples for which the model predicts the correct final numerical answer\n",
+    "- **Answer (Partial) Accuracy**: percentage of samples for which the model predicts a final numerical answer such that the `model answer / answer` ratio lies between 0.9 and 1.1.\n",
+    "- **Format Accuracy**: percentage of samples for which the model outputs the correct format, i.e., reasoning between the reasoning special tokens, and the final answer between the ``, `` tokens.\n",
+    "\n",
+    "**Qualitative**\n",
+    "\n",
+    "We'll also print outputs for a few given questions so that we can compare the generated output later.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def generate_responses(\n",
+    "    prompts,\n",
+    "    rl_cluster,\n",
+    "    num_passes=1,\n",
+    "    temperature=0.7,\n",
+    "    top_k=50,\n",
+    "    top_p=0.95,\n",
+    "):\n",
+    "  \"\"\"\n",
+    "  Generate responses for a batch of prompts across multiple passes.\n",
+    "\n",
+    "  Args:\n",
+    "      prompts: List of prompts to generate responses for\n",
+    "      rl_cluster: Model cluster for generation\n",
+    "      num_passes: Number of generation passes\n",
+    "      temperature: Sampling temperature\n",
+    "      top_k: Top-k sampling parameter\n",
+    "      top_p: Top-p sampling parameter\n",
+    "\n",
+    "  Returns:\n",
+    "      List of lists containing responses for each prompt across passes\n",
+    "  \"\"\"\n",
+    "  multiple_call_responses = [[] for _ in range(len(prompts))]\n",
+    "\n",
+    "  for p in range(num_passes):\n",
+    "    responses = rl_cluster.rollout.generate(\n",
+    "        prompts,\n",
+    "        rollout_config=RolloutConfig(\n",
+    "            max_tokens_to_generate=TOTAL_GENERATION_STEPS,\n",
+    "            temperature=temperature,\n",
+    "            top_k=top_k,\n",
+    "            top_p=top_p,\n",
+    "        ),\n",
+    "    )\n",
+    "    responses = responses.text\n",
+    "\n",
+    "    if DEBUG:\n",
+    "      print(f\"Pass {p+1}/{num_passes}, responses: {responses}\")\n",
+    "\n",
+    "    for idx, response in enumerate(responses):\n",
+    "      multiple_call_responses[idx].append(response)\n",
+    "\n",
+    "  return multiple_call_responses"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def score_responses(question, responses, answer):\n",
+    "  \"\"\"\n",
+    "  Score a set of responses for a single question.\n",
+    "\n",
+    "  Args:\n",
+    "      question: The evaluation question\n",
+    "      responses: List of generated responses for this question\n",
+    "      answer: The correct answer\n",
+    "\n",
+    "  Returns:\n",
+    "      Tuple of (is_correct, is_partially_correct, has_correct_format)\n",
+    "  \"\"\"\n",
+    "  if DEBUG:\n",
+    "    print(\"========================================\")\n",
+    "    print(f\"Evaluation Question: {question}\")\n",
+    "    print(f\"Evaluation Answer: {answer}\")\n",
+    "    print(f\"Evaluation Responses: {responses}\")\n",
+    "    print(\"========================================\")\n",
+    "\n",
+    "  is_correct = False\n",
+    "  is_partially_correct = False\n",
+    "  has_correct_format = False\n",
+    "\n",
+    "  for response in responses:\n",
+    "    # Extract numerical response\n",
+    "    extracted_response = guess.group(1) if (guess := match_numbers.search(response)) is not None else \"-1000000\"\n",
+    "\n",
+    "    if DEBUG:\n",
+    "      print(f\"Evaluation extracted_response: {extracted_response}\")\n",
+    "\n",
+    "    # Check exact correctness\n",
+    "    try:\n",
+    "      if float(extracted_response.strip()) == float(answer.strip()):\n",
+    "        is_correct = True\n",
+    "\n",
+    "      # Check partial correctness (within 10%)\n",
+    "      ratio = float(extracted_response.strip()) / float(answer.strip())\n",
+    "      if 0.9 <= ratio <= 1.1:\n",
+    "        is_partially_correct = True\n",
+    "    except Exception as e:\n",
+    "      if DEBUG:\n",
+    "        print(f\"Evaluation Exception: {e}\")\n",
+    "        print(\"SKIPPED\")\n",
+    "\n",
+    "    # Check format correctness\n",
+    "    if match_format.search(response) is not None:\n",
+    "      has_correct_format = True\n",
+    "\n",
+    "    # Early exit if all criteria are met\n",
+    "    if is_correct and is_partially_correct and has_correct_format:\n",
+    "      break\n",
+    "\n",
+    "  return is_correct, is_partially_correct, has_correct_format"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def evaluate(\n",
+    "    dataset,\n",
+    "    rl_cluster,\n",
+    "    temperature=0.7,\n",
+    "    top_k=50,\n",
+    "    top_p=0.95,\n",
+    "    num_passes=1,\n",
+    "    corr_lst=False,\n",
+    "    make_lst=False,\n",
+    "):\n",
+    "  \"\"\"\n",
+    "  Computes accuracy and percentage of outputs matching the format.\n",
+    "\n",
+    "  Args:\n",
+    "      dataset: The evaluation dataset\n",
+    "      rl_cluster: Model cluster for generation\n",
+    "      temperature: Sampling temperature\n",
+    "      top_k: Top-k sampling parameter\n",
+    "      top_p: Top-p sampling parameter\n",
+    "      num_passes: Number of generation passes\n",
+    "      corr_lst: If True, only include correct responses in the list\n",
+    "      make_lst: If True, return a list of (question, answer, responses)\n",
+    "\n",
+    "  Returns:\n",
+    "      Tuple of statistics and optionally the response list\n",
+    "  \"\"\"\n",
+    "  response_lst = []\n",
+    "  corr = 0\n",
+    "  partially_corr = 0\n",
+    "  corr_format = 0\n",
+    "  total = 0\n",
+    "\n",
+    "  for batch in tqdm(dataset):\n",
+    "    answers = batch[\"answer\"]\n",
+    "    questions = batch[\"question\"]\n",
+    "    prompts = batch[\"prompts\"]\n",
+    "\n",
+    "    # Generate responses for all prompts in the batch\n",
+    "    multiple_call_responses = generate_responses(\n",
+    "        prompts=prompts,\n",
+    "        rl_cluster=rl_cluster,\n",
+    "        num_passes=num_passes,\n",
+    "        temperature=temperature,\n",
+    "        top_k=top_k,\n",
+    "        top_p=top_p,\n",
+    "    )\n",
+    "\n",
+    "    # Score each question-answer pair\n",
+    "    for question, responses, answer in zip(questions, multiple_call_responses, answers):\n",
+    "      is_correct, is_partially_correct, has_correct_format = score_responses(\n",
+    "          question=question,\n",
+    "          responses=responses,\n",
+    "          answer=answer,\n",
+    "      )\n",
+    "\n",
+    "      # Update counters\n",
+    "      if is_correct:\n",
+    "        corr += 1\n",
+    "        if corr_lst and make_lst:\n",
+    "          response_lst.append((question, answer, responses))\n",
+    "      else:\n",
+    "        if not corr_lst and make_lst:\n",
+    "          response_lst.append((question, answer, responses))\n",
+    "\n",
+    "      if is_partially_correct:\n",
+    "        partially_corr += 1\n",
+    "\n",
+    "      if has_correct_format:\n",
+    "        corr_format += 1\n",
+    "\n",
+    "      total += 1\n",
+    "\n",
+    "      # Print progress every 10 items\n",
+    "      if total % 10 == 0:\n",
+    "        print(\n",
+    "            f\"===> {corr=}, {total=}, {corr / total * 100=}, \"\n",
+    "            f\"{partially_corr / total * 100=}, {corr_format / total * 100=}\"\n",
+    "        )\n",
+    "\n",
+    "  # Prepare return values\n",
+    "  to_return = (\n",
+    "      corr,\n",
+    "      total,\n",
+    "      corr / total * 100,\n",
+    "      partially_corr / total * 100,\n",
+    "      corr_format / total * 100,\n",
+    "  )\n",
+    "\n",
+    "  if make_lst:\n",
+    "    return to_return, response_lst\n",
+    "  return to_return"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Training Setup\n",
+    "\n",
+    "Let's set up all the configs first - checkpointing, metric logging and training. We then train the model.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Ckpt saving\n",
+    "checkpointing_options = ocp.CheckpointManagerOptions(save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP)\n",
+    "\n",
+    "# Metrics logger\n",
+    "metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=LOG_DIR, flush_every_n_steps=20)\n",
+    "\n",
+    "\n",
+    "# Logs\n",
+    "print(f\"TensorBoard logs directory: {LOG_DIR}\")\n",
+    "print(f\"tensorboard --logdir {LOG_DIR} --port=8086\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optimizer, learning rate scheduler, gradient clipping\n",
+    "optimizer = optax.adamw(\n",
+    "    learning_rate=optax.schedules.warmup_cosine_decay_schedule(\n",
+    "        init_value=0.0,\n",
+    "        peak_value=LEARNING_RATE,\n",
+    "        warmup_steps=WARMUP_STEPS,\n",
+    "        decay_steps=MAX_STEPS,\n",
+    "        end_value=0.0,\n",
+    "    ),\n",
+    "    b1=B1,\n",
+    "    b2=B2,\n",
+    "    weight_decay=WEIGHT_DECAY,\n",
+    ")\n",
+    "\n",
+    "if MAX_GRAD_NORM is not None:\n",
+    "  optimizer = optax.chain(\n",
+    "      optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),\n",
+    "      optimizer,\n",
+    "  )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# RL Cluster config\n",
+    "# Note that we use vLLM as the rollout engine.\n",
+    "# and we are using Tensor Parallelism for rollout\n",
+    "\n",
+    "cluster_config = rl_cluster_lib.ClusterConfig(\n",
+    "    role_to_mesh={\n",
+    "        rl_cluster_lib.Role.ACTOR: mesh,\n",
+    "        rl_cluster_lib.Role.REFERENCE: mesh,\n",
+    "        rl_cluster_lib.Role.ROLLOUT: mesh,\n",
+    "    },\n",
+    "    rollout_engine=\"vllm\",\n",
+    "    offload_to_cpu=False,\n",
+    "    training_config=rl_cluster_lib.RLTrainingConfig(\n",
+    "        actor_optimizer=optimizer,\n",
+    "        eval_every_n_steps=EVAL_EVERY_N_STEPS,\n",
+    "        max_steps=MAX_STEPS,\n",
+    "        gradient_accumulation_steps=1,\n",
+    "        # metrics logging\n",
+    "        metrics_logging_options=metrics_logging_options,\n",
+    "        # checkpoint saving\n",
+    "        checkpoint_root_directory=CKPT_DIR,\n",
+    "        checkpointing_options=checkpointing_options,\n",
+    "    ),\n",
+    "    rollout_config=base_rollout.RolloutConfig(\n",
+    "        max_tokens_to_generate=TOTAL_GENERATION_STEPS,\n",
+    "        max_prompt_length=MAX_PROMPT_LENGTH,\n",
+    "        kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,\n",
+    "        temperature=TEMPERATURE,\n",
+    "        top_p=TOP_P,\n",
+    "        top_k=TOP_K,\n",
+    "    ),\n",
+    "    rollout_vllm_model_version=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
+    "    rollout_vllm_hbm_utilization=0.2,\n",
+    "    rollout_vllm_tpu_backend_type=\"jax\",\n",
+    ")\n",
+    "\n",
+    "grpo_config = GrpoConfig(\n",
+    "    num_generations=NUM_GENERATIONS,\n",
+    "    num_iterations=NUM_ITERATIONS,\n",
+    "    beta=BETA,\n",
+    "    epsilon=EPSILON,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# RL cluster\n",
+    "\n",
+    "rl_cluster = rl_cluster_lib.RLCluster(\n",
+    "    actor=llama3_1_8b_policy,\n",
+    "    reference=llama3_1_8b,\n",
+    "    tokenizer=model_tokenizer,\n",
+    "    cluster_config=cluster_config,\n",
+    ")\n",
+    "\n",
+    "# GRPO Trainer\n",
+    "grpo_trainer = GrpoLearner(\n",
+    "    rl_cluster=rl_cluster,\n",
+    "    reward_fns=[\n",
+    "        match_format_exactly,\n",
+    "        match_format_approximately,\n",
+    "        check_answer,\n",
+    "        check_numbers,\n",
+    "    ],\n",
+    "    grpo_config=grpo_config,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if DEBUG:\n",
+    "  # verify if vllm sampler works\n",
+    "  output = rl_cluster.rollout.generate(\n",
+    "      [\"The capital of France is\"],\n",
+    "      rollout_config=RolloutConfig(max_tokens_to_generate=64, temperature=0.1),\n",
+    "  )\n",
+    "\n",
+    "  print(f\"Output: {output}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Evaluate Before Training\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(\n",
+    "    test_dataset,\n",
+    "    rl_cluster,\n",
+    "    **GENERATION_CONFIGS[\"greedy\"],\n",
+    ")\n",
+    "print(f\"Pre GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,\" f\" {format_accuracy=}%\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Start Training\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "jax.profiler.start_trace(PROFILE_DIR)\n",
+    "with mesh, nn_partitioning.axis_rules(config_policy.logical_axis_rules):\n",
+    "  grpo_trainer.train(dataset)\n",
+    "jax.profiler.stop_trace()\n",
+    "\n",
+    "print(\"HBM usage after training:\")\n",
+    "show_hbm_usage()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Final Evaluation\n",
+    "\n",
+    "Let's evaluate our model after training!\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(\n",
+    "    test_dataset,\n",
+    "    rl_cluster,\n",
+    "    **GENERATION_CONFIGS[\"greedy\"],\n",
+    ")\n",
+    "print(f\"Post GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,\" f\" {format_accuracy=}%\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py b/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py
index f994e73566..bae80e38ff 100644
--- a/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py
+++ b/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py
@@ -14,6 +14,11 @@
 
 # pylint: disable=bare-except, consider-using-generator
 """ 
+DEPRECATED: This file is kept for reference only.
+Please use the new unified CLI interface: rl_trainer.py
+
+See GRPO_README.md for the migration guide and usage examples.
+
 This tutorial demonstrates training the Llama3.1 8B-IT model on
  the GSM8K math reasoning benchmark using Group Relative Policy Optimization (GRPO).
    GRPO can enhance your model's problem-solving skills on mathematical word problems,
diff --git a/src/MaxText/experimental/rl/rl.yml b/src/MaxText/experimental/rl/rl.yml
new file mode 100644
index 0000000000..38108dc82b
--- /dev/null
+++ b/src/MaxText/experimental/rl/rl.yml
@@ -0,0 +1,77 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+base_config: "base.yml"
+
+logical_axis_rules: [
+                      ['prefill_activation_length', ['data']],
+                      ['prefill_activation_norm_length', ['data']],
+                      ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
+                      ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
+                      ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
+                      ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
+                      ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
+                      ['activation_length', ['context_autoregressive', 'sequence']],
+                      ['activation_length', ['context_autoregressive']],
+                      ['activation_q_length', ['context_autoregressive']],
+                      ['activation_kv_length', ['context_autoregressive']],
+                      ['activation_norm_length', ['tensor_sequence', 'sequence']],
+                      ['activation_embed', ['tensor_transpose']],
+                      ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+                      ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+                      ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
+                      ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
+                      ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+                      ['activation_vocab', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
+                      ['activation_vocab', ['tensor', 'tensor_transpose']],
+                      ['activation_vocab', 'tensor_sequence'],
+                      ['activation_vocab', ['sequence', 'context_autoregressive']],
+                      ['activation_stage', 'stage'],
+                      ['activation_exp', ['expert', 'context_autoregressive']],
+                      ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']],
+                      ['decode_length', []],
+                      ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
+                      ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context_autoregressive']],
+                      ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
+                      ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
+                      ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
+                      ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'expert']],
+                      ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
+                      ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
+                      ['embed', ['fsdp', 'sequence', 'expert']],
+                      ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']],
+                      ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']],
+                      ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']],
+                      ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive']],
+                      ['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']],
+                      ['layers', 'stage'],
+                      ['kv', []],
+                      ['kv_head_dim', []],
+                      ['cache_batch_prefill', []],
+                      ['cache_batch', ['context_autoregressive']],
+                      ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
+                      ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
+                      ['cache_kv', []],
+                      ['cache_sequence', ['context_autoregressive']],
+                      ['cache_scale_sequence', ['context_autoregressive']],
+                      ['exp', ['expert', 'context_autoregressive']],
+                      ['paged_kv_heads', []],
+                      ['num_pages', ['tensor']],
+                      ['tokens_per_page', []],
+                      ['paged_kv_head_dim_size', []],
+                    ]
+# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
+data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
+
+return_log_prob: True
\ No newline at end of file
diff --git a/src/MaxText/rl_utils.py b/src/MaxText/rl_utils.py
new file mode 100644
index 0000000000..48aa58a31b
--- /dev/null
+++ b/src/MaxText/rl_utils.py
@@ -0,0 +1,219 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=bare-except, consider-using-generator
+
+import functools
+import os
+from pprint import pprint
+import re
+import sys
+
+from datetime import datetime
+from flax import nnx
+from flax.linen import partitioning as nn_partitioning
+import grain
+import humanize
+
+
+import jax
+from jax.sharding import Mesh
+import optax
+from orbax import checkpoint as ocp
+import tensorflow_datasets as tfds
+from tqdm.auto import tqdm
+from tunix.rl import rl_cluster as rl_cluster_lib
+from tunix.rl.rollout import base_rollout
+from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner
+from tunix.sft import metrics_logger
+
+
+from transformers import AutoTokenizer
+
+from flax import linen as nn
+import numpy as np
+from etils import epath
+
+from tunix.rl.rollout.base_rollout import RolloutConfig
+
+from MaxText.globals import MAXTEXT_ASSETS_ROOT
+import pathwaysutils
+
+# Let's define a RegEx for checking whether the format matches.
+def get_match_format_regex(mt_config):
+  """Returns a compiled regex that extracts the answer from a correctly formatted completion."""
+  match_format = re.compile(
+      (
+          r"^[\s]{0,}"
+          rf"{mt_config.reasoning_start_token}.+?{mt_config.reasoning_end_token}.*?"
+          rf"{mt_config.solution_start_token}(.+?){mt_config.solution_end_token}"
+          r"[\s]{0,}$"
+      ),
+      flags=re.MULTILINE | re.DOTALL,
+  )
+  if mt_config.debug:
+    print(
+        match_format.search(
+            f"{mt_config.reasoning_start_token}Let me think!"
+            f"{mt_config.reasoning_end_token}{mt_config.solution_start_token}2{mt_config.solution_end_token}"
+        )
+    )
+  return match_format
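+
+# A minimal illustration of what this format regex accepts (a sketch, assuming the
+# config tokens are literal strings such as "<reasoning>", "</reasoning>",
+# "<SOLUTION>" and "</SOLUTION>"; the real values come from mt_config):
+#
+#   match_format = get_match_format_regex(mt_config)
+#   m = match_format.search(
+#       "<reasoning>Let me think!</reasoning><SOLUTION>2</SOLUTION>"
+#   )
+#   m.group(1)  # -> "2"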
+
+
+def match_format_exactly(prompts, completions, mt_config, **kargs):
+  """
+  Give the model a reward of mt_config.reward_exact_format_match points if the format matches exactly.
+  """
+  scores = []
+  match_format = get_match_format_regex(mt_config)
+  for completion in completions:
+    score = 0
+    response = completion
+    # Match if format is seen exactly!
+    if match_format.search(response) is not None:
+      score += mt_config.reward_exact_format_match
+    scores.append(score)
+  return scores
+
+
+def match_format_approximately(prompts, completions, mt_config, **kargs):
+  """
+  We also reward the model if the format of the output matches partially.
+  """
+  scores = []
+
+  for completion in completions:
+    score = 0
+    # Count how many keywords are seen - we penalize if too many!
+    # If we see 1, then plus some points!
+    score += (
+        mt_config.reward_partial_format_match
+        if completion.count(mt_config.reasoning_start_token) == 1
+        else mt_config.penalty_incorrect_format
+    )
+    score += (
+        mt_config.reward_partial_format_match
+        if completion.count(mt_config.reasoning_end_token) == 1
+        else mt_config.penalty_incorrect_format
+    )
+    score += (
+        mt_config.reward_partial_format_match
+        if completion.count(mt_config.solution_start_token) == 1
+        else mt_config.penalty_incorrect_format
+    )
+    score += (
+        mt_config.reward_partial_format_match
+        if completion.count(mt_config.solution_end_token) == 1
+        else mt_config.penalty_incorrect_format
+    )
+    scores.append(score)
+  return scores
+
+
+def check_answer(prompts, completions, answer, mt_config, **kargs):
+  """
+  Reward the model if the answer is correct. A reward is also given if the answer
+  does not match exactly, i.e., based on how close the answer is to the correct
+  value.
+  """
+  match_format = get_match_format_regex(mt_config)
+  extracted_responses = [
+      guess.group(1) if (guess := match_format.search(c)) is not None else None
+      for c in completions
+  ]
+
+  scores = []
+  for guess, true_answer in zip(extracted_responses, answer):
+    score = 0
+    if guess is None:
+      scores.append(0)
+      continue
+    # An exact match gets the full reward!
+    if guess == true_answer:
+      score += mt_config.reward_exact_format_match
+    # Also match when the strings differ only by surrounding whitespace
+    elif guess.strip() == true_answer.strip():
+      score += mt_config.reward_white_space_format_match
+    else:
+      # We also reward the model if the answer is numerically close, based on the
+      # ratio of the guess to the true answer.
+      try:
+        ratio = float(guess) / float(true_answer)
+        if 0.9 <= ratio <= 1.1:
+          score += mt_config.reward_ratio_guess_to_answer_high
+        elif 0.8 <= ratio <= 1.2:
+          score += mt_config.reward_ratio_guess_to_answer_low
+        else:
+          score += mt_config.penalty_incorrect_answer  # Penalize wrong answers
+      except:
+        score += mt_config.penalty_incorrect_format  # Penalize
+    scores.append(score)
+  return scores
+
+
+# Sometimes, the text between the solution start and end tokens is not a single
+# number; it can be a sentence. So we extract the number and compare it with the answer.
+
+def get_match_numbers_regex(mt_config):
+  """Returns a compiled regex that extracts the first number after the solution start token."""
+  match_numbers = re.compile(
+      rf"{mt_config.solution_start_token}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL
+  )
+  if mt_config.debug:
+    print(match_numbers.findall(f"{mt_config.solution_start_token}  0.34  {mt_config.solution_end_token}"))
+  return match_numbers
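+
+# Sketch of the intended behavior (assuming the solution start and end tokens are the
+# literal strings "<SOLUTION>" and "</SOLUTION>"; the real values come from mt_config):
+#
+#   match_numbers = get_match_numbers_regex(mt_config)
+#   match_numbers.findall("<SOLUTION>  0.34  </SOLUTION>")  # -> ["0.34"]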
+
+def check_numbers(prompts, completions, answer, mt_config, **kargs):
+  """
+  Reward the model if the answer is correct.
+  """
+  question = kargs["question"]
+  
+  match_numbers = get_match_numbers_regex(mt_config)
+  extracted_responses = [
+      guess.group(1) if (guess := match_numbers.search(c)) is not None else None
+      for c in completions
+  ]
+
+  scores = []
+  if mt_config.debug:
+    print("START ============================")
+    print(f"Question: {question[0]}")
+    print(f"Answer: {answer[0]}")
+    print(f"Response: {completions[0]}")
+    print(f"Extracted: {extracted_responses[0]}")
+    print("END ==============================")
+  for guess, true_answer in zip(extracted_responses, answer):
+    if guess is None:
+      scores.append(0)
+      continue
+    # Convert to numbers
+    try:
+      true_answer = float(true_answer.strip())
+      guess = float(guess.strip())
+      scores.append(1.5 if guess == true_answer else 0.0)
+    except:
+      scores.append(0)
+      continue
+  return scores
+
+def extract_hash_answer(text: str, debug: bool = False) -> str | None:
+  """Function to extract only the answer hash from the text."""
+  if debug:
+    print(f"Extracting answer from: {text}")
+  if "####" not in text:
+    return None
+  return text.split("####")[1].strip()
+
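+# For example, a GSM8K-style answer string of the form
+# "<step-by-step reasoning> #### 72" yields extract_hash_answer(...) == "72",
+# while a string without "####" yields None.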
diff --git a/src/MaxText/train_rl.py b/src/MaxText/train_rl.py
new file mode 100644
index 0000000000..59e1a53c2f
--- /dev/null
+++ b/src/MaxText/train_rl.py
@@ -0,0 +1,420 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+GRPO Trainer
+
+This module provides a unified `rl_train` function that consolidates the common
+RL training logic. It handles model loading, reward function setup, dataset
+processing, and training orchestration. By default, we run Group Relative Policy
+Optimization (GRPO) on the GSM8K math reasoning benchmark. GRPO can enhance your
+model's problem-solving skills on mathematical word problems, coding problems, etc.
+
+Usage examples:
+
+  # Llama3.1-8B (single host)
+  python3 src/MaxText/examples/rl_trainer.py \\
+    --model_name=llama3.1-8b \\
+    --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \\
+    --load_parameters_path=gs://path/to/checkpoint \\
+    --base_output_directory=/tmp/grpo_output \\
+    --hf_access_token=$HF_TOKEN \\
+    --steps=100
+
+  # Llama3.1-70B with Pathways (multi-host)
+  python3 src/MaxText/examples/rl_trainer.py \\
+    --model_name=llama3.1-70b \\
+    --tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \\
+    --load_parameters_path=gs://path/to/checkpoint \\
+    --base_output_directory=gs://path/to/output \\
+    --hf_access_token=$HF_TOKEN \\
+    --use_pathways=true \\
+    --steps=100
+
+  # Custom dataset
+  python3 src/MaxText/examples/rl_trainer.py \\
+    --model_name=llama3.1-8b \\
+    --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \\
+    --load_parameters_path=gs://path/to/checkpoint \\
+    --base_output_directory=/tmp/grpo_output \\
+    --hf_access_token=$HF_TOKEN \\
+    --hf_path=custom/dataset \\
+    --steps=100
+"""
+
+from pprint import pprint
+from typing import Sequence
+from absl import app
+import os
+import re
+
+import jax
+from jax.sharding import Mesh
+from flax import nnx
+from flax.linen import partitioning as nn_partitioning
+import optax
+from orbax import checkpoint as ocp
+import tensorflow_datasets as tfds
+from transformers import AutoTokenizer
+
+import grain
+
+import pathwaysutils
+
+from tunix.rl import rl_cluster as rl_cluster_lib
+from tunix.rl.rollout import base_rollout
+from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner
+from tunix.sft import metrics_logger
+from tunix.models.llama3 import model as llama3_lib
+
+from MaxText import max_logging, max_utils, maxtext_utils, pyconfig
+from MaxText import model_creation_utils
+from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
+
+from MaxText import rl_utils
+
+
+# We use OpenAI's GSM8K dataset. GSM8K comprises grade school math word problems.
+
+
+
+def get_maxtext_model(config, devices=None):
+  """
+  Load a MaxText model wrapped in the Tunix adapter.
+
+  Note: pass the path to a scanned checkpoint as "load_parameters_path". A scanned
+  checkpoint can be generated with MaxText's `scanned_checkpoint.py` script or with
+  /maxtext/MaxText/utils/ckpt_conversion/to_maxtext.py.
+  """
+  model, mesh = model_creation_utils.create_nnx_model(config, devices)
+  with mesh:
+    tunix_model = TunixMaxTextAdapter(base_model=model)
+    tunix_model.config = None
+  return tunix_model, mesh
+
+
+def setup_device_allocation(mt_config, use_pathways: bool = False):
+  """Setup device allocation for training and inference."""
+
+  devices = jax.devices()
+  num_vms = len(devices) // mt_config.chips_per_vm
+  trainer_devices = devices
+  sampler_devices = devices
+  if num_vms >= 2 and use_pathways:
+    # Multiple hosts with Pathways - potentially split devices for trainer and sampler
+    # based on trainer_devices_fraction and sampler_devices_fraction
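+    # Worked example (hypothetical numbers): with 16 devices,
+    # trainer_devices_fraction=0.5 and sampler_devices_fraction=0.5, the first
+    # 8 devices become trainer devices and the last 8 become sampler devices.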
+    print(f"{num_vms} VMs detected, allocating trainer and sampler devices, and using Pathways.")
+    num_devices = len(devices)
+    num_trainer_devices = int(num_devices * mt_config.trainer_devices_fraction)
+    num_sampler_devices = int(num_devices * mt_config.sampler_devices_fraction)
+    trainer_devices = devices[:num_trainer_devices]
+    sampler_devices = devices[num_devices - num_sampler_devices :]
+    if mt_config.trainer_devices_fraction != 1.0:
+      print(f"Using first {len(trainer_devices)} devices as Trainer devices")
+    if mt_config.sampler_devices_fraction != 1.0:
+      print(f"Using last {len(sampler_devices)} devices as Sampler devices")
+
+  return trainer_devices, sampler_devices, num_vms
+
+def get_dataset(model_tokenizer, mt_config, data_dir, split="train") -> grain.MapDataset:
+  """Downloads GSM8K via TFDS and maps each example to prompt/question/answer fields."""
+  # Download data
+  if not os.path.exists(data_dir):
+    os.makedirs(data_dir)
+
+  data = tfds.data_source(
+      "gsm8k",
+      split=split,
+      data_dir=data_dir,
+      builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD},
+      download=True,
+  )
+
+  loaded_dataset = (
+      grain.MapDataset.source(data)
+      .shuffle(seed=mt_config.data_shuffle_seed)
+      .map(
+          lambda x: {
+              # passed to model forward pass
+              "prompts": model_tokenizer.apply_chat_template(
+                  [
+                      {
+                          "role": "user",
+                          "content": mt_config.template.format(
+                              system_prompt=mt_config.system_prompt,
+                              question=x["question"].decode("utf-8"),
+                          ),
+                      },
+                  ],
+                  tokenize=False,
+                  add_generation_prompt=True,
+              ),
+              # passed to reward functions
+              "question": x["question"].decode("utf-8"),
+              # passed to reward functions
+              "answer": rl_utils.extract_hash_answer(x["answer"].decode("utf-8")),
+          }
+      )
+  )
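+  # Each element is a dict with "prompts" (the chat-templated string fed to the
+  # rollout model), plus "question" and "answer" strings passed to the reward functions.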
+  return loaded_dataset
+
+def rl_train(mt_config):
+  """
+  Run RL training with the provided configuration.
+
+  Args:
+    mt_config: MaxText configuration object
+  """
+
+  print("Starting GRPO Training")
+
+  # Number of training steps.
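+  # Worked example (hypothetical values): num_batches=100, num_iterations=1,
+  # train_fraction=1.0 and num_epochs=1 give max_train_steps = int(100 * 1 * 1.0 * 1) = 100.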
+  max_train_steps = int(mt_config.num_batches * mt_config.num_iterations * mt_config.train_fraction * mt_config.num_epochs)
+
+  # ====== Data ======
+  # Setup data directories
+  home = os.path.expanduser("~")
+  train_data_dir = os.path.join(home, "data", "train")
+  test_data_dir = os.path.join(home, "data", "test")
+  os.makedirs(train_data_dir, exist_ok=True)
+  os.makedirs(test_data_dir, exist_ok=True)
+
+  # Create model tokenizer
+  model_tokenizer = AutoTokenizer.from_pretrained(mt_config.hf_model_name)
+
+  # Load datasets
+  dataset = get_dataset(model_tokenizer, mt_config, train_data_dir, "train").batch(mt_config.batch_size)[:mt_config.num_batches]
+
+  if mt_config.train_fraction == 1.0:
+    train_dataset = dataset.repeat(mt_config.num_epochs)
+    val_dataset = None
+  else:
+    train_dataset = dataset[: int(len(dataset) * mt_config.train_fraction)]
+    train_dataset = train_dataset.repeat(mt_config.num_epochs)
+
+    val_dataset = dataset[int(len(dataset) * mt_config.train_fraction) :].repeat(mt_config.num_epochs)
+
+  test_dataset = get_dataset(model_tokenizer, mt_config, test_data_dir, "test").batch(mt_config.batch_size)[:mt_config.num_test_batches]
+
+
+  # Let's see what one batch of the dataset looks like!
+  if mt_config.debug:
+    for ele in train_dataset[:1]:
+      pprint(ele)
+
+  # Setup device allocation
+  if jax.extend.backend.get_backend().platform_version == "Pathways":
+    max_logging.log("Pathways backend detected; profiler options will not be set.")
+    use_pathways = True
+  else:
+    use_pathways = False
+  print(f"Use Pathways: {use_pathways}")
+  trainer_devices, sampler_devices, num_vms = setup_device_allocation(mt_config, use_pathways)
+
+  # Load reference model
+  print("Creating reference model and also meshes for reference and rollout")
+  reference_model, reference_mesh = get_maxtext_model(mt_config, trainer_devices)
+  devices_array = maxtext_utils.create_device_mesh(mt_config, sampler_devices)
+  # if trainer_devices=sampler_devices, then rollout_mesh=reference_mesh
+  # else rollout_mesh uses sampler_devices
+  rollout_mesh = Mesh(devices_array, mt_config.mesh_axes)
+  if mt_config.debug:
+    print("Reference Model initialized successfully")
+    nnx.display(reference_model)
+    print(f"Reference mesh shape: {reference_mesh.shape}")
+
+    # Sanity check that weights are loaded correctly.
+    _maxtext_state_flatten = nnx.state(reference_model).flat_state()
+    maxtext_state_flatten = {".".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten}
+    print(
+        f"maxtext_state_flatten[base.token_embedder.embedding].value={maxtext_state_flatten['base.token_embedder.embedding'].value}"
+    )
+
+  # TODO: @mazumdera: change this to use lora
+  # Load policy model
+  print("Creating policy model with same config as reference model on trainer mesh")
+  policy_model, policy_mesh = get_maxtext_model(mt_config, trainer_devices)
+  actor_mesh = policy_mesh
+
+  if mt_config.debug:
+      print("Policy Model initialized successfully")
+      nnx.display(policy_model)
+      print(f"Policy mesh shape: {policy_mesh.shape}")
+
+  # Setup optimizer
+  optimizer = optax.adamw(
+      learning_rate=optax.schedules.warmup_cosine_decay_schedule(
+          init_value=0.0,
+          peak_value=mt_config.learning_rate,
+          # Linearly increase the learning rate from 0 to learning_rate over the first
+          # warmup_steps_fraction of training steps, then decay it to 0 with a cosine schedule.
+          warmup_steps=int(mt_config.warmup_steps_fraction * max_train_steps),
+          decay_steps=max_train_steps,
+      ),
+      b1=mt_config.adam_b1,
+      b2=mt_config.adam_b2,
+      weight_decay=mt_config.adam_weight_decay,
+  )
+
+  # TODO: @mazumdera: try optimizer offloading with adamw
+  # Add gradient clipping if specified
+  # Grad clipping to prevent large gradients. We find this
+  # important to keep KL divergence in check.
+  if mt_config.gradient_clipping_threshold > 0:
+    optimizer = optax.chain(
+        optax.clip_by_global_norm(max_norm=mt_config.gradient_clipping_threshold),
+        optimizer,
+    )
+
+  # Setup checkpointing
+  ckpt_dir = mt_config.checkpoint_dir
+  checkpointing_options = ocp.CheckpointManagerOptions(
+      save_interval_steps=mt_config.checkpoint_period, max_to_keep=mt_config.max_num_checkpoints_to_keep
+  )
+
+  # Setup metrics logging
+  log_dir = mt_config.tensorboard_dir
+  print(f"TensorBoard logs directory: {log_dir}")
+  print(f"tensorboard --logdir {log_dir} --port=8086")
+  metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=log_dir, flush_every_n_steps=20)
+
+  # Profiler configurations
+  # TODO: xfgu@: add profiling
+  profiler_options = None
+
+  # RL Cluster config
+  # Note that we use vLLM as the rollout engine,
+  # and we use tensor parallelism for rollout.
+  cluster_config = rl_cluster_lib.ClusterConfig(
+      role_to_mesh={
+          rl_cluster_lib.Role.ACTOR: actor_mesh,
+          rl_cluster_lib.Role.REFERENCE: reference_mesh,
+          rl_cluster_lib.Role.ROLLOUT: rollout_mesh,
+      },
+      rollout_engine="vllm",
+      offload_to_cpu=False,
+      training_config=rl_cluster_lib.RLTrainingConfig(
+          actor_optimizer=optimizer,
+          eval_every_n_steps=mt_config.eval_interval,
+          max_steps=max_train_steps,
+          # metrics logging
+          metrics_logging_options=metrics_logging_options,
+          # profiling
+          profiler_options=profiler_options,
+          # checkpoint saving
+          checkpoint_root_directory=ckpt_dir,
+          checkpointing_options=checkpointing_options,
+      ),
+      rollout_config=base_rollout.RolloutConfig(
+          max_tokens_to_generate=mt_config.max_target_length,
+          max_prompt_length=mt_config.max_prefill_predict_length,
+          kv_cache_size=mt_config.max_prefill_predict_length
+          + mt_config.max_target_length
+          + mt_config.kv_cache_buffer,
+          temperature=mt_config.decode_sampling_temperature,
+          top_p=mt_config.decode_sampling_nucleus_p,
+          top_k=mt_config.decode_sampling_top_k,
+      ),
+      rollout_vllm_model_version=mt_config.hf_model_name,
+      rollout_vllm_hbm_utilization=mt_config.hbm_utilization_vllm,
+      rollout_vllm_tpu_backend_type="jax",
+      rollout_vllm_swap_space_size_gb=mt_config.swap_space_vllm_gb,
+  )
+
+  # Setup GRPO config
+  grpo_config = GrpoConfig(
+      num_generations=mt_config.num_generations,
+      num_iterations=mt_config.num_iterations,
+      beta=mt_config.grpo_beta,
+      epsilon=mt_config.grpo_epsilon,
+      loss_algo=mt_config.loss_algo,
+  )
+
+  # Create RL cluster
+  print("Creating RL cluster...")
+  with nn_partitioning.axis_rules(mt_config.logical_axis_rules):
+    rl_cluster = rl_cluster_lib.RLCluster(
+        actor=policy_model,
+        reference=reference_model,
+        tokenizer=model_tokenizer,
+        cluster_config=cluster_config,
+    )
+
+  # Create GRPO trainer
+  print("Setting up GRPO trainer...")
+  rl_trainer = GrpoLearner(
+      rl_cluster=rl_cluster,
+      reward_fns=[ # type: ignore
+          lambda **kwargs: rl_utils.match_format_exactly(mt_config=mt_config, **kwargs),
+          lambda **kwargs: rl_utils.match_format_approximately(mt_config=mt_config, **kwargs),
+          lambda **kwargs: rl_utils.check_answer(mt_config=mt_config, **kwargs),
+          lambda **kwargs: rl_utils.check_numbers(mt_config=mt_config, **kwargs),
+      ],
+      grpo_config=grpo_config,
+  )
+
+
+
+  if mt_config.debug:
+    # verify if vllm sampler works
+    output = rl_cluster.rollout.generate(
+        ["The capital of France is"],
+        rollout_config=base_rollout.RolloutConfig(max_tokens_to_generate=64, temperature=0.1),
+    )
+
+    print(f"Output: {output}")
+
+  # Start training (a single profiled run; the profile is saved alongside the TensorBoard logs)
+  profile_dir = mt_config.tensorboard_dir
+  max_logging.log(f"Saving profiles to {profile_dir}")
+
+  print("Starting GRPO training...")
+  jax.profiler.start_trace(profile_dir)
+  with policy_mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
+    rl_trainer.train(train_dataset)
+  jax.profiler.stop_trace()
+
+  print("GRPO Training Completed Successfully!")
+
+  return rl_trainer, rl_cluster
+
+def main(argv: Sequence[str]) -> None:
+  """Main function to run SFT training.
+
+  Args:
+    argv: Command-line arguments.
+  """
+  pathwaysutils.initialize()
+  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
+  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
+  if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):
+    os.environ["LIBTPU_INIT_ARGS"] = (
+        os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
+    )
+
+  mt_config = pyconfig.initialize(argv)
+  max_utils.print_system_information()
+
+  rl_train(mt_config)
+
+
+if __name__ == "__main__":
+  app.run(main)