From a703dfd6a661c36b502772dc4774a4ea3276ecf2 Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Tue, 21 Oct 2025 21:50:52 +0400 Subject: [PATCH 01/31] grpo refactoring Signed-off-by: Vladimir Suvorov --- src/MaxText/configs/grpo.yml | 113 ++++++ src/MaxText/examples/GRPO_README.md | 226 +++++++++++ src/MaxText/examples/grpo_demo.py | 353 ++++++++++++++++++ .../examples/grpo_llama3_1_70b_demo_pw.py | 5 + src/MaxText/examples/grpo_llama3_1_8b_demo.py | 5 + .../examples/grpo_llama3_1_8b_demo_pw.py | 5 + 6 files changed, 707 insertions(+) create mode 100644 src/MaxText/configs/grpo.yml create mode 100644 src/MaxText/examples/GRPO_README.md create mode 100755 src/MaxText/examples/grpo_demo.py diff --git a/src/MaxText/configs/grpo.yml b/src/MaxText/configs/grpo.yml new file mode 100644 index 000000000..2fa6e6de0 --- /dev/null +++ b/src/MaxText/configs/grpo.yml @@ -0,0 +1,113 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# GRPO Configuration +# This config consolidates common parameters for GRPO training across different model sizes + +base_config: "base.yml" + +use_grpo: True +train_data_columns: 'prompt' + +# Dataset Configuration +dataset_type: hf # Huggingface input pipeline +hf_path: 'gsm8k' +hf_data_split: 'main' +hf_data_files: 'train' + +# Model and Tokenizer Configuration +# Override these via CLI: +# model_name, tokenizer_path, load_parameters_path + +# Sequence Lengths +max_prefill_predict_length: 256 +max_target_length: 768 + +# Training Hyperparameters +learning_rate: 3.e-6 +adam_b1: 0.9 +adam_b2: 0.99 +weight_decay: 0.1 +max_grad_norm: 0.1 + +# Group Relative Policy Optimization (GRPO) Parameters +num_generations: 2 +grpo_beta: 0.08 # KL divergence penalty coefficient +grpo_epsilon: 0.2 # Clipping value for stable updates +inference_rollouts: 1 + +# Generation Configuration During Training +decode_sampling_strategy: "weighted" +decode_sampling_temperature: 0.9 +decode_sampling_top_p: 1.0 +decode_sampling_top_k: 50 + +# Training Loop Configuration +steps: 100 +per_device_batch_size: 1 +eval_interval: 10 +eval_steps: 5 + +# Checkpoint Configuration +enable_checkpointing: True +async_checkpointing: True +checkpoint_period: 50 + +# Pathways Inference Configuration +# For multi-host/multi-slice setups +use_pathways_reshard: False +inference_devices_per_replica: 4 +inference_replicas: 1 + +# Tokenizer Settings +add_bos: False +add_eos: False +return_log_prob: True + +# Performance and Memory +weight_dtype: bfloat16 +dtype: bfloat16 + +# Profiling +profiler: xplane +skip_first_n_steps_for_profiler: 5 +profiler_steps: 3 + +# Splash Attention Block Sizes +# Tuned for GRPO workloads +sa_block_q: 128 +sa_block_kv: 128 +sa_block_kv_compute: 128 +sa_block_q_dkv: 128 +sa_block_kv_dkv: 128 +sa_block_kv_dkv_compute: 128 +sa_block_q_dq: 128 +sa_block_kv_dq: 128 +sa_use_fused_bwd_kernel: False +sa_q_layout: "HEAD_DIM_MINOR" +sa_k_layout: "HEAD_DIM_MINOR" +sa_v_layout: "HEAD_DIM_MINOR" + +# Model-Specific Overrides (examples) +# For Llama3.1-8B: +# 
model_name: llama3.1-8b +# tokenizer_path: meta-llama/Llama-3.1-8B-Instruct +# ici_fsdp_parallelism: 8 +# +# For Llama3.1-70B with Pathways: +# model_name: llama3.1-70b +# tokenizer_path: meta-llama/Llama-3.1-70B-Instruct +# use_pathways_reshard: True +# ici_fsdp_parallelism: 16 + diff --git a/src/MaxText/examples/GRPO_README.md b/src/MaxText/examples/GRPO_README.md new file mode 100644 index 000000000..380c35459 --- /dev/null +++ b/src/MaxText/examples/GRPO_README.md @@ -0,0 +1,226 @@ +# GRPO Demo - Unified Training Interface + +This directory contains a unified interface for running GRPO (Group Relative Policy Optimization) training demos across different model sizes and configurations. + +## Overview + +Previously, there were separate demo scripts for different model configurations: +- `grpo_llama3_1_8b_demo.py` - Single host 8B model +- `grpo_llama3_1_8b_demo_pw.py` - Pathways-based 8B model +- `grpo_llama3_1_70b_demo_pw.py` - Pathways-based 70B model + +These have been consolidated into a single **unified CLI script** (`grpo_demo.py`) that works with the new **grpo.yml** configuration file. + +## New Structure + +### Configuration File +`src/MaxText/configs/grpo.yml` +- Contains common GRPO parameters +- Can be overridden via CLI arguments +- Consolidates dataset, training, and GRPO-specific settings + +### Unified CLI Script +`src/MaxText/examples/grpo_demo.py` +- Single entry point for all GRPO demos +- Supports both single-host and multi-host (Pathways) setups +- Provides intuitive CLI arguments +- Automatically generates proper config for training and inference + +## Usage Examples + +### Llama3.1-8B (Single Host) + +```bash +python3 src/MaxText/examples/grpo_demo.py \ + --model_name=llama3.1-8b \ + --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \ + --load_parameters_path=gs://path/to/checkpoint \ + --base_output_directory=/tmp/grpo_output \ + --hf_access_token=$HF_TOKEN \ + --steps=100 +``` + +### Llama3.1-70B with Pathways (Multi-Host) + +```bash +python3 src/MaxText/examples/grpo_demo.py \ + --model_name=llama3.1-70b \ + --tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \ + --load_parameters_path=gs://path/to/checkpoint \ + --base_output_directory=gs://path/to/output \ + --hf_access_token=$HF_TOKEN \ + --use_pathways \ + --inference_devices_per_replica=4 \ + --inference_replicas=4 \ + --ici_fsdp_parallelism=16 \ + --steps=100 +``` + +### Custom Dataset + +```bash +python3 src/MaxText/examples/grpo_demo.py \ + --model_name=llama3.1-8b \ + --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \ + --load_parameters_path=gs://path/to/checkpoint \ + --base_output_directory=/tmp/grpo_output \ + --hf_access_token=$HF_TOKEN \ + --hf_path=custom/dataset \ + --hf_data_split=train \ + --steps=100 +``` + +### With Custom GRPO Parameters + +```bash +python3 src/MaxText/examples/grpo_demo.py \ + --model_name=llama3.1-8b \ + --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \ + --load_parameters_path=gs://path/to/checkpoint \ + --base_output_directory=/tmp/grpo_output \ + --hf_access_token=$HF_TOKEN \ + --num_generations=4 \ + --grpo_beta=0.04 \ + --grpo_epsilon=0.15 \ + --learning_rate=5e-6 \ + --steps=200 +``` + +## CLI Arguments + +### Required Arguments + +- `--model_name`: Model name (e.g., llama3.1-8b, llama3.1-70b) +- `--tokenizer_path`: HuggingFace tokenizer path +- `--load_parameters_path`: Path to model checkpoint (local or gs://) +- `--base_output_directory`: Base output directory for logs and checkpoints + +### Dataset Arguments + +- `--hf_access_token`: HuggingFace access token 
(can use $HF_TOKEN env var) +- `--hf_path`: HuggingFace dataset path (default: gsm8k) +- `--hf_data_split`: Dataset split (default: main) +- `--hf_data_files`: Dataset files (default: train) + +### Training Arguments + +- `--steps`: Number of training steps (default: 100) +- `--per_device_batch_size`: Per device batch size (default: 1) +- `--learning_rate`: Learning rate (default: 3e-6) +- `--run_name`: Custom run name for the experiment + +### GRPO-Specific Arguments + +- `--num_generations`: Number of generations per prompt (default: 2) +- `--grpo_beta`: KL divergence penalty coefficient (default: 0.08) +- `--grpo_epsilon`: Clipping value for stable updates (default: 0.2) + +### Sequence Length Arguments + +- `--max_prefill_predict_length`: Maximum prompt length (default: 256) +- `--max_target_length`: Maximum total sequence length (default: 768) + +### Multi-Host/Pathways Arguments + +- `--use_pathways`: Enable Pathways for multi-host training +- `--inference_devices_per_replica`: Devices per inference replica (default: 4) +- `--inference_replicas`: Number of inference replicas (default: 1) + +### Parallelism Arguments + +- `--ici_fsdp_parallelism`: FSDP parallelism (-1 for auto) +- `--ici_tensor_parallelism`: Tensor parallelism (-1 for auto) + +### Other Arguments + +- `--profiler`: Profiler to use (default: xplane) +- `--checkpoint_period`: Checkpoint saving period (default: 50) +- `--config_file`: Optional custom config file (overrides grpo.yml) + +## Migration Guide + +### From Individual Demo Scripts + +**Old way:** +```python +# Editing grpo_llama3_1_8b_demo.py directly +MODEL_NAME = "llama3.1-8b" +TOKENIZER_PATH = "meta-llama/Llama-3.1-8B-Instruct" +# ... many hardcoded parameters +``` + +**New way:** +```bash +python3 src/MaxText/examples/grpo_demo.py \ + --model_name=llama3.1-8b \ + --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \ + # ... all parameters via CLI +``` + +### Benefits + +1. **Single Script**: One script for all model sizes and configurations +2. **No Code Editing**: All parameters configurable via CLI +3. **Better Defaults**: Common parameters in `grpo.yml` +4. **Easier Testing**: Quickly test different configurations +5. **CI/CD Friendly**: Easy to integrate into automated workflows + +## Configuration Files + +### grpo.yml +Main configuration file with sensible defaults for GRPO demos. Override any parameter via CLI. + +Location: `src/MaxText/configs/grpo.yml` + +### grpo.yml and grpo_inference.yml +Low-level configuration files used by the GRPO trainer. Generally, you don't need to modify these directly. + +Location: `src/MaxText/experimental/rl/` + +## Advanced Usage + +### Using a Custom Config File + +If you have a custom configuration: + +```bash +python3 src/MaxText/examples/grpo_demo.py \ + --config_file=/path/to/custom_config.yml \ + --model_name=llama3.1-8b \ + # ... other args +``` + +### Environment Variables + +You can set these environment variables: +- `HF_TOKEN`: HuggingFace access token (alternative to `--hf_access_token`) + +## Troubleshooting + +### Common Issues + +1. **HF_TOKEN not set**: Make sure to either set the environment variable or pass `--hf_access_token` + +2. **Pathways configuration**: For multi-host setups, ensure: + - `--use_pathways` is set + - `--inference_devices_per_replica` and `--inference_replicas` are configured correctly + - The total number of devices is sufficient + +3. 
**Memory issues**: Try reducing: + - `--per_device_batch_size` + - `--max_target_length` + - `--num_generations` + +## Contributing + +When adding new features or model support: +1. Add sensible defaults to `grpo.yml` +2. Add CLI arguments to `grpo_demo.py` if needed +3. Update this README with examples + +## See Also + +- [GRPO Paper](https://arxiv.org/abs/2402.03300) +- [MaxText Documentation](../../../docs/) +- [Tunix Library](https://github.com/google/tunix) + diff --git a/src/MaxText/examples/grpo_demo.py b/src/MaxText/examples/grpo_demo.py new file mode 100755 index 000000000..c8871cfc0 --- /dev/null +++ b/src/MaxText/examples/grpo_demo.py @@ -0,0 +1,353 @@ +#!/usr/bin/env python3 +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unified GRPO Demo Script + +This script provides a unified CLI interface for running GRPO training demos +across different model sizes and configurations. It consolidates the common +logic from individual demo scripts and uses the grpo.yml config. + +Usage Examples: + +# Llama3.1-8B (single host) +python3 src/MaxText/examples/grpo_demo.py \\ + --model_name=llama3.1-8b \\ + --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \\ + --load_parameters_path=gs://path/to/checkpoint \\ + --base_output_directory=/tmp/grpo_output \\ + --hf_access_token=$HF_TOKEN \\ + --steps=100 + +# Llama3.1-70B with Pathways (multi-host) +python3 src/MaxText/examples/grpo_demo.py \\ + --model_name=llama3.1-70b \\ + --tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \\ + --load_parameters_path=gs://path/to/checkpoint \\ + --base_output_directory=gs://path/to/output \\ + --hf_access_token=$HF_TOKEN \\ + --use_pathways=true \\ + --steps=100 + +# Custom dataset +python3 src/MaxText/examples/grpo_demo.py \\ + --model_name=llama3.1-8b \\ + --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \\ + --load_parameters_path=gs://path/to/checkpoint \\ + --base_output_directory=/tmp/grpo_output \\ + --hf_access_token=$HF_TOKEN \\ + --hf_path=custom/dataset \\ + --steps=100 +""" + +import argparse +import os +import sys +from typing import Optional + +# Add MaxText to path +script_dir = os.path.dirname(os.path.abspath(__file__)) +maxtext_root = os.path.abspath(os.path.join(script_dir, "..", "..")) +sys.path.insert(0, maxtext_root) + +from MaxText import pyconfig +from MaxText.globals import MAXTEXT_PKG_DIR +from MaxText.experimental.rl import grpo_trainer + + +def create_parser(): + """Create argument parser for GRPO demo.""" + parser = argparse.ArgumentParser( + description="Unified GRPO Demo Script", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + # Model Configuration + parser.add_argument( + "--model_name", + type=str, + required=True, + help="Model name (e.g., llama3.1-8b, llama3.1-70b)", + ) + parser.add_argument( + "--tokenizer_path", + type=str, + required=True, + help="HuggingFace tokenizer path (e.g., meta-llama/Llama-3.1-8B-Instruct)", + ) + parser.add_argument( + "--load_parameters_path", + type=str, + required=True, + 
help="Path to model checkpoint (local or gs://)", + ) + + # Output Configuration + parser.add_argument( + "--base_output_directory", + type=str, + required=True, + help="Base output directory for logs and checkpoints", + ) + parser.add_argument( + "--run_name", + type=str, + default=None, + help="Run name for this experiment", + ) + + # Dataset Configuration + parser.add_argument( + "--hf_access_token", + type=str, + default=os.environ.get("HF_TOKEN"), + help="HuggingFace access token (default: $HF_TOKEN env var)", + ) + parser.add_argument( + "--hf_path", + type=str, + default="gsm8k", + help="HuggingFace dataset path", + ) + parser.add_argument( + "--hf_data_split", + type=str, + default="main", + help="HuggingFace dataset split", + ) + parser.add_argument( + "--hf_data_files", + type=str, + default="train", + help="HuggingFace dataset files", + ) + + # Training Configuration + parser.add_argument( + "--steps", + type=int, + default=100, + help="Number of training steps", + ) + parser.add_argument( + "--per_device_batch_size", + type=int, + default=1, + help="Per device batch size", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=3e-6, + help="Learning rate", + ) + + # GRPO-Specific Parameters + parser.add_argument( + "--num_generations", + type=int, + default=2, + help="Number of generations per prompt (group size)", + ) + parser.add_argument( + "--grpo_beta", + type=float, + default=0.08, + help="KL divergence penalty coefficient", + ) + parser.add_argument( + "--grpo_epsilon", + type=float, + default=0.2, + help="Clipping value for stable updates", + ) + + # Sequence Lengths + parser.add_argument( + "--max_prefill_predict_length", + type=int, + default=256, + help="Maximum prompt length", + ) + parser.add_argument( + "--max_target_length", + type=int, + default=768, + help="Maximum total sequence length", + ) + + # Pathways/Multi-Host Configuration + parser.add_argument( + "--use_pathways", + action="store_true", + help="Use Pathways for multi-host training", + ) + parser.add_argument( + "--inference_devices_per_replica", + type=int, + default=4, + help="Number of devices per inference replica", + ) + parser.add_argument( + "--inference_replicas", + type=int, + default=1, + help="Number of inference replicas", + ) + + # Parallelism Configuration + parser.add_argument( + "--ici_fsdp_parallelism", + type=int, + default=-1, + help="FSDP parallelism (-1 for auto)", + ) + parser.add_argument( + "--ici_tensor_parallelism", + type=int, + default=-1, + help="Tensor parallelism (-1 for auto)", + ) + + # Other Configuration + parser.add_argument( + "--profiler", + type=str, + default="xplane", + help="Profiler to use (xplane, none)", + ) + parser.add_argument( + "--checkpoint_period", + type=int, + default=50, + help="Checkpoint saving period", + ) + parser.add_argument( + "--config_file", + type=str, + default=None, + help="Optional custom config file (overrides grpo.yml)", + ) + + return parser + + +def build_config_argv(args): + """Build configuration arguments for MaxText pyconfig.""" + # Use custom config or default grpo.yml + if args.config_file: + base_config = args.config_file + else: + base_config = os.path.join(MAXTEXT_PKG_DIR, "configs", "grpo.yml") + + # Build training config argv + train_config_argv = [ + "", # Placeholder for argv[0] + base_config, + f"model_name={args.model_name}", + f"tokenizer_path={args.tokenizer_path}", + f"load_parameters_path={args.load_parameters_path}", + f"base_output_directory={args.base_output_directory}", + 
f"steps={args.steps}", + f"per_device_batch_size={args.per_device_batch_size}", + f"learning_rate={args.learning_rate}", + f"num_generations={args.num_generations}", + f"grpo_beta={args.grpo_beta}", + f"grpo_epsilon={args.grpo_epsilon}", + f"max_prefill_predict_length={args.max_prefill_predict_length}", + f"max_target_length={args.max_target_length}", + f"profiler={args.profiler}", + f"checkpoint_period={args.checkpoint_period}", + f"hf_path={args.hf_path}", + f"hf_data_split={args.hf_data_split}", + f"hf_data_files={args.hf_data_files}", + ] + + # Add optional parameters + if args.run_name: + train_config_argv.append(f"run_name={args.run_name}") + + if args.hf_access_token: + train_config_argv.append(f"hf_access_token={args.hf_access_token}") + + if args.use_pathways: + train_config_argv.append("use_pathways_reshard=True") + train_config_argv.append(f"inference_devices_per_replica={args.inference_devices_per_replica}") + train_config_argv.append(f"inference_replicas={args.inference_replicas}") + + if args.ici_fsdp_parallelism > 0: + train_config_argv.append(f"ici_fsdp_parallelism={args.ici_fsdp_parallelism}") + + if args.ici_tensor_parallelism > 0: + train_config_argv.append(f"ici_tensor_parallelism={args.ici_tensor_parallelism}") + + # Build inference config argv + # For GRPO, inference config is similar but with adjusted batch size + inference_config_argv = train_config_argv.copy() + # Replace base config with grpo_inference.yml + inference_config_argv[1] = os.path.join(MAXTEXT_PKG_DIR, "experimental", "rl", "grpo_inference.yml") + + # Adjust batch size for inference (should include num_generations) + inference_batch_size = args.per_device_batch_size * args.num_generations + # Replace the per_device_batch_size entry + for i, arg in enumerate(inference_config_argv): + if arg.startswith("per_device_batch_size="): + inference_config_argv[i] = f"per_device_batch_size={inference_batch_size}" + break + + return [train_config_argv, inference_config_argv] + + +def main(): + """Main entry point for GRPO demo.""" + parser = create_parser() + args = parser.parse_args() + + # Validate required environment/arguments + if not args.hf_access_token: + print("Error: HF_TOKEN is required. 
Set it as an environment variable or pass --hf_access_token") + sys.exit(1) + + print("=" * 80) + print("GRPO Demo - Unified Training Script") + print("=" * 80) + print(f"Model: {args.model_name}") + print(f"Tokenizer: {args.tokenizer_path}") + print(f"Checkpoint: {args.load_parameters_path}") + print(f"Dataset: {args.hf_path}") + print(f"Output: {args.base_output_directory}") + print(f"Steps: {args.steps}") + print(f"GRPO Beta: {args.grpo_beta}") + print(f"Num Generations: {args.num_generations}") + print(f"Use Pathways: {args.use_pathways}") + print("=" * 80) + + # Build config arguments + config_argv = build_config_argv(args) + + # Convert to the format expected by grpo_trainer.main + sys.argv = ["grpo_demo.py"] + config_argv[0][1:] + config_argv[1][1:] + + # Run GRPO training + grpo_trainer.main(sys.argv) + + print("=" * 80) + print("GRPO Training Completed Successfully!") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py index def23ec94..107a05274 100644 --- a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py +++ b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py @@ -14,6 +14,11 @@ # pylint: disable=bare-except, consider-using-generator """ +DEPRECATED: This file is deprecated and kept for reference only. +Please use the new unified CLI interface: grpo_demo.py + +See GRPO_README.md for migration guide and usage examples. + This tutorial demonstrates training the Llama3.1 70B-IT model on the GSM8K math reasoning benchmark using Group Relative Policy Optimization (GRPO). GRPO can enhance your model's problem-solving skills on mathematical word problems, diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo.py b/src/MaxText/examples/grpo_llama3_1_8b_demo.py index e53322868..d3fc01579 100644 --- a/src/MaxText/examples/grpo_llama3_1_8b_demo.py +++ b/src/MaxText/examples/grpo_llama3_1_8b_demo.py @@ -14,6 +14,11 @@ # pylint: disable=bare-except, consider-using-generator """ +DEPRECATED: This file is deprecated and kept for reference only. +Please use the new unified CLI interface: grpo_demo.py + +See GRPO_README.md for migration guide and usage examples. + This tutorial demonstrates training the Llama3.1 8B-IT model on the GSM8K math reasoning benchmark using Group Relative Policy Optimization (GRPO). GRPO can enhance your model's problem-solving skills on mathematical word problems, diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py b/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py index f994e7356..f7acf722a 100644 --- a/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py +++ b/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py @@ -14,6 +14,11 @@ # pylint: disable=bare-except, consider-using-generator """ +DEPRECATED: This file is deprecated and kept for reference only. +Please use the new unified CLI interface: grpo_demo.py + +See GRPO_README.md for migration guide and usage examples. + This tutorial demonstrates training the Llama3.1 8B-IT model on the GSM8K math reasoning benchmark using Group Relative Policy Optimization (GRPO). 
GRPO can enhance your model's problem-solving skills on mathematical word problems, From 32a441c7c9578ebd64ddd619c80e11161781e4ce Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Tue, 21 Oct 2025 22:57:13 +0400 Subject: [PATCH 02/31] Fix Signed-off-by: Vladimir Suvorov --- .github/workflows/RunTests.yml | 55 ++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/.github/workflows/RunTests.yml b/.github/workflows/RunTests.yml index 069027889..2896360f5 100644 --- a/.github/workflows/RunTests.yml +++ b/.github/workflows/RunTests.yml @@ -108,6 +108,61 @@ jobs: container_resource_option: "--privileged" is_scheduled_run: ${{ github.event_name == 'schedule' }} + tpu_e2e_grpo_test: + needs: tpu_image + runs-on: linux-x86-ct4p-240-4tpu + container: + image: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:tpu + env: + XLA_PYTHON_CLIENT_MEM_FRACTION: 0.75 + TF_FORCE_GPU_ALLOW_GROWTH: false + HF_TOKEN: ${{ secrets.HF_TOKEN }} + MODEL_CHECKPOINT_PATH: gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items + STEPS: 10 + options: "--privileged" + steps: + - uses: actions/checkout@v4 + - name: Install Tunix vLLM Requirements + run: | + bash src/MaxText/examples/install_tunix_vllm_requirement.sh + - name: Run GRPO Llama3.1 8B Demo (Unified Script) + run: | + python3 -m pip install -e . --no-dependencies && + python3 src/MaxText/examples/grpo_demo.py \ + --model_name=llama3.1-8b \ + --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \ + --load_parameters_path=${MODEL_CHECKPOINT_PATH} \ + --base_output_directory=/tmp/grpo_output \ + --hf_access_token=${HF_TOKEN} \ + --steps=${STEPS} + + tpu_e2e_sft_test: + needs: tpu_image + runs-on: linux-x86-ct4p-240-4tpu + container: + image: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:tpu + env: + XLA_PYTHON_CLIENT_MEM_FRACTION: 0.75 + TF_FORCE_GPU_ALLOW_GROWTH: false + HF_TOKEN: ${{ secrets.HF_TOKEN }} + MODEL_CHECKPOINT_PATH: gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items + STEPS: 10 + options: "--privileged" + steps: + - uses: actions/checkout@v4 + - name: Install Dependencies + run: | + python3 -m pip install -e . 
--no-dependencies + - name: Install Tunix vLLM Requirements + run: | + bash src/MaxText/examples/install_tunix_vllm_requirement.sh + + - name: Run SFT Llama3.1 8B Demo + run: | + python3 src/MaxText/examples/sft_llama3_demo.py \ + --skip_checkpoint_download \ + --model_checkpoint_path=${MODEL_CHECKPOINT_PATH} + gpu_unit_tests: needs: gpu_image uses: ./.github/workflows/run_tests_internal.yml From cf4e918a28e57a23baec214360d683118956ae29 Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Tue, 21 Oct 2025 23:14:06 +0400 Subject: [PATCH 03/31] Fix Signed-off-by: Vladimir Suvorov --- .github/workflows/RunTests.yml | 55 ---------------------------------- 1 file changed, 55 deletions(-) diff --git a/.github/workflows/RunTests.yml b/.github/workflows/RunTests.yml index 2896360f5..069027889 100644 --- a/.github/workflows/RunTests.yml +++ b/.github/workflows/RunTests.yml @@ -108,61 +108,6 @@ jobs: container_resource_option: "--privileged" is_scheduled_run: ${{ github.event_name == 'schedule' }} - tpu_e2e_grpo_test: - needs: tpu_image - runs-on: linux-x86-ct4p-240-4tpu - container: - image: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:tpu - env: - XLA_PYTHON_CLIENT_MEM_FRACTION: 0.75 - TF_FORCE_GPU_ALLOW_GROWTH: false - HF_TOKEN: ${{ secrets.HF_TOKEN }} - MODEL_CHECKPOINT_PATH: gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items - STEPS: 10 - options: "--privileged" - steps: - - uses: actions/checkout@v4 - - name: Install Tunix vLLM Requirements - run: | - bash src/MaxText/examples/install_tunix_vllm_requirement.sh - - name: Run GRPO Llama3.1 8B Demo (Unified Script) - run: | - python3 -m pip install -e . --no-dependencies && - python3 src/MaxText/examples/grpo_demo.py \ - --model_name=llama3.1-8b \ - --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \ - --load_parameters_path=${MODEL_CHECKPOINT_PATH} \ - --base_output_directory=/tmp/grpo_output \ - --hf_access_token=${HF_TOKEN} \ - --steps=${STEPS} - - tpu_e2e_sft_test: - needs: tpu_image - runs-on: linux-x86-ct4p-240-4tpu - container: - image: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:tpu - env: - XLA_PYTHON_CLIENT_MEM_FRACTION: 0.75 - TF_FORCE_GPU_ALLOW_GROWTH: false - HF_TOKEN: ${{ secrets.HF_TOKEN }} - MODEL_CHECKPOINT_PATH: gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items - STEPS: 10 - options: "--privileged" - steps: - - uses: actions/checkout@v4 - - name: Install Dependencies - run: | - python3 -m pip install -e . 
--no-dependencies - - name: Install Tunix vLLM Requirements - run: | - bash src/MaxText/examples/install_tunix_vllm_requirement.sh - - - name: Run SFT Llama3.1 8B Demo - run: | - python3 src/MaxText/examples/sft_llama3_demo.py \ - --skip_checkpoint_download \ - --model_checkpoint_path=${MODEL_CHECKPOINT_PATH} - gpu_unit_tests: needs: gpu_image uses: ./.github/workflows/run_tests_internal.yml From 115bf846b5d2a9785be9f44357781fa5162b807d Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Fri, 24 Oct 2025 21:15:23 +0400 Subject: [PATCH 04/31] grpo refactor Signed-off-by: Vladimir Suvorov --- src/MaxText/examples/GRPO_README.md | 268 +++++------ src/MaxText/examples/grpo_demo.py | 51 +- .../experimental/rl/grpo_demo_trainer.py | 445 ++++++++++++++++++ 3 files changed, 583 insertions(+), 181 deletions(-) create mode 100644 src/MaxText/experimental/rl/grpo_demo_trainer.py diff --git a/src/MaxText/examples/GRPO_README.md b/src/MaxText/examples/GRPO_README.md index 380c35459..c1751b234 100644 --- a/src/MaxText/examples/GRPO_README.md +++ b/src/MaxText/examples/GRPO_README.md @@ -1,36 +1,19 @@ -# GRPO Demo - Unified Training Interface +# GRPO Demo - Unified Interface -This directory contains a unified interface for running GRPO (Group Relative Policy Optimization) training demos across different model sizes and configurations. +This directory contains the unified GRPO (Group Relative Policy Optimization) demo interface that consolidates the common logic from individual demo scripts. The interface is **model-agnostic** and supports any model (Llama, Qwen, etc.). -## Overview +## Structure -Previously, there were separate demo scripts for different model configurations: -- `grpo_llama3_1_8b_demo.py` - Single host 8B model -- `grpo_llama3_1_8b_demo_pw.py` - Pathways-based 8B model -- `grpo_llama3_1_70b_demo_pw.py` - Pathways-based 70B model +- **`grpo_demo.py`** - Simple CLI interface for running GRPO training +- **`grpo_demo_trainer.py`** - Core GRPO training logic (in `experimental/rl/`) +- **`grpo.yml`** - Unified model-agnostic configuration file (in `configs/`) -These have been consolidated into a single **unified CLI script** (`grpo_demo.py`) that works with the new **grpo.yml** configuration file. 
+## Usage -## New Structure - -### Configuration File -`src/MaxText/configs/grpo.yml` -- Contains common GRPO parameters -- Can be overridden via CLI arguments -- Consolidates dataset, training, and GRPO-specific settings - -### Unified CLI Script -`src/MaxText/examples/grpo_demo.py` -- Single entry point for all GRPO demos -- Supports both single-host and multi-host (Pathways) setups -- Provides intuitive CLI arguments -- Automatically generates proper config for training and inference - -## Usage Examples - -### Llama3.1-8B (Single Host) +### Llama Models ```bash +# Llama3.1-8B (single host) python3 src/MaxText/examples/grpo_demo.py \ --model_name=llama3.1-8b \ --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \ @@ -38,189 +21,178 @@ python3 src/MaxText/examples/grpo_demo.py \ --base_output_directory=/tmp/grpo_output \ --hf_access_token=$HF_TOKEN \ --steps=100 -``` - -### Llama3.1-70B with Pathways (Multi-Host) -```bash +# Llama3.1-70B with Pathways python3 src/MaxText/examples/grpo_demo.py \ --model_name=llama3.1-70b \ --tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \ --load_parameters_path=gs://path/to/checkpoint \ --base_output_directory=gs://path/to/output \ --hf_access_token=$HF_TOKEN \ - --use_pathways \ - --inference_devices_per_replica=4 \ - --inference_replicas=4 \ - --ici_fsdp_parallelism=16 \ + --use_pathways=true \ --steps=100 ``` -### Custom Dataset +### Qwen Models ```bash +# Qwen2.5-7B (single host) python3 src/MaxText/examples/grpo_demo.py \ - --model_name=llama3.1-8b \ - --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \ + --model_name=qwen2.5-7b \ + --tokenizer_path=Qwen/Qwen2.5-7B-Instruct \ --load_parameters_path=gs://path/to/checkpoint \ --base_output_directory=/tmp/grpo_output \ --hf_access_token=$HF_TOKEN \ - --hf_path=custom/dataset \ - --hf_data_split=train \ + --steps=100 + +# Qwen2.5-72B with Pathways +python3 src/MaxText/examples/grpo_demo.py \ + --model_name=qwen2.5-72b \ + --tokenizer_path=Qwen/Qwen2.5-72B-Instruct \ + --load_parameters_path=gs://path/to/checkpoint \ + --base_output_directory=gs://path/to/output \ + --hf_access_token=$HF_TOKEN \ + --use_pathways=true \ --steps=100 ``` -### With Custom GRPO Parameters +### Custom Dataset ```bash +# Any model with custom HuggingFace dataset python3 src/MaxText/examples/grpo_demo.py \ - --model_name=llama3.1-8b \ - --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \ + --model_name=your-model \ + --tokenizer_path=your/tokenizer \ --load_parameters_path=gs://path/to/checkpoint \ --base_output_directory=/tmp/grpo_output \ --hf_access_token=$HF_TOKEN \ - --num_generations=4 \ - --grpo_beta=0.04 \ - --grpo_epsilon=0.15 \ - --learning_rate=5e-6 \ - --steps=200 + --hf_path=custom/dataset \ + --steps=100 ``` -## CLI Arguments +## Key Features -### Required Arguments +### GRPO-Specific Components -- `--model_name`: Model name (e.g., llama3.1-8b, llama3.1-70b) -- `--tokenizer_path`: HuggingFace tokenizer path -- `--load_parameters_path`: Path to model checkpoint (local or gs://) -- `--base_output_directory`: Base output directory for logs and checkpoints +The unified interface includes all essential GRPO components: -### Dataset Arguments +1. 
**Reward Functions**: + - `match_format_exactly` - Rewards exact format matching + - `match_format_approximately` - Rewards partial format matching + - `check_answer` - Rewards correct answers + - `check_numbers` - Rewards correct numerical answers -- `--hf_access_token`: HuggingFace access token (can use $HF_TOKEN env var) -- `--hf_path`: HuggingFace dataset path (default: gsm8k) -- `--hf_data_split`: Dataset split (default: main) -- `--hf_data_files`: Dataset files (default: train) +2. **Model Loading**: + - Reference model (for KL divergence) + - Policy model (for training) + - Proper device allocation for multi-host setups -### Training Arguments +3. **Dataset Processing**: + - GSM8K math reasoning dataset + - Special token formatting for reasoning tasks + - Batch processing for training and evaluation -- `--steps`: Number of training steps (default: 100) -- `--per_device_batch_size`: Per device batch size (default: 1) -- `--learning_rate`: Learning rate (default: 3e-6) -- `--run_name`: Custom run name for the experiment +4. **Training Configuration**: + - GRPO-specific hyperparameters (beta, epsilon, num_generations) + - Optimizer setup with warmup and cosine decay + - Checkpointing and metrics logging -### GRPO-Specific Arguments +### Device Allocation -- `--num_generations`: Number of generations per prompt (default: 2) -- `--grpo_beta`: KL divergence penalty coefficient (default: 0.08) -- `--grpo_epsilon`: Clipping value for stable updates (default: 0.2) +The system automatically handles device allocation: -### Sequence Length Arguments +- **Single Host**: Uses all available devices +- **Multi-Host**: Splits devices between training and inference +- **Pathways**: Full multi-host support with proper mesh setup -- `--max_prefill_predict_length`: Maximum prompt length (default: 256) -- `--max_target_length`: Maximum total sequence length (default: 768) +### Configuration -### Multi-Host/Pathways Arguments +The `grpo.yml` config file provides sensible defaults for: -- `--use_pathways`: Enable Pathways for multi-host training -- `--inference_devices_per_replica`: Devices per inference replica (default: 4) -- `--inference_replicas`: Number of inference replicas (default: 1) +- GRPO hyperparameters +- Training loop configuration +- Dataset processing +- Checkpointing settings +- Performance optimizations -### Parallelism Arguments +## Migration from Individual Demos -- `--ici_fsdp_parallelism`: FSDP parallelism (-1 for auto) -- `--ici_tensor_parallelism`: Tensor parallelism (-1 for auto) +The old individual demo files (`grpo_llama3_1_8b_demo.py`, etc.) are now deprecated. To migrate: -### Other Arguments +1. **Replace model-specific scripts** with the unified `grpo_demo.py` +2. **Use CLI arguments** instead of hardcoded parameters +3. **Leverage `grpo.yml`** for common configuration +4. **Customize via CLI** for model-specific needs -- `--profiler`: Profiler to use (default: xplane) -- `--checkpoint_period`: Checkpoint saving period (default: 50) -- `--config_file`: Optional custom config file (overrides grpo.yml) +## Examples -## Migration Guide +### Llama3.1-8B Training -### From Individual Demo Scripts - -**Old way:** -```python -# Editing grpo_llama3_1_8b_demo.py directly -MODEL_NAME = "llama3.1-8b" -TOKENIZER_PATH = "meta-llama/Llama-3.1-8B-Instruct" -# ... many hardcoded parameters -``` - -**New way:** ```bash python3 src/MaxText/examples/grpo_demo.py \ --model_name=llama3.1-8b \ --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \ - # ... 
all parameters via CLI + --load_parameters_path=gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items \ + --base_output_directory=/tmp/grpo_output \ + --hf_access_token=$HF_TOKEN \ + --steps=10 ``` -### Benefits - -1. **Single Script**: One script for all model sizes and configurations -2. **No Code Editing**: All parameters configurable via CLI -3. **Better Defaults**: Common parameters in `grpo.yml` -4. **Easier Testing**: Quickly test different configurations -5. **CI/CD Friendly**: Easy to integrate into automated workflows - -## Configuration Files - -### grpo.yml -Main configuration file with sensible defaults for GRPO demos. Override any parameter via CLI. - -Location: `src/MaxText/configs/grpo.yml` - -### grpo.yml and grpo_inference.yml -Low-level configuration files used by the GRPO trainer. Generally, you don't need to modify these directly. - -Location: `src/MaxText/experimental/rl/` - -## Advanced Usage - -### Using a Custom Config File - -If you have a custom configuration: +### Qwen2.5-7B Training ```bash python3 src/MaxText/examples/grpo_demo.py \ - --config_file=/path/to/custom_config.yml \ - --model_name=llama3.1-8b \ - # ... other args + --model_name=qwen2.5-7b \ + --tokenizer_path=Qwen/Qwen2.5-7B-Instruct \ + --load_parameters_path=gs://path/to/qwen/checkpoint \ + --base_output_directory=/tmp/grpo_output \ + --hf_access_token=$HF_TOKEN \ + --steps=100 ``` -### Environment Variables - -You can set these environment variables: -- `HF_TOKEN`: HuggingFace access token (alternative to `--hf_access_token`) - -## Troubleshooting - -### Common Issues +### Large Models with Pathways -1. **HF_TOKEN not set**: Make sure to either set the environment variable or pass `--hf_access_token` - -2. **Pathways configuration**: For multi-host setups, ensure: - - `--use_pathways` is set - - `--inference_devices_per_replica` and `--inference_replicas` are configured correctly - - The total number of devices is sufficient +```bash +# Llama3.1-70B with Pathways +python3 src/MaxText/examples/grpo_demo.py \ + --model_name=llama3.1-70b \ + --tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \ + --load_parameters_path=gs://path/to/70b/checkpoint \ + --base_output_directory=gs://path/to/output \ + --hf_access_token=$HF_TOKEN \ + --use_pathways=true \ + --inference_devices_per_replica=8 \ + --inference_replicas=2 \ + --steps=100 -3. **Memory issues**: Try reducing: - - `--per_device_batch_size` - - `--max_target_length` - - `--num_generations` +# Qwen2.5-72B with Pathways +python3 src/MaxText/examples/grpo_demo.py \ + --model_name=qwen2.5-72b \ + --tokenizer_path=Qwen/Qwen2.5-72B-Instruct \ + --load_parameters_path=gs://path/to/qwen72b/checkpoint \ + --base_output_directory=gs://path/to/output \ + --hf_access_token=$HF_TOKEN \ + --use_pathways=true \ + --inference_devices_per_replica=8 \ + --inference_replicas=2 \ + --steps=100 +``` ## Contributing -When adding new features or model support: -1. Add sensible defaults to `grpo.yml` -2. Add CLI arguments to `grpo_demo.py` if needed -3. Update this README with examples +When adding new features: + +1. **Add CLI arguments** to `grpo_demo.py` +2. **Update `grpo.yml`** with new configuration options +3. **Extend `grpo_demo_trainer.py`** with new logic +4. 
**Update this README** with usage examples -## See Also +## Dependencies -- [GRPO Paper](https://arxiv.org/abs/2402.03300) -- [MaxText Documentation](../../../docs/) -- [Tunix Library](https://github.com/google/tunix) +The GRPO demo requires: +- MaxText core dependencies +- Tunix library for RL +- vLLM for efficient inference +- HuggingFace datasets and tokenizers +- JAX/Flax for model training \ No newline at end of file diff --git a/src/MaxText/examples/grpo_demo.py b/src/MaxText/examples/grpo_demo.py index c8871cfc0..aba94e93e 100755 --- a/src/MaxText/examples/grpo_demo.py +++ b/src/MaxText/examples/grpo_demo.py @@ -17,8 +17,8 @@ Unified GRPO Demo Script This script provides a unified CLI interface for running GRPO training demos -across different model sizes and configurations. It consolidates the common -logic from individual demo scripts and uses the grpo.yml config. +across different model sizes and configurations. It uses the grpo_train function +from grpo_demo_trainer.py which consolidates all the GRPO-specific logic. Usage Examples: @@ -55,7 +55,6 @@ import argparse import os import sys -from typing import Optional # Add MaxText to path script_dir = os.path.dirname(os.path.abspath(__file__)) @@ -64,7 +63,7 @@ from MaxText import pyconfig from MaxText.globals import MAXTEXT_PKG_DIR -from MaxText.experimental.rl import grpo_trainer +from MaxText.experimental.rl.grpo_demo_trainer import grpo_train def create_parser(): @@ -254,7 +253,7 @@ def build_config_argv(args): base_config = os.path.join(MAXTEXT_PKG_DIR, "configs", "grpo.yml") # Build training config argv - train_config_argv = [ + config_argv = [ "", # Placeholder for argv[0] base_config, f"model_name={args.model_name}", @@ -278,37 +277,23 @@ def build_config_argv(args): # Add optional parameters if args.run_name: - train_config_argv.append(f"run_name={args.run_name}") - + config_argv.append(f"run_name={args.run_name}") + if args.hf_access_token: - train_config_argv.append(f"hf_access_token={args.hf_access_token}") + config_argv.append(f"hf_access_token={args.hf_access_token}") if args.use_pathways: - train_config_argv.append("use_pathways_reshard=True") - train_config_argv.append(f"inference_devices_per_replica={args.inference_devices_per_replica}") - train_config_argv.append(f"inference_replicas={args.inference_replicas}") + config_argv.append("use_pathways_reshard=True") + config_argv.append(f"inference_devices_per_replica={args.inference_devices_per_replica}") + config_argv.append(f"inference_replicas={args.inference_replicas}") if args.ici_fsdp_parallelism > 0: - train_config_argv.append(f"ici_fsdp_parallelism={args.ici_fsdp_parallelism}") + config_argv.append(f"ici_fsdp_parallelism={args.ici_fsdp_parallelism}") if args.ici_tensor_parallelism > 0: - train_config_argv.append(f"ici_tensor_parallelism={args.ici_tensor_parallelism}") - - # Build inference config argv - # For GRPO, inference config is similar but with adjusted batch size - inference_config_argv = train_config_argv.copy() - # Replace base config with grpo_inference.yml - inference_config_argv[1] = os.path.join(MAXTEXT_PKG_DIR, "experimental", "rl", "grpo_inference.yml") - - # Adjust batch size for inference (should include num_generations) - inference_batch_size = args.per_device_batch_size * args.num_generations - # Replace the per_device_batch_size entry - for i, arg in enumerate(inference_config_argv): - if arg.startswith("per_device_batch_size="): - inference_config_argv[i] = f"per_device_batch_size={inference_batch_size}" - break + 
config_argv.append(f"ici_tensor_parallelism={args.ici_tensor_parallelism}") - return [train_config_argv, inference_config_argv] + return config_argv def main(): @@ -338,11 +323,11 @@ def main(): # Build config arguments config_argv = build_config_argv(args) - # Convert to the format expected by grpo_trainer.main - sys.argv = ["grpo_demo.py"] + config_argv[0][1:] + config_argv[1][1:] + # Initialize configuration + config = pyconfig.initialize(config_argv) - # Run GRPO training - grpo_trainer.main(sys.argv) + # Run GRPO training using the unified trainer + grpo_train(config) print("=" * 80) print("GRPO Training Completed Successfully!") @@ -350,4 +335,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/src/MaxText/experimental/rl/grpo_demo_trainer.py b/src/MaxText/experimental/rl/grpo_demo_trainer.py new file mode 100644 index 000000000..ac0200a5d --- /dev/null +++ b/src/MaxText/experimental/rl/grpo_demo_trainer.py @@ -0,0 +1,445 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +GRPO Demo Trainer + +This module provides a unified `grpo_train` function that consolidates the common +GRPO training logic from the individual demo scripts. It handles model loading, +reward function setup, dataset processing, and training orchestration. + +Usage: + from MaxText.experimental.rl.grpo_demo_trainer import grpo_train + + # Train with GRPO + grpo_train(config, goodput_recorder=None) +""" + +import os +import re +from typing import Optional, List, Tuple, Dict, Any +from pprint import pprint + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh +from flax import nnx +from flax.linen import partitioning as nn_partitioning +import optax +from orbax import checkpoint as ocp +import tensorflow_datasets as tfds +from tqdm.auto import tqdm +from transformers import AutoTokenizer + +from tunix.rl import rl_cluster as rl_cluster_lib +from tunix.rl.rollout import base_rollout +from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner +from tunix.sft import metrics_logger +from tunix.models.llama3 import model as llama3_lib + +from MaxText import maxtext_utils +from MaxText import model_creation_utils +from MaxText import pyconfig +from MaxText.globals import MAXTEXT_ASSETS_ROOT +from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter +from MaxText.utils.goodput_utils import GoodputRecorder + + +# GRPO-specific constants +REWARD_EXACT_FORMAT_MATCH = 3.0 +REWARD_WHITE_SPACE_FORMAT_MATCH = 1.5 +REWARD_PARTIAL_FORMAT_MATCH = 0.5 +REWARD_RATIO_GUESS_TO_ANSWER_HIGH = 0.5 +REWARD_RATIO_GUESS_TO_ANSWER_LOW = 0.25 +PENALTY_INCORRECT_FORMAT = -0.5 +PENALTY_INCORRECT_ANSWER = -1.0 + +# Special tokens for GSM8K reasoning +REASONING_START = "" +REASONING_END = "" +SOLUTION_START = "" +SOLUTION_END = "" + +SYSTEM_PROMPT = f"""You are given a problem. Think about the problem and \ +provide your reasoning. Place it between {REASONING_START} and \ +{REASONING_END}. 
Then, provide the final answer (i.e., just one numerical \ +value) between {SOLUTION_START} and {SOLUTION_END}.""" + +TEMPLATE = """user +{system_prompt} + +{question} +model""" + + +def extract_hash_answer(text: str) -> str | None: + """Extract the numerical answer from GSM8K format.""" + if "####" not in text: + return None + return text.split("####")[1].strip() + + +def get_gsm8k_dataset(data_dir: str, split: str = "train", batch_size: int = 1, + num_batches: int = 4, seed: int = 42): + """Load and process GSM8K dataset for GRPO training.""" + import grain + + # Download data + if not os.path.exists(data_dir): + os.makedirs(data_dir) + + data = tfds.data_source( + "gsm8k", + split=split, + data_dir=data_dir, + builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD}, + download=True, + ) + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct") + + loaded_dataset = ( + grain.MapDataset.source(data) + .shuffle(seed=seed) + .map( + lambda x: { + # passed to model forward pass + "prompts": tokenizer.apply_chat_template( + [ + { + "role": "user", + "content": TEMPLATE.format( + system_prompt=SYSTEM_PROMPT, + question=x["question"].decode("utf-8"), + ), + }, + ], + tokenize=False, + add_generation_prompt=True, + ), + # passed to reward functions + "question": x["question"].decode("utf-8"), + # passed to reward functions + "answer": extract_hash_answer(x["answer"].decode("utf-8")), + } + ) + ) + + return loaded_dataset.batch(batch_size)[:num_batches], tokenizer + + +def get_maxtext_model(config, devices=None): + """Load MaxText model with Tunix adapter.""" + model, mesh = model_creation_utils.create_nnx_model(config, devices) + with mesh: + tunix_model = TunixMaxTextAdapter(base_model=model) + model_config = llama3_lib.ModelConfig.llama3_1_8b() + tunix_model.config = model_config + return tunix_model, mesh + + +def setup_device_allocation(config, use_pathways: bool = False): + """Setup device allocation for training and inference.""" + devices = jax.devices() + + if use_pathways: + # Multi-host setup with Pathways + import pathwaysutils + pathwaysutils.initialize() + + # For Pathways, use all devices for both training and inference + trainer_devices = devices + sampler_devices = devices + num_vms = len(devices) // 4 # Assuming 4 chips per VM + else: + # Single host setup + num_vms = len(devices) // 4 # Assuming 4 chips per VM + if num_vms >= 2: + # Multi-VM single host setup + num_devices = len(devices) + num_trainer_devices = int(num_devices * 0.5) # 50% for training + num_sampler_devices = int(num_devices * 0.5) # 50% for sampling + trainer_devices = devices[:num_trainer_devices] + sampler_devices = devices[num_devices - num_sampler_devices:] + else: + # Single VM setup + trainer_devices = devices + sampler_devices = devices + + return trainer_devices, sampler_devices, num_vms + + +# Reward Functions +def match_format_exactly(prompts, completions, **kwargs): + """Reward exact format matching.""" + scores = [] + match_format = re.compile( + rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" 
+ rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$", + flags=re.MULTILINE | re.DOTALL, + ) + + for completion in completions: + score = 0 + if match_format.search(completion) is not None: + score += REWARD_EXACT_FORMAT_MATCH + scores.append(score) + return scores + + +def match_format_approximately(prompts, completions, **kwargs): + """Reward approximate format matching.""" + scores = [] + for completion in completions: + score = 0 + score += REWARD_PARTIAL_FORMAT_MATCH if completion.count(REASONING_START) == 1 else PENALTY_INCORRECT_FORMAT + score += REWARD_PARTIAL_FORMAT_MATCH if completion.count(REASONING_END) == 1 else PENALTY_INCORRECT_FORMAT + score += REWARD_PARTIAL_FORMAT_MATCH if completion.count(SOLUTION_START) == 1 else PENALTY_INCORRECT_FORMAT + score += REWARD_PARTIAL_FORMAT_MATCH if completion.count(SOLUTION_END) == 1 else PENALTY_INCORRECT_FORMAT + scores.append(score) + return scores + + +def check_answer(prompts, completions, answer, **kwargs): + """Reward correct answers.""" + match_format = re.compile( + rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" + rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$", + flags=re.MULTILINE | re.DOTALL, + ) + + extracted_responses = [guess.group(1) if (guess := match_format.search(r)) is not None else None for r in completions] + + scores = [] + for guess, true_answer in zip(extracted_responses, answer): + score = 0 + if guess is None: + scores.append(0) + continue + + if guess == true_answer: + score += REWARD_EXACT_FORMAT_MATCH + elif guess.strip() == true_answer.strip(): + score += REWARD_WHITE_SPACE_FORMAT_MATCH + else: + try: + ratio = float(guess) / float(true_answer) + if ratio >= 0.9 and ratio <= 1.1: + score += REWARD_RATIO_GUESS_TO_ANSWER_HIGH + elif ratio >= 0.8 and ratio <= 1.2: + score += REWARD_RATIO_GUESS_TO_ANSWER_LOW + else: + score += PENALTY_INCORRECT_ANSWER + except: + score += PENALTY_INCORRECT_FORMAT + scores.append(score) + return scores + + +def check_numbers(prompts, completions, answer, **kwargs): + """Reward correct numerical answers.""" + match_numbers = re.compile(rf"{SOLUTION_START}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL) + extracted_responses = [guess.group(1) if (guess := match_numbers.search(r)) is not None else None for r in completions] + + scores = [] + for guess, true_answer in zip(extracted_responses, answer): + if guess is None: + scores.append(0) + continue + try: + true_answer = float(true_answer.strip()) + guess = float(guess.strip()) + scores.append(1.5 if guess == true_answer else 0.0) + except: + scores.append(0) + continue + return scores + + +def grpo_train(config, goodput_recorder: Optional[GoodputRecorder] = None): + """ + Run GRPO training with the provided configuration. 
+ + Args: + config: MaxText configuration object + goodput_recorder: Optional goodput recorder for performance monitoring + """ + print("=" * 80) + print("Starting GRPO Training") + print("=" * 80) + + # Setup device allocation + use_pathways = getattr(config, 'use_pathways_reshard', False) + trainer_devices, sampler_devices, num_vms = setup_device_allocation(config, use_pathways) + + print(f"Device allocation: {len(trainer_devices)} trainer, {len(sampler_devices)} sampler") + print(f"Use Pathways: {use_pathways}") + + # Setup data directories + home = os.path.expanduser("~") + "/" + train_data_dir = f"{home}/data/train" + test_data_dir = f"{home}/data/test" + + # Load datasets + print("Loading GSM8K dataset...") + train_dataset, tokenizer = get_gsm8k_dataset( + train_data_dir, + split="train", + batch_size=config.per_device_batch_size, + num_batches=getattr(config, 'num_batches', 4) + ) + + test_dataset, _ = get_gsm8k_dataset( + test_data_dir, + split="test", + batch_size=config.per_device_batch_size, + num_batches=getattr(config, 'num_test_batches', 5) + ) + + # Load reference model + print("Loading reference model...") + reference_model, reference_mesh = get_maxtext_model(config, trainer_devices) + reference_model.config = None + + # Load policy model + print("Loading policy model...") + policy_model, policy_mesh = get_maxtext_model(config, trainer_devices) + policy_model.config = None + + # Setup meshes + if num_vms >= 2 and not use_pathways: + actor_mesh = policy_mesh + rollout_mesh = Mesh(maxtext_utils.create_device_mesh(config, sampler_devices), config.mesh_axes) + else: + actor_mesh = policy_mesh + rollout_mesh = policy_mesh + + # Setup optimizer + learning_rate = getattr(config, 'learning_rate', 3e-6) + max_steps = getattr(config, 'steps', 100) + warmup_steps = int(0.1 * max_steps) + + optimizer = optax.adamw( + learning_rate=optax.schedules.warmup_cosine_decay_schedule( + init_value=0.0, + peak_value=learning_rate, + warmup_steps=warmup_steps, + decay_steps=max_steps, + end_value=0.0, + ), + b1=0.9, + b2=0.99, + weight_decay=0.1, + ) + + # Add gradient clipping if specified + max_grad_norm = getattr(config, 'max_grad_norm', 0.1) + if max_grad_norm is not None: + optimizer = optax.chain( + optax.clip_by_global_norm(max_norm=max_grad_norm), + optimizer, + ) + + # Setup checkpointing + ckpt_dir = f"{config.base_output_directory}/checkpoints" + os.makedirs(ckpt_dir, exist_ok=True) + + checkpointing_options = ocp.CheckpointManagerOptions( + save_interval_steps=getattr(config, 'checkpoint_period', 50), + max_to_keep=4 + ) + + # Setup metrics logging + log_dir = f"{config.base_output_directory}/logs" + os.makedirs(log_dir, exist_ok=True) + + metrics_logging_options = metrics_logger.MetricsLoggerOptions( + log_dir=log_dir, + flush_every_n_steps=20 + ) + + # Setup RL cluster config + cluster_config = rl_cluster_lib.ClusterConfig( + role_to_mesh={ + rl_cluster_lib.Role.ACTOR: actor_mesh, + rl_cluster_lib.Role.REFERENCE: reference_mesh, + rl_cluster_lib.Role.ROLLOUT: rollout_mesh, + }, + rollout_engine="vllm", + offload_to_cpu=False, + training_config=rl_cluster_lib.RLTrainingConfig( + actor_optimizer=optimizer, + eval_every_n_steps=getattr(config, 'eval_interval', 10), + max_steps=max_steps, + metrics_logging_options=metrics_logging_options, + profiler_options=None, + checkpoint_root_directory=ckpt_dir, + checkpointing_options=checkpointing_options, + ), + rollout_config=base_rollout.RolloutConfig( + max_tokens_to_generate=getattr(config, 'max_target_length', 768), + 
max_prompt_length=getattr(config, 'max_prefill_predict_length', 256), + kv_cache_size=getattr(config, 'max_prefill_predict_length', 256) + + getattr(config, 'max_target_length', 768) + 256, + temperature=getattr(config, 'decode_sampling_temperature', 0.9), + top_p=getattr(config, 'decode_sampling_top_p', 1.0), + top_k=getattr(config, 'decode_sampling_top_k', 50), + ), + rollout_vllm_model_version="meta-llama/Meta-Llama-3.1-8B-Instruct", + rollout_vllm_hbm_utilization=0.2, + rollout_vllm_tpu_backend_type="jax", + ) + + # Setup GRPO config + grpo_config = GrpoConfig( + num_generations=getattr(config, 'num_generations', 2), + num_iterations=1, + beta=getattr(config, 'grpo_beta', 0.08), + epsilon=getattr(config, 'grpo_epsilon', 0.2), + ) + + # Create RL cluster + print("Creating RL cluster...") + with nn_partitioning.axis_rules(config.logical_axis_rules): + rl_cluster = rl_cluster_lib.RLCluster( + actor=policy_model, + reference=reference_model, + tokenizer=tokenizer, + cluster_config=cluster_config, + ) + + # Create GRPO trainer + print("Setting up GRPO trainer...") + grpo_trainer = GrpoLearner( + rl_cluster=rl_cluster, + reward_fns=[ + match_format_exactly, + match_format_approximately, + check_answer, + check_numbers, + ], + grpo_config=grpo_config, + ) + + # Start training + print("Starting GRPO training...") + with policy_mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + grpo_trainer.train(train_dataset) + + print("=" * 80) + print("GRPO Training Completed Successfully!") + print("=" * 80) + + return grpo_trainer, rl_cluster From e1708e6511715d5838d53c7920d9d9866ffd661f Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Fri, 24 Oct 2025 21:23:01 +0400 Subject: [PATCH 05/31] grpo refactor Signed-off-by: Vladimir Suvorov --- src/MaxText/examples/GRPO_README.md | 198 ------------------ src/MaxText/examples/README.md | 23 ++ src/MaxText/examples/grpo_demo.py | 12 +- ..._demo_trainer.py => grpo_tunix_trainer.py} | 8 +- 4 files changed, 33 insertions(+), 208 deletions(-) delete mode 100644 src/MaxText/examples/GRPO_README.md rename src/MaxText/experimental/rl/{grpo_demo_trainer.py => grpo_tunix_trainer.py} (98%) diff --git a/src/MaxText/examples/GRPO_README.md b/src/MaxText/examples/GRPO_README.md deleted file mode 100644 index c1751b234..000000000 --- a/src/MaxText/examples/GRPO_README.md +++ /dev/null @@ -1,198 +0,0 @@ -# GRPO Demo - Unified Interface - -This directory contains the unified GRPO (Group Relative Policy Optimization) demo interface that consolidates the common logic from individual demo scripts. The interface is **model-agnostic** and supports any model (Llama, Qwen, etc.). 
- -## Structure - -- **`grpo_demo.py`** - Simple CLI interface for running GRPO training -- **`grpo_demo_trainer.py`** - Core GRPO training logic (in `experimental/rl/`) -- **`grpo.yml`** - Unified model-agnostic configuration file (in `configs/`) - -## Usage - -### Llama Models - -```bash -# Llama3.1-8B (single host) -python3 src/MaxText/examples/grpo_demo.py \ - --model_name=llama3.1-8b \ - --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \ - --load_parameters_path=gs://path/to/checkpoint \ - --base_output_directory=/tmp/grpo_output \ - --hf_access_token=$HF_TOKEN \ - --steps=100 - -# Llama3.1-70B with Pathways -python3 src/MaxText/examples/grpo_demo.py \ - --model_name=llama3.1-70b \ - --tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \ - --load_parameters_path=gs://path/to/checkpoint \ - --base_output_directory=gs://path/to/output \ - --hf_access_token=$HF_TOKEN \ - --use_pathways=true \ - --steps=100 -``` - -### Qwen Models - -```bash -# Qwen2.5-7B (single host) -python3 src/MaxText/examples/grpo_demo.py \ - --model_name=qwen2.5-7b \ - --tokenizer_path=Qwen/Qwen2.5-7B-Instruct \ - --load_parameters_path=gs://path/to/checkpoint \ - --base_output_directory=/tmp/grpo_output \ - --hf_access_token=$HF_TOKEN \ - --steps=100 - -# Qwen2.5-72B with Pathways -python3 src/MaxText/examples/grpo_demo.py \ - --model_name=qwen2.5-72b \ - --tokenizer_path=Qwen/Qwen2.5-72B-Instruct \ - --load_parameters_path=gs://path/to/checkpoint \ - --base_output_directory=gs://path/to/output \ - --hf_access_token=$HF_TOKEN \ - --use_pathways=true \ - --steps=100 -``` - -### Custom Dataset - -```bash -# Any model with custom HuggingFace dataset -python3 src/MaxText/examples/grpo_demo.py \ - --model_name=your-model \ - --tokenizer_path=your/tokenizer \ - --load_parameters_path=gs://path/to/checkpoint \ - --base_output_directory=/tmp/grpo_output \ - --hf_access_token=$HF_TOKEN \ - --hf_path=custom/dataset \ - --steps=100 -``` - -## Key Features - -### GRPO-Specific Components - -The unified interface includes all essential GRPO components: - -1. **Reward Functions**: - - `match_format_exactly` - Rewards exact format matching - - `match_format_approximately` - Rewards partial format matching - - `check_answer` - Rewards correct answers - - `check_numbers` - Rewards correct numerical answers - -2. **Model Loading**: - - Reference model (for KL divergence) - - Policy model (for training) - - Proper device allocation for multi-host setups - -3. **Dataset Processing**: - - GSM8K math reasoning dataset - - Special token formatting for reasoning tasks - - Batch processing for training and evaluation - -4. **Training Configuration**: - - GRPO-specific hyperparameters (beta, epsilon, num_generations) - - Optimizer setup with warmup and cosine decay - - Checkpointing and metrics logging - -### Device Allocation - -The system automatically handles device allocation: - -- **Single Host**: Uses all available devices -- **Multi-Host**: Splits devices between training and inference -- **Pathways**: Full multi-host support with proper mesh setup - -### Configuration - -The `grpo.yml` config file provides sensible defaults for: - -- GRPO hyperparameters -- Training loop configuration -- Dataset processing -- Checkpointing settings -- Performance optimizations - -## Migration from Individual Demos - -The old individual demo files (`grpo_llama3_1_8b_demo.py`, etc.) are now deprecated. To migrate: - -1. **Replace model-specific scripts** with the unified `grpo_demo.py` -2. **Use CLI arguments** instead of hardcoded parameters -3. 
**Leverage `grpo.yml`** for common configuration -4. **Customize via CLI** for model-specific needs - -## Examples - -### Llama3.1-8B Training - -```bash -python3 src/MaxText/examples/grpo_demo.py \ - --model_name=llama3.1-8b \ - --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \ - --load_parameters_path=gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items \ - --base_output_directory=/tmp/grpo_output \ - --hf_access_token=$HF_TOKEN \ - --steps=10 -``` - -### Qwen2.5-7B Training - -```bash -python3 src/MaxText/examples/grpo_demo.py \ - --model_name=qwen2.5-7b \ - --tokenizer_path=Qwen/Qwen2.5-7B-Instruct \ - --load_parameters_path=gs://path/to/qwen/checkpoint \ - --base_output_directory=/tmp/grpo_output \ - --hf_access_token=$HF_TOKEN \ - --steps=100 -``` - -### Large Models with Pathways - -```bash -# Llama3.1-70B with Pathways -python3 src/MaxText/examples/grpo_demo.py \ - --model_name=llama3.1-70b \ - --tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \ - --load_parameters_path=gs://path/to/70b/checkpoint \ - --base_output_directory=gs://path/to/output \ - --hf_access_token=$HF_TOKEN \ - --use_pathways=true \ - --inference_devices_per_replica=8 \ - --inference_replicas=2 \ - --steps=100 - -# Qwen2.5-72B with Pathways -python3 src/MaxText/examples/grpo_demo.py \ - --model_name=qwen2.5-72b \ - --tokenizer_path=Qwen/Qwen2.5-72B-Instruct \ - --load_parameters_path=gs://path/to/qwen72b/checkpoint \ - --base_output_directory=gs://path/to/output \ - --hf_access_token=$HF_TOKEN \ - --use_pathways=true \ - --inference_devices_per_replica=8 \ - --inference_replicas=2 \ - --steps=100 -``` - -## Contributing - -When adding new features: - -1. **Add CLI arguments** to `grpo_demo.py` -2. **Update `grpo.yml`** with new configuration options -3. **Extend `grpo_demo_trainer.py`** with new logic -4. **Update this README** with usage examples - -## Dependencies - -The GRPO demo requires: - -- MaxText core dependencies -- Tunix library for RL -- vLLM for efficient inference -- HuggingFace datasets and tokenizers -- JAX/Flax for model training \ No newline at end of file diff --git a/src/MaxText/examples/README.md b/src/MaxText/examples/README.md index a5b46934e..bfdc6dca6 100644 --- a/src/MaxText/examples/README.md +++ b/src/MaxText/examples/README.md @@ -128,6 +128,29 @@ Use the link for Jupyter Lab as a link for "Connect to a local runtime" in Colla ### GRPO Training - **`grpo_llama3_demo.ipynb`** → GRPO training on math dataset +- **`grpo_demo.py`** → Unified CLI for GRPO training (any model) + +#### GRPO Usage + +```bash +# Llama3.1-8B +python3 src/MaxText/examples/grpo_demo.py \ + --model_name=llama3.1-8b \ + --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \ + --load_parameters_path=gs://path/to/checkpoint \ + --base_output_directory=/tmp/grpo_output \ + --hf_access_token=$HF_TOKEN \ + --steps=100 + +# Qwen2.5-7B +python3 src/MaxText/examples/grpo_demo.py \ + --model_name=qwen2.5-7b \ + --tokenizer_path=Qwen/Qwen2.5-7B-Instruct \ + --load_parameters_path=gs://path/to/checkpoint \ + --base_output_directory=/tmp/grpo_output \ + --hf_access_token=$HF_TOKEN \ + --steps=100 +``` ## Common Pitfalls & Debugging diff --git a/src/MaxText/examples/grpo_demo.py b/src/MaxText/examples/grpo_demo.py index aba94e93e..bb5f92be1 100755 --- a/src/MaxText/examples/grpo_demo.py +++ b/src/MaxText/examples/grpo_demo.py @@ -14,11 +14,11 @@ # limitations under the License. 
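# What the CLI wiring below boils down to: every flag is turned into a
# "key=value" override appended after the base GRPO config, and the resulting
# MaxText config is handed to grpo_train(). A rough sketch (the initialize()
# call and the paths here are illustrative assumptions; see build_config_argv()
# and main() below for the actual wiring):
#
#   config_argv = [
#       "",                               # argv[0] placeholder
#       "src/MaxText/configs/grpo.yml",   # GRPO defaults
#       "model_name=llama3.1-8b",
#       "tokenizer_path=meta-llama/Llama-3.1-8B-Instruct",
#       "load_parameters_path=gs://path/to/checkpoint",
#       "base_output_directory=/tmp/grpo_output",
#       "steps=100",
#   ]
#   config = pyconfig.initialize(config_argv)
#   grpo_train(config)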
""" -Unified GRPO Demo Script +Unified GRPO Script -This script provides a unified CLI interface for running GRPO training demos +This script provides a unified CLI interface for running GRPO training across different model sizes and configurations. It uses the grpo_train function -from grpo_demo_trainer.py which consolidates all the GRPO-specific logic. +from grpo_tunix_trainer.py which consolidates all the GRPO-specific logic. Usage Examples: @@ -63,13 +63,13 @@ from MaxText import pyconfig from MaxText.globals import MAXTEXT_PKG_DIR -from MaxText.experimental.rl.grpo_demo_trainer import grpo_train +from MaxText.experimental.rl.grpo_tunix_trainer import grpo_train def create_parser(): """Create argument parser for GRPO demo.""" parser = argparse.ArgumentParser( - description="Unified GRPO Demo Script", + description="Unified GRPO Script", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=__doc__, ) @@ -307,7 +307,7 @@ def main(): sys.exit(1) print("=" * 80) - print("GRPO Demo - Unified Training Script") + print("GRPO - Unified Training Script") print("=" * 80) print(f"Model: {args.model_name}") print(f"Tokenizer: {args.tokenizer_path}") diff --git a/src/MaxText/experimental/rl/grpo_demo_trainer.py b/src/MaxText/experimental/rl/grpo_tunix_trainer.py similarity index 98% rename from src/MaxText/experimental/rl/grpo_demo_trainer.py rename to src/MaxText/experimental/rl/grpo_tunix_trainer.py index ac0200a5d..adf9aad44 100644 --- a/src/MaxText/experimental/rl/grpo_demo_trainer.py +++ b/src/MaxText/experimental/rl/grpo_tunix_trainer.py @@ -13,14 +13,14 @@ # limitations under the License. """ -GRPO Demo Trainer +GRPO Tunix Trainer This module provides a unified `grpo_train` function that consolidates the common -GRPO training logic from the individual demo scripts. It handles model loading, -reward function setup, dataset processing, and training orchestration. +GRPO training logic. It handles model loading, reward function setup, dataset +processing, and training orchestration. Usage: - from MaxText.experimental.rl.grpo_demo_trainer import grpo_train + from MaxText.experimental.rl.grpo_tunix_trainer import grpo_train # Train with GRPO grpo_train(config, goodput_recorder=None) From d5bcf7a01f487b529e3c52767d2f1bfd45ab8cce Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Sat, 25 Oct 2025 06:53:33 +0400 Subject: [PATCH 06/31] Fix Signed-off-by: Vladimir Suvorov --- src/MaxText/examples/README.md | 12 +- src/MaxText/examples/grpo_demo.py | 25 +++ .../experimental/rl/grpo_tunix_trainer.py | 149 +++++++++--------- 3 files changed, 107 insertions(+), 79 deletions(-) diff --git a/src/MaxText/examples/README.md b/src/MaxText/examples/README.md index bfdc6dca6..470975de8 100644 --- a/src/MaxText/examples/README.md +++ b/src/MaxText/examples/README.md @@ -127,10 +127,18 @@ Use the link for Jupyter Lab as a link for "Connect to a local runtime" in Colla ### GRPO Training -- **`grpo_llama3_demo.ipynb`** → GRPO training on math dataset +- **`grpo_llama3_1_8b_demo.ipynb`** → GRPO training on math dataset (Colab/notebook) - **`grpo_demo.py`** → Unified CLI for GRPO training (any model) -#### GRPO Usage +#### GRPO Colab Usage + +For interactive GRPO training in Google Colab or Jupyter: + +1. **Open** `src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb` +2. **Enable TPU runtime** (Runtime → Change runtime type → TPU) +3. 
**Run cells** to train Llama3.1-8B with GRPO on GSM8K dataset + +#### GRPO Python Script Usage ```bash # Llama3.1-8B diff --git a/src/MaxText/examples/grpo_demo.py b/src/MaxText/examples/grpo_demo.py index bb5f92be1..57071c3b8 100755 --- a/src/MaxText/examples/grpo_demo.py +++ b/src/MaxText/examples/grpo_demo.py @@ -221,6 +221,26 @@ def create_parser(): help="Tensor parallelism (-1 for auto)", ) + # Device Allocation Configuration + parser.add_argument( + "--trainer_devices_fraction", + type=float, + default=0.5, + help="Fraction of devices for training (0.0-1.0)", + ) + parser.add_argument( + "--sampler_devices_fraction", + type=float, + default=0.5, + help="Fraction of devices for sampling (0.0-1.0)", + ) + parser.add_argument( + "--chips_per_vm", + type=int, + default=4, + help="Number of chips per VM (hardware dependent)", + ) + # Other Configuration parser.add_argument( "--profiler", @@ -293,6 +313,11 @@ def build_config_argv(args): if args.ici_tensor_parallelism > 0: config_argv.append(f"ici_tensor_parallelism={args.ici_tensor_parallelism}") + # Add device allocation parameters + config_argv.append(f"trainer_devices_fraction={args.trainer_devices_fraction}") + config_argv.append(f"sampler_devices_fraction={args.sampler_devices_fraction}") + config_argv.append(f"chips_per_vm={args.chips_per_vm}") + return config_argv diff --git a/src/MaxText/experimental/rl/grpo_tunix_trainer.py b/src/MaxText/experimental/rl/grpo_tunix_trainer.py index adf9aad44..cbcf6c6c0 100644 --- a/src/MaxText/experimental/rl/grpo_tunix_trainer.py +++ b/src/MaxText/experimental/rl/grpo_tunix_trainer.py @@ -16,30 +16,26 @@ GRPO Tunix Trainer This module provides a unified `grpo_train` function that consolidates the common -GRPO training logic. It handles model loading, reward function setup, dataset +GRPO training logic. It handles model loading, reward function setup, dataset processing, and training orchestration. 
Usage: from MaxText.experimental.rl.grpo_tunix_trainer import grpo_train - + # Train with GRPO grpo_train(config, goodput_recorder=None) """ import os import re -from typing import Optional, List, Tuple, Dict, Any -from pprint import pprint +from typing import Optional import jax -import jax.numpy as jnp from jax.sharding import Mesh -from flax import nnx from flax.linen import partitioning as nn_partitioning import optax from orbax import checkpoint as ocp import tensorflow_datasets as tfds -from tqdm.auto import tqdm from transformers import AutoTokenizer from tunix.rl import rl_cluster as rl_cluster_lib @@ -50,10 +46,7 @@ from MaxText import maxtext_utils from MaxText import model_creation_utils -from MaxText import pyconfig -from MaxText.globals import MAXTEXT_ASSETS_ROOT from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter -from MaxText.utils.goodput_utils import GoodputRecorder # GRPO-specific constants @@ -90,11 +83,11 @@ def extract_hash_answer(text: str) -> str | None: return text.split("####")[1].strip() -def get_gsm8k_dataset(data_dir: str, split: str = "train", batch_size: int = 1, +def get_gsm8k_dataset(data_dir: str, split: str = "train", batch_size: int = 1, num_batches: int = 4, seed: int = 42): """Load and process GSM8K dataset for GRPO training.""" import grain - + # Download data if not os.path.exists(data_dir): os.makedirs(data_dir) @@ -136,7 +129,7 @@ def get_gsm8k_dataset(data_dir: str, split: str = "train", batch_size: int = 1, } ) ) - + return loaded_dataset.batch(batch_size)[:num_batches], tokenizer @@ -153,31 +146,33 @@ def get_maxtext_model(config, devices=None): def setup_device_allocation(config, use_pathways: bool = False): """Setup device allocation for training and inference.""" devices = jax.devices() - - if use_pathways: - # Multi-host setup with Pathways + + # Get device allocation parameters from config + trainer_devices_fraction = getattr(config, 'trainer_devices_fraction', 0.5) + sampler_devices_fraction = getattr(config, 'sampler_devices_fraction', 0.5) + chips_per_vm = getattr(config, 'chips_per_vm', 4) + + num_vms = len(devices) // chips_per_vm + + if use_pathways and num_vms >= 2: + # Multiple hosts with Pathways - split devices for trainer and sampler import pathwaysutils pathwaysutils.initialize() - - # For Pathways, use all devices for both training and inference + + num_devices = len(devices) + num_trainer_devices = int(num_devices * trainer_devices_fraction) + num_sampler_devices = int(num_devices * sampler_devices_fraction) + trainer_devices = devices[:num_trainer_devices] + sampler_devices = devices[num_devices - num_sampler_devices:] + else: + # Not using Pathways OR single host - use all devices for both + if use_pathways: + import pathwaysutils + pathwaysutils.initialize() + trainer_devices = devices sampler_devices = devices - num_vms = len(devices) // 4 # Assuming 4 chips per VM - else: - # Single host setup - num_vms = len(devices) // 4 # Assuming 4 chips per VM - if num_vms >= 2: - # Multi-VM single host setup - num_devices = len(devices) - num_trainer_devices = int(num_devices * 0.5) # 50% for training - num_sampler_devices = int(num_devices * 0.5) # 50% for sampling - trainer_devices = devices[:num_trainer_devices] - sampler_devices = devices[num_devices - num_sampler_devices:] - else: - # Single VM setup - trainer_devices = devices - sampler_devices = devices - + return trainer_devices, sampler_devices, num_vms @@ -186,11 +181,11 @@ def match_format_exactly(prompts, completions, **kwargs): """Reward exact format 
matching.""" scores = [] match_format = re.compile( - rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" + rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$", flags=re.MULTILINE | re.DOTALL, ) - + for completion in completions: score = 0 if match_format.search(completion) is not None: @@ -215,20 +210,20 @@ def match_format_approximately(prompts, completions, **kwargs): def check_answer(prompts, completions, answer, **kwargs): """Reward correct answers.""" match_format = re.compile( - rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" + rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$", flags=re.MULTILINE | re.DOTALL, ) - + extracted_responses = [guess.group(1) if (guess := match_format.search(r)) is not None else None for r in completions] - + scores = [] for guess, true_answer in zip(extracted_responses, answer): score = 0 if guess is None: scores.append(0) continue - + if guess == true_answer: score += REWARD_EXACT_FORMAT_MATCH elif guess.strip() == true_answer.strip(): @@ -236,13 +231,13 @@ def check_answer(prompts, completions, answer, **kwargs): else: try: ratio = float(guess) / float(true_answer) - if ratio >= 0.9 and ratio <= 1.1: + if 0.9 <= ratio <= 1.1: score += REWARD_RATIO_GUESS_TO_ANSWER_HIGH - elif ratio >= 0.8 and ratio <= 1.2: + elif 0.8 <= ratio <= 1.2: score += REWARD_RATIO_GUESS_TO_ANSWER_LOW else: score += PENALTY_INCORRECT_ANSWER - except: + except (ValueError, ZeroDivisionError): score += PENALTY_INCORRECT_FORMAT scores.append(score) return scores @@ -252,7 +247,7 @@ def check_numbers(prompts, completions, answer, **kwargs): """Reward correct numerical answers.""" match_numbers = re.compile(rf"{SOLUTION_START}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL) extracted_responses = [guess.group(1) if (guess := match_numbers.search(r)) is not None else None for r in completions] - + scores = [] for guess, true_answer in zip(extracted_responses, answer): if guess is None: @@ -262,62 +257,62 @@ def check_numbers(prompts, completions, answer, **kwargs): true_answer = float(true_answer.strip()) guess = float(guess.strip()) scores.append(1.5 if guess == true_answer else 0.0) - except: + except (ValueError, TypeError): scores.append(0) continue return scores -def grpo_train(config, goodput_recorder: Optional[GoodputRecorder] = None): +def grpo_train(config): """ Run GRPO training with the provided configuration. 
- + Args: config: MaxText configuration object - goodput_recorder: Optional goodput recorder for performance monitoring """ print("=" * 80) print("Starting GRPO Training") print("=" * 80) - + # Setup device allocation use_pathways = getattr(config, 'use_pathways_reshard', False) trainer_devices, sampler_devices, num_vms = setup_device_allocation(config, use_pathways) - + print(f"Device allocation: {len(trainer_devices)} trainer, {len(sampler_devices)} sampler") print(f"Use Pathways: {use_pathways}") - + # Setup data directories home = os.path.expanduser("~") + "/" train_data_dir = f"{home}/data/train" test_data_dir = f"{home}/data/test" - + # Load datasets print("Loading GSM8K dataset...") train_dataset, tokenizer = get_gsm8k_dataset( - train_data_dir, - split="train", + train_data_dir, + split="train", batch_size=config.per_device_batch_size, num_batches=getattr(config, 'num_batches', 4) ) - - test_dataset, _ = get_gsm8k_dataset( - test_data_dir, - split="test", + + # Load test dataset for evaluation (currently not used in training loop) + get_gsm8k_dataset( + test_data_dir, + split="test", batch_size=config.per_device_batch_size, num_batches=getattr(config, 'num_test_batches', 5) ) - + # Load reference model print("Loading reference model...") reference_model, reference_mesh = get_maxtext_model(config, trainer_devices) reference_model.config = None - + # Load policy model print("Loading policy model...") policy_model, policy_mesh = get_maxtext_model(config, trainer_devices) policy_model.config = None - + # Setup meshes if num_vms >= 2 and not use_pathways: actor_mesh = policy_mesh @@ -325,12 +320,12 @@ def grpo_train(config, goodput_recorder: Optional[GoodputRecorder] = None): else: actor_mesh = policy_mesh rollout_mesh = policy_mesh - + # Setup optimizer learning_rate = getattr(config, 'learning_rate', 3e-6) max_steps = getattr(config, 'steps', 100) warmup_steps = int(0.1 * max_steps) - + optimizer = optax.adamw( learning_rate=optax.schedules.warmup_cosine_decay_schedule( init_value=0.0, @@ -343,7 +338,7 @@ def grpo_train(config, goodput_recorder: Optional[GoodputRecorder] = None): b2=0.99, weight_decay=0.1, ) - + # Add gradient clipping if specified max_grad_norm = getattr(config, 'max_grad_norm', 0.1) if max_grad_norm is not None: @@ -351,25 +346,25 @@ def grpo_train(config, goodput_recorder: Optional[GoodputRecorder] = None): optax.clip_by_global_norm(max_norm=max_grad_norm), optimizer, ) - + # Setup checkpointing ckpt_dir = f"{config.base_output_directory}/checkpoints" os.makedirs(ckpt_dir, exist_ok=True) - + checkpointing_options = ocp.CheckpointManagerOptions( save_interval_steps=getattr(config, 'checkpoint_period', 50), max_to_keep=4 ) - + # Setup metrics logging log_dir = f"{config.base_output_directory}/logs" os.makedirs(log_dir, exist_ok=True) - + metrics_logging_options = metrics_logger.MetricsLoggerOptions( - log_dir=log_dir, + log_dir=log_dir, flush_every_n_steps=20 ) - + # Setup RL cluster config cluster_config = rl_cluster_lib.ClusterConfig( role_to_mesh={ @@ -391,7 +386,7 @@ def grpo_train(config, goodput_recorder: Optional[GoodputRecorder] = None): rollout_config=base_rollout.RolloutConfig( max_tokens_to_generate=getattr(config, 'max_target_length', 768), max_prompt_length=getattr(config, 'max_prefill_predict_length', 256), - kv_cache_size=getattr(config, 'max_prefill_predict_length', 256) + + kv_cache_size=getattr(config, 'max_prefill_predict_length', 256) + getattr(config, 'max_target_length', 768) + 256, temperature=getattr(config, 'decode_sampling_temperature', 
0.9), top_p=getattr(config, 'decode_sampling_top_p', 1.0), @@ -401,7 +396,7 @@ def grpo_train(config, goodput_recorder: Optional[GoodputRecorder] = None): rollout_vllm_hbm_utilization=0.2, rollout_vllm_tpu_backend_type="jax", ) - + # Setup GRPO config grpo_config = GrpoConfig( num_generations=getattr(config, 'num_generations', 2), @@ -409,7 +404,7 @@ def grpo_train(config, goodput_recorder: Optional[GoodputRecorder] = None): beta=getattr(config, 'grpo_beta', 0.08), epsilon=getattr(config, 'grpo_epsilon', 0.2), ) - + # Create RL cluster print("Creating RL cluster...") with nn_partitioning.axis_rules(config.logical_axis_rules): @@ -419,7 +414,7 @@ def grpo_train(config, goodput_recorder: Optional[GoodputRecorder] = None): tokenizer=tokenizer, cluster_config=cluster_config, ) - + # Create GRPO trainer print("Setting up GRPO trainer...") grpo_trainer = GrpoLearner( @@ -432,14 +427,14 @@ def grpo_train(config, goodput_recorder: Optional[GoodputRecorder] = None): ], grpo_config=grpo_config, ) - + # Start training print("Starting GRPO training...") with policy_mesh, nn_partitioning.axis_rules(config.logical_axis_rules): grpo_trainer.train(train_dataset) - + print("=" * 80) print("GRPO Training Completed Successfully!") print("=" * 80) - + return grpo_trainer, rl_cluster From 81e4c2b64df5642afa448006584aa688450d1cb1 Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Sat, 25 Oct 2025 06:57:16 +0400 Subject: [PATCH 07/31] Fix naming Signed-off-by: Vladimir Suvorov --- src/MaxText/examples/README.md | 6 +++--- src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py | 2 +- src/MaxText/examples/grpo_llama3_1_8b_demo.py | 2 +- src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py | 2 +- src/MaxText/examples/{grpo_demo.py => grpo_runner.py} | 6 +++--- 5 files changed, 9 insertions(+), 9 deletions(-) rename src/MaxText/examples/{grpo_demo.py => grpo_runner.py} (98%) diff --git a/src/MaxText/examples/README.md b/src/MaxText/examples/README.md index 470975de8..a6dc4ae79 100644 --- a/src/MaxText/examples/README.md +++ b/src/MaxText/examples/README.md @@ -128,7 +128,7 @@ Use the link for Jupyter Lab as a link for "Connect to a local runtime" in Colla ### GRPO Training - **`grpo_llama3_1_8b_demo.ipynb`** → GRPO training on math dataset (Colab/notebook) -- **`grpo_demo.py`** → Unified CLI for GRPO training (any model) +- **`grpo_runner.py`** → Unified CLI for GRPO training (any model) #### GRPO Colab Usage @@ -142,7 +142,7 @@ For interactive GRPO training in Google Colab or Jupyter: ```bash # Llama3.1-8B -python3 src/MaxText/examples/grpo_demo.py \ +python3 src/MaxText/examples/grpo_runner.py \ --model_name=llama3.1-8b \ --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \ --load_parameters_path=gs://path/to/checkpoint \ @@ -151,7 +151,7 @@ python3 src/MaxText/examples/grpo_demo.py \ --steps=100 # Qwen2.5-7B -python3 src/MaxText/examples/grpo_demo.py \ +python3 src/MaxText/examples/grpo_runner.py \ --model_name=qwen2.5-7b \ --tokenizer_path=Qwen/Qwen2.5-7B-Instruct \ --load_parameters_path=gs://path/to/checkpoint \ diff --git a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py index 107a05274..c45e7b174 100644 --- a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py +++ b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py @@ -15,7 +15,7 @@ # pylint: disable=bare-except, consider-using-generator """ DEPRECATED: This file is deprecated and kept for reference only. 
-Please use the new unified CLI interface: grpo_demo.py +Please use the new unified CLI interface: grpo_runner.py See GRPO_README.md for migration guide and usage examples. diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo.py b/src/MaxText/examples/grpo_llama3_1_8b_demo.py index d3fc01579..baf9e53ae 100644 --- a/src/MaxText/examples/grpo_llama3_1_8b_demo.py +++ b/src/MaxText/examples/grpo_llama3_1_8b_demo.py @@ -15,7 +15,7 @@ # pylint: disable=bare-except, consider-using-generator """ DEPRECATED: This file is deprecated and kept for reference only. -Please use the new unified CLI interface: grpo_demo.py +Please use the new unified CLI interface: grpo_runner.py See GRPO_README.md for migration guide and usage examples. diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py b/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py index f7acf722a..2536ff0eb 100644 --- a/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py +++ b/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py @@ -15,7 +15,7 @@ # pylint: disable=bare-except, consider-using-generator """ DEPRECATED: This file is deprecated and kept for reference only. -Please use the new unified CLI interface: grpo_demo.py +Please use the new unified CLI interface: grpo_runner.py See GRPO_README.md for migration guide and usage examples. diff --git a/src/MaxText/examples/grpo_demo.py b/src/MaxText/examples/grpo_runner.py similarity index 98% rename from src/MaxText/examples/grpo_demo.py rename to src/MaxText/examples/grpo_runner.py index 57071c3b8..f92887274 100755 --- a/src/MaxText/examples/grpo_demo.py +++ b/src/MaxText/examples/grpo_runner.py @@ -23,7 +23,7 @@ Usage Examples: # Llama3.1-8B (single host) -python3 src/MaxText/examples/grpo_demo.py \\ +python3 src/MaxText/examples/grpo_runner.py \\ --model_name=llama3.1-8b \\ --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \\ --load_parameters_path=gs://path/to/checkpoint \\ @@ -32,7 +32,7 @@ --steps=100 # Llama3.1-70B with Pathways (multi-host) -python3 src/MaxText/examples/grpo_demo.py \\ +python3 src/MaxText/examples/grpo_runner.py \\ --model_name=llama3.1-70b \\ --tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \\ --load_parameters_path=gs://path/to/checkpoint \\ @@ -42,7 +42,7 @@ --steps=100 # Custom dataset -python3 src/MaxText/examples/grpo_demo.py \\ +python3 src/MaxText/examples/grpo_runner.py \\ --model_name=llama3.1-8b \\ --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \\ --load_parameters_path=gs://path/to/checkpoint \\ From 0b011a4eda40492f47cefc6121298a40b04f15ff Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Sat, 25 Oct 2025 07:07:22 +0400 Subject: [PATCH 08/31] simplification of nb Signed-off-by: Vladimir Suvorov --- .../examples/grpo_llama3_1_8b_demo.ipynb | 1361 ++-------------- .../grpo_llama3_1_8b_demo_backup.ipynb | 1371 +++++++++++++++++ 2 files changed, 1495 insertions(+), 1237 deletions(-) create mode 100644 src/MaxText/examples/grpo_llama3_1_8b_demo_backup.ipynb diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb b/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb index 3e82bb137..322659340 100644 --- a/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb +++ b/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb @@ -4,36 +4,34 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# GRPO Llama3.1-8B-Instruct Demo: Group Relative Policy Optimization\n", + "# GRPO Llama3.1-8B Demo: Direct Function Call\n", "\n", - "This tutorial demonstrates training the Llama3.1 8B-Instruct model on the GSM8K math reasoning benchmark using Group Relative Policy 
Optimization (GRPO). GRPO can enhance your model's problem-solving skills on mathematical word problems, coding problems, etc.\n", + "This notebook demonstrates GRPO training by directly calling the `grpo_train` function from `grpo_tunix_trainer.py`.\n", "\n", "## What is GRPO?\n", "\n", - "GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by:\n", + "GRPO (Group Relative Policy Optimization) is an RL algorithm that enhances reasoning abilities of LLMs by:\n", + "1. Generating multiple responses for each prompt\n", + "2. Evaluating responses using reward models \n", + "3. Calculating relative advantages to update the policy\n", "\n", - "1. Generating multiple responses for a given prompt\n", - "2. Evaluating these responses using a reward model\n", - "3. Calculating a relative advantage based on the group's performance to update the policy\n", + "## Direct Function Approach\n", "\n", - "## Libraries Used\n", - "\n", - "- **Tunix**: Library for GRPO implementation\n", - "- **vLLM**: Library for efficient model inference and generation\n", - "- **MaxText**: For model creation and training infrastructure\n", + "This notebook imports and calls the `grpo_train` function directly, avoiding subprocess calls and providing better integration with the notebook environment.\n", "\n", "## Hardware Requirements\n", "\n", - "This tutorial uses a single host TPUVM such as `v6e-8/v5p-8`.\n" + "- Single host TPUVM (v6e-8/v5p-8) or multi-host with Pathways\n", + "- Sufficient memory for Llama3.1-8B model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Install Necessary Libraries\n", + "## Setup\n", "\n", - "First, let's install the required dependencies:\n" + "Install dependencies and set up the environment:" ] }, { @@ -42,12 +40,8 @@ "metadata": {}, "outputs": [], "source": [ - "### (Optional) Run this if you just have this file and nothing else\n", - "\n", - "# 1. Clone the MaxText repository (from AI‑Hypercomputer)\n", + "# Clone MaxText repository\n", "!git clone https://github.com/AI-Hypercomputer/maxtext.git\n", - "\n", - "# 2. Navigate into the cloned directory\n", "%cd maxtext" ] }, @@ -57,35 +51,28 @@ "metadata": {}, "outputs": [], "source": [ - "### (Optional) Do not run this if you already installed the dependencies\n", - "\n", - "# 3. Ensure setup.sh is executable\n", + "# Install dependencies\n", "!chmod +x setup.sh\n", - "\n", - "# 4. 
Execute the setup script\n", "!./setup.sh\n", "\n", - "# Install vllm requirements\n", + "# Install GRPO-specific dependencies\n", "!./src/MaxText/examples/install_tunix_vllm_requirement.sh\n", "\n", - "# force numpy version\n", - "!pip install --force-reinstall numpy==2.1.2\n", - "# install nest_asyncio\n", - "!pip install nest_asyncio\n", + "# Install additional requirements\n", + "%pip install --force-reinstall numpy==2.1.2\n", + "%pip install nest_asyncio\n", "\n", "import nest_asyncio\n", - "\n", - "nest_asyncio.apply()\n", - "# To fix \"This event loop is already running\" error in Colab" + "nest_asyncio.apply() # Fix for Colab event loop" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Imports\n", + "## Configuration\n", "\n", - "Import all necessary libraries for GRPO training:\n" + "Set up the training parameters:" ] }, { @@ -94,50 +81,22 @@ "metadata": {}, "outputs": [], "source": [ - "# Set up the variables for the script\n", + "# Configuration for GRPO training\n", "import os\n", - "import sys\n", "\n", - "# Set the MaxText home directory (where you cloned the maxtext repo)\n", + "# Set up paths\n", "MAXTEXT_REPO_ROOT = os.path.expanduser(\"~\") + \"/maxtext\"\n", - "print(f\"MaxText Home directory (from Python): {MAXTEXT_REPO_ROOT}\")\n", - "\n", - "DEBUG = False # set to True to run in debug mode, for more print statements\n", - "# set this to the path of the checkpoint you want to load, gs:// supported\n", - "MODEL_CHECKPOINT_PATH = \"/path/to/scanned/model/ckpt_load_dir/\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import functools\n", - "from pprint import pprint\n", - "import re\n", - "from flax import nnx\n", - "from flax.linen import partitioning as nn_partitioning\n", - "import grain\n", - "import humanize\n", - "import jax\n", - "import optax\n", - "from orbax import checkpoint as ocp\n", - "import tensorflow_datasets as tfds\n", - "from tqdm.auto import tqdm\n", - "from tunix.rl import rl_cluster as rl_cluster_lib\n", - "from tunix.rl.rollout import base_rollout\n", - "from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner\n", - "from tunix.sft import metrics_logger\n", + "print(f\"MaxText Home directory: {MAXTEXT_REPO_ROOT}\")\n", "\n", - "from transformers import AutoTokenizer\n", + "# Training configuration\n", + "MODEL_CHECKPOINT_PATH = \"gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items\"\n", + "OUTPUT_DIRECTORY = \"/tmp/grpo_output\"\n", + "STEPS = 10 # Reduced for demo purposes\n", + "HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"your_hf_token_here\")\n", "\n", - "from flax import linen as nn\n", - "from tunix.models.llama3 import model as llama3_lib\n", - "import numpy as np\n", - "from etils import epath\n", - "\n", - "from tunix.rl.rollout.base_rollout import RolloutConfig" + "print(f\"Model checkpoint: {MODEL_CHECKPOINT_PATH}\")\n", + "print(f\"Output directory: {OUTPUT_DIRECTORY}\")\n", + "print(f\"Training steps: {STEPS}\")" ] }, { @@ -146,25 +105,24 @@ "metadata": {}, "outputs": [], "source": [ + "# Import GRPO training function directly\n", + "import sys\n", + "import os\n", "from pathlib import Path\n", - "from typing import Optional, Dict, Any\n", - "\n", - "maxtext_path = Path(f\"{MAXTEXT_REPO_ROOT}\") / \"src\" / \"MaxText\"\n", "\n", - "# Change working directory to MaxText project root\n", - "os.chdir(maxtext_path)\n", + "# Add MaxText to Python path\n", + "maxtext_path = Path(MAXTEXT_REPO_ROOT) / \"src\" / \"MaxText\"\n", 
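    "# Putting the repo's src/MaxText first on sys.path makes the MaxText imports\n",
    "# below resolve against this checkout rather than an installed package.\n",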
"sys.path.insert(0, str(maxtext_path))\n", "\n", - "from MaxText import model_creation_utils, pyconfig\n", - "from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter\n", + "# Import required modules\n", + "from MaxText import pyconfig\n", + "from MaxText.experimental.rl.grpo_tunix_trainer import grpo_train\n", "\n", - "print(f\"✓ Changed working directory to: {os.getcwd()}\")\n", - "print(f\"✓ MaxText project root: {maxtext_path}\")\n", - "print(f\"✓ Added to Python path: {maxtext_path}\")\n", - "\n", - "if not jax.distributed.is_initialized():\n", - " jax.distributed.initialize()\n", - "jax.devices()" + "print(\"✅ Successfully imported GRPO training function\")\n", + "print(f\"📁 MaxText path: {maxtext_path}\")\n", + "print(\"\\n\" + \"=\"*80)\n", + "print(\"Starting GRPO Training...\")\n", + "print(\"=\"*80)" ] }, { @@ -173,26 +131,34 @@ "metadata": {}, "outputs": [], "source": [ - "# Hugging Face Authentication Setup\n", - "import os\n", - "from huggingface_hub import login\n", - "\n", - "# Use your Hugging Face token (recommended)\n", - "# Get your token from: https://huggingface.co/settings/tokens\n", - "\n", - "os.environ[\"HF_TOKEN\"] = \"hf_your_token_here\"\n", - "login(token=os.environ[\"HF_TOKEN\"])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Hyperparameters\n", + "# Build configuration for GRPO training\n", + "config_argv = [\n", + " \"\", # Placeholder for argv[0]\n", + " \"src/MaxText/configs/grpo.yml\", # Base config\n", + " f\"model_name=llama3.1-8b\",\n", + " f\"tokenizer_path=meta-llama/Llama-3.1-8B-Instruct\",\n", + " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}\",\n", + " f\"base_output_directory={OUTPUT_DIRECTORY}\",\n", + " f\"hf_access_token={HF_TOKEN}\",\n", + " f\"steps={STEPS}\",\n", + " \"per_device_batch_size=1\",\n", + " \"learning_rate=3e-6\",\n", + " \"num_generations=2\",\n", + " \"grpo_beta=0.08\",\n", + " \"grpo_epsilon=0.2\",\n", + " \"trainer_devices_fraction=0.5\",\n", + " \"sampler_devices_fraction=0.5\",\n", + " \"chips_per_vm=4\"\n", + "]\n", "\n", - "Let's define the configuration we are going to use. Note that this is by no means a \"perfect\" set of hyperparameters. 
To get good results, you might have to train the model for longer.\n", + "# Create configuration object\n", + "config = pyconfig.Config()\n", + "config.parse_flags(config_argv)\n", "\n", - "### Data Configuration\n" + "print(\"✅ Configuration created successfully\")\n", + "print(f\"📊 Training steps: {config.steps}\")\n", + "print(f\"📁 Output directory: {config.base_output_directory}\")\n", + "print(f\"🤖 Model: {config.model_name}\")" ] }, { @@ -201,1155 +167,76 @@ "metadata": {}, "outputs": [], "source": [ - "# ====== Data ======\n", - "TRAIN_DATA_DIR = f\"{MAXTEXT_REPO_ROOT}/data/train\"\n", - "TEST_DATA_DIR = f\"{MAXTEXT_REPO_ROOT}/data/test\"\n", - "if not os.path.exists(TRAIN_DATA_DIR):\n", - " os.makedirs(TRAIN_DATA_DIR)\n", - "if not os.path.exists(TEST_DATA_DIR):\n", - " os.makedirs(TEST_DATA_DIR)\n", - "TRAIN_FRACTION = 1.0" + "# Execute GRPO training directly\n", + "try:\n", + " # Call the grpo_train function\n", + " grpo_trainer, rl_cluster = grpo_train(config)\n", + " \n", + " print(\"\\n\" + \"=\"*80)\n", + " print(\"✅ GRPO Training Completed Successfully!\")\n", + " print(\"=\"*80)\n", + " print(f\"📁 Checkpoints saved to: {config.base_output_directory}/checkpoints\")\n", + " print(f\"📊 Logs available in: {config.base_output_directory}/logs\")\n", + " print(f\"🎯 Final model ready for inference!\")\n", + " \n", + "except Exception as e:\n", + " print(\"\\n\" + \"=\"*80)\n", + " print(\"❌ GRPO Training Failed!\")\n", + " print(\"=\"*80)\n", + " print(f\"Error: {str(e)}\")\n", + " print(\"\\nPlease check the error message and try again.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Checkpoint and Logging Configuration\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# ====== Checkpoint directory =====\n", - "LOG_DIR = f\"{MAXTEXT_REPO_ROOT}/content/tensorboard/grpo/logs_llama3/\"\n", - "if not os.path.exists(LOG_DIR):\n", - " os.makedirs(LOG_DIR)\n", + "## Summary\n", "\n", - "# ===== Profiling =====\n", - "PROFILE_DIR = f\"{MAXTEXT_REPO_ROOT}/content/jax_traces/grpo/profiles_llama3/\"\n", - "if not os.path.exists(PROFILE_DIR):\n", - " os.makedirs(PROFILE_DIR)\n", + "This notebook demonstrates GRPO training by directly calling the `grpo_train` function. The key benefits are:\n", "\n", - "# ====== Checkpoint saving ======\n", - "CKPT_DIR = f\"{MAXTEXT_REPO_ROOT}/content/ckpts_llama3/\"\n", + "### ✅ **Direct Function Approach**\n", + "- **No subprocess calls** - Direct Python function execution\n", + "- **Better integration** - Seamless notebook environment\n", + "- **Easier debugging** - Direct access to variables and state\n", "\n", - "if not os.path.exists(CKPT_DIR):\n", - " os.makedirs(CKPT_DIR)\n", + "### ✅ **What Happened**\n", + "1. **Model Loading** - Llama3.1-8B with Tunix adapter\n", + "2. **Dataset Processing** - GSM8K math reasoning dataset\n", + "3. **GRPO Training** - Multiple reward functions for math problems\n", + "4. 
**Checkpointing** - Model weights saved for inference\n", "\n", - "SAVE_INTERVAL_STEPS = 500\n", - "MAX_TO_KEEP = 4\n", - "\n", - "# ====== Reproducibility ======\n", - "SEED = 42" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### GRPO Configuration\n", + "### ✅ **Next Steps**\n", + "- **Inference** - Use the trained model for math problem solving\n", + "- **Evaluation** - Test on GSM8K test set\n", + "- **Customization** - Modify parameters for different models/datasets\n", "\n", - "GRPO-specific hyperparameters for generation and training:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# ====== GRPO ======\n", - "# === Generation during GRPO training ===\n", - "MAX_PROMPT_LENGTH = 256\n", - "TOTAL_GENERATION_STEPS = 768\n", - "# Important to keep a high-ish temperature for varied, diverse responses during\n", - "# training.\n", - "TEMPERATURE = 0.9\n", - "TOP_P = 1.0\n", - "TOP_K = 50\n", - "# The number of times the policy generates multiple responses for a given prompt\n", - "# within a single training step. This corresponds to `G` in Algorithm 1 in the\n", - "# paper. The \"group\" in GRPO comes from here.\n", - "NUM_GENERATIONS = 2\n", - "\n", - "# === other GRPO configs ===\n", - "# The number of iterations per batch (𝜇 in GRPO algo 1).\n", - "NUM_ITERATIONS = 1\n", - "# The coefficient for the KL divergence penalty (𝛽) in the GRPO loss function.\n", - "# Important to keep a high enough value for this, otherwise, the KL divergence\n", - "# can increase unchecked.\n", - "BETA = 0.08\n", - "# Epsilon value for clipping (𝜀 in GRPO loss in paper). Similar to PPO, for\n", - "# stable updates.\n", - "EPSILON = 0.2" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Training Configuration\n", - "\n", - "Training hyperparameters including batch size, learning rate, and optimization settings:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# ====== Training ======\n", - "BATCH_SIZE = 1\n", - "# Increase `NUM_BATCHES` and `MAX_STEPS` for better results.\n", - "NUM_BATCHES = 4 # 200\n", - "# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be\n", - "# increased to a max. of 330 (if batch size is 4).\n", - "NUM_TEST_BATCHES = 5 # 200\n", - "\n", - "SEQUENCE_LENGTH = 1024\n", - "\n", - "EVAL_EVERY_N_STEPS = 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`.\n", - "NUM_EPOCHS = 1 # can potentially train for more epochs\n", - "\n", - "# Number of training steps.\n", - "MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)\n", - "\n", - "# === AdamW, warmup, cosine scheduler ===\n", - "LEARNING_RATE = 3e-6\n", - "B1 = 0.9\n", - "B2 = 0.99\n", - "WEIGHT_DECAY = 0.1\n", - "# == Cosine decay with warmup scheduler ==\n", - "# Linearly increase learning rate from 0. to 5e-6 in the first 10% training\n", - "# steps, and then gradually decrease the learning rate to 0 using cosine\n", - "# scheduler.\n", - "WARMUP_STEPS = int(0.1 * MAX_STEPS)\n", - "# == Grad clipping ==\n", - "# Grad clipping to prevent large gradients. 
Found this\n", - "# important to keep KL divergence in check.\n", - "MAX_GRAD_NORM = 0.1" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Inference and Reward Configuration\n", - "\n", - "Configuration for model inference and reward function parameters:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# ====== Inference ======\n", - "GENERATION_CONFIGS = {\n", - " # greedy search\n", - " \"greedy\": {\"temperature\": 0.01, \"top_k\": 1, \"top_p\": 1.0},\n", - " # some randomness\n", - " \"standard\": {\"temperature\": 0.7, \"top_k\": 50, \"top_p\": 0.95},\n", - " # liberal\n", - " \"liberal\": {\"temperature\": 0.85, \"top_k\": 2000, \"top_p\": 1.0},\n", - "}\n", - "\n", - "# ====== Reward ======\n", - "REWARD_EXACT_FORMAT_MATCH = 3.0\n", - "REWARD_WHITE_SPACE_FORMAT_MATCH = 1.5\n", - "REWARD_PARTIAL_FORMAT_MATCH = 0.5\n", - "REWARD_RATIO_GUESS_TO_ANSWER_HIGH = 0.5\n", - "REWARD_RATIO_GUESS_TO_ANSWER_LOW = 0.25\n", - "PENALTY_INCORRECT_FORMAT = -0.5\n", - "PENALTY_INCORRECT_ANSWER = -1.0" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Utility Functions\n", - "\n", - "Helper functions for monitoring memory usage and other utilities:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def show_hbm_usage():\n", - " \"\"\"Displays memory usage per device.\"\"\"\n", - " fmt_size = functools.partial(humanize.naturalsize, binary=True)\n", - "\n", - " for d in jax.local_devices():\n", - " stats = d.memory_stats()\n", - " used = stats[\"bytes_in_use\"]\n", - " limit = stats[\"bytes_limit\"]\n", - " print(f\"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Data Preprocessing\n", - "\n", - "First, let's define some special tokens. We instruct the model to first reason between the `` and `` tokens. After reasoning, we expect it to provide the answer between the `` and `` tokens.\n", - "\n", - "We use OpenAI's GSM8K dataset. GSM8K comprises grade school math word problems.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model_tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-3.1-8B-Instruct\")\n", - "\n", - "\n", - "reasoning_start = \"\"\n", - "reasoning_end = \"\"\n", - "solution_start = \"\"\n", - "solution_end = \"\"\n", - "\n", - "\n", - "SYSTEM_PROMPT = f\"\"\"You are given a problem. Think about the problem and \\\n", - "provide your reasoning. Place it between {reasoning_start} and \\\n", - "{reasoning_end}. 
Then, provide the final answer (i.e., just one numerical \\\n", - "value) between {solution_start} and {solution_end}.\"\"\"\n", - "\n", - "TEMPLATE = \"\"\"user\n", - "{system_prompt}\n", - "\n", - "{question}\n", - "model\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def extract_hash_answer(text: str) -> str | None:\n", - " if DEBUG:\n", - " print(f\"Extracting answer from: {text}\")\n", - " if \"####\" not in text:\n", - " return None\n", - " return text.split(\"####\")[1].strip()\n", - "\n", - "\n", - "def get_dataset(data_dir, split=\"train\") -> grain.MapDataset:\n", - " # Download data\n", - " if not os.path.exists(data_dir):\n", - " os.makedirs(data_dir)\n", - "\n", - " data = tfds.data_source(\n", - " \"gsm8k\",\n", - " split=split,\n", - " data_dir=data_dir,\n", - " builder_kwargs={\"file_format\": tfds.core.FileFormat.ARRAY_RECORD},\n", - " download=True,\n", - " )\n", - "\n", - " loaded_dataset = (\n", - " grain.MapDataset.source(data)\n", - " .shuffle(seed=SEED)\n", - " .map(\n", - " lambda x: {\n", - " # passed to model forward pass\n", - " \"prompts\": model_tokenizer.apply_chat_template(\n", - " [\n", - " {\n", - " \"role\": \"user\",\n", - " \"content\": TEMPLATE.format(\n", - " system_prompt=SYSTEM_PROMPT,\n", - " question=x[\"question\"].decode(\"utf-8\"),\n", - " ),\n", - " },\n", - " ],\n", - " tokenize=False,\n", - " add_generation_prompt=True,\n", - " ),\n", - " # passed to reward functions\n", - " \"question\": x[\"question\"].decode(\"utf-8\"),\n", - " # passed to reward functions\n", - " \"answer\": extract_hash_answer(x[\"answer\"].decode(\"utf-8\")),\n", - " }\n", - " )\n", - " )\n", - " return loaded_dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dataset = get_dataset(TRAIN_DATA_DIR, \"train\").batch(BATCH_SIZE)[:NUM_BATCHES]\n", - "\n", - "if TRAIN_FRACTION == 1.0:\n", - " train_dataset = dataset.repeat(NUM_EPOCHS)\n", - " val_dataset = None\n", - "else:\n", - " train_dataset = dataset[: int(len(dataset) * TRAIN_FRACTION)]\n", - " train_dataset = train_dataset.repeat(NUM_EPOCHS)\n", - "\n", - " val_dataset = dataset[int(len(dataset) * TRAIN_FRACTION) :].repeat(NUM_EPOCHS)\n", - "\n", - "test_dataset = get_dataset(TEST_DATA_DIR, \"test\").batch(BATCH_SIZE)[:NUM_TEST_BATCHES]\n", - "\n", - "\n", - "# Let's see how one batch of the dataset looks like!\n", - "\n", - "\n", - "if DEBUG:\n", - " for ele in train_dataset[:1]:\n", - " pprint(ele)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Load the Policy Model and the Reference Model\n", - "\n", - "The policy model is the model which is actually trained and whose weights are updated. The reference model is the model with which we compute KL divergence. This is to ensure that the policy updates are not huge and that it does not deviate too much from the reference model.\n", - "\n", - "Typically, the reference model is the base model, and the policy model is the same base model, but with potentially LoRA parameters where only the LoRA parameters are updated. This script is not using LoRA, so both the reference and policy models are the same.\n", - "\n", - "Note: We perform full precision (fp32) training. 
You can, however, leverage Qwix for QAT.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"HBM usage before loading model:\")\n", - "show_hbm_usage()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Load MaxText Model\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def get_ref_maxtext_model(config):\n", - "\n", - " model, mesh = model_creation_utils.create_nnx_model(config)\n", - " with mesh:\n", - " tunix_model = TunixMaxTextAdapter(\n", - " base_model=model,\n", - " )\n", - "\n", - " model_config = llama3_lib.ModelConfig.llama3_1_8b()\n", - " tunix_model.config = model_config\n", - "\n", - " return tunix_model, mesh\n", - "\n", - "\n", - "model_config = llama3_lib.ModelConfig.llama3_1_8b()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Load Reference Model\n", - "\n", - "Note: pass the path to your scanned checkpoint for \"load_parameters_path\". To create a scanned checkpoint, you can use `/maxtext/src/MaxText/utils/ckpt_conversion/to_maxtext.py`\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Load the reference model\n", - "config_ref = pyconfig.initialize(\n", - " [\n", - " \"\",\n", - " f\"{MAXTEXT_REPO_ROOT}/src/MaxText/configs/base.yml\",\n", - " ],\n", - " base_output_directory=\"dummy\", # This is not used in Tunix.\n", - " run_name=\"test-tunix-maxtext-llama3.1-8b\",\n", - " tokenizer_type=\"tiktoken\",\n", - " tokenizer_path=\"assets/tokenizer_llama3.tiktoken\",\n", - " load_parameters_path=f\"{MODEL_CHECKPOINT_PATH}\",\n", - " max_target_length=SEQUENCE_LENGTH,\n", - " async_checkpointing=\"false\",\n", - " model_name=\"llama3.1-8b\",\n", - " skip_jax_distributed_system=\"true\",\n", - " weight_dtype=\"bfloat16\",\n", - " attention=\"dot_product\",\n", - " remat_policy=\"custom\",\n", - " decoder_layer_input=\"offload\",\n", - " query_proj=\"offload\",\n", - " key_proj=\"offload\",\n", - " value_proj=\"offload\",\n", - ")\n", - "\n", - "llama3_1_8b, mesh = get_ref_maxtext_model(config_ref)\n", - "\n", - "llama3_1_8b.config = model_config\n", - "\n", - "nnx.display(llama3_1_8b)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "if DEBUG:\n", - " print(\"Model initialized successfully\")\n", - " print(f\"Model mesh shape: {mesh.shape}\")\n", - " print(f\"Model config: {model_config}\")\n", - "\n", - " # Sanity check that weights are loaded correctly\n", - " _maxtext_state_flatten = nnx.state(llama3_1_8b).flat_state()\n", - " maxtext_state_flatten = {\".\".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten}\n", - " print(\n", - " f\"maxtext_state_flatten[base.token_embedder.embedding].value={maxtext_state_flatten['base.token_embedder.embedding'].value}\"\n", - " )\n", - "\n", - "\n", - "# See the memory use after loading the reference model:\n", - "print(\"HBM usage after loading ref model:\")\n", - "show_hbm_usage()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Load Policy Model\n", - "\n", - "Note: pass the path to your scanned checkpoint for \"load_parameters_path\". 
To create a scanned checkpoint, you can use `/maxtext/src/MaxText/utils/ckpt_conversion/to_maxtext.py`\n", - "\n", - "TODO: @mazumdera: change this to use lora\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "config_policy = pyconfig.initialize(\n", - " [\n", - " \"\",\n", - " f\"{MAXTEXT_REPO_ROOT}/src/MaxText/configs/base.yml\",\n", - " ],\n", - " base_output_directory=\"dummy\", # This is not used in Tunix.\n", - " run_name=\"test-tunix-maxtext-llama3.1-8b\", # This is not used in Tunix.\n", - " tokenizer_type=\"tiktoken\",\n", - " tokenizer_path=\"assets/tokenizer_llama3.tiktoken\",\n", - " load_parameters_path=f\"{MODEL_CHECKPOINT_PATH}\",\n", - " max_target_length=SEQUENCE_LENGTH,\n", - " async_checkpointing=\"false\",\n", - " model_name=\"llama3.1-8b\",\n", - " skip_jax_distributed_system=\"true\",\n", - " weight_dtype=\"bfloat16\",\n", - " attention=\"dot_product\",\n", - " remat_policy=\"custom\",\n", - " decoder_layer_input=\"offload\",\n", - " query_proj=\"offload\",\n", - " key_proj=\"offload\",\n", - " value_proj=\"offload\",\n", - ")\n", - "llama3_1_8b_policy, mesh_policy = get_ref_maxtext_model(config_policy)\n", - "\n", - "llama3_1_8b_policy.config = model_config\n", - "\n", - "nnx.display(llama3_1_8b_policy)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "if DEBUG:\n", - " print(\"Model initialized successfully\")\n", - " print(f\"Model mesh shape: {mesh_policy.shape}\")\n", - " print(f\"Model config: {model_config}\")\n", - "\n", - " # Sanity check that weights are loaded correctly\n", - " _maxtext_state_flatten = nnx.state(llama3_1_8b_policy).flat_state()\n", - " maxtext_state_flatten = {\".\".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten}\n", - " print(\n", - " f\"maxtext_state_flatten[base.token_embedder.embedding].value={maxtext_state_flatten['base.token_embedder.embedding'].value}\"\n", - " )\n", - "\n", - "# See memory usage after loading the policy model:\n", - "print(\"HBM usage after loading policy model:\")\n", - "show_hbm_usage()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Define Reward Functions\n", - "\n", - "We define four reward functions:\n", - "\n", - "1. **Format Matching**: Reward if the format of the output exactly matches the instruction given in `TEMPLATE`\n", - "2. **Approximate Format Matching**: Reward if the format of the output approximately matches the instruction given in `TEMPLATE`\n", - "3. **Answer Correctness**: Reward if the answer is correct/partially correct\n", - "4. **Number Extraction**: Sometimes, the text between ``, `` might not be one number. 
So, extract the number, and reward the model if the answer is correct.\n", - "\n", - "The reward functions are inspired from [here](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb).\n", - "\n", - "First off, let's define a RegEx for checking whether the format matches.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "match_format = re.compile(\n", - " rf\"^[\\s]{{0,}}\" rf\"{reasoning_start}.+?{reasoning_end}.*?\" rf\"{solution_start}(.+?){solution_end}\" rf\"[\\s]{{0,}}$\",\n", - " flags=re.MULTILINE | re.DOTALL,\n", - ")\n", - "\n", - "match_format.search(\n", - " f\"{reasoning_start}Let me\" f\" think!{reasoning_end}{solution_start}2{solution_end}\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Reward Function 1: Exact Format Matching\n", - "\n", - "Give the model a reward of 3 points if the format matches exactly.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def match_format_exactly(prompts, completions, **kargs):\n", - " scores = []\n", - " for completion in completions:\n", - " score = 0\n", - " response = completion\n", - " # Match if format is seen exactly!\n", - " if match_format.search(response) is not None:\n", - " score += REWARD_EXACT_FORMAT_MATCH\n", - " scores.append(score)\n", - " return scores" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Reward Function 2: Approximate Format Matching\n", - "\n", - "We also reward the model if the format of the output matches partially.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def match_format_approximately(prompts, completions, **kargs):\n", - " scores = []\n", - "\n", - " for completion in completions:\n", - " score = 0\n", - " response = completion\n", - " # Count how many keywords are seen - we penalize if too many!\n", - " # If we see 1, then plus some points!\n", - " score += REWARD_PARTIAL_FORMAT_MATCH if response.count(reasoning_start) == 1 else PENALTY_INCORRECT_FORMAT\n", - " score += REWARD_PARTIAL_FORMAT_MATCH if response.count(reasoning_end) == 1 else PENALTY_INCORRECT_FORMAT\n", - " score += REWARD_PARTIAL_FORMAT_MATCH if response.count(solution_start) == 1 else PENALTY_INCORRECT_FORMAT\n", - " score += REWARD_PARTIAL_FORMAT_MATCH if response.count(solution_end) == 1 else PENALTY_INCORRECT_FORMAT\n", - " scores.append(score)\n", - " return scores" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Reward Function 3: Answer Correctness\n", - "\n", - "Reward the model if the answer is correct. 
A reward is also given if the answer does not match exactly, i.e., based on how close the answer is to the correct value.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def check_answer(prompts, completions, answer, **kargs):\n", - " responses = completions\n", - "\n", - " extracted_responses = [guess.group(1) if (guess := match_format.search(r)) is not None else None for r in responses]\n", - "\n", - " scores = []\n", - " for guess, true_answer in zip(extracted_responses, answer):\n", - " score = 0\n", - " if guess is None:\n", - " scores.append(0)\n", - " continue\n", - " # Correct answer gets 3 points!\n", - " if guess == true_answer:\n", - " score += REWARD_EXACT_FORMAT_MATCH\n", - " # Match if spaces are seen\n", - " elif guess.strip() == true_answer.strip():\n", - " score += REWARD_WHITE_SPACE_FORMAT_MATCH\n", - " else:\n", - " # We also reward it if the answer is close via ratios!\n", - " # Ie if the answer is within some range, reward it!\n", - " try:\n", - " ratio = float(guess) / float(true_answer)\n", - " if ratio >= 0.9 and ratio <= 1.1:\n", - " score += REWARD_RATIO_GUESS_TO_ANSWER_HIGH\n", - " elif ratio >= 0.8 and ratio <= 1.2:\n", - " score += REWARD_RATIO_GUESS_TO_ANSWER_LOW\n", - " else:\n", - " score += PENALTY_INCORRECT_ANSWER # Penalize wrong answers\n", - " except:\n", - " score += PENALTY_INCORRECT_FORMAT # Penalize\n", - " scores.append(score)\n", - " return scores" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Reward Function 4: Number Extraction\n", - "\n", - "Sometimes, the text between `` and `` might not be one number; it can be a sentence. So, we extract the number and compare the answer.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "match_numbers = re.compile(rf\"{solution_start}.*?([\\d\\.]{{1,}})\", flags=re.MULTILINE | re.DOTALL)\n", - "match_numbers.findall(f\"{solution_start} 0.34 {solution_end}\")\n", - "\n", - "\n", - "def check_numbers(prompts, completions, answer, **kargs):\n", - " question = kargs[\"question\"]\n", - " responses = completions\n", - "\n", - " extracted_responses = [guess.group(1) if (guess := match_numbers.search(r)) is not None else None for r in responses]\n", - "\n", - " scores = []\n", - " if DEBUG:\n", - " print(\"START ============================\")\n", - " print(f\"Question: {question[0]}\")\n", - " print(f\"Answer: {answer[0]}\")\n", - " print(f\"Response: {responses[0]}\")\n", - " print(f\"Extracted: {extracted_responses[0]}\")\n", - " print(\"END ==============================\")\n", - " for guess, true_answer in zip(extracted_responses, answer):\n", - " if guess is None:\n", - " scores.append(0)\n", - " continue\n", - " # Convert to numbers\n", - " try:\n", - " true_answer = float(true_answer.strip())\n", - " guess = float(guess.strip())\n", - " scores.append(1.5 if guess == true_answer else 0.0)\n", - " except:\n", - " scores.append(0)\n", - " continue\n", - " return scores" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Evaluation Functions\n", - "\n", - "Before we train the model, let's evaluate the model on the test set so we can see the improvement post training.\n", - "\n", - "We evaluate it in two ways:\n", - "\n", - "**Quantitative**\n", - "- **Answer Accuracy**: percentage of samples for which the model predicts the correct final numerical answer\n", - "- **Answer (Partial) Accuracy**: percentage of samples 
for which the model predicts a final numerical answer such that the `model answer / answer` ratio lies between 0.9 and 1.1.\n", - "- **Format Accuracy**: percentage of samples for which the model outputs the correct format, i.e., reasoning between the reasoning special tokens, and the final answer between the ``, `` tokens.\n", - "\n", - "**Qualitative**\n", - "\n", - "We'll also print outputs for a few given questions so that we can compare the generated output later.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def generate_responses(\n", - " prompts,\n", - " rl_cluster,\n", - " num_passes=1,\n", - " temperature=0.7,\n", - " top_k=50,\n", - " top_p=0.95,\n", - "):\n", - " \"\"\"\n", - " Generate responses for a batch of prompts across multiple passes.\n", - "\n", - " Args:\n", - " prompts: List of prompts to generate responses for\n", - " rl_cluster: Model cluster for generation\n", - " num_passes: Number of generation passes\n", - " temperature: Sampling temperature\n", - " top_k: Top-k sampling parameter\n", - " top_p: Top-p sampling parameter\n", - "\n", - " Returns:\n", - " List of lists containing responses for each prompt across passes\n", - " \"\"\"\n", - " multiple_call_responses = [[] for _ in range(len(prompts))]\n", - "\n", - " for p in range(num_passes):\n", - " responses = rl_cluster.rollout.generate(\n", - " prompts,\n", - " rollout_config=RolloutConfig(\n", - " max_tokens_to_generate=TOTAL_GENERATION_STEPS,\n", - " temperature=temperature,\n", - " top_k=top_k,\n", - " top_p=top_p,\n", - " ),\n", - " )\n", - " responses = responses.text\n", - "\n", - " if DEBUG:\n", - " print(f\"Pass {p+1}/{num_passes}, responses: {responses}\")\n", - "\n", - " for idx, response in enumerate(responses):\n", - " multiple_call_responses[idx].append(response)\n", - "\n", - " return multiple_call_responses" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def score_responses(question, responses, answer):\n", - " \"\"\"\n", - " Score a set of responses for a single question.\n", - "\n", - " Args:\n", - " question: The evaluation question\n", - " responses: List of generated responses for this question\n", - " answer: The correct answer\n", - "\n", - " Returns:\n", - " Tuple of (is_correct, is_partially_correct, has_correct_format)\n", - " \"\"\"\n", - " if DEBUG:\n", - " print(\"========================================\")\n", - " print(f\"Evaluation Question: {question}\")\n", - " print(f\"Evaluation Answer: {answer}\")\n", - " print(f\"Evaluation Responses: {responses}\")\n", - " print(\"========================================\")\n", - "\n", - " is_correct = False\n", - " is_partially_correct = False\n", - " has_correct_format = False\n", - "\n", - " for response in responses:\n", - " # Extract numerical response\n", - " extracted_response = guess.group(1) if (guess := match_numbers.search(response)) is not None else \"-1000000\"\n", - "\n", - " if DEBUG:\n", - " print(f\"Evaluation extracted_response: {extracted_response}\")\n", - "\n", - " # Check exact correctness\n", - " try:\n", - " if float(extracted_response.strip()) == float(answer.strip()):\n", - " is_correct = True\n", - "\n", - " # Check partial correctness (within 10%)\n", - " ratio = float(extracted_response.strip()) / float(answer.strip())\n", - " if 0.9 <= ratio <= 1.1:\n", - " is_partially_correct = True\n", - " except Exception as e:\n", - " if DEBUG:\n", - " print(f\"Evaluation Exception: 
{e}\")\n", - " print(\"SKIPPED\")\n", - "\n", - " # Check format correctness\n", - " if match_format.search(response) is not None:\n", - " has_correct_format = True\n", - "\n", - " # Early exit if all criteria are met\n", - " if is_correct and is_partially_correct and has_correct_format:\n", - " break\n", - "\n", - " return is_correct, is_partially_correct, has_correct_format" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def evaluate(\n", - " dataset,\n", - " rl_cluster,\n", - " temperature=0.7,\n", - " top_k=50,\n", - " top_p=0.95,\n", - " num_passes=1,\n", - " corr_lst=False,\n", - " make_lst=False,\n", - "):\n", - " \"\"\"\n", - " Computes accuracy and percentage of outputs matching the format.\n", - "\n", - " Args:\n", - " dataset: The evaluation dataset\n", - " rl_cluster: Model cluster for generation\n", - " temperature: Sampling temperature\n", - " top_k: Top-k sampling parameter\n", - " top_p: Top-p sampling parameter\n", - " num_passes: Number of generation passes\n", - " corr_lst: If True, only include correct responses in the list\n", - " make_lst: If True, return a list of (question, answer, responses)\n", - "\n", - " Returns:\n", - " Tuple of statistics and optionally the response list\n", - " \"\"\"\n", - " response_lst = []\n", - " corr = 0\n", - " partially_corr = 0\n", - " corr_format = 0\n", - " total = 0\n", - "\n", - " for batch in tqdm(dataset):\n", - " answers = batch[\"answer\"]\n", - " questions = batch[\"question\"]\n", - " prompts = batch[\"prompts\"]\n", - "\n", - " # Generate responses for all prompts in the batch\n", - " multiple_call_responses = generate_responses(\n", - " prompts=prompts,\n", - " rl_cluster=rl_cluster,\n", - " num_passes=num_passes,\n", - " temperature=temperature,\n", - " top_k=top_k,\n", - " top_p=top_p,\n", - " )\n", - "\n", - " # Score each question-answer pair\n", - " for question, responses, answer in zip(questions, multiple_call_responses, answers):\n", - " is_correct, is_partially_correct, has_correct_format = score_responses(\n", - " question=question,\n", - " responses=responses,\n", - " answer=answer,\n", - " )\n", - "\n", - " # Update counters\n", - " if is_correct:\n", - " corr += 1\n", - " if corr_lst and make_lst:\n", - " response_lst.append((question, answer, responses))\n", - " else:\n", - " if not corr_lst and make_lst:\n", - " response_lst.append((question, answer, responses))\n", - "\n", - " if is_partially_correct:\n", - " partially_corr += 1\n", - "\n", - " if has_correct_format:\n", - " corr_format += 1\n", - "\n", - " total += 1\n", - "\n", - " # Print progress every 10 items\n", - " if total % 10 == 0:\n", - " print(\n", - " f\"===> {corr=}, {total=}, {corr / total * 100=}, \"\n", - " f\"{partially_corr / total * 100=}, {corr_format / total * 100=}\"\n", - " )\n", - "\n", - " # Prepare return values\n", - " to_return = (\n", - " corr,\n", - " total,\n", - " corr / total * 100,\n", - " partially_corr / total * 100,\n", - " corr_format / total * 100,\n", - " )\n", - "\n", - " if make_lst:\n", - " return to_return, response_lst\n", - " return to_return" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Training Setup\n", - "\n", - "Let's set up all the configs first - checkpointing, metric logging and training. 
We then train the model.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Ckpt saving\n", - "checkpointing_options = ocp.CheckpointManagerOptions(save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP)\n", - "\n", - "# Metrics logger\n", - "metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=LOG_DIR, flush_every_n_steps=20)\n", - "\n", - "\n", - "# Logs\n", - "print(f\"TensorBoard logs directory: {LOG_DIR}\")\n", - "print(f\"tensorboard --logdir {LOG_DIR} --port=8086\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Optimizer, learning rate scheduler, gradient clipping\n", - "optimizer = optax.adamw(\n", - " learning_rate=optax.schedules.warmup_cosine_decay_schedule(\n", - " init_value=0.0,\n", - " peak_value=LEARNING_RATE,\n", - " warmup_steps=WARMUP_STEPS,\n", - " decay_steps=MAX_STEPS,\n", - " end_value=0.0,\n", - " ),\n", - " b1=B1,\n", - " b2=B2,\n", - " weight_decay=WEIGHT_DECAY,\n", - ")\n", - "\n", - "if MAX_GRAD_NORM is not None:\n", - " optimizer = optax.chain(\n", - " optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),\n", - " optimizer,\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# RL Cluster config\n", - "# Note that we use vLLM as the rollout engine.\n", - "# and we are using Tensor Parallelism for rollout\n", - "\n", - "cluster_config = rl_cluster_lib.ClusterConfig(\n", - " role_to_mesh={\n", - " rl_cluster_lib.Role.ACTOR: mesh,\n", - " rl_cluster_lib.Role.REFERENCE: mesh,\n", - " rl_cluster_lib.Role.ROLLOUT: mesh,\n", - " },\n", - " rollout_engine=\"vllm\",\n", - " offload_to_cpu=False,\n", - " training_config=rl_cluster_lib.RLTrainingConfig(\n", - " actor_optimizer=optimizer,\n", - " eval_every_n_steps=EVAL_EVERY_N_STEPS,\n", - " max_steps=MAX_STEPS,\n", - " gradient_accumulation_steps=1,\n", - " # metrics logging\n", - " metrics_logging_options=metrics_logging_options,\n", - " # checkpoint saving\n", - " checkpoint_root_directory=CKPT_DIR,\n", - " checkpointing_options=checkpointing_options,\n", - " ),\n", - " rollout_config=base_rollout.RolloutConfig(\n", - " max_tokens_to_generate=TOTAL_GENERATION_STEPS,\n", - " max_prompt_length=MAX_PROMPT_LENGTH,\n", - " kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,\n", - " temperature=TEMPERATURE,\n", - " top_p=TOP_P,\n", - " top_k=TOP_K,\n", - " ),\n", - " rollout_vllm_model_version=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", - " rollout_vllm_hbm_utilization=0.2,\n", - " rollout_vllm_tpu_backend_type=\"jax\",\n", - ")\n", - "\n", - "grpo_config = GrpoConfig(\n", - " num_generations=NUM_GENERATIONS,\n", - " num_iterations=NUM_ITERATIONS,\n", - " beta=BETA,\n", - " epsilon=EPSILON,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# RL cluster\n", - "\n", - "rl_cluster = rl_cluster_lib.RLCluster(\n", - " actor=llama3_1_8b_policy,\n", - " reference=llama3_1_8b,\n", - " tokenizer=model_tokenizer,\n", - " cluster_config=cluster_config,\n", - ")\n", - "\n", - "# GRPO Trainer\n", - "grpo_trainer = GrpoLearner(\n", - " rl_cluster=rl_cluster,\n", - " reward_fns=[\n", - " match_format_exactly,\n", - " match_format_approximately,\n", - " check_answer,\n", - " check_numbers,\n", - " ],\n", - " grpo_config=grpo_config,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": 
{}, - "outputs": [], - "source": [ - "if DEBUG:\n", - " # verify if vllm sampler works\n", - " output = rl_cluster.rollout.generate(\n", - " [\"The capital of France is\"],\n", - " rollout_config=RolloutConfig(max_tokens_to_generate=64, temperature=0.1),\n", - " )\n", - "\n", - " print(f\"Output: {output}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Evaluate Before Training\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(\n", - " test_dataset,\n", - " rl_cluster,\n", - " **GENERATION_CONFIGS[\"greedy\"],\n", - ")\n", - "print(f\"Pre GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,\" f\" {format_accuracy=}%\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Start Training\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "jax.profiler.start_trace(PROFILE_DIR)\n", - "with mesh, nn_partitioning.axis_rules(config_policy.logical_axis_rules):\n", - " grpo_trainer.train(dataset)\n", - "jax.profiler.stop_trace()\n", - "\n", - "print(\"HBM usage after training:\")\n", - "show_hbm_usage()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Final Evaluation\n", - "\n", - "Let's evaluate our model after training!\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(\n", - " test_dataset,\n", - " rl_cluster,\n", - " **GENERATION_CONFIGS[\"greedy\"],\n", - ")\n", - "print(f\"Post GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,\" f\" {format_accuracy=}%\")" + "### 📚 **Learn More**\n", + "- See `src/MaxText/examples/grpo_runner.py` for CLI usage\n", + "- Check `src/MaxText/configs/grpo.yml` for configuration options\n", + "- Read `src/MaxText/examples/README.md` for more examples" ] } ], "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo_backup.ipynb b/src/MaxText/examples/grpo_llama3_1_8b_demo_backup.ipynb new file mode 100644 index 000000000..75394b950 --- /dev/null +++ b/src/MaxText/examples/grpo_llama3_1_8b_demo_backup.ipynb @@ -0,0 +1,1371 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# GRPO Llama3.1-8B-Instruct Demo: Group Relative Policy Optimization\n", + "\n", + "This tutorial demonstrates training the Llama3.1 8B-Instruct model on the GSM8K math reasoning benchmark using Group Relative Policy Optimization (GRPO). GRPO can enhance your model's problem-solving skills on mathematical word problems, coding problems, etc.\n", + "\n", + "## What is GRPO?\n", + "\n", + "GRPO is an RL algorithm designed to enhance the reasoning abilities of LLMs. It is a variant of Proximal Policy Optimization (PPO) that reduces memory usage by eliminating the need for a separate value function model. GRPO works by:\n", + "\n", + "1. 
Generating multiple responses for a given prompt\n", + "2. Evaluating these responses using a reward model\n", + "3. Calculating a relative advantage based on the group's performance to update the policy\n", + "\n", + "## Libraries Used\n", + "\n", + "- **Tunix**: Library for GRPO implementation\n", + "- **vLLM**: Library for efficient model inference and generation\n", + "- **MaxText**: For model creation and training infrastructure\n", + "\n", + "## Hardware Requirements\n", + "\n", + "This tutorial uses a single host TPUVM such as `v6e-8/v5p-8`.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Install Necessary Libraries\n", + "\n", + "First, let's install the required dependencies:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "### (Optional) Run this if you just have this file and nothing else\n", + "\n", + "# 1. Clone the MaxText repository (from AI‑Hypercomputer)\n", + "!git clone https://github.com/AI-Hypercomputer/maxtext.git\n", + "\n", + "# 2. Navigate into the cloned directory\n", + "%cd maxtext" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "### (Optional) Do not run this if you already installed the dependencies\n", + "\n", + "# 3. Ensure setup.sh is executable\n", + "!chmod +x setup.sh\n", + "\n", + "# 4. Execute the setup script\n", + "!./setup.sh\n", + "\n", + "# Install vllm requirements\n", + "!./src/MaxText/examples/install_tunix_vllm_requirement.sh\n", + "\n", + "# force numpy version\n", + "!pip install --force-reinstall numpy==2.1.2\n", + "# install nest_asyncio\n", + "!pip install nest_asyncio\n", + "\n", + "import nest_asyncio\n", + "\n", + "nest_asyncio.apply()\n", + "# To fix \"This event loop is already running\" error in Colab" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports\n", + "\n", + "Import all necessary libraries for GRPO training:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set up the variables for the script\n", + "import os\n", + "import sys\n", + "\n", + "# Set the MaxText home directory (where you cloned the maxtext repo)\n", + "MAXTEXT_REPO_ROOT = os.path.expanduser(\"~\") + \"/maxtext\"\n", + "print(f\"MaxText Home directory (from Python): {MAXTEXT_REPO_ROOT}\")\n", + "\n", + "DEBUG = False # set to True to run in debug mode, for more print statements\n", + "# set this to the path of the checkpoint you want to load, gs:// supported\n", + "MODEL_CHECKPOINT_PATH = \"/path/to/scanned/model/ckpt_load_dir/\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run GRPO training using the unified script\n", + "import subprocess\n", + "import sys\n", + "\n", + "# Build the command\n", + "cmd = [\n", + " \"python3\", \"src/MaxText/examples/grpo_runner.py\",\n", + " \"--model_name=llama3.1-8b\",\n", + " \"--tokenizer_path=meta-llama/Llama-3.1-8B-Instruct\", \n", + " f\"--load_parameters_path={MODEL_CHECKPOINT_PATH}\",\n", + " f\"--base_output_directory={OUTPUT_DIRECTORY}\",\n", + " f\"--hf_access_token={HF_TOKEN}\",\n", + " f\"--steps={STEPS}\",\n", + " \"--per_device_batch_size=1\",\n", + " \"--learning_rate=3e-6\",\n", + " \"--num_generations=2\",\n", + " \"--grpo_beta=0.08\",\n", + " \"--grpo_epsilon=0.2\"\n", + "]\n", + "\n", + "print(\"Running GRPO training with the following command:\")\n", + "print(\" 
\".join(cmd))\n", + "print(\"\\n\" + \"=\"*80)\n", + "print(\"Starting GRPO Training...\")\n", + "print(\"=\"*80)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Execute the GRPO training\n", + "result = subprocess.run(cmd, capture_output=False, text=True)\n", + "\n", + "if result.returncode == 0:\n", + " print(\"\\n\" + \"=\"*80)\n", + " print(\"✅ GRPO Training Completed Successfully!\")\n", + " print(\"=\"*80)\n", + " print(f\"📁 Checkpoints saved to: {OUTPUT_DIRECTORY}\")\n", + " print(f\"📊 Logs available in: {OUTPUT_DIRECTORY}/logs\")\n", + "else:\n", + " print(\"\\n\" + \"=\"*80)\n", + " print(\"❌ GRPO Training Failed!\")\n", + " print(\"=\"*80)\n", + " print(f\"Exit code: {result.returncode}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "## Summary\n", + "\n", + "This simplified notebook demonstrates GRPO training using the unified `grpo_runner.py` script. The key benefits are:\n", + "\n", + "### ✅ **Simplified Approach**\n", + "- **Single command** - All GRPO logic is consolidated\n", + "- **Easy configuration** - Just set parameters and run\n", + "- **Consistent interface** - Same as CLI usage\n", + "\n", + "### ✅ **What Happened**\n", + "1. **Model Loading** - Llama3.1-8B with Tunix adapter\n", + "2. **Dataset Processing** - GSM8K math reasoning dataset\n", + "3. **GRPO Training** - Multiple reward functions for math problems\n", + "4. **Checkpointing** - Model weights saved for inference\n", + "\n", + "### ✅ **Next Steps**\n", + "- **Inference** - Use the trained model for math problem solving\n", + "- **Evaluation** - Test on GSM8K test set\n", + "- **Customization** - Modify parameters for different models/datasets\n", + "\n", + "### 📚 **Learn More**\n", + "- See `src/MaxText/examples/grpo_runner.py` for CLI usage\n", + "- Check `src/MaxText/configs/grpo.yml` for configuration options\n", + "- Read `src/MaxText/examples/README.md` for more examples\n", + "\n", + "# Use your Hugging Face token (recommended)\n", + "# Get your token from: https://huggingface.co/settings/tokens\n", + "\n", + "os.environ[\"HF_TOKEN\"] = \"hf_your_token_here\"\n", + "login(token=os.environ[\"HF_TOKEN\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Hyperparameters\n", + "\n", + "Let's define the configuration we are going to use. Note that this is by no means a \"perfect\" set of hyperparameters. 
To get good results, you might have to train the model for longer.\n", + "\n", + "### Data Configuration\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ====== Data ======\n", + "TRAIN_DATA_DIR = f\"{MAXTEXT_REPO_ROOT}/data/train\"\n", + "TEST_DATA_DIR = f\"{MAXTEXT_REPO_ROOT}/data/test\"\n", + "if not os.path.exists(TRAIN_DATA_DIR):\n", + " os.makedirs(TRAIN_DATA_DIR)\n", + "if not os.path.exists(TEST_DATA_DIR):\n", + " os.makedirs(TEST_DATA_DIR)\n", + "TRAIN_FRACTION = 1.0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Checkpoint and Logging Configuration\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ====== Checkpoint directory =====\n", + "LOG_DIR = f\"{MAXTEXT_REPO_ROOT}/content/tensorboard/grpo/logs_llama3/\"\n", + "if not os.path.exists(LOG_DIR):\n", + " os.makedirs(LOG_DIR)\n", + "\n", + "# ===== Profiling =====\n", + "PROFILE_DIR = f\"{MAXTEXT_REPO_ROOT}/content/jax_traces/grpo/profiles_llama3/\"\n", + "if not os.path.exists(PROFILE_DIR):\n", + " os.makedirs(PROFILE_DIR)\n", + "\n", + "# ====== Checkpoint saving ======\n", + "CKPT_DIR = f\"{MAXTEXT_REPO_ROOT}/content/ckpts_llama3/\"\n", + "\n", + "if not os.path.exists(CKPT_DIR):\n", + " os.makedirs(CKPT_DIR)\n", + "\n", + "SAVE_INTERVAL_STEPS = 500\n", + "MAX_TO_KEEP = 4\n", + "\n", + "# ====== Reproducibility ======\n", + "SEED = 42" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### GRPO Configuration\n", + "\n", + "GRPO-specific hyperparameters for generation and training:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ====== GRPO ======\n", + "# === Generation during GRPO training ===\n", + "MAX_PROMPT_LENGTH = 256\n", + "TOTAL_GENERATION_STEPS = 768\n", + "# Important to keep a high-ish temperature for varied, diverse responses during\n", + "# training.\n", + "TEMPERATURE = 0.9\n", + "TOP_P = 1.0\n", + "TOP_K = 50\n", + "# The number of times the policy generates multiple responses for a given prompt\n", + "# within a single training step. This corresponds to `G` in Algorithm 1 in the\n", + "# paper. The \"group\" in GRPO comes from here.\n", + "NUM_GENERATIONS = 2\n", + "\n", + "# === other GRPO configs ===\n", + "# The number of iterations per batch (𝜇 in GRPO algo 1).\n", + "NUM_ITERATIONS = 1\n", + "# The coefficient for the KL divergence penalty (𝛽) in the GRPO loss function.\n", + "# Important to keep a high enough value for this, otherwise, the KL divergence\n", + "# can increase unchecked.\n", + "BETA = 0.08\n", + "# Epsilon value for clipping (𝜀 in GRPO loss in paper). Similar to PPO, for\n", + "# stable updates.\n", + "EPSILON = 0.2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training Configuration\n", + "\n", + "Training hyperparameters including batch size, learning rate, and optimization settings:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ====== Training ======\n", + "BATCH_SIZE = 1\n", + "# Increase `NUM_BATCHES` and `MAX_STEPS` for better results.\n", + "NUM_BATCHES = 4 # 200\n", + "# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be\n", + "# increased to a max. 
of 330 (if batch size is 4).\n", + "NUM_TEST_BATCHES = 5 # 200\n", + "\n", + "SEQUENCE_LENGTH = 1024\n", + "\n", + "EVAL_EVERY_N_STEPS = 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`.\n", + "NUM_EPOCHS = 1 # can potentially train for more epochs\n", + "\n", + "# Number of training steps.\n", + "MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)\n", + "\n", + "# === AdamW, warmup, cosine scheduler ===\n", + "LEARNING_RATE = 3e-6\n", + "B1 = 0.9\n", + "B2 = 0.99\n", + "WEIGHT_DECAY = 0.1\n", + "# == Cosine decay with warmup scheduler ==\n", + "# Linearly increase learning rate from 0. to 5e-6 in the first 10% training\n", + "# steps, and then gradually decrease the learning rate to 0 using cosine\n", + "# scheduler.\n", + "WARMUP_STEPS = int(0.1 * MAX_STEPS)\n", + "# == Grad clipping ==\n", + "# Grad clipping to prevent large gradients. Found this\n", + "# important to keep KL divergence in check.\n", + "MAX_GRAD_NORM = 0.1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Inference and Reward Configuration\n", + "\n", + "Configuration for model inference and reward function parameters:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ====== Inference ======\n", + "GENERATION_CONFIGS = {\n", + " # greedy search\n", + " \"greedy\": {\"temperature\": 0.01, \"top_k\": 1, \"top_p\": 1.0},\n", + " # some randomness\n", + " \"standard\": {\"temperature\": 0.7, \"top_k\": 50, \"top_p\": 0.95},\n", + " # liberal\n", + " \"liberal\": {\"temperature\": 0.85, \"top_k\": 2000, \"top_p\": 1.0},\n", + "}\n", + "\n", + "# ====== Reward ======\n", + "REWARD_EXACT_FORMAT_MATCH = 3.0\n", + "REWARD_WHITE_SPACE_FORMAT_MATCH = 1.5\n", + "REWARD_PARTIAL_FORMAT_MATCH = 0.5\n", + "REWARD_RATIO_GUESS_TO_ANSWER_HIGH = 0.5\n", + "REWARD_RATIO_GUESS_TO_ANSWER_LOW = 0.25\n", + "PENALTY_INCORRECT_FORMAT = -0.5\n", + "PENALTY_INCORRECT_ANSWER = -1.0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Utility Functions\n", + "\n", + "Helper functions for monitoring memory usage and other utilities:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def show_hbm_usage():\n", + " \"\"\"Displays memory usage per device.\"\"\"\n", + " fmt_size = functools.partial(humanize.naturalsize, binary=True)\n", + "\n", + " for d in jax.local_devices():\n", + " stats = d.memory_stats()\n", + " used = stats[\"bytes_in_use\"]\n", + " limit = stats[\"bytes_limit\"]\n", + " print(f\"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data Preprocessing\n", + "\n", + "First, let's define some special tokens. We instruct the model to first reason between the `` and `` tokens. After reasoning, we expect it to provide the answer between the `` and `` tokens.\n", + "\n", + "We use OpenAI's GSM8K dataset. GSM8K comprises grade school math word problems.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-3.1-8B-Instruct\")\n", + "\n", + "\n", + "reasoning_start = \"\"\n", + "reasoning_end = \"\"\n", + "solution_start = \"\"\n", + "solution_end = \"\"\n", + "\n", + "\n", + "SYSTEM_PROMPT = f\"\"\"You are given a problem. Think about the problem and \\\n", + "provide your reasoning. 
Place it between {reasoning_start} and \\\n", + "{reasoning_end}. Then, provide the final answer (i.e., just one numerical \\\n", + "value) between {solution_start} and {solution_end}.\"\"\"\n", + "\n", + "TEMPLATE = \"\"\"user\n", + "{system_prompt}\n", + "\n", + "{question}\n", + "model\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def extract_hash_answer(text: str) -> str | None:\n", + " if DEBUG:\n", + " print(f\"Extracting answer from: {text}\")\n", + " if \"####\" not in text:\n", + " return None\n", + " return text.split(\"####\")[1].strip()\n", + "\n", + "\n", + "def get_dataset(data_dir, split=\"train\") -> grain.MapDataset:\n", + " # Download data\n", + " if not os.path.exists(data_dir):\n", + " os.makedirs(data_dir)\n", + "\n", + " data = tfds.data_source(\n", + " \"gsm8k\",\n", + " split=split,\n", + " data_dir=data_dir,\n", + " builder_kwargs={\"file_format\": tfds.core.FileFormat.ARRAY_RECORD},\n", + " download=True,\n", + " )\n", + "\n", + " loaded_dataset = (\n", + " grain.MapDataset.source(data)\n", + " .shuffle(seed=SEED)\n", + " .map(\n", + " lambda x: {\n", + " # passed to model forward pass\n", + " \"prompts\": model_tokenizer.apply_chat_template(\n", + " [\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": TEMPLATE.format(\n", + " system_prompt=SYSTEM_PROMPT,\n", + " question=x[\"question\"].decode(\"utf-8\"),\n", + " ),\n", + " },\n", + " ],\n", + " tokenize=False,\n", + " add_generation_prompt=True,\n", + " ),\n", + " # passed to reward functions\n", + " \"question\": x[\"question\"].decode(\"utf-8\"),\n", + " # passed to reward functions\n", + " \"answer\": extract_hash_answer(x[\"answer\"].decode(\"utf-8\")),\n", + " }\n", + " )\n", + " )\n", + " return loaded_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = get_dataset(TRAIN_DATA_DIR, \"train\").batch(BATCH_SIZE)[:NUM_BATCHES]\n", + "\n", + "if TRAIN_FRACTION == 1.0:\n", + " train_dataset = dataset.repeat(NUM_EPOCHS)\n", + " val_dataset = None\n", + "else:\n", + " train_dataset = dataset[: int(len(dataset) * TRAIN_FRACTION)]\n", + " train_dataset = train_dataset.repeat(NUM_EPOCHS)\n", + "\n", + " val_dataset = dataset[int(len(dataset) * TRAIN_FRACTION) :].repeat(NUM_EPOCHS)\n", + "\n", + "test_dataset = get_dataset(TEST_DATA_DIR, \"test\").batch(BATCH_SIZE)[:NUM_TEST_BATCHES]\n", + "\n", + "\n", + "# Let's see how one batch of the dataset looks like!\n", + "\n", + "\n", + "if DEBUG:\n", + " for ele in train_dataset[:1]:\n", + " pprint(ele)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load the Policy Model and the Reference Model\n", + "\n", + "The policy model is the model which is actually trained and whose weights are updated. The reference model is the model with which we compute KL divergence. This is to ensure that the policy updates are not huge and that it does not deviate too much from the reference model.\n", + "\n", + "Typically, the reference model is the base model, and the policy model is the same base model, but with potentially LoRA parameters where only the LoRA parameters are updated. This script is not using LoRA, so both the reference and policy models are the same.\n", + "\n", + "Note: We perform full precision (fp32) training. 
You can, however, leverage Qwix for QAT.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"HBM usage before loading model:\")\n", + "show_hbm_usage()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load MaxText Model\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_ref_maxtext_model(config):\n", + "\n", + " model, mesh = model_creation_utils.create_nnx_model(config)\n", + " with mesh:\n", + " tunix_model = TunixMaxTextAdapter(\n", + " base_model=model,\n", + " )\n", + "\n", + " model_config = llama3_lib.ModelConfig.llama3_1_8b()\n", + " tunix_model.config = model_config\n", + "\n", + " return tunix_model, mesh\n", + "\n", + "\n", + "model_config = llama3_lib.ModelConfig.llama3_1_8b()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load Reference Model\n", + "\n", + "Note: pass the path to your scanned checkpoint for \"load_parameters_path\". To create a scanned checkpoint, you can use `/maxtext/src/MaxText/utils/ckpt_conversion/to_maxtext.py`\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the reference model\n", + "config_ref = pyconfig.initialize(\n", + " [\n", + " \"\",\n", + " f\"{MAXTEXT_REPO_ROOT}/src/MaxText/configs/base.yml\",\n", + " ],\n", + " base_output_directory=\"dummy\", # This is not used in Tunix.\n", + " run_name=\"test-tunix-maxtext-llama3.1-8b\",\n", + " tokenizer_type=\"tiktoken\",\n", + " tokenizer_path=\"assets/tokenizer_llama3.tiktoken\",\n", + " load_parameters_path=f\"{MODEL_CHECKPOINT_PATH}\",\n", + " max_target_length=SEQUENCE_LENGTH,\n", + " async_checkpointing=\"false\",\n", + " model_name=\"llama3.1-8b\",\n", + " skip_jax_distributed_system=\"true\",\n", + " weight_dtype=\"bfloat16\",\n", + " attention=\"dot_product\",\n", + " remat_policy=\"custom\",\n", + " decoder_layer_input=\"offload\",\n", + " query_proj=\"offload\",\n", + " key_proj=\"offload\",\n", + " value_proj=\"offload\",\n", + ")\n", + "\n", + "llama3_1_8b, mesh = get_ref_maxtext_model(config_ref)\n", + "\n", + "llama3_1_8b.config = model_config\n", + "\n", + "nnx.display(llama3_1_8b)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if DEBUG:\n", + " print(\"Model initialized successfully\")\n", + " print(f\"Model mesh shape: {mesh.shape}\")\n", + " print(f\"Model config: {model_config}\")\n", + "\n", + " # Sanity check that weights are loaded correctly\n", + " _maxtext_state_flatten = nnx.state(llama3_1_8b).flat_state()\n", + " maxtext_state_flatten = {\".\".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten}\n", + " print(\n", + " f\"maxtext_state_flatten[base.token_embedder.embedding].value={maxtext_state_flatten['base.token_embedder.embedding'].value}\"\n", + " )\n", + "\n", + "\n", + "# See the memory use after loading the reference model:\n", + "print(\"HBM usage after loading ref model:\")\n", + "show_hbm_usage()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load Policy Model\n", + "\n", + "Note: pass the path to your scanned checkpoint for \"load_parameters_path\". 
To create a scanned checkpoint, you can use `/maxtext/src/MaxText/utils/ckpt_conversion/to_maxtext.py`\n", + "\n", + "TODO: @mazumdera: change this to use lora\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config_policy = pyconfig.initialize(\n", + " [\n", + " \"\",\n", + " f\"{MAXTEXT_REPO_ROOT}/src/MaxText/configs/base.yml\",\n", + " ],\n", + " base_output_directory=\"dummy\", # This is not used in Tunix.\n", + " run_name=\"test-tunix-maxtext-llama3.1-8b\", # This is not used in Tunix.\n", + " tokenizer_type=\"tiktoken\",\n", + " tokenizer_path=\"assets/tokenizer_llama3.tiktoken\",\n", + " load_parameters_path=f\"{MODEL_CHECKPOINT_PATH}\",\n", + " max_target_length=SEQUENCE_LENGTH,\n", + " async_checkpointing=\"false\",\n", + " model_name=\"llama3.1-8b\",\n", + " skip_jax_distributed_system=\"true\",\n", + " weight_dtype=\"bfloat16\",\n", + " attention=\"dot_product\",\n", + " remat_policy=\"custom\",\n", + " decoder_layer_input=\"offload\",\n", + " query_proj=\"offload\",\n", + " key_proj=\"offload\",\n", + " value_proj=\"offload\",\n", + ")\n", + "llama3_1_8b_policy, mesh_policy = get_ref_maxtext_model(config_policy)\n", + "\n", + "llama3_1_8b_policy.config = model_config\n", + "\n", + "nnx.display(llama3_1_8b_policy)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if DEBUG:\n", + " print(\"Model initialized successfully\")\n", + " print(f\"Model mesh shape: {mesh_policy.shape}\")\n", + " print(f\"Model config: {model_config}\")\n", + "\n", + " # Sanity check that weights are loaded correctly\n", + " _maxtext_state_flatten = nnx.state(llama3_1_8b_policy).flat_state()\n", + " maxtext_state_flatten = {\".\".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten}\n", + " print(\n", + " f\"maxtext_state_flatten[base.token_embedder.embedding].value={maxtext_state_flatten['base.token_embedder.embedding'].value}\"\n", + " )\n", + "\n", + "# See memory usage after loading the policy model:\n", + "print(\"HBM usage after loading policy model:\")\n", + "show_hbm_usage()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define Reward Functions\n", + "\n", + "We define four reward functions:\n", + "\n", + "1. **Format Matching**: Reward if the format of the output exactly matches the instruction given in `TEMPLATE`\n", + "2. **Approximate Format Matching**: Reward if the format of the output approximately matches the instruction given in `TEMPLATE`\n", + "3. **Answer Correctness**: Reward if the answer is correct/partially correct\n", + "4. **Number Extraction**: Sometimes, the text between ``, `` might not be one number. 
So, extract the number, and reward the model if the answer is correct.\n", + "\n", + "The reward functions are inspired from [here](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb).\n", + "\n", + "First off, let's define a RegEx for checking whether the format matches.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "match_format = re.compile(\n", + " rf\"^[\\s]{{0,}}\" rf\"{reasoning_start}.+?{reasoning_end}.*?\" rf\"{solution_start}(.+?){solution_end}\" rf\"[\\s]{{0,}}$\",\n", + " flags=re.MULTILINE | re.DOTALL,\n", + ")\n", + "\n", + "match_format.search(\n", + " f\"{reasoning_start}Let me\" f\" think!{reasoning_end}{solution_start}2{solution_end}\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Reward Function 1: Exact Format Matching\n", + "\n", + "Give the model a reward of 3 points if the format matches exactly.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def match_format_exactly(prompts, completions, **kargs):\n", + " scores = []\n", + " for completion in completions:\n", + " score = 0\n", + " response = completion\n", + " # Match if format is seen exactly!\n", + " if match_format.search(response) is not None:\n", + " score += REWARD_EXACT_FORMAT_MATCH\n", + " scores.append(score)\n", + " return scores" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Reward Function 2: Approximate Format Matching\n", + "\n", + "We also reward the model if the format of the output matches partially.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def match_format_approximately(prompts, completions, **kargs):\n", + " scores = []\n", + "\n", + " for completion in completions:\n", + " score = 0\n", + " response = completion\n", + " # Count how many keywords are seen - we penalize if too many!\n", + " # If we see 1, then plus some points!\n", + " score += REWARD_PARTIAL_FORMAT_MATCH if response.count(reasoning_start) == 1 else PENALTY_INCORRECT_FORMAT\n", + " score += REWARD_PARTIAL_FORMAT_MATCH if response.count(reasoning_end) == 1 else PENALTY_INCORRECT_FORMAT\n", + " score += REWARD_PARTIAL_FORMAT_MATCH if response.count(solution_start) == 1 else PENALTY_INCORRECT_FORMAT\n", + " score += REWARD_PARTIAL_FORMAT_MATCH if response.count(solution_end) == 1 else PENALTY_INCORRECT_FORMAT\n", + " scores.append(score)\n", + " return scores" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Reward Function 3: Answer Correctness\n", + "\n", + "Reward the model if the answer is correct. 
A reward is also given if the answer does not match exactly, i.e., based on how close the answer is to the correct value.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def check_answer(prompts, completions, answer, **kargs):\n", + " responses = completions\n", + "\n", + " extracted_responses = [guess.group(1) if (guess := match_format.search(r)) is not None else None for r in responses]\n", + "\n", + " scores = []\n", + " for guess, true_answer in zip(extracted_responses, answer):\n", + " score = 0\n", + " if guess is None:\n", + " scores.append(0)\n", + " continue\n", + " # Correct answer gets 3 points!\n", + " if guess == true_answer:\n", + " score += REWARD_EXACT_FORMAT_MATCH\n", + " # Match if spaces are seen\n", + " elif guess.strip() == true_answer.strip():\n", + " score += REWARD_WHITE_SPACE_FORMAT_MATCH\n", + " else:\n", + " # We also reward it if the answer is close via ratios!\n", + " # Ie if the answer is within some range, reward it!\n", + " try:\n", + " ratio = float(guess) / float(true_answer)\n", + " if ratio >= 0.9 and ratio <= 1.1:\n", + " score += REWARD_RATIO_GUESS_TO_ANSWER_HIGH\n", + " elif ratio >= 0.8 and ratio <= 1.2:\n", + " score += REWARD_RATIO_GUESS_TO_ANSWER_LOW\n", + " else:\n", + " score += PENALTY_INCORRECT_ANSWER # Penalize wrong answers\n", + " except:\n", + " score += PENALTY_INCORRECT_FORMAT # Penalize\n", + " scores.append(score)\n", + " return scores" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Reward Function 4: Number Extraction\n", + "\n", + "Sometimes, the text between `` and `` might not be one number; it can be a sentence. So, we extract the number and compare the answer.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "match_numbers = re.compile(rf\"{solution_start}.*?([\\d\\.]{{1,}})\", flags=re.MULTILINE | re.DOTALL)\n", + "match_numbers.findall(f\"{solution_start} 0.34 {solution_end}\")\n", + "\n", + "\n", + "def check_numbers(prompts, completions, answer, **kargs):\n", + " question = kargs[\"question\"]\n", + " responses = completions\n", + "\n", + " extracted_responses = [guess.group(1) if (guess := match_numbers.search(r)) is not None else None for r in responses]\n", + "\n", + " scores = []\n", + " if DEBUG:\n", + " print(\"START ============================\")\n", + " print(f\"Question: {question[0]}\")\n", + " print(f\"Answer: {answer[0]}\")\n", + " print(f\"Response: {responses[0]}\")\n", + " print(f\"Extracted: {extracted_responses[0]}\")\n", + " print(\"END ==============================\")\n", + " for guess, true_answer in zip(extracted_responses, answer):\n", + " if guess is None:\n", + " scores.append(0)\n", + " continue\n", + " # Convert to numbers\n", + " try:\n", + " true_answer = float(true_answer.strip())\n", + " guess = float(guess.strip())\n", + " scores.append(1.5 if guess == true_answer else 0.0)\n", + " except:\n", + " scores.append(0)\n", + " continue\n", + " return scores" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluation Functions\n", + "\n", + "Before we train the model, let's evaluate the model on the test set so we can see the improvement post training.\n", + "\n", + "We evaluate it in two ways:\n", + "\n", + "**Quantitative**\n", + "- **Answer Accuracy**: percentage of samples for which the model predicts the correct final numerical answer\n", + "- **Answer (Partial) Accuracy**: percentage of samples 
for which the model predicts a final numerical answer such that the `model answer / answer` ratio lies between 0.9 and 1.1.\n", + "- **Format Accuracy**: percentage of samples for which the model outputs the correct format, i.e., reasoning between the reasoning special tokens, and the final answer between the ``, `` tokens.\n", + "\n", + "**Qualitative**\n", + "\n", + "We'll also print outputs for a few given questions so that we can compare the generated output later.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def generate_responses(\n", + " prompts,\n", + " rl_cluster,\n", + " num_passes=1,\n", + " temperature=0.7,\n", + " top_k=50,\n", + " top_p=0.95,\n", + "):\n", + " \"\"\"\n", + " Generate responses for a batch of prompts across multiple passes.\n", + "\n", + " Args:\n", + " prompts: List of prompts to generate responses for\n", + " rl_cluster: Model cluster for generation\n", + " num_passes: Number of generation passes\n", + " temperature: Sampling temperature\n", + " top_k: Top-k sampling parameter\n", + " top_p: Top-p sampling parameter\n", + "\n", + " Returns:\n", + " List of lists containing responses for each prompt across passes\n", + " \"\"\"\n", + " multiple_call_responses = [[] for _ in range(len(prompts))]\n", + "\n", + " for p in range(num_passes):\n", + " responses = rl_cluster.rollout.generate(\n", + " prompts,\n", + " rollout_config=RolloutConfig(\n", + " max_tokens_to_generate=TOTAL_GENERATION_STEPS,\n", + " temperature=temperature,\n", + " top_k=top_k,\n", + " top_p=top_p,\n", + " ),\n", + " )\n", + " responses = responses.text\n", + "\n", + " if DEBUG:\n", + " print(f\"Pass {p+1}/{num_passes}, responses: {responses}\")\n", + "\n", + " for idx, response in enumerate(responses):\n", + " multiple_call_responses[idx].append(response)\n", + "\n", + " return multiple_call_responses" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def score_responses(question, responses, answer):\n", + " \"\"\"\n", + " Score a set of responses for a single question.\n", + "\n", + " Args:\n", + " question: The evaluation question\n", + " responses: List of generated responses for this question\n", + " answer: The correct answer\n", + "\n", + " Returns:\n", + " Tuple of (is_correct, is_partially_correct, has_correct_format)\n", + " \"\"\"\n", + " if DEBUG:\n", + " print(\"========================================\")\n", + " print(f\"Evaluation Question: {question}\")\n", + " print(f\"Evaluation Answer: {answer}\")\n", + " print(f\"Evaluation Responses: {responses}\")\n", + " print(\"========================================\")\n", + "\n", + " is_correct = False\n", + " is_partially_correct = False\n", + " has_correct_format = False\n", + "\n", + " for response in responses:\n", + " # Extract numerical response\n", + " extracted_response = guess.group(1) if (guess := match_numbers.search(response)) is not None else \"-1000000\"\n", + "\n", + " if DEBUG:\n", + " print(f\"Evaluation extracted_response: {extracted_response}\")\n", + "\n", + " # Check exact correctness\n", + " try:\n", + " if float(extracted_response.strip()) == float(answer.strip()):\n", + " is_correct = True\n", + "\n", + " # Check partial correctness (within 10%)\n", + " ratio = float(extracted_response.strip()) / float(answer.strip())\n", + " if 0.9 <= ratio <= 1.1:\n", + " is_partially_correct = True\n", + " except Exception as e:\n", + " if DEBUG:\n", + " print(f\"Evaluation Exception: 
{e}\")\n", + " print(\"SKIPPED\")\n", + "\n", + " # Check format correctness\n", + " if match_format.search(response) is not None:\n", + " has_correct_format = True\n", + "\n", + " # Early exit if all criteria are met\n", + " if is_correct and is_partially_correct and has_correct_format:\n", + " break\n", + "\n", + " return is_correct, is_partially_correct, has_correct_format" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate(\n", + " dataset,\n", + " rl_cluster,\n", + " temperature=0.7,\n", + " top_k=50,\n", + " top_p=0.95,\n", + " num_passes=1,\n", + " corr_lst=False,\n", + " make_lst=False,\n", + "):\n", + " \"\"\"\n", + " Computes accuracy and percentage of outputs matching the format.\n", + "\n", + " Args:\n", + " dataset: The evaluation dataset\n", + " rl_cluster: Model cluster for generation\n", + " temperature: Sampling temperature\n", + " top_k: Top-k sampling parameter\n", + " top_p: Top-p sampling parameter\n", + " num_passes: Number of generation passes\n", + " corr_lst: If True, only include correct responses in the list\n", + " make_lst: If True, return a list of (question, answer, responses)\n", + "\n", + " Returns:\n", + " Tuple of statistics and optionally the response list\n", + " \"\"\"\n", + " response_lst = []\n", + " corr = 0\n", + " partially_corr = 0\n", + " corr_format = 0\n", + " total = 0\n", + "\n", + " for batch in tqdm(dataset):\n", + " answers = batch[\"answer\"]\n", + " questions = batch[\"question\"]\n", + " prompts = batch[\"prompts\"]\n", + "\n", + " # Generate responses for all prompts in the batch\n", + " multiple_call_responses = generate_responses(\n", + " prompts=prompts,\n", + " rl_cluster=rl_cluster,\n", + " num_passes=num_passes,\n", + " temperature=temperature,\n", + " top_k=top_k,\n", + " top_p=top_p,\n", + " )\n", + "\n", + " # Score each question-answer pair\n", + " for question, responses, answer in zip(questions, multiple_call_responses, answers):\n", + " is_correct, is_partially_correct, has_correct_format = score_responses(\n", + " question=question,\n", + " responses=responses,\n", + " answer=answer,\n", + " )\n", + "\n", + " # Update counters\n", + " if is_correct:\n", + " corr += 1\n", + " if corr_lst and make_lst:\n", + " response_lst.append((question, answer, responses))\n", + " else:\n", + " if not corr_lst and make_lst:\n", + " response_lst.append((question, answer, responses))\n", + "\n", + " if is_partially_correct:\n", + " partially_corr += 1\n", + "\n", + " if has_correct_format:\n", + " corr_format += 1\n", + "\n", + " total += 1\n", + "\n", + " # Print progress every 10 items\n", + " if total % 10 == 0:\n", + " print(\n", + " f\"===> {corr=}, {total=}, {corr / total * 100=}, \"\n", + " f\"{partially_corr / total * 100=}, {corr_format / total * 100=}\"\n", + " )\n", + "\n", + " # Prepare return values\n", + " to_return = (\n", + " corr,\n", + " total,\n", + " corr / total * 100,\n", + " partially_corr / total * 100,\n", + " corr_format / total * 100,\n", + " )\n", + "\n", + " if make_lst:\n", + " return to_return, response_lst\n", + " return to_return" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training Setup\n", + "\n", + "Let's set up all the configs first - checkpointing, metric logging and training. 
We then train the model.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Ckpt saving\n", + "checkpointing_options = ocp.CheckpointManagerOptions(save_interval_steps=SAVE_INTERVAL_STEPS, max_to_keep=MAX_TO_KEEP)\n", + "\n", + "# Metrics logger\n", + "metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=LOG_DIR, flush_every_n_steps=20)\n", + "\n", + "\n", + "# Logs\n", + "print(f\"TensorBoard logs directory: {LOG_DIR}\")\n", + "print(f\"tensorboard --logdir {LOG_DIR} --port=8086\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Optimizer, learning rate scheduler, gradient clipping\n", + "optimizer = optax.adamw(\n", + " learning_rate=optax.schedules.warmup_cosine_decay_schedule(\n", + " init_value=0.0,\n", + " peak_value=LEARNING_RATE,\n", + " warmup_steps=WARMUP_STEPS,\n", + " decay_steps=MAX_STEPS,\n", + " end_value=0.0,\n", + " ),\n", + " b1=B1,\n", + " b2=B2,\n", + " weight_decay=WEIGHT_DECAY,\n", + ")\n", + "\n", + "if MAX_GRAD_NORM is not None:\n", + " optimizer = optax.chain(\n", + " optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),\n", + " optimizer,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# RL Cluster config\n", + "# Note that we use vLLM as the rollout engine.\n", + "# and we are using Tensor Parallelism for rollout\n", + "\n", + "cluster_config = rl_cluster_lib.ClusterConfig(\n", + " role_to_mesh={\n", + " rl_cluster_lib.Role.ACTOR: mesh,\n", + " rl_cluster_lib.Role.REFERENCE: mesh,\n", + " rl_cluster_lib.Role.ROLLOUT: mesh,\n", + " },\n", + " rollout_engine=\"vllm\",\n", + " offload_to_cpu=False,\n", + " training_config=rl_cluster_lib.RLTrainingConfig(\n", + " actor_optimizer=optimizer,\n", + " eval_every_n_steps=EVAL_EVERY_N_STEPS,\n", + " max_steps=MAX_STEPS,\n", + " gradient_accumulation_steps=1,\n", + " # metrics logging\n", + " metrics_logging_options=metrics_logging_options,\n", + " # checkpoint saving\n", + " checkpoint_root_directory=CKPT_DIR,\n", + " checkpointing_options=checkpointing_options,\n", + " ),\n", + " rollout_config=base_rollout.RolloutConfig(\n", + " max_tokens_to_generate=TOTAL_GENERATION_STEPS,\n", + " max_prompt_length=MAX_PROMPT_LENGTH,\n", + " kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,\n", + " temperature=TEMPERATURE,\n", + " top_p=TOP_P,\n", + " top_k=TOP_K,\n", + " ),\n", + " rollout_vllm_model_version=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", + " rollout_vllm_hbm_utilization=0.2,\n", + " rollout_vllm_tpu_backend_type=\"jax\",\n", + ")\n", + "\n", + "grpo_config = GrpoConfig(\n", + " num_generations=NUM_GENERATIONS,\n", + " num_iterations=NUM_ITERATIONS,\n", + " beta=BETA,\n", + " epsilon=EPSILON,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# RL cluster\n", + "\n", + "rl_cluster = rl_cluster_lib.RLCluster(\n", + " actor=llama3_1_8b_policy,\n", + " reference=llama3_1_8b,\n", + " tokenizer=model_tokenizer,\n", + " cluster_config=cluster_config,\n", + ")\n", + "\n", + "# GRPO Trainer\n", + "grpo_trainer = GrpoLearner(\n", + " rl_cluster=rl_cluster,\n", + " reward_fns=[\n", + " match_format_exactly,\n", + " match_format_approximately,\n", + " check_answer,\n", + " check_numbers,\n", + " ],\n", + " grpo_config=grpo_config,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": 
{}, + "outputs": [], + "source": [ + "if DEBUG:\n", + " # verify if vllm sampler works\n", + " output = rl_cluster.rollout.generate(\n", + " [\"The capital of France is\"],\n", + " rollout_config=RolloutConfig(max_tokens_to_generate=64, temperature=0.1),\n", + " )\n", + "\n", + " print(f\"Output: {output}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluate Before Training\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(\n", + " test_dataset,\n", + " rl_cluster,\n", + " **GENERATION_CONFIGS[\"greedy\"],\n", + ")\n", + "print(f\"Pre GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,\" f\" {format_accuracy=}%\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Start Training\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "jax.profiler.start_trace(PROFILE_DIR)\n", + "with mesh, nn_partitioning.axis_rules(config_policy.logical_axis_rules):\n", + " grpo_trainer.train(dataset)\n", + "jax.profiler.stop_trace()\n", + "\n", + "print(\"HBM usage after training:\")\n", + "show_hbm_usage()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Final Evaluation\n", + "\n", + "Let's evaluate our model after training!\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "(corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate(\n", + " test_dataset,\n", + " rl_cluster,\n", + " **GENERATION_CONFIGS[\"greedy\"],\n", + ")\n", + "print(f\"Post GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%,\" f\" {format_accuracy=}%\")" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 28fb1929457b9f84ab1d8d7fe016530ec14642fd Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Sat, 25 Oct 2025 07:07:51 +0400 Subject: [PATCH 09/31] simplification of nb Signed-off-by: Vladimir Suvorov --- ..._8b_demo_backup.ipynb => grpo_llama3_1_8b_demo_detailed.ipynb} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/MaxText/examples/{grpo_llama3_1_8b_demo_backup.ipynb => grpo_llama3_1_8b_demo_detailed.ipynb} (100%) diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo_backup.ipynb b/src/MaxText/examples/grpo_llama3_1_8b_demo_detailed.ipynb similarity index 100% rename from src/MaxText/examples/grpo_llama3_1_8b_demo_backup.ipynb rename to src/MaxText/examples/grpo_llama3_1_8b_demo_detailed.ipynb From cebf7c95e92da601184d589b47da0a47cb122a18 Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Sat, 25 Oct 2025 07:11:28 +0400 Subject: [PATCH 10/31] fix Signed-off-by: Vladimir Suvorov --- .../examples/grpo_llama3_1_8b_demo.ipynb | 23 +--- .../experimental/rl/grpo_tunix_trainer.py | 115 ++++++++++-------- 2 files changed, 65 insertions(+), 73 deletions(-) diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb b/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb index 322659340..173c34c36 100644 --- a/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb +++ b/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb @@ -15,9 +15,8 @@ "2. Evaluating responses using reward models \n", "3. 
Calculating relative advantages to update the policy\n", "\n", - "## Direct Function Approach\n", "\n", - "This notebook imports and calls the `grpo_train` function directly, avoiding subprocess calls and providing better integration with the notebook environment.\n", + "This notebook imports and calls the `grpo_train` function \n", "\n", "## Hardware Requirements\n", "\n", @@ -191,26 +190,6 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Summary\n", - "\n", - "This notebook demonstrates GRPO training by directly calling the `grpo_train` function. The key benefits are:\n", - "\n", - "### ✅ **Direct Function Approach**\n", - "- **No subprocess calls** - Direct Python function execution\n", - "- **Better integration** - Seamless notebook environment\n", - "- **Easier debugging** - Direct access to variables and state\n", - "\n", - "### ✅ **What Happened**\n", - "1. **Model Loading** - Llama3.1-8B with Tunix adapter\n", - "2. **Dataset Processing** - GSM8K math reasoning dataset\n", - "3. **GRPO Training** - Multiple reward functions for math problems\n", - "4. **Checkpointing** - Model weights saved for inference\n", - "\n", - "### ✅ **Next Steps**\n", - "- **Inference** - Use the trained model for math problem solving\n", - "- **Evaluation** - Test on GSM8K test set\n", - "- **Customization** - Modify parameters for different models/datasets\n", - "\n", "### 📚 **Learn More**\n", "- See `src/MaxText/examples/grpo_runner.py` for CLI usage\n", "- Check `src/MaxText/configs/grpo.yml` for configuration options\n", diff --git a/src/MaxText/experimental/rl/grpo_tunix_trainer.py b/src/MaxText/experimental/rl/grpo_tunix_trainer.py index cbcf6c6c0..3cfd9d954 100644 --- a/src/MaxText/experimental/rl/grpo_tunix_trainer.py +++ b/src/MaxText/experimental/rl/grpo_tunix_trainer.py @@ -16,19 +16,18 @@ GRPO Tunix Trainer This module provides a unified `grpo_train` function that consolidates the common -GRPO training logic. It handles model loading, reward function setup, dataset +GRPO training logic. It handles model loading, reward function setup, dataset processing, and training orchestration. Usage: from MaxText.experimental.rl.grpo_tunix_trainer import grpo_train - + # Train with GRPO grpo_train(config, goodput_recorder=None) """ import os import re -from typing import Optional import jax from jax.sharding import Mesh @@ -38,6 +37,17 @@ import tensorflow_datasets as tfds from transformers import AutoTokenizer +# Conditional imports for optional dependencies +try: + import grain +except ImportError: + grain = None + +try: + import pathwaysutils +except ImportError: + pathwaysutils = None + from tunix.rl import rl_cluster as rl_cluster_lib from tunix.rl.rollout import base_rollout from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner @@ -83,11 +93,12 @@ def extract_hash_answer(text: str) -> str | None: return text.split("####")[1].strip() -def get_gsm8k_dataset(data_dir: str, split: str = "train", batch_size: int = 1, +def get_gsm8k_dataset(data_dir: str, split: str = "train", batch_size: int = 1, num_batches: int = 4, seed: int = 42): """Load and process GSM8K dataset for GRPO training.""" - import grain - + if grain is None: + raise ImportError("grain is required for dataset processing. 
Please install it.") + # Download data if not os.path.exists(data_dir): os.makedirs(data_dir) @@ -129,7 +140,7 @@ def get_gsm8k_dataset(data_dir: str, split: str = "train", batch_size: int = 1, } ) ) - + return loaded_dataset.batch(batch_size)[:num_batches], tokenizer @@ -146,19 +157,20 @@ def get_maxtext_model(config, devices=None): def setup_device_allocation(config, use_pathways: bool = False): """Setup device allocation for training and inference.""" devices = jax.devices() - + # Get device allocation parameters from config trainer_devices_fraction = getattr(config, 'trainer_devices_fraction', 0.5) sampler_devices_fraction = getattr(config, 'sampler_devices_fraction', 0.5) chips_per_vm = getattr(config, 'chips_per_vm', 4) - + num_vms = len(devices) // chips_per_vm - + if use_pathways and num_vms >= 2: # Multiple hosts with Pathways - split devices for trainer and sampler - import pathwaysutils + if pathwaysutils is None: + raise ImportError("pathwaysutils is required for Pathways support. Please install it.") pathwaysutils.initialize() - + num_devices = len(devices) num_trainer_devices = int(num_devices * trainer_devices_fraction) num_sampler_devices = int(num_devices * sampler_devices_fraction) @@ -167,12 +179,13 @@ def setup_device_allocation(config, use_pathways: bool = False): else: # Not using Pathways OR single host - use all devices for both if use_pathways: - import pathwaysutils + if pathwaysutils is None: + raise ImportError("pathwaysutils is required for Pathways support. Please install it.") pathwaysutils.initialize() - + trainer_devices = devices sampler_devices = devices - + return trainer_devices, sampler_devices, num_vms @@ -181,11 +194,11 @@ def match_format_exactly(prompts, completions, **kwargs): """Reward exact format matching.""" scores = [] match_format = re.compile( - rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" + rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$", flags=re.MULTILINE | re.DOTALL, ) - + for completion in completions: score = 0 if match_format.search(completion) is not None: @@ -210,20 +223,20 @@ def match_format_approximately(prompts, completions, **kwargs): def check_answer(prompts, completions, answer, **kwargs): """Reward correct answers.""" match_format = re.compile( - rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" + rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$", flags=re.MULTILINE | re.DOTALL, ) - + extracted_responses = [guess.group(1) if (guess := match_format.search(r)) is not None else None for r in completions] - + scores = [] for guess, true_answer in zip(extracted_responses, answer): score = 0 if guess is None: scores.append(0) continue - + if guess == true_answer: score += REWARD_EXACT_FORMAT_MATCH elif guess.strip() == true_answer.strip(): @@ -247,7 +260,7 @@ def check_numbers(prompts, completions, answer, **kwargs): """Reward correct numerical answers.""" match_numbers = re.compile(rf"{SOLUTION_START}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL) extracted_responses = [guess.group(1) if (guess := match_numbers.search(r)) is not None else None for r in completions] - + scores = [] for guess, true_answer in zip(extracted_responses, answer): if guess is None: @@ -266,53 +279,53 @@ def check_numbers(prompts, completions, answer, **kwargs): def grpo_train(config): """ Run GRPO training with the provided configuration. 
- + Args: config: MaxText configuration object """ print("=" * 80) print("Starting GRPO Training") print("=" * 80) - + # Setup device allocation use_pathways = getattr(config, 'use_pathways_reshard', False) trainer_devices, sampler_devices, num_vms = setup_device_allocation(config, use_pathways) - + print(f"Device allocation: {len(trainer_devices)} trainer, {len(sampler_devices)} sampler") print(f"Use Pathways: {use_pathways}") - + # Setup data directories home = os.path.expanduser("~") + "/" train_data_dir = f"{home}/data/train" test_data_dir = f"{home}/data/test" - + # Load datasets print("Loading GSM8K dataset...") train_dataset, tokenizer = get_gsm8k_dataset( - train_data_dir, - split="train", + train_data_dir, + split="train", batch_size=config.per_device_batch_size, num_batches=getattr(config, 'num_batches', 4) ) - + # Load test dataset for evaluation (currently not used in training loop) get_gsm8k_dataset( - test_data_dir, - split="test", + test_data_dir, + split="test", batch_size=config.per_device_batch_size, num_batches=getattr(config, 'num_test_batches', 5) ) - + # Load reference model print("Loading reference model...") reference_model, reference_mesh = get_maxtext_model(config, trainer_devices) reference_model.config = None - + # Load policy model print("Loading policy model...") policy_model, policy_mesh = get_maxtext_model(config, trainer_devices) policy_model.config = None - + # Setup meshes if num_vms >= 2 and not use_pathways: actor_mesh = policy_mesh @@ -320,12 +333,12 @@ def grpo_train(config): else: actor_mesh = policy_mesh rollout_mesh = policy_mesh - + # Setup optimizer learning_rate = getattr(config, 'learning_rate', 3e-6) max_steps = getattr(config, 'steps', 100) warmup_steps = int(0.1 * max_steps) - + optimizer = optax.adamw( learning_rate=optax.schedules.warmup_cosine_decay_schedule( init_value=0.0, @@ -338,7 +351,7 @@ def grpo_train(config): b2=0.99, weight_decay=0.1, ) - + # Add gradient clipping if specified max_grad_norm = getattr(config, 'max_grad_norm', 0.1) if max_grad_norm is not None: @@ -346,25 +359,25 @@ def grpo_train(config): optax.clip_by_global_norm(max_norm=max_grad_norm), optimizer, ) - + # Setup checkpointing ckpt_dir = f"{config.base_output_directory}/checkpoints" os.makedirs(ckpt_dir, exist_ok=True) - + checkpointing_options = ocp.CheckpointManagerOptions( save_interval_steps=getattr(config, 'checkpoint_period', 50), max_to_keep=4 ) - + # Setup metrics logging log_dir = f"{config.base_output_directory}/logs" os.makedirs(log_dir, exist_ok=True) - + metrics_logging_options = metrics_logger.MetricsLoggerOptions( - log_dir=log_dir, + log_dir=log_dir, flush_every_n_steps=20 ) - + # Setup RL cluster config cluster_config = rl_cluster_lib.ClusterConfig( role_to_mesh={ @@ -386,7 +399,7 @@ def grpo_train(config): rollout_config=base_rollout.RolloutConfig( max_tokens_to_generate=getattr(config, 'max_target_length', 768), max_prompt_length=getattr(config, 'max_prefill_predict_length', 256), - kv_cache_size=getattr(config, 'max_prefill_predict_length', 256) + + kv_cache_size=getattr(config, 'max_prefill_predict_length', 256) + getattr(config, 'max_target_length', 768) + 256, temperature=getattr(config, 'decode_sampling_temperature', 0.9), top_p=getattr(config, 'decode_sampling_top_p', 1.0), @@ -396,7 +409,7 @@ def grpo_train(config): rollout_vllm_hbm_utilization=0.2, rollout_vllm_tpu_backend_type="jax", ) - + # Setup GRPO config grpo_config = GrpoConfig( num_generations=getattr(config, 'num_generations', 2), @@ -404,7 +417,7 @@ def grpo_train(config): 
beta=getattr(config, 'grpo_beta', 0.08), epsilon=getattr(config, 'grpo_epsilon', 0.2), ) - + # Create RL cluster print("Creating RL cluster...") with nn_partitioning.axis_rules(config.logical_axis_rules): @@ -414,7 +427,7 @@ def grpo_train(config): tokenizer=tokenizer, cluster_config=cluster_config, ) - + # Create GRPO trainer print("Setting up GRPO trainer...") grpo_trainer = GrpoLearner( @@ -427,14 +440,14 @@ def grpo_train(config): ], grpo_config=grpo_config, ) - + # Start training print("Starting GRPO training...") with policy_mesh, nn_partitioning.axis_rules(config.logical_axis_rules): grpo_trainer.train(train_dataset) - + print("=" * 80) print("GRPO Training Completed Successfully!") print("=" * 80) - + return grpo_trainer, rl_cluster From 280be97f88f9bef0379c67d667362d605a2704b0 Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Sat, 25 Oct 2025 07:14:18 +0400 Subject: [PATCH 11/31] fix Signed-off-by: Vladimir Suvorov --- .../experimental/rl/grpo_tunix_trainer.py | 102 +++++++++--------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/src/MaxText/experimental/rl/grpo_tunix_trainer.py b/src/MaxText/experimental/rl/grpo_tunix_trainer.py index 3cfd9d954..3266a58e9 100644 --- a/src/MaxText/experimental/rl/grpo_tunix_trainer.py +++ b/src/MaxText/experimental/rl/grpo_tunix_trainer.py @@ -16,12 +16,12 @@ GRPO Tunix Trainer This module provides a unified `grpo_train` function that consolidates the common -GRPO training logic. It handles model loading, reward function setup, dataset +GRPO training logic. It handles model loading, reward function setup, dataset processing, and training orchestration. Usage: from MaxText.experimental.rl.grpo_tunix_trainer import grpo_train - + # Train with GRPO grpo_train(config, goodput_recorder=None) """ @@ -39,14 +39,14 @@ # Conditional imports for optional dependencies try: - import grain + import grain except ImportError: - grain = None + grain = None try: - import pathwaysutils + import pathwaysutils except ImportError: - pathwaysutils = None + pathwaysutils = None from tunix.rl import rl_cluster as rl_cluster_lib from tunix.rl.rollout import base_rollout @@ -93,12 +93,12 @@ def extract_hash_answer(text: str) -> str | None: return text.split("####")[1].strip() -def get_gsm8k_dataset(data_dir: str, split: str = "train", batch_size: int = 1, +def get_gsm8k_dataset(data_dir: str, split: str = "train", batch_size: int = 1, num_batches: int = 4, seed: int = 42): """Load and process GSM8K dataset for GRPO training.""" if grain is None: raise ImportError("grain is required for dataset processing. 
Please install it.") - + # Download data if not os.path.exists(data_dir): os.makedirs(data_dir) @@ -140,7 +140,7 @@ def get_gsm8k_dataset(data_dir: str, split: str = "train", batch_size: int = 1, } ) ) - + return loaded_dataset.batch(batch_size)[:num_batches], tokenizer @@ -157,20 +157,20 @@ def get_maxtext_model(config, devices=None): def setup_device_allocation(config, use_pathways: bool = False): """Setup device allocation for training and inference.""" devices = jax.devices() - + # Get device allocation parameters from config trainer_devices_fraction = getattr(config, 'trainer_devices_fraction', 0.5) sampler_devices_fraction = getattr(config, 'sampler_devices_fraction', 0.5) chips_per_vm = getattr(config, 'chips_per_vm', 4) - + num_vms = len(devices) // chips_per_vm - + if use_pathways and num_vms >= 2: # Multiple hosts with Pathways - split devices for trainer and sampler if pathwaysutils is None: raise ImportError("pathwaysutils is required for Pathways support. Please install it.") pathwaysutils.initialize() - + num_devices = len(devices) num_trainer_devices = int(num_devices * trainer_devices_fraction) num_sampler_devices = int(num_devices * sampler_devices_fraction) @@ -182,10 +182,10 @@ def setup_device_allocation(config, use_pathways: bool = False): if pathwaysutils is None: raise ImportError("pathwaysutils is required for Pathways support. Please install it.") pathwaysutils.initialize() - + trainer_devices = devices sampler_devices = devices - + return trainer_devices, sampler_devices, num_vms @@ -194,11 +194,11 @@ def match_format_exactly(prompts, completions, **kwargs): """Reward exact format matching.""" scores = [] match_format = re.compile( - rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" + rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$", flags=re.MULTILINE | re.DOTALL, ) - + for completion in completions: score = 0 if match_format.search(completion) is not None: @@ -223,20 +223,20 @@ def match_format_approximately(prompts, completions, **kwargs): def check_answer(prompts, completions, answer, **kwargs): """Reward correct answers.""" match_format = re.compile( - rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" + rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$", flags=re.MULTILINE | re.DOTALL, ) - + extracted_responses = [guess.group(1) if (guess := match_format.search(r)) is not None else None for r in completions] - + scores = [] for guess, true_answer in zip(extracted_responses, answer): score = 0 if guess is None: scores.append(0) continue - + if guess == true_answer: score += REWARD_EXACT_FORMAT_MATCH elif guess.strip() == true_answer.strip(): @@ -260,7 +260,7 @@ def check_numbers(prompts, completions, answer, **kwargs): """Reward correct numerical answers.""" match_numbers = re.compile(rf"{SOLUTION_START}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL) extracted_responses = [guess.group(1) if (guess := match_numbers.search(r)) is not None else None for r in completions] - + scores = [] for guess, true_answer in zip(extracted_responses, answer): if guess is None: @@ -279,53 +279,53 @@ def check_numbers(prompts, completions, answer, **kwargs): def grpo_train(config): """ Run GRPO training with the provided configuration. 
- + Args: config: MaxText configuration object """ print("=" * 80) print("Starting GRPO Training") print("=" * 80) - + # Setup device allocation use_pathways = getattr(config, 'use_pathways_reshard', False) trainer_devices, sampler_devices, num_vms = setup_device_allocation(config, use_pathways) - + print(f"Device allocation: {len(trainer_devices)} trainer, {len(sampler_devices)} sampler") print(f"Use Pathways: {use_pathways}") - + # Setup data directories home = os.path.expanduser("~") + "/" train_data_dir = f"{home}/data/train" test_data_dir = f"{home}/data/test" - + # Load datasets print("Loading GSM8K dataset...") train_dataset, tokenizer = get_gsm8k_dataset( - train_data_dir, - split="train", + train_data_dir, + split="train", batch_size=config.per_device_batch_size, num_batches=getattr(config, 'num_batches', 4) ) - + # Load test dataset for evaluation (currently not used in training loop) get_gsm8k_dataset( - test_data_dir, - split="test", + test_data_dir, + split="test", batch_size=config.per_device_batch_size, num_batches=getattr(config, 'num_test_batches', 5) ) - + # Load reference model print("Loading reference model...") reference_model, reference_mesh = get_maxtext_model(config, trainer_devices) reference_model.config = None - + # Load policy model print("Loading policy model...") policy_model, policy_mesh = get_maxtext_model(config, trainer_devices) policy_model.config = None - + # Setup meshes if num_vms >= 2 and not use_pathways: actor_mesh = policy_mesh @@ -333,12 +333,12 @@ def grpo_train(config): else: actor_mesh = policy_mesh rollout_mesh = policy_mesh - + # Setup optimizer learning_rate = getattr(config, 'learning_rate', 3e-6) max_steps = getattr(config, 'steps', 100) warmup_steps = int(0.1 * max_steps) - + optimizer = optax.adamw( learning_rate=optax.schedules.warmup_cosine_decay_schedule( init_value=0.0, @@ -351,7 +351,7 @@ def grpo_train(config): b2=0.99, weight_decay=0.1, ) - + # Add gradient clipping if specified max_grad_norm = getattr(config, 'max_grad_norm', 0.1) if max_grad_norm is not None: @@ -359,25 +359,25 @@ def grpo_train(config): optax.clip_by_global_norm(max_norm=max_grad_norm), optimizer, ) - + # Setup checkpointing ckpt_dir = f"{config.base_output_directory}/checkpoints" os.makedirs(ckpt_dir, exist_ok=True) - + checkpointing_options = ocp.CheckpointManagerOptions( save_interval_steps=getattr(config, 'checkpoint_period', 50), max_to_keep=4 ) - + # Setup metrics logging log_dir = f"{config.base_output_directory}/logs" os.makedirs(log_dir, exist_ok=True) - + metrics_logging_options = metrics_logger.MetricsLoggerOptions( - log_dir=log_dir, + log_dir=log_dir, flush_every_n_steps=20 ) - + # Setup RL cluster config cluster_config = rl_cluster_lib.ClusterConfig( role_to_mesh={ @@ -399,7 +399,7 @@ def grpo_train(config): rollout_config=base_rollout.RolloutConfig( max_tokens_to_generate=getattr(config, 'max_target_length', 768), max_prompt_length=getattr(config, 'max_prefill_predict_length', 256), - kv_cache_size=getattr(config, 'max_prefill_predict_length', 256) + + kv_cache_size=getattr(config, 'max_prefill_predict_length', 256) + getattr(config, 'max_target_length', 768) + 256, temperature=getattr(config, 'decode_sampling_temperature', 0.9), top_p=getattr(config, 'decode_sampling_top_p', 1.0), @@ -409,7 +409,7 @@ def grpo_train(config): rollout_vllm_hbm_utilization=0.2, rollout_vllm_tpu_backend_type="jax", ) - + # Setup GRPO config grpo_config = GrpoConfig( num_generations=getattr(config, 'num_generations', 2), @@ -417,7 +417,7 @@ def grpo_train(config): 
beta=getattr(config, 'grpo_beta', 0.08), epsilon=getattr(config, 'grpo_epsilon', 0.2), ) - + # Create RL cluster print("Creating RL cluster...") with nn_partitioning.axis_rules(config.logical_axis_rules): @@ -427,7 +427,7 @@ def grpo_train(config): tokenizer=tokenizer, cluster_config=cluster_config, ) - + # Create GRPO trainer print("Setting up GRPO trainer...") grpo_trainer = GrpoLearner( @@ -440,14 +440,14 @@ def grpo_train(config): ], grpo_config=grpo_config, ) - + # Start training print("Starting GRPO training...") with policy_mesh, nn_partitioning.axis_rules(config.logical_axis_rules): grpo_trainer.train(train_dataset) - + print("=" * 80) print("GRPO Training Completed Successfully!") print("=" * 80) - + return grpo_trainer, rl_cluster From 129cf5743d7ed7e9d7f15887ac406dcaacd57bc8 Mon Sep 17 00:00:00 2001 From: Vladimir Suvorov Date: Sat, 25 Oct 2025 07:17:45 +0400 Subject: [PATCH 12/31] Fix Signed-off-by: Vladimir Suvorov --- src/MaxText/examples/grpo_runner.py | 4 +- .../experimental/rl/grpo_tunix_trainer.py | 60 +++++++++---------- 2 files changed, 29 insertions(+), 35 deletions(-) diff --git a/src/MaxText/examples/grpo_runner.py b/src/MaxText/examples/grpo_runner.py index f92887274..f96767954 100755 --- a/src/MaxText/examples/grpo_runner.py +++ b/src/MaxText/examples/grpo_runner.py @@ -298,7 +298,7 @@ def build_config_argv(args): # Add optional parameters if args.run_name: config_argv.append(f"run_name={args.run_name}") - + if args.hf_access_token: config_argv.append(f"hf_access_token={args.hf_access_token}") @@ -360,4 +360,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/MaxText/experimental/rl/grpo_tunix_trainer.py b/src/MaxText/experimental/rl/grpo_tunix_trainer.py index 3266a58e9..f2aef351b 100644 --- a/src/MaxText/experimental/rl/grpo_tunix_trainer.py +++ b/src/MaxText/experimental/rl/grpo_tunix_trainer.py @@ -93,8 +93,7 @@ def extract_hash_answer(text: str) -> str | None: return text.split("####")[1].strip() -def get_gsm8k_dataset(data_dir: str, split: str = "train", batch_size: int = 1, - num_batches: int = 4, seed: int = 42): +def get_gsm8k_dataset(data_dir: str, split: str = "train", batch_size: int = 1, num_batches: int = 4, seed: int = 42): """Load and process GSM8K dataset for GRPO training.""" if grain is None: raise ImportError("grain is required for dataset processing. 
Please install it.") @@ -159,9 +158,9 @@ def setup_device_allocation(config, use_pathways: bool = False): devices = jax.devices() # Get device allocation parameters from config - trainer_devices_fraction = getattr(config, 'trainer_devices_fraction', 0.5) - sampler_devices_fraction = getattr(config, 'sampler_devices_fraction', 0.5) - chips_per_vm = getattr(config, 'chips_per_vm', 4) + trainer_devices_fraction = getattr(config, "trainer_devices_fraction", 0.5) + sampler_devices_fraction = getattr(config, "sampler_devices_fraction", 0.5) + chips_per_vm = getattr(config, "chips_per_vm", 4) num_vms = len(devices) // chips_per_vm @@ -175,7 +174,7 @@ def setup_device_allocation(config, use_pathways: bool = False): num_trainer_devices = int(num_devices * trainer_devices_fraction) num_sampler_devices = int(num_devices * sampler_devices_fraction) trainer_devices = devices[:num_trainer_devices] - sampler_devices = devices[num_devices - num_sampler_devices:] + sampler_devices = devices[num_devices - num_sampler_devices :] else: # Not using Pathways OR single host - use all devices for both if use_pathways: @@ -194,8 +193,7 @@ def match_format_exactly(prompts, completions, **kwargs): """Reward exact format matching.""" scores = [] match_format = re.compile( - rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" - rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$", + rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$", flags=re.MULTILINE | re.DOTALL, ) @@ -223,8 +221,7 @@ def match_format_approximately(prompts, completions, **kwargs): def check_answer(prompts, completions, answer, **kwargs): """Reward correct answers.""" match_format = re.compile( - rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" - rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$", + rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" 
rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$", flags=re.MULTILINE | re.DOTALL, ) @@ -288,7 +285,7 @@ def grpo_train(config): print("=" * 80) # Setup device allocation - use_pathways = getattr(config, 'use_pathways_reshard', False) + use_pathways = getattr(config, "use_pathways_reshard", False) trainer_devices, sampler_devices, num_vms = setup_device_allocation(config, use_pathways) print(f"Device allocation: {len(trainer_devices)} trainer, {len(sampler_devices)} sampler") @@ -305,7 +302,7 @@ def grpo_train(config): train_data_dir, split="train", batch_size=config.per_device_batch_size, - num_batches=getattr(config, 'num_batches', 4) + num_batches=getattr(config, "num_batches", 4), ) # Load test dataset for evaluation (currently not used in training loop) @@ -313,7 +310,7 @@ def grpo_train(config): test_data_dir, split="test", batch_size=config.per_device_batch_size, - num_batches=getattr(config, 'num_test_batches', 5) + num_batches=getattr(config, "num_test_batches", 5), ) # Load reference model @@ -335,8 +332,8 @@ def grpo_train(config): rollout_mesh = policy_mesh # Setup optimizer - learning_rate = getattr(config, 'learning_rate', 3e-6) - max_steps = getattr(config, 'steps', 100) + learning_rate = getattr(config, "learning_rate", 3e-6) + max_steps = getattr(config, "steps", 100) warmup_steps = int(0.1 * max_steps) optimizer = optax.adamw( @@ -353,7 +350,7 @@ def grpo_train(config): ) # Add gradient clipping if specified - max_grad_norm = getattr(config, 'max_grad_norm', 0.1) + max_grad_norm = getattr(config, "max_grad_norm", 0.1) if max_grad_norm is not None: optimizer = optax.chain( optax.clip_by_global_norm(max_norm=max_grad_norm), @@ -365,18 +362,14 @@ def grpo_train(config): os.makedirs(ckpt_dir, exist_ok=True) checkpointing_options = ocp.CheckpointManagerOptions( - save_interval_steps=getattr(config, 'checkpoint_period', 50), - max_to_keep=4 + save_interval_steps=getattr(config, "checkpoint_period", 50), max_to_keep=4 ) # Setup metrics logging log_dir = f"{config.base_output_directory}/logs" os.makedirs(log_dir, exist_ok=True) - metrics_logging_options = metrics_logger.MetricsLoggerOptions( - log_dir=log_dir, - flush_every_n_steps=20 - ) + metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=log_dir, flush_every_n_steps=20) # Setup RL cluster config cluster_config = rl_cluster_lib.ClusterConfig( @@ -389,7 +382,7 @@ def grpo_train(config): offload_to_cpu=False, training_config=rl_cluster_lib.RLTrainingConfig( actor_optimizer=optimizer, - eval_every_n_steps=getattr(config, 'eval_interval', 10), + eval_every_n_steps=getattr(config, "eval_interval", 10), max_steps=max_steps, metrics_logging_options=metrics_logging_options, profiler_options=None, @@ -397,13 +390,14 @@ def grpo_train(config): checkpointing_options=checkpointing_options, ), rollout_config=base_rollout.RolloutConfig( - max_tokens_to_generate=getattr(config, 'max_target_length', 768), - max_prompt_length=getattr(config, 'max_prefill_predict_length', 256), - kv_cache_size=getattr(config, 'max_prefill_predict_length', 256) + - getattr(config, 'max_target_length', 768) + 256, - temperature=getattr(config, 'decode_sampling_temperature', 0.9), - top_p=getattr(config, 'decode_sampling_top_p', 1.0), - top_k=getattr(config, 'decode_sampling_top_k', 50), + max_tokens_to_generate=getattr(config, "max_target_length", 768), + max_prompt_length=getattr(config, "max_prefill_predict_length", 256), + kv_cache_size=getattr(config, "max_prefill_predict_length", 256) + + getattr(config, "max_target_length", 768) + + 
256, + temperature=getattr(config, "decode_sampling_temperature", 0.9), + top_p=getattr(config, "decode_sampling_top_p", 1.0), + top_k=getattr(config, "decode_sampling_top_k", 50), ), rollout_vllm_model_version="meta-llama/Meta-Llama-3.1-8B-Instruct", rollout_vllm_hbm_utilization=0.2, @@ -412,10 +406,10 @@ def grpo_train(config): # Setup GRPO config grpo_config = GrpoConfig( - num_generations=getattr(config, 'num_generations', 2), + num_generations=getattr(config, "num_generations", 2), num_iterations=1, - beta=getattr(config, 'grpo_beta', 0.08), - epsilon=getattr(config, 'grpo_epsilon', 0.2), + beta=getattr(config, "grpo_beta", 0.08), + epsilon=getattr(config, "grpo_epsilon", 0.2), ) # Create RL cluster From 6dfab38b2616dd9e499faa46c737e1907c089ace Mon Sep 17 00:00:00 2001 From: A9isha Date: Tue, 28 Oct 2025 17:18:45 +0000 Subject: [PATCH 13/31] nit changes --- .../examples/{README.md => README_how_to_run_examples.md} | 0 src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename src/MaxText/examples/{README.md => README_how_to_run_examples.md} (100%) diff --git a/src/MaxText/examples/README.md b/src/MaxText/examples/README_how_to_run_examples.md similarity index 100% rename from src/MaxText/examples/README.md rename to src/MaxText/examples/README_how_to_run_examples.md diff --git a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py index c45e7b174..bfe54b41c 100644 --- a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py +++ b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py @@ -39,7 +39,7 @@ # We use Tunix as the library for GRPO. # And we use vLLM as the library for efficient model inference and generation. # -# In this tutorial we use a single host TPUVM such as `v6e-8/v5p-8`. Let's get started! +# In this tutorial we use `v5p-256` or `v5p-128`. Let's get started! # ## Install necessary libraries From d514567d3ddf6dece30e89f5517f2caa9252aa6a Mon Sep 17 00:00:00 2001 From: A9isha Date: Wed, 29 Oct 2025 01:07:51 +0000 Subject: [PATCH 14/31] WIP src/MaxText/rl/rl_trainer.py, delete grpo_runner --- src/MaxText/configs/grpo.yml | 105 ++++- .../examples/grpo_llama3_1_70b_demo_pw.py | 2 +- .../examples/grpo_llama3_1_8b_demo.ipynb | 15 +- src/MaxText/examples/grpo_llama3_1_8b_demo.py | 2 +- .../examples/grpo_llama3_1_8b_demo_pw.py | 2 +- src/MaxText/examples/grpo_runner.py | 363 ------------------ .../rl_trainer.py} | 236 +++++++----- 7 files changed, 258 insertions(+), 467 deletions(-) delete mode 100755 src/MaxText/examples/grpo_runner.py rename src/MaxText/{experimental/rl/grpo_tunix_trainer.py => rl/rl_trainer.py} (64%) diff --git a/src/MaxText/configs/grpo.yml b/src/MaxText/configs/grpo.yml index 2fa6e6de0..4d4b43116 100644 --- a/src/MaxText/configs/grpo.yml +++ b/src/MaxText/configs/grpo.yml @@ -12,14 +12,109 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-# GRPO Configuration -# This config consolidates common parameters for GRPO training across different model sizes +# RL Configuration +# This config consolidates common parameters for RL training across different model sizes base_config: "base.yml" -use_grpo: True -train_data_columns: 'prompt' +# ====== Hardware ===== +trainer_devices_fraction: 0.5 +sampler_devices_fraction: 0.5 +chips_per_vm: 4 # depends on hardware, for v5p this is 4 +# ====== Debug ====== +debug: True + +# ====== Reproducibility ====== +data_shuffle_seed: 42 +loss_algo: 'grpo' # grpo or gspo-token + +# ====== Checkpoint saving ====== +save_interval_steps: 500 +max_to_keep: 4 + +# ====== GRPO ====== +# === Generation during GRPO training === +max_prompt_length: 256 +total_generation_steps: 768 + +# The number of times the policy generates multiple responses for a given prompt +# within a single training step. This corresponds to `G` in Algorithm 1 in the +# paper. The "group" in GRPO comes from here. +num_generations: 2 + +# === other GRPO configs === +# The number of iterations per batch (𝜇 in GRPO algo 1). +num_iterations: 1 + +# The coefficient for the KL divergence penalty (𝛽) in the GRPO loss function. +# Important to keep a high enough value for this, otherwise, the KL divergence +# can increase unchecked. +beta: 0.08 +# Epsilon value for clipping (𝜀 in GRPO loss in paper). Similar to PPO, for +# stable updates. +epsilon: 0.2 + +# ====== Training ====== + +batch_size: 1 +# Increase `batch_size` and `MAX_STEPS` for better results. +# NUM_BATCHES = 3738 +NUM_BATCHES = 4 # 200 +# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be +# increased to a max. of 330 (if batch size is 4). +NUM_TEST_BATCHES = 5 # 200 + +EVAL_EVERY_N_STEPS = 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`. +NUM_EPOCHS = 1 # can potentially train for more epochs + + +# === AdamW, warmup, cosine scheduler === +LEARNING_RATE = 3e-6 +B1 = 0.9 +B2 = 0.99 +WEIGHT_DECAY = 0.1 +# == Cosine decay with warmup scheduler == +# Linearly increase learning rate from 0. to 5e-6 in the first 10% training +# steps, and then gradually decrease the learning rate to 0 using cosine +# scheduler. +WARMUP_STEPS = int(0.1 * MAX_STEPS) +# == Grad clipping == +# Grad clipping to prevent large gradients. Found this +# important to keep KL divergence in check. +MAX_GRAD_NORM = 0.1 + + +# ====== Inference ====== +# Important to keep a high-ish temperature for varied, diverse responses during +# training. 
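+# The greedy preset below mirrors the "greedy" generation config used for
+# evaluation in the demo notebook; for training rollouts, looser sampling such
+# as the commented "standard"/"liberal" presets below or the decode_sampling_*
+# values further down in this file keeps responses varied.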
+# greedy search +temperature: 0.01 +top_p: 1.0 +top_k: 1 + +# # some randomness +# "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95}, +# # liberal +# "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0}, + +TRAINER_DEVICES_FRACTION = 0.5 +SAMPLER_DEVICES_FRACTION = 0.5 +HBM_UTILIZATION_VLLM = 0.72 +SWAP_SPACE_VLLM_GB = 2 + + +# ====== Reward ====== +REWARD_EXACT_FORMAT_MATCH = 3.0 +REWARD_WHITE_SPACE_FORMAT_MATCH = 1.5 +REWARD_PARTIAL_FORMAT_MATCH = 0.5 +REWARD_RATIO_GUESS_TO_ANSWER_HIGH = 0.5 +REWARD_RATIO_GUESS_TO_ANSWER_LOW = 0.25 +PENALTY_INCORRECT_FORMAT = -0.5 +PENALTY_INCORRECT_ANSWER = -1.0 + + +# TODO: fix this # Dataset Configuration dataset_type: hf # Huggingface input pipeline hf_path: 'gsm8k' @@ -56,8 +151,6 @@ decode_sampling_top_k: 50 # Training Loop Configuration steps: 100 per_device_batch_size: 1 -eval_interval: 10 -eval_steps: 5 # Checkpoint Configuration enable_checkpointing: True diff --git a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py index bfe54b41c..33d7fcd33 100644 --- a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py +++ b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py @@ -15,7 +15,7 @@ # pylint: disable=bare-except, consider-using-generator """ DEPRECATED: This file is deprecated and kept for reference only. -Please use the new unified CLI interface: grpo_runner.py +Please use the new unified CLI interface: rl_trainer.py See GRPO_README.md for migration guide and usage examples. diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb b/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb index 173c34c36..41fbeccfc 100644 --- a/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb +++ b/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb @@ -6,7 +6,7 @@ "source": [ "# GRPO Llama3.1-8B Demo: Direct Function Call\n", "\n", - "This notebook demonstrates GRPO training by directly calling the `grpo_train` function from `grpo_tunix_trainer.py`.\n", + "This notebook demonstrates GRPO training by directly calling the `rl_train` function from `rl_trainer.py`.\n", "\n", "## What is GRPO?\n", "\n", @@ -16,7 +16,7 @@ "3. 
Calculating relative advantages to update the policy\n", "\n", "\n", - "This notebook imports and calls the `grpo_train` function \n", + "This notebook imports and calls the `rl_train` function \n", "\n", "## Hardware Requirements\n", "\n", @@ -115,7 +115,7 @@ "\n", "# Import required modules\n", "from MaxText import pyconfig\n", - "from MaxText.experimental.rl.grpo_tunix_trainer import grpo_train\n", + "from MaxText.rl.rl_trainer import rl_train\n", "\n", "print(\"✅ Successfully imported GRPO training function\")\n", "print(f\"📁 MaxText path: {maxtext_path}\")\n", @@ -145,8 +145,6 @@ " \"num_generations=2\",\n", " \"grpo_beta=0.08\",\n", " \"grpo_epsilon=0.2\",\n", - " \"trainer_devices_fraction=0.5\",\n", - " \"sampler_devices_fraction=0.5\",\n", " \"chips_per_vm=4\"\n", "]\n", "\n", @@ -168,14 +166,13 @@ "source": [ "# Execute GRPO training directly\n", "try:\n", - " # Call the grpo_train function\n", - " grpo_trainer, rl_cluster = grpo_train(config)\n", + " # Call the rl_train function\n", + " grpo_trainer, rl_cluster = rl_train(config)\n", " \n", " print(\"\\n\" + \"=\"*80)\n", " print(\"✅ GRPO Training Completed Successfully!\")\n", " print(\"=\"*80)\n", - " print(f\"📁 Checkpoints saved to: {config.base_output_directory}/checkpoints\")\n", - " print(f\"📊 Logs available in: {config.base_output_directory}/logs\")\n", + " print(f\"📁 Checkpoints and logs saved to: {config.base_output_directory}\")\n", " print(f\"🎯 Final model ready for inference!\")\n", " \n", "except Exception as e:\n", diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo.py b/src/MaxText/examples/grpo_llama3_1_8b_demo.py index baf9e53ae..3ba2d537d 100644 --- a/src/MaxText/examples/grpo_llama3_1_8b_demo.py +++ b/src/MaxText/examples/grpo_llama3_1_8b_demo.py @@ -15,7 +15,7 @@ # pylint: disable=bare-except, consider-using-generator """ DEPRECATED: This file is deprecated and kept for reference only. -Please use the new unified CLI interface: grpo_runner.py +Please use the new unified CLI interface: rl_trainer.py See GRPO_README.md for migration guide and usage examples. diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py b/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py index 2536ff0eb..bae80e38f 100644 --- a/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py +++ b/src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py @@ -15,7 +15,7 @@ # pylint: disable=bare-except, consider-using-generator """ DEPRECATED: This file is deprecated and kept for reference only. -Please use the new unified CLI interface: grpo_runner.py +Please use the new unified CLI interface: rl_trainer.py See GRPO_README.md for migration guide and usage examples. diff --git a/src/MaxText/examples/grpo_runner.py b/src/MaxText/examples/grpo_runner.py deleted file mode 100755 index f96767954..000000000 --- a/src/MaxText/examples/grpo_runner.py +++ /dev/null @@ -1,363 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023–2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -""" -Unified GRPO Script - -This script provides a unified CLI interface for running GRPO training -across different model sizes and configurations. It uses the grpo_train function -from grpo_tunix_trainer.py which consolidates all the GRPO-specific logic. - -Usage Examples: - -# Llama3.1-8B (single host) -python3 src/MaxText/examples/grpo_runner.py \\ - --model_name=llama3.1-8b \\ - --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \\ - --load_parameters_path=gs://path/to/checkpoint \\ - --base_output_directory=/tmp/grpo_output \\ - --hf_access_token=$HF_TOKEN \\ - --steps=100 - -# Llama3.1-70B with Pathways (multi-host) -python3 src/MaxText/examples/grpo_runner.py \\ - --model_name=llama3.1-70b \\ - --tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \\ - --load_parameters_path=gs://path/to/checkpoint \\ - --base_output_directory=gs://path/to/output \\ - --hf_access_token=$HF_TOKEN \\ - --use_pathways=true \\ - --steps=100 - -# Custom dataset -python3 src/MaxText/examples/grpo_runner.py \\ - --model_name=llama3.1-8b \\ - --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \\ - --load_parameters_path=gs://path/to/checkpoint \\ - --base_output_directory=/tmp/grpo_output \\ - --hf_access_token=$HF_TOKEN \\ - --hf_path=custom/dataset \\ - --steps=100 -""" - -import argparse -import os -import sys - -# Add MaxText to path -script_dir = os.path.dirname(os.path.abspath(__file__)) -maxtext_root = os.path.abspath(os.path.join(script_dir, "..", "..")) -sys.path.insert(0, maxtext_root) - -from MaxText import pyconfig -from MaxText.globals import MAXTEXT_PKG_DIR -from MaxText.experimental.rl.grpo_tunix_trainer import grpo_train - - -def create_parser(): - """Create argument parser for GRPO demo.""" - parser = argparse.ArgumentParser( - description="Unified GRPO Script", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=__doc__, - ) - - # Model Configuration - parser.add_argument( - "--model_name", - type=str, - required=True, - help="Model name (e.g., llama3.1-8b, llama3.1-70b)", - ) - parser.add_argument( - "--tokenizer_path", - type=str, - required=True, - help="HuggingFace tokenizer path (e.g., meta-llama/Llama-3.1-8B-Instruct)", - ) - parser.add_argument( - "--load_parameters_path", - type=str, - required=True, - help="Path to model checkpoint (local or gs://)", - ) - - # Output Configuration - parser.add_argument( - "--base_output_directory", - type=str, - required=True, - help="Base output directory for logs and checkpoints", - ) - parser.add_argument( - "--run_name", - type=str, - default=None, - help="Run name for this experiment", - ) - - # Dataset Configuration - parser.add_argument( - "--hf_access_token", - type=str, - default=os.environ.get("HF_TOKEN"), - help="HuggingFace access token (default: $HF_TOKEN env var)", - ) - parser.add_argument( - "--hf_path", - type=str, - default="gsm8k", - help="HuggingFace dataset path", - ) - parser.add_argument( - "--hf_data_split", - type=str, - default="main", - help="HuggingFace dataset split", - ) - parser.add_argument( - "--hf_data_files", - type=str, - default="train", - help="HuggingFace dataset files", - ) - - # Training Configuration - parser.add_argument( - "--steps", - type=int, - default=100, - help="Number of training steps", - ) - parser.add_argument( - "--per_device_batch_size", - type=int, - default=1, - help="Per device batch size", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=3e-6, - help="Learning rate", - ) - - # GRPO-Specific Parameters - parser.add_argument( - "--num_generations", - 
type=int, - default=2, - help="Number of generations per prompt (group size)", - ) - parser.add_argument( - "--grpo_beta", - type=float, - default=0.08, - help="KL divergence penalty coefficient", - ) - parser.add_argument( - "--grpo_epsilon", - type=float, - default=0.2, - help="Clipping value for stable updates", - ) - - # Sequence Lengths - parser.add_argument( - "--max_prefill_predict_length", - type=int, - default=256, - help="Maximum prompt length", - ) - parser.add_argument( - "--max_target_length", - type=int, - default=768, - help="Maximum total sequence length", - ) - - # Pathways/Multi-Host Configuration - parser.add_argument( - "--use_pathways", - action="store_true", - help="Use Pathways for multi-host training", - ) - parser.add_argument( - "--inference_devices_per_replica", - type=int, - default=4, - help="Number of devices per inference replica", - ) - parser.add_argument( - "--inference_replicas", - type=int, - default=1, - help="Number of inference replicas", - ) - - # Parallelism Configuration - parser.add_argument( - "--ici_fsdp_parallelism", - type=int, - default=-1, - help="FSDP parallelism (-1 for auto)", - ) - parser.add_argument( - "--ici_tensor_parallelism", - type=int, - default=-1, - help="Tensor parallelism (-1 for auto)", - ) - - # Device Allocation Configuration - parser.add_argument( - "--trainer_devices_fraction", - type=float, - default=0.5, - help="Fraction of devices for training (0.0-1.0)", - ) - parser.add_argument( - "--sampler_devices_fraction", - type=float, - default=0.5, - help="Fraction of devices for sampling (0.0-1.0)", - ) - parser.add_argument( - "--chips_per_vm", - type=int, - default=4, - help="Number of chips per VM (hardware dependent)", - ) - - # Other Configuration - parser.add_argument( - "--profiler", - type=str, - default="xplane", - help="Profiler to use (xplane, none)", - ) - parser.add_argument( - "--checkpoint_period", - type=int, - default=50, - help="Checkpoint saving period", - ) - parser.add_argument( - "--config_file", - type=str, - default=None, - help="Optional custom config file (overrides grpo.yml)", - ) - - return parser - - -def build_config_argv(args): - """Build configuration arguments for MaxText pyconfig.""" - # Use custom config or default grpo.yml - if args.config_file: - base_config = args.config_file - else: - base_config = os.path.join(MAXTEXT_PKG_DIR, "configs", "grpo.yml") - - # Build training config argv - config_argv = [ - "", # Placeholder for argv[0] - base_config, - f"model_name={args.model_name}", - f"tokenizer_path={args.tokenizer_path}", - f"load_parameters_path={args.load_parameters_path}", - f"base_output_directory={args.base_output_directory}", - f"steps={args.steps}", - f"per_device_batch_size={args.per_device_batch_size}", - f"learning_rate={args.learning_rate}", - f"num_generations={args.num_generations}", - f"grpo_beta={args.grpo_beta}", - f"grpo_epsilon={args.grpo_epsilon}", - f"max_prefill_predict_length={args.max_prefill_predict_length}", - f"max_target_length={args.max_target_length}", - f"profiler={args.profiler}", - f"checkpoint_period={args.checkpoint_period}", - f"hf_path={args.hf_path}", - f"hf_data_split={args.hf_data_split}", - f"hf_data_files={args.hf_data_files}", - ] - - # Add optional parameters - if args.run_name: - config_argv.append(f"run_name={args.run_name}") - - if args.hf_access_token: - config_argv.append(f"hf_access_token={args.hf_access_token}") - - if args.use_pathways: - config_argv.append("use_pathways_reshard=True") - 
config_argv.append(f"inference_devices_per_replica={args.inference_devices_per_replica}") - config_argv.append(f"inference_replicas={args.inference_replicas}") - - if args.ici_fsdp_parallelism > 0: - config_argv.append(f"ici_fsdp_parallelism={args.ici_fsdp_parallelism}") - - if args.ici_tensor_parallelism > 0: - config_argv.append(f"ici_tensor_parallelism={args.ici_tensor_parallelism}") - - # Add device allocation parameters - config_argv.append(f"trainer_devices_fraction={args.trainer_devices_fraction}") - config_argv.append(f"sampler_devices_fraction={args.sampler_devices_fraction}") - config_argv.append(f"chips_per_vm={args.chips_per_vm}") - - return config_argv - - -def main(): - """Main entry point for GRPO demo.""" - parser = create_parser() - args = parser.parse_args() - - # Validate required environment/arguments - if not args.hf_access_token: - print("Error: HF_TOKEN is required. Set it as an environment variable or pass --hf_access_token") - sys.exit(1) - - print("=" * 80) - print("GRPO - Unified Training Script") - print("=" * 80) - print(f"Model: {args.model_name}") - print(f"Tokenizer: {args.tokenizer_path}") - print(f"Checkpoint: {args.load_parameters_path}") - print(f"Dataset: {args.hf_path}") - print(f"Output: {args.base_output_directory}") - print(f"Steps: {args.steps}") - print(f"GRPO Beta: {args.grpo_beta}") - print(f"Num Generations: {args.num_generations}") - print(f"Use Pathways: {args.use_pathways}") - print("=" * 80) - - # Build config arguments - config_argv = build_config_argv(args) - - # Initialize configuration - config = pyconfig.initialize(config_argv) - - # Run GRPO training using the unified trainer - grpo_train(config) - - print("=" * 80) - print("GRPO Training Completed Successfully!") - print("=" * 80) - - -if __name__ == "__main__": - main() diff --git a/src/MaxText/experimental/rl/grpo_tunix_trainer.py b/src/MaxText/rl/rl_trainer.py similarity index 64% rename from src/MaxText/experimental/rl/grpo_tunix_trainer.py rename to src/MaxText/rl/rl_trainer.py index f2aef351b..287e8338a 100644 --- a/src/MaxText/experimental/rl/grpo_tunix_trainer.py +++ b/src/MaxText/rl/rl_trainer.py @@ -13,19 +13,48 @@ # limitations under the License. """ -GRPO Tunix Trainer +GRPO Trainer -This module provides a unified `grpo_train` function that consolidates the common -GRPO training logic. It handles model loading, reward function setup, dataset -processing, and training orchestration. +This module provides a unified `rl_train` function that consolidates the common +RL training logic. It handles model loading, reward function setup, dataset +processing, and training orchestration. By default, we run Group Relative Policy Optimization (GRPO) on +GSM8K math reasoning benchmark. GRPO can enhance your model's problem-solving skills on mathematical word problems, +coding problems, etc. 
Usage: - from MaxText.experimental.rl.grpo_tunix_trainer import grpo_train - - # Train with GRPO - grpo_train(config, goodput_recorder=None) + Usage Examples: + + # Llama3.1-8B (single host) + python3 src/MaxText/examples/rl_trainer.py \\ + --model_name=llama3.1-8b \\ + --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \\ + --load_parameters_path=gs://path/to/checkpoint \\ + --base_output_directory=/tmp/grpo_output \\ + --hf_access_token=$HF_TOKEN \\ + --steps=100 + + # Llama3.1-70B with Pathways (multi-host) + python3 src/MaxText/examples/rl_trainer.py \\ + --model_name=llama3.1-70b \\ + --tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \\ + --load_parameters_path=gs://path/to/checkpoint \\ + --base_output_directory=gs://path/to/output \\ + --hf_access_token=$HF_TOKEN \\ + --use_pathways=true \\ + --steps=100 + + # Custom dataset + python3 src/MaxText/examples/rl_trainer.py \\ + --model_name=llama3.1-8b \\ + --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \\ + --load_parameters_path=gs://path/to/checkpoint \\ + --base_output_directory=/tmp/grpo_output \\ + --hf_access_token=$HF_TOKEN \\ + --hf_path=custom/dataset \\ + --steps=100 """ +from absl import app import os import re @@ -37,16 +66,9 @@ import tensorflow_datasets as tfds from transformers import AutoTokenizer -# Conditional imports for optional dependencies -try: - import grain -except ImportError: - grain = None +import grain -try: - import pathwaysutils -except ImportError: - pathwaysutils = None +import pathwaysutils from tunix.rl import rl_cluster as rl_cluster_lib from tunix.rl.rollout import base_rollout @@ -54,12 +76,11 @@ from tunix.sft import metrics_logger from tunix.models.llama3 import model as llama3_lib -from MaxText import maxtext_utils +from MaxText import max_logging, max_utils, maxtext_utils, pyconfig from MaxText import model_creation_utils from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter - -# GRPO-specific constants +# ====== Reward-specific constants ====== REWARD_EXACT_FORMAT_MATCH = 3.0 REWARD_WHITE_SPACE_FORMAT_MATCH = 1.5 REWARD_PARTIAL_FORMAT_MATCH = 0.5 @@ -68,12 +89,14 @@ PENALTY_INCORRECT_FORMAT = -0.5 PENALTY_INCORRECT_ANSWER = -1.0 -# Special tokens for GSM8K reasoning +# ====== Special tokens for GSM8K reasoning ====== REASONING_START = "" REASONING_END = "" SOLUTION_START = "" SOLUTION_END = "" +# ====== System prompt and Templates ====== + SYSTEM_PROMPT = f"""You are given a problem. Think about the problem and \ provide your reasoning. Place it between {REASONING_START} and \ {REASONING_END}. Then, provide the final answer (i.e., just one numerical \ @@ -85,19 +108,24 @@ {question} model""" +# ====== Debug flag for verbose logs ====== +DEBUG=False + +# ====== Reproducibility ====== +SEED = 42 + +# We use OpenAI's GSM8K dataset. GSM8K comprises grade school math word problems. + def extract_hash_answer(text: str) -> str | None: - """Extract the numerical answer from GSM8K format.""" + if DEBUG: + print(f"Extracting answer from: {text}") if "####" not in text: return None return text.split("####")[1].strip() -def get_gsm8k_dataset(data_dir: str, split: str = "train", batch_size: int = 1, num_batches: int = 4, seed: int = 42): - """Load and process GSM8K dataset for GRPO training.""" - if grain is None: - raise ImportError("grain is required for dataset processing. 
Please install it.") - +def get_dataset(data_dir, split="train") -> grain.MapDataset: # Download data if not os.path.exists(data_dir): os.makedirs(data_dir) @@ -110,16 +138,13 @@ def get_gsm8k_dataset(data_dir: str, split: str = "train", batch_size: int = 1, download=True, ) - # Load tokenizer - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct") - loaded_dataset = ( grain.MapDataset.source(data) - .shuffle(seed=seed) + .shuffle(seed=SEED) .map( lambda x: { # passed to model forward pass - "prompts": tokenizer.apply_chat_template( + "prompts": model_tokenizer.apply_chat_template( [ { "role": "user", @@ -139,8 +164,7 @@ def get_gsm8k_dataset(data_dir: str, split: str = "train", batch_size: int = 1, } ) ) - - return loaded_dataset.batch(batch_size)[:num_batches], tokenizer + return loaded_dataset def get_maxtext_model(config, devices=None): @@ -164,27 +188,23 @@ def setup_device_allocation(config, use_pathways: bool = False): num_vms = len(devices) // chips_per_vm - if use_pathways and num_vms >= 2: + trainer_devices = devices + sampler_devices = devices + if num_vms >= 2 and use_pathways: # Multiple hosts with Pathways - split devices for trainer and sampler - if pathwaysutils is None: - raise ImportError("pathwaysutils is required for Pathways support. Please install it.") - pathwaysutils.initialize() - + print(f"{num_vms} VMs detected, allocating trainer and sampler devices, and using Pathways.") num_devices = len(devices) num_trainer_devices = int(num_devices * trainer_devices_fraction) num_sampler_devices = int(num_devices * sampler_devices_fraction) trainer_devices = devices[:num_trainer_devices] sampler_devices = devices[num_devices - num_sampler_devices :] - else: - # Not using Pathways OR single host - use all devices for both - if use_pathways: - if pathwaysutils is None: - raise ImportError("pathwaysutils is required for Pathways support. Please install it.") - pathwaysutils.initialize() - - trainer_devices = devices - sampler_devices = devices + print("Creating reference model and also meshes for reference and rollout") + llama3_1_70b, reference_mesh = get_maxtext_model(config_ref, trainer_devices) + devices_array = maxtext_utils.create_device_mesh(config_ref, sampler_devices) + rollout_mesh = Mesh(devices_array, config_ref.mesh_axes) + mesh = reference_mesh + return trainer_devices, sampler_devices, num_vms @@ -273,67 +293,80 @@ def check_numbers(prompts, completions, answer, **kwargs): return scores -def grpo_train(config): +def rl_train(mt_config): """ - Run GRPO training with the provided configuration. + Run RL training with the provided configuration. Args: - config: MaxText configuration object + mt_config: MaxText configuration object """ - print("=" * 80) print("Starting GRPO Training") - print("=" * 80) + + # Number of training steps. + max_steps = int(mt_config.num_batches * mt_config.num_iterations * TRAIN_FRACTION * NUM_EPOCHS) # Setup device allocation - use_pathways = getattr(config, "use_pathways_reshard", False) - trainer_devices, sampler_devices, num_vms = setup_device_allocation(config, use_pathways) + if jax.extend.backend.get_backend().platform_version == "Pathways": + max_logging.log("Pathways backend detected. 
Disabling setting profile options.") + use_pathways = True + else: + use_pathways = False + + trainer_devices, sampler_devices, num_vms = setup_device_allocation(mt_config, use_pathways) print(f"Device allocation: {len(trainer_devices)} trainer, {len(sampler_devices)} sampler") print(f"Use Pathways: {use_pathways}") + # ====== Data ====== # Setup data directories home = os.path.expanduser("~") + "/" train_data_dir = f"{home}/data/train" test_data_dir = f"{home}/data/test" + if not os.path.exists(train_data_dir): + os.makedirs(train_data_dir) + if not os.path.exists(test_data_dir): + os.makedirs(test_data_dir) + TRAIN_FRACTION = 1.0 # Load datasets print("Loading GSM8K dataset...") train_dataset, tokenizer = get_gsm8k_dataset( train_data_dir, split="train", - batch_size=config.per_device_batch_size, - num_batches=getattr(config, "num_batches", 4), + batch_size=mt_config.per_device_batch_size, + num_batches=getattr(mt_config, "num_batches", 4), ) # Load test dataset for evaluation (currently not used in training loop) get_gsm8k_dataset( test_data_dir, split="test", - batch_size=config.per_device_batch_size, - num_batches=getattr(config, "num_test_batches", 5), + batch_size=mt_config.per_device_batch_size, + num_batches=getattr(mt_config, "num_test_batches", 5), ) # Load reference model print("Loading reference model...") - reference_model, reference_mesh = get_maxtext_model(config, trainer_devices) + reference_model, reference_mesh = get_maxtext_model(mt_config, trainer_devices) reference_model.config = None # Load policy model print("Loading policy model...") - policy_model, policy_mesh = get_maxtext_model(config, trainer_devices) + policy_model, policy_mesh = get_maxtext_model(mt_config, trainer_devices) policy_model.config = None + # Setup meshes if num_vms >= 2 and not use_pathways: actor_mesh = policy_mesh - rollout_mesh = Mesh(maxtext_utils.create_device_mesh(config, sampler_devices), config.mesh_axes) + rollout_mesh = Mesh(maxtext_utils.create_device_mesh(mt_config, sampler_devices), mt_config.mesh_axes) else: actor_mesh = policy_mesh rollout_mesh = policy_mesh # Setup optimizer - learning_rate = getattr(config, "learning_rate", 3e-6) - max_steps = getattr(config, "steps", 100) + learning_rate = getattr(mt_config, "learning_rate", 3e-6) + max_steps = getattr(mt_config, "steps", 100) warmup_steps = int(0.1 * max_steps) optimizer = optax.adamw( @@ -350,7 +383,7 @@ def grpo_train(config): ) # Add gradient clipping if specified - max_grad_norm = getattr(config, "max_grad_norm", 0.1) + max_grad_norm = getattr(mt_config, "max_grad_norm", 0.1) if max_grad_norm is not None: optimizer = optax.chain( optax.clip_by_global_norm(max_norm=max_grad_norm), @@ -358,16 +391,15 @@ def grpo_train(config): ) # Setup checkpointing - ckpt_dir = f"{config.base_output_directory}/checkpoints" - os.makedirs(ckpt_dir, exist_ok=True) + ckpt_dir = mt_config.base_output_directory checkpointing_options = ocp.CheckpointManagerOptions( - save_interval_steps=getattr(config, "checkpoint_period", 50), max_to_keep=4 + save_interval_steps=getattr(mt_config, "checkpoint_period", 50), max_to_keep=4 ) # Setup metrics logging - log_dir = f"{config.base_output_directory}/logs" - os.makedirs(log_dir, exist_ok=True) + log_dir = mt_config.base_output_directory + max_logging.log(f"Logging to {log_dir}") metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=log_dir, flush_every_n_steps=20) @@ -382,7 +414,7 @@ def grpo_train(config): offload_to_cpu=False, training_config=rl_cluster_lib.RLTrainingConfig( 
actor_optimizer=optimizer, - eval_every_n_steps=getattr(config, "eval_interval", 10), + eval_every_n_steps=getattr(mt_config, "eval_interval", 10), max_steps=max_steps, metrics_logging_options=metrics_logging_options, profiler_options=None, @@ -390,14 +422,14 @@ def grpo_train(config): checkpointing_options=checkpointing_options, ), rollout_config=base_rollout.RolloutConfig( - max_tokens_to_generate=getattr(config, "max_target_length", 768), - max_prompt_length=getattr(config, "max_prefill_predict_length", 256), - kv_cache_size=getattr(config, "max_prefill_predict_length", 256) - + getattr(config, "max_target_length", 768) + max_tokens_to_generate=getattr(mt_config, "max_target_length", 768), + max_prompt_length=getattr(mt_config, "max_prefill_predict_length", 256), + kv_cache_size=getattr(mt_config, "max_prefill_predict_length", 256) + + getattr(mt_config, "max_target_length", 768) + 256, - temperature=getattr(config, "decode_sampling_temperature", 0.9), - top_p=getattr(config, "decode_sampling_top_p", 1.0), - top_k=getattr(config, "decode_sampling_top_k", 50), + temperature=getattr(mt_config, "decode_sampling_temperature", 0.9), + top_p=getattr(mt_config, "decode_sampling_top_p", 1.0), + top_k=getattr(mt_config, "decode_sampling_top_k", 50), ), rollout_vllm_model_version="meta-llama/Meta-Llama-3.1-8B-Instruct", rollout_vllm_hbm_utilization=0.2, @@ -406,15 +438,16 @@ def grpo_train(config): # Setup GRPO config grpo_config = GrpoConfig( - num_generations=getattr(config, "num_generations", 2), + num_generations=getattr(mt_config, "num_generations", 2), num_iterations=1, - beta=getattr(config, "grpo_beta", 0.08), - epsilon=getattr(config, "grpo_epsilon", 0.2), + beta=getattr(mt_config, "grpo_beta", 0.08), + epsilon=getattr(mt_config, "grpo_epsilon", 0.2), + loss_algo=mt_config.loss_algo, ) # Create RL cluster print("Creating RL cluster...") - with nn_partitioning.axis_rules(config.logical_axis_rules): + with nn_partitioning.axis_rules(mt_config.logical_axis_rules): rl_cluster = rl_cluster_lib.RLCluster( actor=policy_model, reference=reference_model, @@ -424,7 +457,7 @@ def grpo_train(config): # Create GRPO trainer print("Setting up GRPO trainer...") - grpo_trainer = GrpoLearner( + rl_trainer = GrpoLearner( rl_cluster=rl_cluster, reward_fns=[ match_format_exactly, @@ -437,11 +470,42 @@ def grpo_train(config): # Start training print("Starting GRPO training...") - with policy_mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - grpo_trainer.train(train_dataset) + with policy_mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules): + rl_trainer.train(train_dataset) + + profile_dir = mt_config.tensorboard_dir + max_logging.log(f"Saving profiles to {profile_dir}") + + jax.profiler.start_trace(profile_dir) + with mesh, nn_partitioning.axis_rules(config_policy.logical_axis_rules): + rl_trainer.train(dataset) + jax.profiler.stop_trace() print("=" * 80) print("GRPO Training Completed Successfully!") print("=" * 80) - return grpo_trainer, rl_cluster + return rl_trainer, rl_cluster + +def main(argv: Sequence[str]) -> None: + """Main function to run SFT training. + + Args: + argv: Command-line arguments. 
+ """ + pathwaysutils.initialize() + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""): + os.environ["LIBTPU_INIT_ARGS"] = ( + os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" + ) + + mt_config = pyconfig.initialize(argv) + max_utils.print_system_information() + + rl_train(mt_config) + + +if __name__ == "__main__": + app.run(main) From db589d386ef307bea7c8f3930676a7bee86646d7 Mon Sep 17 00:00:00 2001 From: A9isha Date: Wed, 29 Oct 2025 23:04:52 +0000 Subject: [PATCH 15/31] train_rl created, refactoring WIP --- src/MaxText/configs/base.yml | 1 + src/MaxText/configs/grpo.yml | 206 ------------ src/MaxText/configs/rl.yml | 246 ++++++++++---- src/MaxText/experimental/rl/rl.yml | 77 +++++ src/MaxText/rl/{rl_trainer.py => train_rl.py} | 306 ++++++++++-------- 5 files changed, 426 insertions(+), 410 deletions(-) delete mode 100644 src/MaxText/configs/grpo.yml create mode 100644 src/MaxText/experimental/rl/rl.yml rename src/MaxText/rl/{rl_trainer.py => train_rl.py} (59%) diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index 391da2918..a51181f6a 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -47,6 +47,7 @@ enable_checkpointing: True save_checkpoint_on_completion: True async_checkpointing: True checkpoint_period: 10_000 +max_num_checkpoints_to_keep: None # enables one replica to read the ckpt then broadcast to the rest enable_single_replica_ckpt_restoring: False diff --git a/src/MaxText/configs/grpo.yml b/src/MaxText/configs/grpo.yml deleted file mode 100644 index 4d4b43116..000000000 --- a/src/MaxText/configs/grpo.yml +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright 2023–2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# RL Configuration -# This config consolidates common parameters for RL training across different model sizes - -base_config: "base.yml" - -# ====== Hardware ===== -trainer_devices_fraction: 0.5 -sampler_devices_fraction: 0.5 -chips_per_vm: 4 # depends on hardware, for v5p this is 4 - -# ====== Debug ====== -debug: True - -# ====== Reproducibility ====== -data_shuffle_seed: 42 -loss_algo: 'grpo' # grpo or gspo-token - -# ====== Checkpoint saving ====== -save_interval_steps: 500 -max_to_keep: 4 - -# ====== GRPO ====== -# === Generation during GRPO training === -max_prompt_length: 256 -total_generation_steps: 768 - -# The number of times the policy generates multiple responses for a given prompt -# within a single training step. This corresponds to `G` in Algorithm 1 in the -# paper. The "group" in GRPO comes from here. -num_generations: 2 - -# === other GRPO configs === -# The number of iterations per batch (𝜇 in GRPO algo 1). -num_iterations: 1 - -# The coefficient for the KL divergence penalty (𝛽) in the GRPO loss function. -# Important to keep a high enough value for this, otherwise, the KL divergence -# can increase unchecked. 
-beta: 0.08 -# Epsilon value for clipping (𝜀 in GRPO loss in paper). Similar to PPO, for -# stable updates. -epsilon: 0.2 - -# ====== Training ====== - -batch_size: 1 -# Increase `batch_size` and `MAX_STEPS` for better results. -# NUM_BATCHES = 3738 -NUM_BATCHES = 4 # 200 -# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be -# increased to a max. of 330 (if batch size is 4). -NUM_TEST_BATCHES = 5 # 200 - -EVAL_EVERY_N_STEPS = 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`. -NUM_EPOCHS = 1 # can potentially train for more epochs - - -# === AdamW, warmup, cosine scheduler === -LEARNING_RATE = 3e-6 -B1 = 0.9 -B2 = 0.99 -WEIGHT_DECAY = 0.1 -# == Cosine decay with warmup scheduler == -# Linearly increase learning rate from 0. to 5e-6 in the first 10% training -# steps, and then gradually decrease the learning rate to 0 using cosine -# scheduler. -WARMUP_STEPS = int(0.1 * MAX_STEPS) -# == Grad clipping == -# Grad clipping to prevent large gradients. Found this -# important to keep KL divergence in check. -MAX_GRAD_NORM = 0.1 - - -# ====== Inference ====== -# Important to keep a high-ish temperature for varied, diverse responses during -# training. -# greedy search -temperature: 0.01 -top_p: 1.0 -top_k: 1 - -# # some randomness -# "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95}, -# # liberal -# "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0}, - -TRAINER_DEVICES_FRACTION = 0.5 -SAMPLER_DEVICES_FRACTION = 0.5 -HBM_UTILIZATION_VLLM = 0.72 -SWAP_SPACE_VLLM_GB = 2 - - -# ====== Reward ====== -REWARD_EXACT_FORMAT_MATCH = 3.0 -REWARD_WHITE_SPACE_FORMAT_MATCH = 1.5 -REWARD_PARTIAL_FORMAT_MATCH = 0.5 -REWARD_RATIO_GUESS_TO_ANSWER_HIGH = 0.5 -REWARD_RATIO_GUESS_TO_ANSWER_LOW = 0.25 -PENALTY_INCORRECT_FORMAT = -0.5 -PENALTY_INCORRECT_ANSWER = -1.0 - - -# TODO: fix this -# Dataset Configuration -dataset_type: hf # Huggingface input pipeline -hf_path: 'gsm8k' -hf_data_split: 'main' -hf_data_files: 'train' - -# Model and Tokenizer Configuration -# Override these via CLI: -# model_name, tokenizer_path, load_parameters_path - -# Sequence Lengths -max_prefill_predict_length: 256 -max_target_length: 768 - -# Training Hyperparameters -learning_rate: 3.e-6 -adam_b1: 0.9 -adam_b2: 0.99 -weight_decay: 0.1 -max_grad_norm: 0.1 - -# Group Relative Policy Optimization (GRPO) Parameters -num_generations: 2 -grpo_beta: 0.08 # KL divergence penalty coefficient -grpo_epsilon: 0.2 # Clipping value for stable updates -inference_rollouts: 1 - -# Generation Configuration During Training -decode_sampling_strategy: "weighted" -decode_sampling_temperature: 0.9 -decode_sampling_top_p: 1.0 -decode_sampling_top_k: 50 - -# Training Loop Configuration -steps: 100 -per_device_batch_size: 1 - -# Checkpoint Configuration -enable_checkpointing: True -async_checkpointing: True -checkpoint_period: 50 - -# Pathways Inference Configuration -# For multi-host/multi-slice setups -use_pathways_reshard: False -inference_devices_per_replica: 4 -inference_replicas: 1 - -# Tokenizer Settings -add_bos: False -add_eos: False -return_log_prob: True - -# Performance and Memory -weight_dtype: bfloat16 -dtype: bfloat16 - -# Profiling -profiler: xplane -skip_first_n_steps_for_profiler: 5 -profiler_steps: 3 - -# Splash Attention Block Sizes -# Tuned for GRPO workloads -sa_block_q: 128 -sa_block_kv: 128 -sa_block_kv_compute: 128 -sa_block_q_dkv: 128 -sa_block_kv_dkv: 128 -sa_block_kv_dkv_compute: 128 -sa_block_q_dq: 128 -sa_block_kv_dq: 128 -sa_use_fused_bwd_kernel: False -sa_q_layout: "HEAD_DIM_MINOR" 
-sa_k_layout: "HEAD_DIM_MINOR" -sa_v_layout: "HEAD_DIM_MINOR" - -# Model-Specific Overrides (examples) -# For Llama3.1-8B: -# model_name: llama3.1-8b -# tokenizer_path: meta-llama/Llama-3.1-8B-Instruct -# ici_fsdp_parallelism: 8 -# -# For Llama3.1-70B with Pathways: -# model_name: llama3.1-70b -# tokenizer_path: meta-llama/Llama-3.1-70B-Instruct -# use_pathways_reshard: True -# ici_fsdp_parallelism: 16 - diff --git a/src/MaxText/configs/rl.yml b/src/MaxText/configs/rl.yml index 38108dc82..79081e644 100644 --- a/src/MaxText/configs/rl.yml +++ b/src/MaxText/configs/rl.yml @@ -12,66 +12,190 @@ # See the License for the specific language governing permissions and # limitations under the License. +# RL Configuration +# This config consolidates common parameters for RL training across different model sizes + base_config: "base.yml" -logical_axis_rules: [ - ['prefill_activation_length', ['data']], - ['prefill_activation_norm_length', ['data']], - ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], - ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], - ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], - ['activation_length', ['context_autoregressive', 'sequence']], - ['activation_length', ['context_autoregressive']], - ['activation_q_length', ['context_autoregressive']], - ['activation_kv_length', ['context_autoregressive']], - ['activation_norm_length', ['tensor_sequence', 'sequence']], - ['activation_embed', ['tensor_transpose']], - ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], - ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], - ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']], - ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], - ['activation_vocab', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']], - ['activation_vocab', ['tensor', 'tensor_transpose']], - ['activation_vocab', 'tensor_sequence'], - ['activation_vocab', ['sequence', 'context_autoregressive']], - ['activation_stage', 'stage'], - ['activation_exp', ['expert', 'context_autoregressive']], - ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']], - ['decode_length', []], - ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], - ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context_autoregressive']], - ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], - ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], - ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], - ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'expert']], - ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']], - ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']], - ['embed', ['fsdp', 'sequence', 'expert']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']], - ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']], - ['embed_no_exp', 
['fsdp', 'sequence', 'context_autoregressive']], - ['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']], - ['layers', 'stage'], - ['kv', []], - ['kv_head_dim', []], - ['cache_batch_prefill', []], - ['cache_batch', ['context_autoregressive']], - ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], - ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], - ['cache_kv', []], - ['cache_sequence', ['context_autoregressive']], - ['cache_scale_sequence', ['context_autoregressive']], - ['exp', ['expert', 'context_autoregressive']], - ['paged_kv_heads', []], - ['num_pages', ['tensor']], - ['tokens_per_page', []], - ['paged_kv_head_dim_size', []], - ] -# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details -data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']] - -return_log_prob: True \ No newline at end of file +# ====== Hardware ===== +trainer_devices_fraction: 0.5 +sampler_devices_fraction: 0.5 +chips_per_vm: 4 # depends on hardware, for v5p this is 4 + +# ====== Debug ====== +debug: True + +# ====== Reproducibility ====== +data_shuffle_seed: 42 + +# ====== Checkpoint saving ====== +save_interval_steps: 500 +max_to_keep: 4 + +# ====== GRPO ====== + +# The number of times the policy generates multiple responses for a given prompt +# within a single training step. This corresponds to `G` in Algorithm 1 in the +# paper. The "group" in GRPO comes from here. +num_generations: 2 + +# === other GRPO configs === +# The number of iterations per batch (𝜇 in GRPO algo 1). +num_iterations: 1 + +# The coefficient for the KL divergence penalty (𝛽) in the GRPO loss function. +# Important to keep a high enough value for this, otherwise, the KL divergence +# can increase unchecked. +grpo_beta: 0.08 +# Epsilon value for clipping (𝜀 in GRPO loss in paper). Similar to PPO, for +# stable updates. +grpo_epsilon: 0.2 +loss_algo: 'grpo' # grpo or gspo-token + + +# ====== Models ====== +# for MaxText +# Model and Tokenizer Configuration +# Override these via CLI: +# model_name, tokenizer_path, load_parameters_path +# Model-Specific Overrides (examples) +# For Llama3.1-8B: +# model_name: llama3.1-8b +# tokenizer_path: meta-llama/Llama-3.1-8B-Instruct +# +# For Llama3.1-70B with Pathways: +# model_name: llama3.1-70b +# tokenizer_path: meta-llama/Llama-3.1-70B-Instruct + +async_checkpointing: 'false' +checkpoint_period: 5 +skip_jax_distributed_system: True +weight_dtype: 'bfloat16' +attention: 'dot_product' +remat_policy: 'custom' +decoder_layer_input: 'offload' +query_proj: 'offload' +key_proj: 'offload' +value_proj: 'offload' +# for vLLM +hf_model_name: 'meta-llama/Llama-3.1-70B-Instruct' + +# ====== Training ====== +batch_size: 1 +# Increase `batch_size` and `MAX_STEPS` for better results. +# num_batches = 3738 +num_batches = 4 # 200 +# Keep `num_test_batches` low so that evaluation runs quickly. It can be +# increased to a max. of 330 (if batch size is 4). +num_test_batches = 5 # 200 +train_fraction: 1.0 + +eval_interval: 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`. 
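# Note: together with num_epochs below, these knobs determine the total number of
# training steps, which train_rl.py derives as
#   max_train_steps = int(num_batches * num_iterations * train_fraction * num_epochs)
# e.g. with the values in this file, 4 * 1 * 1.0 * 1 = 4 steps.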
+ +num_epochs: 1 # can potentially train for more epochs + +gradient_clipping_threshold: 0.1 + + +# greedy +eval_temperature: 0.01 +eval_top_k: 1 +eval_top_p: 1.0 +# # some randomness +# 'standard': {'eval_temperature': 0.7, 'eval_top_k': 50, 'eval_top_p': 0.95}, +# # liberal +# # 'liberal': {'eval_temperature': 0.85, 'eval_top_k': 2000, 'eval_top_p': 1.0}, + + +# ====== Inference ====== +# Important to keep a high-ish temperature for varied, diverse responses during +# training. +decode_sampling_temperature: 0.9 +decode_sampling_top_k: 50 +decode_sampling_nucleus_p: 1.0 + +# for vLLM +# === Generation during GRPO training === +# max Lengths for prompt and completion +max_prefill_predict_length: 256 +max_target_length: 768 +kv_cache_buffer: 256 +hbm_utilization_vllm: 0.72 +swap_space_vllm_gb: 2 +# Generation Configuration During Training +decode_sampling_temperature: 0.9 +decode_sampling_top_p: 1.0 +decode_sampling_top_k: 50 + +# ====== Checkpoint Configuration ====== +enable_checkpointing: True +async_checkpointing: True +checkpoint_period: 50 +max_num_checkpoints_to_keep: 10 + +# ====== Reward ====== + +reward_exact_format_match: 3.0 +reward_white_space_format_match: 1.5 +reward_partial_format_match: 0.5 +reward_ratio_guess_to_answer_high: 0.5 +reward_ratio_guess_to_answer_low: 0.25 +penalty_incorrect_format: -0.5 +penalty_incorrect_answer: -1.0 + +# ====== Special tokens for GSM8K reasoning ====== +reasoning_start_token: '' +reasoning_end_token: '' +solution_start_token: '' +solution_end_token: '' + +# ====== System prompt and Templates ====== + +system_prompt: | + You are given a problem. Think about the problem and provide your reasoning. Place it between {reasoning_start_token} and {reasoning_end_token}. Then, provide the final answer (i.e., just one numerical value) between {solution_start_token} and {solution_end_token}. + +template: | + user + {system_prompt} + + {question} + model + + +# TODO: fix this +# Dataset Configuration +dataset_type: hf # Huggingface input pipeline +hf_path: 'gsm8k' +hf_data_split: 'main' +hf_data_files: 'train' + + +# Pathways Inference Configuration +# For multi-host/multi-slice setups +use_pathways_reshard: False +inference_devices_per_replica: 4 +inference_replicas: 1 + +# Tokenizer Settings +add_bos: False +add_eos: False +return_log_prob: True + +# Performance and Memory +weight_dtype: bfloat16 +dtype: bfloat16 + +# Splash Attention Block Sizes +# Tuned for GRPO workloads +sa_block_q: 128 +sa_block_kv: 128 +sa_block_kv_compute: 128 +sa_block_q_dkv: 128 +sa_block_kv_dkv: 128 +sa_block_kv_dkv_compute: 128 +sa_block_q_dq: 128 +sa_block_kv_dq: 128 +sa_use_fused_bwd_kernel: False +sa_q_layout: "HEAD_DIM_MINOR" +sa_k_layout: "HEAD_DIM_MINOR" +sa_v_layout: "HEAD_DIM_MINOR" \ No newline at end of file diff --git a/src/MaxText/experimental/rl/rl.yml b/src/MaxText/experimental/rl/rl.yml new file mode 100644 index 000000000..38108dc82 --- /dev/null +++ b/src/MaxText/experimental/rl/rl.yml @@ -0,0 +1,77 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +base_config: "base.yml" + +logical_axis_rules: [ + ['prefill_activation_length', ['data']], + ['prefill_activation_norm_length', ['data']], + ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], + ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], + ['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']], + ['activation_length', ['context_autoregressive', 'sequence']], + ['activation_length', ['context_autoregressive']], + ['activation_q_length', ['context_autoregressive']], + ['activation_kv_length', ['context_autoregressive']], + ['activation_norm_length', ['tensor_sequence', 'sequence']], + ['activation_embed', ['tensor_transpose']], + ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], + ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], + ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']], + ['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']], + ['activation_vocab', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']], + ['activation_vocab', ['tensor', 'tensor_transpose']], + ['activation_vocab', 'tensor_sequence'], + ['activation_vocab', ['sequence', 'context_autoregressive']], + ['activation_stage', 'stage'], + ['activation_exp', ['expert', 'context_autoregressive']], + ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']], + ['decode_length', []], + ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], + ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context_autoregressive']], + ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], + ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], + ['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], + ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'expert']], + ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']], + ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']], + ['embed', ['fsdp', 'sequence', 'expert']], + ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']], + ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']], + ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']], + ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive']], + ['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']], + ['layers', 'stage'], + ['kv', []], + ['kv_head_dim', []], + ['cache_batch_prefill', []], + ['cache_batch', ['context_autoregressive']], + ['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']], + ['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']], + ['cache_kv', []], + ['cache_sequence', ['context_autoregressive']], + ['cache_scale_sequence', ['context_autoregressive']], + ['exp', ['expert', 'context_autoregressive']], + ['paged_kv_heads', []], + ['num_pages', ['tensor']], + ['tokens_per_page', []], + ['paged_kv_head_dim_size', []], + ] +# 
Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details +data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']] + +return_log_prob: True \ No newline at end of file diff --git a/src/MaxText/rl/rl_trainer.py b/src/MaxText/rl/train_rl.py similarity index 59% rename from src/MaxText/rl/rl_trainer.py rename to src/MaxText/rl/train_rl.py index 287e8338a..037395d00 100644 --- a/src/MaxText/rl/rl_trainer.py +++ b/src/MaxText/rl/train_rl.py @@ -80,36 +80,6 @@ from MaxText import model_creation_utils from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter -# ====== Reward-specific constants ====== -REWARD_EXACT_FORMAT_MATCH = 3.0 -REWARD_WHITE_SPACE_FORMAT_MATCH = 1.5 -REWARD_PARTIAL_FORMAT_MATCH = 0.5 -REWARD_RATIO_GUESS_TO_ANSWER_HIGH = 0.5 -REWARD_RATIO_GUESS_TO_ANSWER_LOW = 0.25 -PENALTY_INCORRECT_FORMAT = -0.5 -PENALTY_INCORRECT_ANSWER = -1.0 - -# ====== Special tokens for GSM8K reasoning ====== -REASONING_START = "" -REASONING_END = "" -SOLUTION_START = "" -SOLUTION_END = "" - -# ====== System prompt and Templates ====== - -SYSTEM_PROMPT = f"""You are given a problem. Think about the problem and \ -provide your reasoning. Place it between {REASONING_START} and \ -{REASONING_END}. Then, provide the final answer (i.e., just one numerical \ -value) between {SOLUTION_START} and {SOLUTION_END}.""" - -TEMPLATE = """user -{system_prompt} - -{question} -model""" - -# ====== Debug flag for verbose logs ====== -DEBUG=False # ====== Reproducibility ====== SEED = 42 @@ -117,8 +87,8 @@ # We use OpenAI's GSM8K dataset. GSM8K comprises grade school math word problems. -def extract_hash_answer(text: str) -> str | None: - if DEBUG: +def extract_hash_answer(text: str, debug: bool = False) -> str | None: + if debug: print(f"Extracting answer from: {text}") if "####" not in text: return None @@ -160,7 +130,7 @@ def get_dataset(data_dir, split="train") -> grain.MapDataset: # passed to reward functions "question": x["question"].decode("utf-8"), # passed to reward functions - "answer": extract_hash_answer(x["answer"].decode("utf-8")), + "answer": extract_hash_answer(x["answer"].decode("utf-8"), debug=mt_config.debug), } ) ) @@ -168,84 +138,108 @@ def get_dataset(data_dir, split="train") -> grain.MapDataset: def get_maxtext_model(config, devices=None): - """Load MaxText model with Tunix adapter.""" + """ + Load MaxText model with Tunix adapter. + # Note: pass the path to your scanned checkpoint for "load_parameters_path". To generate a scanned checkpoint, you can use the `scanned_checkpoint.py` script in MaxText. 
+ # To create a scanned checkpoint, you can use /maxtext/MaxText/utils/ckpt_conversion/to_maxtext.py + """ model, mesh = model_creation_utils.create_nnx_model(config, devices) with mesh: tunix_model = TunixMaxTextAdapter(base_model=model) - model_config = llama3_lib.ModelConfig.llama3_1_8b() - tunix_model.config = model_config + tunix_model.config = None return tunix_model, mesh -def setup_device_allocation(config, use_pathways: bool = False): +def setup_device_allocation(mt_config, use_pathways: bool = False): """Setup device allocation for training and inference.""" - devices = jax.devices() - - # Get device allocation parameters from config - trainer_devices_fraction = getattr(config, "trainer_devices_fraction", 0.5) - sampler_devices_fraction = getattr(config, "sampler_devices_fraction", 0.5) - chips_per_vm = getattr(config, "chips_per_vm", 4) - - num_vms = len(devices) // chips_per_vm + devices = jax.devices() + num_vms = len(devices) // mt_config.chips_per_vm trainer_devices = devices sampler_devices = devices if num_vms >= 2 and use_pathways: - # Multiple hosts with Pathways - split devices for trainer and sampler + # Multiple hosts with Pathways - potentially split devices for trainer and sampler + # based on trainer_devices_fraction and sampler_devices_fraction print(f"{num_vms} VMs detected, allocating trainer and sampler devices, and using Pathways.") num_devices = len(devices) - num_trainer_devices = int(num_devices * trainer_devices_fraction) - num_sampler_devices = int(num_devices * sampler_devices_fraction) + num_trainer_devices = int(num_devices * mt_config.trainer_devices_fraction) + num_sampler_devices = int(num_devices * mt_config.sampler_devices_fraction) trainer_devices = devices[:num_trainer_devices] sampler_devices = devices[num_devices - num_sampler_devices :] - - print("Creating reference model and also meshes for reference and rollout") - llama3_1_70b, reference_mesh = get_maxtext_model(config_ref, trainer_devices) - devices_array = maxtext_utils.create_device_mesh(config_ref, sampler_devices) - rollout_mesh = Mesh(devices_array, config_ref.mesh_axes) - mesh = reference_mesh + if mt_config.trainer_devices_fraction!=1.0: + print(f"Using first {len(trainer_devices)} devices as Trainer devices") + if mt_config.sampler_devices_fraction != 1.0: + print(f"Using last {len(sampler_devices)} devices as Sampler devices") return trainer_devices, sampler_devices, num_vms # Reward Functions -def match_format_exactly(prompts, completions, **kwargs): +def match_format_exactly(prompts, completions, mt_config, **kwargs): """Reward exact format matching.""" scores = [] match_format = re.compile( - rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$", + ( + r"^[\s]{0,}}" + rf"{mt_config.reasoning_start_token}.+?{mt_config.reasoning_end_token}.*?" 
+ rf"{mt_config.solution_start_token}(.+?){mt_config.solution_end_token}" + r"[\s]{0,}$" + ), flags=re.MULTILINE | re.DOTALL, ) for completion in completions: score = 0 if match_format.search(completion) is not None: - score += REWARD_EXACT_FORMAT_MATCH + score += mt_config.reward_exact_format_match scores.append(score) return scores -def match_format_approximately(prompts, completions, **kwargs): +def match_format_approximately(prompts, completions, mt_config, **kwargs): """Reward approximate format matching.""" scores = [] for completion in completions: score = 0 - score += REWARD_PARTIAL_FORMAT_MATCH if completion.count(REASONING_START) == 1 else PENALTY_INCORRECT_FORMAT - score += REWARD_PARTIAL_FORMAT_MATCH if completion.count(REASONING_END) == 1 else PENALTY_INCORRECT_FORMAT - score += REWARD_PARTIAL_FORMAT_MATCH if completion.count(SOLUTION_START) == 1 else PENALTY_INCORRECT_FORMAT - score += REWARD_PARTIAL_FORMAT_MATCH if completion.count(SOLUTION_END) == 1 else PENALTY_INCORRECT_FORMAT + score += ( + mt_config.reward_partial_format_match + if completion.count(mt_config.reasoning_start_token) == 1 + else mt_config.penalty_incorrect_format + ) + score += ( + mt_config.reward_partial_format_match + if completion.count(mt_config.reasoning_end_token) == 1 + else mt_config.penalty_incorrect_format + ) + score += ( + mt_config.reward_partial_format_match + if completion.count(mt_config.solution_start_token) == 1 + else mt_config.penalty_incorrect_format + ) + score += ( + mt_config.reward_partial_format_match + if completion.count(mt_config.solution_end_token) == 1 + else mt_config.penalty_incorrect_format + ) scores.append(score) return scores -def check_answer(prompts, completions, answer, **kwargs): +def check_answer(prompts, completions, answer, mt_config, **kwargs): """Reward correct answers.""" match_format = re.compile( - rf"^[\s]{{0,}}" rf"{REASONING_START}.+?{REASONING_END}.*?" rf"{SOLUTION_START}(.+?){SOLUTION_END}" rf"[\s]{{0,}}$", + ( + r"^[\s]{0,}}" + rf"{mt_config.reasoning_start_token}.+?{mt_config.reasoning_end_token}.*?" 
+ rf"{mt_config.solution_start_token}(.+?){mt_config.solution_end_token}" + r"[\s]{0,}$" + ), flags=re.MULTILINE | re.DOTALL, ) - extracted_responses = [guess.group(1) if (guess := match_format.search(r)) is not None else None for r in completions] + extracted_responses = [ + guess.group(1) if (guess := match_format.search(r)) is not None else None for r in completions + ] scores = [] for guess, true_answer in zip(extracted_responses, answer): @@ -255,28 +249,33 @@ def check_answer(prompts, completions, answer, **kwargs): continue if guess == true_answer: - score += REWARD_EXACT_FORMAT_MATCH + score += mt_config.reward_exact_format_match elif guess.strip() == true_answer.strip(): - score += REWARD_WHITE_SPACE_FORMAT_MATCH + score += mt_config.reward_white_space_format_match else: try: ratio = float(guess) / float(true_answer) if 0.9 <= ratio <= 1.1: - score += REWARD_RATIO_GUESS_TO_ANSWER_HIGH + score += mt_config.reward_ratio_guess_to_answer_high elif 0.8 <= ratio <= 1.2: - score += REWARD_RATIO_GUESS_TO_ANSWER_LOW + score += mt_config.reward_ratio_guess_to_answer_low else: - score += PENALTY_INCORRECT_ANSWER + score += mt_config.penalty_incorrect_answer except (ValueError, ZeroDivisionError): - score += PENALTY_INCORRECT_FORMAT + score += mt_config.penalty_incorrect_format scores.append(score) return scores -def check_numbers(prompts, completions, answer, **kwargs): +def check_numbers(prompts, completions, answer, mt_config, **kwargs): """Reward correct numerical answers.""" - match_numbers = re.compile(rf"{SOLUTION_START}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL) - extracted_responses = [guess.group(1) if (guess := match_numbers.search(r)) is not None else None for r in completions] + match_numbers = re.compile( + rf"{mt_config.solution_start_token}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL + ) + extracted_responses = [ + guess.group(1) if (guess := match_numbers.search(r)) is not None else None + for r in completions + ] scores = [] for guess, true_answer in zip(extracted_responses, answer): @@ -300,22 +299,13 @@ def rl_train(mt_config): Args: mt_config: MaxText configuration object """ - print("Starting GRPO Training") + # ====== Debug flag for verbose logs ====== + DEBUG = mt_config.debug + print("Starting GRPO Training") # Number of training steps. - max_steps = int(mt_config.num_batches * mt_config.num_iterations * TRAIN_FRACTION * NUM_EPOCHS) - # Setup device allocation - if jax.extend.backend.get_backend().platform_version == "Pathways": - max_logging.log("Pathways backend detected. Disabling setting profile options.") - use_pathways = True - else: - use_pathways = False - - trainer_devices, sampler_devices, num_vms = setup_device_allocation(mt_config, use_pathways) - - print(f"Device allocation: {len(trainer_devices)} trainer, {len(sampler_devices)} sampler") - print(f"Use Pathways: {use_pathways}") + max_train_steps = int(mt_config.num_batches * mt_config.num_iterations * mt_config.train_fraction * mt_config.num_epochs) # ====== Data ====== # Setup data directories @@ -326,7 +316,7 @@ def rl_train(mt_config): os.makedirs(train_data_dir) if not os.path.exists(test_data_dir): os.makedirs(test_data_dir) - TRAIN_FRACTION = 1.0 + # Load datasets print("Loading GSM8K dataset...") @@ -345,65 +335,91 @@ def rl_train(mt_config): num_batches=getattr(mt_config, "num_test_batches", 5), ) + + # Setup device allocation + if jax.extend.backend.get_backend().platform_version == "Pathways": + max_logging.log("Pathways backend detected. 
Disabling setting profile options.") + use_pathways = True + else: + use_pathways = False + print(f"Use Pathways: {use_pathways}") + trainer_devices, sampler_devices, num_vms = setup_device_allocation(mt_config, use_pathways) + # Load reference model - print("Loading reference model...") - reference_model, reference_mesh = get_maxtext_model(mt_config, trainer_devices) - reference_model.config = None + print("Creating reference model and also meshes for reference and rollout") + reference_model, reference_mesh = get_ref_maxtext_model(mt_config, trainer_devices) + devices_array = maxtext_utils.create_device_mesh(mt_config, sampler_devices) + # if trainer_devices=sampler_devices, then rollout_mesh=reference_mesh + # else rollout_mesh uses sampler_devices + rollout_mesh = Mesh(devices_array, mt_config.mesh_axes) + if mt_config.debug: + print("Reference Model initialized successfully") + nnx.display(reference_model) + print(f"Reference mesh shape: {reference_mesh.shape}") + + # Sanity check that weights are loaded correctly + _maxtext_state_flatten = nnx.state(llama3_1_70b).flat_state() + maxtext_state_flatten = {".".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten} + print( + f"maxtext_state_flatten[base.token_embedder.embedding].value={maxtext_state_flatten['base.token_embedder.embedding'].value}" + ) + # TODO: @mazumdera: change this to use lora # Load policy model - print("Loading policy model...") + print("Creating policy model with same config as reference model on trainer mesh") policy_model, policy_mesh = get_maxtext_model(mt_config, trainer_devices) - policy_model.config = None + actor_mesh = policy_mesh - - # Setup meshes - if num_vms >= 2 and not use_pathways: - actor_mesh = policy_mesh - rollout_mesh = Mesh(maxtext_utils.create_device_mesh(mt_config, sampler_devices), mt_config.mesh_axes) - else: - actor_mesh = policy_mesh - rollout_mesh = policy_mesh + if mt_config.debug: + print("Policy Model initialized successfully") + nnx.display(policy_model) + print(f"Policy mesh shape: {policy_mesh.shape}") # Setup optimizer - learning_rate = getattr(mt_config, "learning_rate", 3e-6) - max_steps = getattr(mt_config, "steps", 100) - warmup_steps = int(0.1 * max_steps) - optimizer = optax.adamw( learning_rate=optax.schedules.warmup_cosine_decay_schedule( init_value=0.0, - peak_value=learning_rate, - warmup_steps=warmup_steps, - decay_steps=max_steps, - end_value=0.0, + peak_value=mt_config.learning_rate, + # Linearly increase learning rate from 0. to learning_rate in the first + # warmup_steps_fraction training steps, and then gradually decrease the + # learning rate to 0 using cosine scheduler. + warmup_steps=int(mt_config.warmup_steps_fraction*mt_config.max_train_steps), + decay_steps=max_train_steps, ), - b1=0.9, - b2=0.99, - weight_decay=0.1, + b1=mt_config.adam_b1, + b2=mt_config.adam_b2, + weight_decay=mt_config.adam_weight_decay, ) + # TODO: @mazumdera: try optimizer offloading with adamw # Add gradient clipping if specified - max_grad_norm = getattr(mt_config, "max_grad_norm", 0.1) - if max_grad_norm is not None: + # Grad clipping to prevent large gradients. We find this + # important to keep KL divergence in check. 
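  # optax.chain applies its transforms in order, so the global-norm clip added
  # below runs before the AdamW update; a non-positive threshold is treated as
  # "clipping disabled".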
+ if mt_config.gradient_clipping_threshold > 0: optimizer = optax.chain( - optax.clip_by_global_norm(max_norm=max_grad_norm), + optax.clip_by_global_norm(max_norm=mt_config.gradient_clipping_threshold), optimizer, ) # Setup checkpointing - ckpt_dir = mt_config.base_output_directory - + ckpt_dir = mt_config.checkpoint_dir checkpointing_options = ocp.CheckpointManagerOptions( - save_interval_steps=getattr(mt_config, "checkpoint_period", 50), max_to_keep=4 + save_interval_steps=mt_config.checkpoint_period, mt_config.max_num_checkpoints_to_keep ) # Setup metrics logging - log_dir = mt_config.base_output_directory - max_logging.log(f"Logging to {log_dir}") - + log_dir=mt_config.tensorboard_dir + print(f"TensorBoard logs directory: {LOG_DIR}") + print(f"tensorboard --logdir {LOG_DIR} --port=8086") metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=log_dir, flush_every_n_steps=20) - # Setup RL cluster config + # Profiler configurations + # TODO: xfgu@: add profiling + profiler_options = None + + # RL Cluster config + # Note that we use vLLM as the rollout engine. + # and we are using Tensor Parallelism for rollout cluster_config = rl_cluster_lib.ClusterConfig( role_to_mesh={ rl_cluster_lib.Role.ACTOR: actor_mesh, @@ -414,34 +430,38 @@ def rl_train(mt_config): offload_to_cpu=False, training_config=rl_cluster_lib.RLTrainingConfig( actor_optimizer=optimizer, - eval_every_n_steps=getattr(mt_config, "eval_interval", 10), - max_steps=max_steps, + eval_every_n_steps=mt_config.eval_interval, + max_steps=max_train_steps, + # metrics logging metrics_logging_options=metrics_logging_options, - profiler_options=None, + # profiling + profiler_options=profiler_options, + # checkpoint saving checkpoint_root_directory=ckpt_dir, checkpointing_options=checkpointing_options, ), rollout_config=base_rollout.RolloutConfig( - max_tokens_to_generate=getattr(mt_config, "max_target_length", 768), - max_prompt_length=getattr(mt_config, "max_prefill_predict_length", 256), - kv_cache_size=getattr(mt_config, "max_prefill_predict_length", 256) - + getattr(mt_config, "max_target_length", 768) - + 256, - temperature=getattr(mt_config, "decode_sampling_temperature", 0.9), - top_p=getattr(mt_config, "decode_sampling_top_p", 1.0), - top_k=getattr(mt_config, "decode_sampling_top_k", 50), + max_tokens_to_generate=mt_config.max_target_length, + max_prompt_length=mt_config.max_prefill_predict_length, + kv_cache_size=mt_config.max_prefill_predict_length + + mt_config.max_target_length + + mt_config.kv_cache_buffer, + temperature=mt_config.decode_sampling_temperature, + top_p=mt_config.decode_sampling_nucleus_p, + top_k=mt_config.decode_sampling_top_k, ), - rollout_vllm_model_version="meta-llama/Meta-Llama-3.1-8B-Instruct", - rollout_vllm_hbm_utilization=0.2, + rollout_vllm_model_version=mt_config.hf_model_name, + rollout_vllm_hbm_utilization=mt_config.hbm_utilization_vllm, rollout_vllm_tpu_backend_type="jax", + rollout_vllm_swap_space_size_gb=mt_config.swap_space_vllm_gb, ) # Setup GRPO config grpo_config = GrpoConfig( - num_generations=getattr(mt_config, "num_generations", 2), - num_iterations=1, - beta=getattr(mt_config, "grpo_beta", 0.08), - epsilon=getattr(mt_config, "grpo_epsilon", 0.2), + num_generations=mt_config.num_generations, + num_iterations=mt_config.num_iterations, + beta=mt_config.grpo_beta, + epsilon=mt_config.grpo_epsilon, loss_algo=mt_config.loss_algo, ) @@ -460,10 +480,10 @@ def rl_train(mt_config): rl_trainer = GrpoLearner( rl_cluster=rl_cluster, reward_fns=[ - match_format_exactly, - 
match_format_approximately, - check_answer, - check_numbers, + lambda **kwargs: match_format_exactly(mt_config=mt_config, **kwargs), + lambda **kwargs: match_format_approximately(mt_config=mt_config, **kwargs), + lambda **kwargs: check_answer(mt_config=mt_config, **kwargs), + lambda **kwargs: check_numbers(mt_config=mt_config, **kwargs), ], grpo_config=grpo_config, ) From 22a2623bd4ce4d411f8a1906edbb2db288fc91fd Mon Sep 17 00:00:00 2001 From: A9isha Date: Thu, 30 Oct 2025 03:14:14 +0000 Subject: [PATCH 16/31] create rl_utils --- src/MaxText/evaluate_rl.py | 275 ++++++++++++++++++++++++++++++ src/MaxText/rl_utils.py | 219 ++++++++++++++++++++++++ src/MaxText/{rl => }/train_rl.py | 277 +++++++++---------------------- 3 files changed, 577 insertions(+), 194 deletions(-) create mode 100644 src/MaxText/evaluate_rl.py create mode 100644 src/MaxText/rl_utils.py rename src/MaxText/{rl => }/train_rl.py (69%) diff --git a/src/MaxText/evaluate_rl.py b/src/MaxText/evaluate_rl.py new file mode 100644 index 000000000..99657c346 --- /dev/null +++ b/src/MaxText/evaluate_rl.py @@ -0,0 +1,275 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=bare-except, consider-using-generator + +import functools +import os +from pprint import pprint +import re +import sys + +from datetime import datetime +from flax import nnx +from flax.linen import partitioning as nn_partitioning +import grain +import humanize + + +import jax +from jax.sharding import Mesh +import optax +from orbax import checkpoint as ocp +import tensorflow_datasets as tfds +from tqdm.auto import tqdm +from tunix.rl import rl_cluster as rl_cluster_lib +from tunix.rl.rollout import base_rollout +from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner +from tunix.sft import metrics_logger + + +from transformers import AutoTokenizer + +from flax import linen as nn +import numpy as np +from etils import epath + +from tunix.rl.rollout.base_rollout import RolloutConfig + +from MaxText.globals import MAXTEXT_ASSETS_ROOT +from MaxText import rl_utils + +# ## Evaluate +# +# +# Before we train the model, let's evaluate the model on the test set so we can +# see the improvement post training. +# +# We evaluate it in two ways: +# +# **Quantitative** +# +# * **Answer Accuracy**: percentage of samples for which the model predicts the +# correct final numerical answer +# * **Answer (Partial) Accuracy**: percentage of samples for which the model +# predicts a final numerical answer such that the \`model answer / answer\` +# ratio lies between 0.9 and 1.1. +# * **Format Accuracy**: percentage of samples for which the model outputs the +# correct format, i.e., reasoning between the reasoning special tokens, and the +# final answer between the \`\\`, \`\\` tokens. +# +# **Qualitative** +# +# We'll also print outputs for a few given questions so that we can compare the generated output later. 
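# A rough usage sketch of the `evaluate` helper defined below; `test_dataset` and
# `rl_cluster` are placeholder names, assumed to be built the same way as in
# train_rl.py:
#
#   (corr, total, acc, partial_acc, fmt_acc), samples = evaluate(
#       mt_config,
#       test_dataset,
#       rl_cluster,
#       num_passes=1,
#       make_lst=True,  # also collect (question, answer, responses) tuples for inspection
#   )
#   print(f"answer acc={acc:.1f}%  partial acc={partial_acc:.1f}%  format acc={fmt_acc:.1f}%")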
+# + + +def generate_responses( + mt_config, + prompts, + rl_cluster, + num_passes=1, + temperature=0.7, + top_k=50, + top_p=0.95, +): + """ + Generate responses for a batch of prompts across multiple passes. + + Args: + prompts: List of prompts to generate responses for + rl_cluster: Model cluster for generation + num_passes: Number of generation passes + temperature: Sampling temperature + top_k: Top-k sampling parameter + top_p: Top-p sampling parameter + + Returns: + List of lists containing responses for each prompt across passes + """ + multiple_call_responses = [[] for _ in range(len(prompts))] + + for p in range(num_passes): + responses = rl_cluster.rollout.generate( + prompts, + rollout_config=RolloutConfig( + max_tokens_to_generate=mt_config.max_target_length, + temperature=mt_config.eval_temperature, + top_k=mt_config.eval_top_k, + top_p=mt_config.eval_top_p, + ), + ) + responses = responses.text + + if mt_config.debug: + print(f"Pass {p+1}/{num_passes}, responses: {responses}") + + for idx, response in enumerate(responses): + multiple_call_responses[idx].append(response) + + return multiple_call_responses + + +def score_responses(mt_config, question, responses, answer): + """ + Score a set of responses for a single question. + + Args: + question: The evaluation question + responses: List of generated responses for this question + answer: The correct answer + + Returns: + Tuple of (is_correct, is_partially_correct, has_correct_format) + """ + match_format = rl_utils.get_match_format_regex(mt_config) + match_numbers = rl_utils.get_match_numbers_regex(mt_config) + + if DEBUG: + print("========================================") + print(f"Evaluation Question: {question}") + print(f"Evaluation Answer: {answer}") + print(f"Evaluation Responses: {responses}") + print("========================================") + + is_correct = False + is_partially_correct = False + has_correct_format = False + + for response in responses: + # Extract numerical response + extracted_response = guess.group(1) if (guess := match_numbers.search(response)) is not None else "-1000000" + + if DEBUG: + print(f"Evaluation extracted_response: {extracted_response}") + + # Check exact correctness + try: + if float(extracted_response.strip()) == float(answer.strip()): + is_correct = True + + # Check partial correctness (within 10%) + ratio = float(extracted_response.strip()) / float(answer.strip()) + if 0.9 <= ratio <= 1.1: + is_partially_correct = True + except Exception as e: + if DEBUG: + print(f"Evaluation Exception: {e}") + print("SKIPPED") + + # Check format correctness + if match_format.search(response) is not None: + has_correct_format = True + + # Early exit if all criteria are met + if is_correct and is_partially_correct and has_correct_format: + break + + return is_correct, is_partially_correct, has_correct_format + + +def evaluate( + mt_config, + dataset, + rl_cluster, + temperature=0.7, + top_k=50, + top_p=0.95, + num_passes=1, + corr_lst=False, + make_lst=False, +): + """ + Computes accuracy and percentage of outputs matching the format. 
+ + Args: + dataset: The evaluation dataset + rl_cluster: Model cluster for generation + temperature: Sampling temperature + top_k: Top-k sampling parameter + top_p: Top-p sampling parameter + num_passes: Number of generation passes + corr_lst: If True, only include correct responses in the list + make_lst: If True, return a list of (question, answer, responses) + + Returns: + Tuple of statistics and optionally the response list + """ + response_lst = [] + corr = 0 + partially_corr = 0 + corr_format = 0 + total = 0 + + for batch in tqdm(dataset): + answers = batch["answer"] + questions = batch["question"] + prompts = batch["prompts"] + + # Generate responses for all prompts in the batch + multiple_call_responses = generate_responses( + mt_config=mt_config, + prompts=prompts, + rl_cluster=rl_cluster, + num_passes=num_passes, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + + # Score each question-answer pair + for question, responses, answer in zip(questions, multiple_call_responses, answers): + is_correct, is_partially_correct, has_correct_format = score_responses( + mt_config=mt_config, + question=question, + responses=responses, + answer=answer, + ) + + # Update counters + if is_correct: + corr += 1 + if corr_lst and make_lst: + response_lst.append((question, answer, responses)) + else: + if not corr_lst and make_lst: + response_lst.append((question, answer, responses)) + + if is_partially_correct: + partially_corr += 1 + + if has_correct_format: + corr_format += 1 + + total += 1 + + # Print progress every 10 items + if total % 10 == 0: + print( + f"===> {corr=}, {total=}, {corr / total * 100=}, " + f"{partially_corr / total * 100=}, {corr_format / total * 100=}" + ) + + # Prepare return values + to_return = ( + corr, + total, + corr / total * 100, + partially_corr / total * 100, + corr_format / total * 100, + ) + + if make_lst: + return to_return, response_lst + return to_return diff --git a/src/MaxText/rl_utils.py b/src/MaxText/rl_utils.py new file mode 100644 index 000000000..48aa58a31 --- /dev/null +++ b/src/MaxText/rl_utils.py @@ -0,0 +1,219 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# pylint: disable=bare-except, consider-using-generator + +import functools +import os +from pprint import pprint +import re +import sys + +from datetime import datetime +from flax import nnx +from flax.linen import partitioning as nn_partitioning +import grain +import humanize + + +import jax +from jax.sharding import Mesh +import optax +from orbax import checkpoint as ocp +import tensorflow_datasets as tfds +from tqdm.auto import tqdm +from tunix.rl import rl_cluster as rl_cluster_lib +from tunix.rl.rollout import base_rollout +from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner +from tunix.sft import metrics_logger + + +from transformers import AutoTokenizer + +from flax import linen as nn +import numpy as np +from etils import epath + +from tunix.rl.rollout.base_rollout import RolloutConfig + +from MaxText.globals import MAXTEXT_ASSETS_ROOT +import pathwaysutils + +# Let's define a RegEx for checking whether the format matches. +# +def get_match_format_regex(mt_config): + """Returns a compiled regex to extract the answer from a completion.""" + match_format = re.compile( + ( + r"^[\s]{0,}}" + rf"{mt_config.reasoning_start_token}.+?{mt_config.reasoning_end_token}.*?" + rf"{mt_config.solution_start_token}(.+?){mt_config.solution_end_token}" + r"[\s]{0,}$" + ), + flags=re.MULTILINE | re.DOTALL, + ) + if mt_config.debug: + match_format.search( + f"{mt_config.reasoning_start_token}Let me" f" think!{mt_config.reasoning_end_token}{mt_config.solution_start_token}2{mt_config.solution_end_token}", + ) + return match_format + + +def match_format_exactly(prompts, completions, mt_config, **kargs): + """ + Give the model a reward of mt_config.reward_exact_format_match points if the format matches exactly. + """ + scores = [] + match_format = get_match_format_regex(mt_config) + for completion in completions: + score = 0 + response = completion + # Match if format is seen exactly! + if match_format.search(response) is not None: + score += mt_config.reward_exact_format_match + scores.append(score) + return scores + + +def match_format_approximately(prompts, completions, mt_config, **kargs): + """ + We also reward the model if the format of the output matches partially. + """ + scores = [] + + for completion in completions: + score = 0 + # Count how many keywords are seen - we penalize if too many! + # If we see 1, then plus some points! + score += ( + mt_config.reward_partial_format_match + if completion.count(mt_config.reasoning_start_token) == 1 + else mt_config.penalty_incorrect_format + ) + score += ( + mt_config.reward_partial_format_match + if completion.count(mt_config.reasoning_end_token) == 1 + else mt_config.penalty_incorrect_format + ) + score += ( + mt_config.reward_partial_format_match + if completion.count(mt_config.solution_start_token) == 1 + else mt_config.penalty_incorrect_format + ) + score += ( + mt_config.reward_partial_format_match + if completion.count(mt_config.solution_end_token) == 1 + else mt_config.penalty_incorrect_format + ) + scores.append(score) + return scores + + +def check_answer(prompts, completions, answer, mt_config, **kargs): + """ + Reward the model if the answer is correct. A reward is also given if the answer + does not match exactly, i.e., based on how close the answer is to the correct + value. 
+ """ + match_format = get_match_format_regex(mt_config) + extracted_responses = [ + guess.group(1) if (guess := match_format.search(c)) is not None else None + for c in completions + ] + + extracted_responses = [guess.group(1) if (guess := match_format.search(c)) is not None else None for c in completions] + + scores = [] + for guess, true_answer in zip(extracted_responses, answer): + score = 0 + if guess is None: + scores.append(0) + continue + # Correct answer gets 3 points! + if guess == true_answer: + score += mt_config.reward_exact_format_match + # Match if spaces are seen + elif guess.strip() == true_answer.strip(): + score += mt_config.reward_white_space_format_match + else: + # We also reward it if the answer is close via ratios! + # Ie if the answer is within some range, reward it! + try: + ratio = float(guess) / float(true_answer) + if ratio >= 0.9 and ratio <= 1.1: + score += mt_config.reward_ratio_guess_to_answer_high + elif ratio >= 0.8 and ratio <= 1.2: + score += mt_config.reward_ratio_guess_to_answer_low + else: + score += mt_config.penalty_incorrect_answer # Penalize wrong answers + except: + score += mt_config.penalty_incorrect_format # Penalize + scores.append(score) + return scores + + +# Sometimes, the text between `` and `` might not be one +# number; it can be a sentence. So, we extract the number and compare the answer. + +def get_match_numbers_regex(mt_config): + """Returns a compiled regex to extract the answer from a completion.""" + match_numbers = re.compile( + rf"{mt_config.solution_start_token}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL + ) + if mt_config.debug: + match_numbers.findall(f"{mt_config.solution_start_token} 0.34 {mt_config.solution_end_token}") + return match_numbers + +def check_numbers(prompts, completions, answer, mt_config, **kargs): + """ + Reward the model if the answer is correct. 
+ """ + question = kargs["question"] + + match_numbers = get_match_numbers_regex(mt_config) + extracted_responses = [ + guess.group(1) if (guess := match_numbers.search(c)) is not None else None + for c in completions + ] + + scores = [] + if mt_config.debug: + print("START ============================") + print(f"Question: {question[0]}") + print(f"Answer: {answer[0]}") + print(f"Response: {completions[0]}") + print(f"Extracted: {extracted_responses[0]}") + print("END ==============================") + for guess, true_answer in zip(extracted_responses, answer): + if guess is None: + scores.append(0) + continue + # Convert to numbers + try: + true_answer = float(true_answer.strip()) + guess = float(guess.strip()) + scores.append(1.5 if guess == true_answer else 0.0) + except: + scores.append(0) + continue + return scores + +def extract_hash_answer(text: str, debug: bool = False) -> str | None: + """Function to extract only the answer hash from the text.""" + if debug: + print(f"Extracting answer from: {text}") + if "####" not in text: + return None + return text.split("####")[1].strip() + diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/train_rl.py similarity index 69% rename from src/MaxText/rl/train_rl.py rename to src/MaxText/train_rl.py index 037395d00..59e1a53c2 100644 --- a/src/MaxText/rl/train_rl.py +++ b/src/MaxText/train_rl.py @@ -54,6 +54,8 @@ --steps=100 """ +from pprint import pprint +from typing import Sequence from absl import app import os import re @@ -80,62 +82,12 @@ from MaxText import model_creation_utils from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter +from MaxText import rl_utils -# ====== Reproducibility ====== -SEED = 42 # We use OpenAI's GSM8K dataset. GSM8K comprises grade school math word problems. -def extract_hash_answer(text: str, debug: bool = False) -> str | None: - if debug: - print(f"Extracting answer from: {text}") - if "####" not in text: - return None - return text.split("####")[1].strip() - - -def get_dataset(data_dir, split="train") -> grain.MapDataset: - # Download data - if not os.path.exists(data_dir): - os.makedirs(data_dir) - - data = tfds.data_source( - "gsm8k", - split=split, - data_dir=data_dir, - builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD}, - download=True, - ) - - loaded_dataset = ( - grain.MapDataset.source(data) - .shuffle(seed=SEED) - .map( - lambda x: { - # passed to model forward pass - "prompts": model_tokenizer.apply_chat_template( - [ - { - "role": "user", - "content": TEMPLATE.format( - system_prompt=SYSTEM_PROMPT, - question=x["question"].decode("utf-8"), - ), - }, - ], - tokenize=False, - add_generation_prompt=True, - ), - # passed to reward functions - "question": x["question"].decode("utf-8"), - # passed to reward functions - "answer": extract_hash_answer(x["answer"].decode("utf-8"), debug=mt_config.debug), - } - ) - ) - return loaded_dataset - def get_maxtext_model(config, devices=None): """ @@ -173,124 +125,46 @@ def setup_device_allocation(mt_config, use_pathways: bool = False): return trainer_devices, sampler_devices, num_vms +def get_dataset(model_tokenizer, mt_config, data_dir, split="train") -> grain.MapDataset: + # Download data + if not os.path.exists(data_dir): + os.makedirs(data_dir) -# Reward Functions -def match_format_exactly(prompts, completions, mt_config, **kwargs): - """Reward exact format matching.""" - scores = [] - match_format = re.compile( - ( - r"^[\s]{0,}}" - rf"{mt_config.reasoning_start_token}.+?{mt_config.reasoning_end_token}.*?" 
- rf"{mt_config.solution_start_token}(.+?){mt_config.solution_end_token}" - r"[\s]{0,}$" - ), - flags=re.MULTILINE | re.DOTALL, - ) - - for completion in completions: - score = 0 - if match_format.search(completion) is not None: - score += mt_config.reward_exact_format_match - scores.append(score) - return scores - - -def match_format_approximately(prompts, completions, mt_config, **kwargs): - """Reward approximate format matching.""" - scores = [] - for completion in completions: - score = 0 - score += ( - mt_config.reward_partial_format_match - if completion.count(mt_config.reasoning_start_token) == 1 - else mt_config.penalty_incorrect_format - ) - score += ( - mt_config.reward_partial_format_match - if completion.count(mt_config.reasoning_end_token) == 1 - else mt_config.penalty_incorrect_format - ) - score += ( - mt_config.reward_partial_format_match - if completion.count(mt_config.solution_start_token) == 1 - else mt_config.penalty_incorrect_format - ) - score += ( - mt_config.reward_partial_format_match - if completion.count(mt_config.solution_end_token) == 1 - else mt_config.penalty_incorrect_format - ) - scores.append(score) - return scores - - -def check_answer(prompts, completions, answer, mt_config, **kwargs): - """Reward correct answers.""" - match_format = re.compile( - ( - r"^[\s]{0,}}" - rf"{mt_config.reasoning_start_token}.+?{mt_config.reasoning_end_token}.*?" - rf"{mt_config.solution_start_token}(.+?){mt_config.solution_end_token}" - r"[\s]{0,}$" - ), - flags=re.MULTILINE | re.DOTALL, + data = tfds.data_source( + "gsm8k", + split=split, + data_dir=data_dir, + builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD}, + download=True, ) - extracted_responses = [ - guess.group(1) if (guess := match_format.search(r)) is not None else None for r in completions - ] - - scores = [] - for guess, true_answer in zip(extracted_responses, answer): - score = 0 - if guess is None: - scores.append(0) - continue - - if guess == true_answer: - score += mt_config.reward_exact_format_match - elif guess.strip() == true_answer.strip(): - score += mt_config.reward_white_space_format_match - else: - try: - ratio = float(guess) / float(true_answer) - if 0.9 <= ratio <= 1.1: - score += mt_config.reward_ratio_guess_to_answer_high - elif 0.8 <= ratio <= 1.2: - score += mt_config.reward_ratio_guess_to_answer_low - else: - score += mt_config.penalty_incorrect_answer - except (ValueError, ZeroDivisionError): - score += mt_config.penalty_incorrect_format - scores.append(score) - return scores - - -def check_numbers(prompts, completions, answer, mt_config, **kwargs): - """Reward correct numerical answers.""" - match_numbers = re.compile( - rf"{mt_config.solution_start_token}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL + loaded_dataset = ( + grain.MapDataset.source(data) + .shuffle(seed=mt_config.data_shuffle_seed) + .map( + lambda x: { + # passed to model forward pass + "prompts": model_tokenizer.apply_chat_template( + [ + { + "role": "user", + "content": mt_config.template.format( + system_prompt=mt_config.system_prompt, + question=x["question"].decode("utf-8"), + ), + }, + ], + tokenize=False, + add_generation_prompt=True, + ), + # passed to reward functions + "question": x["question"].decode("utf-8"), + # passed to reward functions + "answer": rl_utils.extract_hash_answer(x["answer"].decode("utf-8")), + } + ) ) - extracted_responses = [ - guess.group(1) if (guess := match_numbers.search(r)) is not None else None - for r in completions - ] - - scores = [] - for guess, true_answer in 
zip(extracted_responses, answer): - if guess is None: - scores.append(0) - continue - try: - true_answer = float(true_answer.strip()) - guess = float(guess.strip()) - scores.append(1.5 if guess == true_answer else 0.0) - except (ValueError, TypeError): - scores.append(0) - continue - return scores - + return loaded_dataset def rl_train(mt_config): """ @@ -317,23 +191,29 @@ def rl_train(mt_config): if not os.path.exists(test_data_dir): os.makedirs(test_data_dir) + # Create model tokenizer + model_tokenizer = AutoTokenizer.from_pretrained(mt_config.hf_model_name) # Load datasets - print("Loading GSM8K dataset...") - train_dataset, tokenizer = get_gsm8k_dataset( - train_data_dir, - split="train", - batch_size=mt_config.per_device_batch_size, - num_batches=getattr(mt_config, "num_batches", 4), - ) + dataset = get_dataset(model_tokenizer, mt_config, train_data_dir, "train").batch(mt_config.batch_size)[:mt_config.num_batches] + + if mt_config.train_fraction == 1.0: + train_dataset = dataset.repeat(mt_config.num_epochs) + val_dataset = None + else: + train_dataset = dataset[: int(len(dataset) * mt_config.train_fraction)] + train_dataset = train_dataset.repeat(mt_config.num_epochs) + + val_dataset = dataset[int(len(dataset) * mt_config.train_fraction) :].repeat(mt_config.num_epochs) + + test_dataset = get_dataset(model_tokenizer, mt_config, test_data_dir, "test").batch(mt_config.batch_size)[:mt_config.num_test_batches] + + + # Let's see how one batch of the dataset looks like! + if mt_config.debug: + for ele in train_dataset[:1]: + pprint(ele) - # Load test dataset for evaluation (currently not used in training loop) - get_gsm8k_dataset( - test_data_dir, - split="test", - batch_size=mt_config.per_device_batch_size, - num_batches=getattr(mt_config, "num_test_batches", 5), - ) # Setup device allocation @@ -347,7 +227,7 @@ def rl_train(mt_config): # Load reference model print("Creating reference model and also meshes for reference and rollout") - reference_model, reference_mesh = get_ref_maxtext_model(mt_config, trainer_devices) + reference_model, reference_mesh = get_maxtext_model(mt_config, trainer_devices) devices_array = maxtext_utils.create_device_mesh(mt_config, sampler_devices) # if trainer_devices=sampler_devices, then rollout_mesh=reference_mesh # else rollout_mesh uses sampler_devices @@ -357,8 +237,8 @@ def rl_train(mt_config): nnx.display(reference_model) print(f"Reference mesh shape: {reference_mesh.shape}") - # Sanity check that weights are loaded correctly - _maxtext_state_flatten = nnx.state(llama3_1_70b).flat_state() + # Sanity check that weights are loaded correctly. 
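+  # (Note: flat_state() yields (key-path, value) pairs; joining each key path with "."
+  # below produces flat names such as "base.token_embedder.embedding", which is why that
+  # specific embedding entry can be printed as a quick weight-loading sanity check.)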
+ _maxtext_state_flatten = nnx.state(reference_model).flat_state() maxtext_state_flatten = {".".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten} print( f"maxtext_state_flatten[base.token_embedder.embedding].value={maxtext_state_flatten['base.token_embedder.embedding'].value}" @@ -409,8 +289,8 @@ def rl_train(mt_config): # Setup metrics logging log_dir=mt_config.tensorboard_dir - print(f"TensorBoard logs directory: {LOG_DIR}") - print(f"tensorboard --logdir {LOG_DIR} --port=8086") + print(f"TensorBoard logs directory: {log_dir}") + print(f"tensorboard --logdir {log_dir} --port=8086") metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=log_dir, flush_every_n_steps=20) # Profiler configurations @@ -471,7 +351,7 @@ def rl_train(mt_config): rl_cluster = rl_cluster_lib.RLCluster( actor=policy_model, reference=reference_model, - tokenizer=tokenizer, + tokenizer=model_tokenizer, cluster_config=cluster_config, ) @@ -479,15 +359,26 @@ def rl_train(mt_config): print("Setting up GRPO trainer...") rl_trainer = GrpoLearner( rl_cluster=rl_cluster, - reward_fns=[ - lambda **kwargs: match_format_exactly(mt_config=mt_config, **kwargs), - lambda **kwargs: match_format_approximately(mt_config=mt_config, **kwargs), - lambda **kwargs: check_answer(mt_config=mt_config, **kwargs), - lambda **kwargs: check_numbers(mt_config=mt_config, **kwargs), + reward_fns=[ # type: ignore + lambda **kwargs: rl_utils.match_format_exactly(mt_config=mt_config, **kwargs), + lambda **kwargs: rl_utils.match_format_approximately(mt_config=mt_config, **kwargs), + lambda **kwargs: rl_utils.check_answer(mt_config=mt_config, **kwargs), + lambda **kwargs: rl_utils.check_numbers(mt_config=mt_config, **kwargs), ], grpo_config=grpo_config, ) + + + if mt_config.debug: + # verify if vllm sampler works + output = rl_cluster.rollout.generate( + ["The capital of France is"], + rollout_config=base_rollout.RolloutConfig(max_tokens_to_generate=64, temperature=0.1), + ) + + print(f"Output: {output}") + # Start training print("Starting GRPO training...") with policy_mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules): @@ -497,13 +388,11 @@ def rl_train(mt_config): max_logging.log(f"Saving profiles to {profile_dir}") jax.profiler.start_trace(profile_dir) - with mesh, nn_partitioning.axis_rules(config_policy.logical_axis_rules): - rl_trainer.train(dataset) + with reference_mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules): + rl_trainer.train(train_dataset) jax.profiler.stop_trace() - print("=" * 80) print("GRPO Training Completed Successfully!") - print("=" * 80) return rl_trainer, rl_cluster From 27305f879c7c191618f463e1ed60b95b42c941d2 Mon Sep 17 00:00:00 2001 From: A9isha Date: Fri, 31 Oct 2025 09:32:16 +0000 Subject: [PATCH 17/31] refactored --- src/MaxText/configs/rl.yml | 88 +++++++---------- src/MaxText/evaluate_rl.py | 51 ++++------ src/MaxText/rl_utils.py | 79 ++++++++------- src/MaxText/train_rl.py | 191 +++++++++++++++++++++---------------- 4 files changed, 199 insertions(+), 210 deletions(-) diff --git a/src/MaxText/configs/rl.yml b/src/MaxText/configs/rl.yml index 79081e644..a0c2626e2 100644 --- a/src/MaxText/configs/rl.yml +++ b/src/MaxText/configs/rl.yml @@ -61,15 +61,13 @@ loss_algo: 'grpo' # grpo or gspo-token # Model-Specific Overrides (examples) # For Llama3.1-8B: # model_name: llama3.1-8b -# tokenizer_path: meta-llama/Llama-3.1-8B-Instruct +# HF tokenizer_path: meta-llama/Llama-3.1-8B-Instruct # # For Llama3.1-70B with Pathways: # model_name: llama3.1-70b -# 
tokenizer_path: meta-llama/Llama-3.1-70B-Instruct +# HF tokenizer_path: meta-llama/Llama-3.1-70B-Instruct -async_checkpointing: 'false' -checkpoint_period: 5 -skip_jax_distributed_system: True +# ====== MaxText configs ====== weight_dtype: 'bfloat16' attention: 'dot_product' remat_policy: 'custom' @@ -77,17 +75,16 @@ decoder_layer_input: 'offload' query_proj: 'offload' key_proj: 'offload' value_proj: 'offload' -# for vLLM -hf_model_name: 'meta-llama/Llama-3.1-70B-Instruct' + # ====== Training ====== batch_size: 1 # Increase `batch_size` and `MAX_STEPS` for better results. -# num_batches = 3738 -num_batches = 4 # 200 +# num_batches: 3738 +num_batches: 4 # 200 # Keep `num_test_batches` low so that evaluation runs quickly. It can be # increased to a max. of 330 (if batch size is 4). -num_test_batches = 5 # 200 +num_test_batches: 5 # 200 train_fraction: 1.0 eval_interval: 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`. @@ -96,6 +93,20 @@ num_epochs: 1 # can potentially train for more epochs gradient_clipping_threshold: 0.1 +# ====== Evaluation ====== +generation_configs: + greedy: + temperature: 0.01 + top_k: 1 + top_p: 1.0 + standard: + temperature: 0.7 + top_k: 50 + top_p: 0.95 + liberal: + temperature: 0.85 + top_k: 2000 + top_p: 1.0 # greedy eval_temperature: 0.01 @@ -108,13 +119,9 @@ eval_top_p: 1.0 # ====== Inference ====== -# Important to keep a high-ish temperature for varied, diverse responses during -# training. -decode_sampling_temperature: 0.9 -decode_sampling_top_k: 50 -decode_sampling_nucleus_p: 1.0 - # for vLLM +hf_model_name: None + # === Generation during GRPO training === # max Lengths for prompt and completion max_prefill_predict_length: 256 @@ -123,13 +130,15 @@ kv_cache_buffer: 256 hbm_utilization_vllm: 0.72 swap_space_vllm_gb: 2 # Generation Configuration During Training +# Important to keep a high-ish temperature for varied, diverse responses during +# training. 
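+# (Evaluation uses the separate eval_temperature / eval_top_k / eval_top_p values
+# above; the settings below only shape the rollouts sampled during GRPO training.)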
decode_sampling_temperature: 0.9 -decode_sampling_top_p: 1.0 decode_sampling_top_k: 50 +decode_sampling_nucleus_p: 1.0 # ====== Checkpoint Configuration ====== enable_checkpointing: True -async_checkpointing: True +async_checkpointing: False checkpoint_period: 50 max_num_checkpoints_to_keep: 10 @@ -162,40 +171,9 @@ template: | model -# TODO: fix this -# Dataset Configuration -dataset_type: hf # Huggingface input pipeline -hf_path: 'gsm8k' -hf_data_split: 'main' -hf_data_files: 'train' - - -# Pathways Inference Configuration -# For multi-host/multi-slice setups -use_pathways_reshard: False -inference_devices_per_replica: 4 -inference_replicas: 1 - -# Tokenizer Settings -add_bos: False -add_eos: False -return_log_prob: True - -# Performance and Memory -weight_dtype: bfloat16 -dtype: bfloat16 - -# Splash Attention Block Sizes -# Tuned for GRPO workloads -sa_block_q: 128 -sa_block_kv: 128 -sa_block_kv_compute: 128 -sa_block_q_dkv: 128 -sa_block_kv_dkv: 128 -sa_block_kv_dkv_compute: 128 -sa_block_q_dq: 128 -sa_block_kv_dq: 128 -sa_use_fused_bwd_kernel: False -sa_q_layout: "HEAD_DIM_MINOR" -sa_k_layout: "HEAD_DIM_MINOR" -sa_v_layout: "HEAD_DIM_MINOR" \ No newline at end of file +# # TODO(@mazumdera): fix this +# # Dataset Configuration +# dataset_type: hf # Huggingface input pipeline +# hf_path: 'gsm8k' +# hf_data_split: 'main' +# hf_data_files: 'train' diff --git a/src/MaxText/evaluate_rl.py b/src/MaxText/evaluate_rl.py index 99657c346..efc6dd447 100644 --- a/src/MaxText/evaluate_rl.py +++ b/src/MaxText/evaluate_rl.py @@ -18,6 +18,7 @@ import os from pprint import pprint import re + import sys from datetime import datetime @@ -51,11 +52,6 @@ from MaxText import rl_utils # ## Evaluate -# -# -# Before we train the model, let's evaluate the model on the test set so we can -# see the improvement post training. -# # We evaluate it in two ways: # # **Quantitative** @@ -76,13 +72,10 @@ def generate_responses( - mt_config, + tmvp_config, prompts, rl_cluster, num_passes=1, - temperature=0.7, - top_k=50, - top_p=0.95, ): """ Generate responses for a batch of prompts across multiple passes. @@ -104,15 +97,15 @@ def generate_responses( responses = rl_cluster.rollout.generate( prompts, rollout_config=RolloutConfig( - max_tokens_to_generate=mt_config.max_target_length, - temperature=mt_config.eval_temperature, - top_k=mt_config.eval_top_k, - top_p=mt_config.eval_top_p, + max_tokens_to_generate=tmvp_config.max_target_length, + temperature=tmvp_config.eval_temperature, + top_k=tmvp_config.eval_top_k, + top_p=tmvp_config.eval_top_p, ), ) responses = responses.text - if mt_config.debug: + if tmvp_config.debug: print(f"Pass {p+1}/{num_passes}, responses: {responses}") for idx, response in enumerate(responses): @@ -121,7 +114,7 @@ def generate_responses( return multiple_call_responses -def score_responses(mt_config, question, responses, answer): +def score_responses(tmvp_config, question, responses, answer): """ Score a set of responses for a single question. 
@@ -133,10 +126,10 @@ def score_responses(mt_config, question, responses, answer): Returns: Tuple of (is_correct, is_partially_correct, has_correct_format) """ - match_format = rl_utils.get_match_format_regex(mt_config) - match_numbers = rl_utils.get_match_numbers_regex(mt_config) + match_format = rl_utils.get_match_format_regex(tmvp_config) + match_numbers = rl_utils.get_match_numbers_regex(tmvp_config) - if DEBUG: + if tmvp_config.debug: print("========================================") print(f"Evaluation Question: {question}") print(f"Evaluation Answer: {answer}") @@ -151,7 +144,7 @@ def score_responses(mt_config, question, responses, answer): # Extract numerical response extracted_response = guess.group(1) if (guess := match_numbers.search(response)) is not None else "-1000000" - if DEBUG: + if tmvp_config.debug: print(f"Evaluation extracted_response: {extracted_response}") # Check exact correctness @@ -164,7 +157,7 @@ def score_responses(mt_config, question, responses, answer): if 0.9 <= ratio <= 1.1: is_partially_correct = True except Exception as e: - if DEBUG: + if tmvp_config.debug: print(f"Evaluation Exception: {e}") print("SKIPPED") @@ -180,12 +173,9 @@ def score_responses(mt_config, question, responses, answer): def evaluate( - mt_config, + tmvp_config, dataset, rl_cluster, - temperature=0.7, - top_k=50, - top_p=0.95, num_passes=1, corr_lst=False, make_lst=False, @@ -194,11 +184,9 @@ def evaluate( Computes accuracy and percentage of outputs matching the format. Args: + tmvp_config: Configuration object dataset: The evaluation dataset - rl_cluster: Model cluster for generation - temperature: Sampling temperature - top_k: Top-k sampling parameter - top_p: Top-p sampling parameter + rl_cluster: Model cluster for generation. num_passes: Number of generation passes corr_lst: If True, only include correct responses in the list make_lst: If True, return a list of (question, answer, responses) @@ -219,19 +207,16 @@ def evaluate( # Generate responses for all prompts in the batch multiple_call_responses = generate_responses( - mt_config=mt_config, + tmvp_config=tmvp_config, prompts=prompts, rl_cluster=rl_cluster, num_passes=num_passes, - temperature=temperature, - top_k=top_k, - top_p=top_p, ) # Score each question-answer pair for question, responses, answer in zip(questions, multiple_call_responses, answers): is_correct, is_partially_correct, has_correct_format = score_responses( - mt_config=mt_config, + tmvp_config=tmvp_config, question=question, responses=responses, answer=answer, diff --git a/src/MaxText/rl_utils.py b/src/MaxText/rl_utils.py index 48aa58a31..dd6c59bc6 100644 --- a/src/MaxText/rl_utils.py +++ b/src/MaxText/rl_utils.py @@ -52,41 +52,41 @@ # Let's define a RegEx for checking whether the format matches. # -def get_match_format_regex(mt_config): +def get_match_format_regex(tmvp_config): """Returns a compiled regex to extract the answer from a completion.""" match_format = re.compile( ( r"^[\s]{0,}}" - rf"{mt_config.reasoning_start_token}.+?{mt_config.reasoning_end_token}.*?" - rf"{mt_config.solution_start_token}(.+?){mt_config.solution_end_token}" + rf"{tmvp_config.reasoning_start_token}.+?{tmvp_config.reasoning_end_token}.*?" 
+ rf"{tmvp_config.solution_start_token}(.+?){tmvp_config.solution_end_token}" r"[\s]{0,}$" ), flags=re.MULTILINE | re.DOTALL, ) - if mt_config.debug: + if tmvp_config.debug: match_format.search( - f"{mt_config.reasoning_start_token}Let me" f" think!{mt_config.reasoning_end_token}{mt_config.solution_start_token}2{mt_config.solution_end_token}", + f"{tmvp_config.reasoning_start_token}Let me" f" think!{tmvp_config.reasoning_end_token}{tmvp_config.solution_start_token}2{tmvp_config.solution_end_token}", ) return match_format -def match_format_exactly(prompts, completions, mt_config, **kargs): +def match_format_exactly(prompts, completions, tmvp_config, **kargs): """ - Give the model a reward of mt_config.reward_exact_format_match points if the format matches exactly. + Give the model a reward of tmvp_config.reward_exact_format_match points if the format matches exactly. """ scores = [] - match_format = get_match_format_regex(mt_config) + match_format = get_match_format_regex(tmvp_config) for completion in completions: score = 0 response = completion # Match if format is seen exactly! if match_format.search(response) is not None: - score += mt_config.reward_exact_format_match + score += tmvp_config.reward_exact_format_match scores.append(score) return scores -def match_format_approximately(prompts, completions, mt_config, **kargs): +def match_format_approximately(prompts, completions, tmvp_config, **kargs): """ We also reward the model if the format of the output matches partially. """ @@ -97,68 +97,66 @@ def match_format_approximately(prompts, completions, mt_config, **kargs): # Count how many keywords are seen - we penalize if too many! # If we see 1, then plus some points! score += ( - mt_config.reward_partial_format_match - if completion.count(mt_config.reasoning_start_token) == 1 - else mt_config.penalty_incorrect_format + tmvp_config.reward_partial_format_match + if completion.count(tmvp_config.reasoning_start_token) == 1 + else tmvp_config.penalty_incorrect_format ) score += ( - mt_config.reward_partial_format_match - if completion.count(mt_config.reasoning_end_token) == 1 - else mt_config.penalty_incorrect_format + tmvp_config.reward_partial_format_match + if completion.count(tmvp_config.reasoning_end_token) == 1 + else tmvp_config.penalty_incorrect_format ) score += ( - mt_config.reward_partial_format_match - if completion.count(mt_config.solution_start_token) == 1 - else mt_config.penalty_incorrect_format + tmvp_config.reward_partial_format_match + if completion.count(tmvp_config.solution_start_token) == 1 + else tmvp_config.penalty_incorrect_format ) score += ( - mt_config.reward_partial_format_match - if completion.count(mt_config.solution_end_token) == 1 - else mt_config.penalty_incorrect_format + tmvp_config.reward_partial_format_match + if completion.count(tmvp_config.solution_end_token) == 1 + else tmvp_config.penalty_incorrect_format ) scores.append(score) return scores -def check_answer(prompts, completions, answer, mt_config, **kargs): +def check_answer(prompts, completions, answer, tmvp_config, **kargs): """ Reward the model if the answer is correct. A reward is also given if the answer does not match exactly, i.e., based on how close the answer is to the correct value. 
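   (Illustrative scoring under the configured rewards: a guess of "9.5" against a true
   answer of "10" has ratio 0.95, inside the 0.9-1.1 band, so it earns
   reward_ratio_guess_to_answer_high; a guess of "7" has ratio 0.7, outside 0.8-1.2,
   and earns penalty_incorrect_answer.)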
""" - match_format = get_match_format_regex(mt_config) + match_format = get_match_format_regex(tmvp_config) extracted_responses = [ guess.group(1) if (guess := match_format.search(c)) is not None else None for c in completions ] - extracted_responses = [guess.group(1) if (guess := match_format.search(c)) is not None else None for c in completions] - scores = [] for guess, true_answer in zip(extracted_responses, answer): score = 0 if guess is None: scores.append(0) continue - # Correct answer gets 3 points! + # Correct answer gets points! if guess == true_answer: - score += mt_config.reward_exact_format_match + score += tmvp_config.reward_exact_format_match # Match if spaces are seen elif guess.strip() == true_answer.strip(): - score += mt_config.reward_white_space_format_match + score += tmvp_config.reward_white_space_format_match else: # We also reward it if the answer is close via ratios! # Ie if the answer is within some range, reward it! try: ratio = float(guess) / float(true_answer) if ratio >= 0.9 and ratio <= 1.1: - score += mt_config.reward_ratio_guess_to_answer_high + score += tmvp_config.reward_ratio_guess_to_answer_high elif ratio >= 0.8 and ratio <= 1.2: - score += mt_config.reward_ratio_guess_to_answer_low + score += tmvp_config.reward_ratio_guess_to_answer_low else: - score += mt_config.penalty_incorrect_answer # Penalize wrong answers + score += tmvp_config.penalty_incorrect_answer # Penalize wrong answers except: - score += mt_config.penalty_incorrect_format # Penalize + score += tmvp_config.penalty_incorrect_format # Penalize scores.append(score) return scores @@ -166,29 +164,29 @@ def check_answer(prompts, completions, answer, mt_config, **kargs): # Sometimes, the text between `` and `` might not be one # number; it can be a sentence. So, we extract the number and compare the answer. -def get_match_numbers_regex(mt_config): +def get_match_numbers_regex(tmvp_config): """Returns a compiled regex to extract the answer from a completion.""" match_numbers = re.compile( - rf"{mt_config.solution_start_token}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL + rf"{tmvp_config.solution_start_token}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL ) - if mt_config.debug: - match_numbers.findall(f"{mt_config.solution_start_token} 0.34 {mt_config.solution_end_token}") + if tmvp_config.debug: + match_numbers.findall(f"{tmvp_config.solution_start_token} 0.34 {tmvp_config.solution_end_token}") return match_numbers -def check_numbers(prompts, completions, answer, mt_config, **kargs): +def check_numbers(prompts, completions, answer, tmvp_config, **kargs): """ Reward the model if the answer is correct. 
""" question = kargs["question"] - match_numbers = get_match_numbers_regex(mt_config) + match_numbers = get_match_numbers_regex(tmvp_config) extracted_responses = [ guess.group(1) if (guess := match_numbers.search(c)) is not None else None for c in completions ] scores = [] - if mt_config.debug: + if tmvp_config.debug: print("START ============================") print(f"Question: {question[0]}") print(f"Answer: {answer[0]}") @@ -216,4 +214,3 @@ def extract_hash_answer(text: str, debug: bool = False) -> str | None: if "####" not in text: return None return text.split("####")[1].strip() - diff --git a/src/MaxText/train_rl.py b/src/MaxText/train_rl.py index 59e1a53c2..764ee7224 100644 --- a/src/MaxText/train_rl.py +++ b/src/MaxText/train_rl.py @@ -25,7 +25,7 @@ Usage Examples: # Llama3.1-8B (single host) - python3 src/MaxText/examples/rl_trainer.py \\ + python3 src/MaxText/examples/train_rl \\ --model_name=llama3.1-8b \\ --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \\ --load_parameters_path=gs://path/to/checkpoint \\ @@ -34,7 +34,7 @@ --steps=100 # Llama3.1-70B with Pathways (multi-host) - python3 src/MaxText/examples/rl_trainer.py \\ + python3 src/MaxText/examples/train_rl \\ --model_name=llama3.1-70b \\ --tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \\ --load_parameters_path=gs://path/to/checkpoint \\ @@ -44,7 +44,7 @@ --steps=100 # Custom dataset - python3 src/MaxText/examples/rl_trainer.py \\ + python3 src/MaxText/examples/train_rl \\ --model_name=llama3.1-8b \\ --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \\ --load_parameters_path=gs://path/to/checkpoint \\ @@ -55,14 +55,15 @@ """ from pprint import pprint -from typing import Sequence from absl import app import os -import re +import sys +from typing import Sequence import jax from jax.sharding import Mesh from flax.linen import partitioning as nn_partitioning +from flax import nnx import optax from orbax import checkpoint as ocp import tensorflow_datasets as tfds @@ -72,16 +73,27 @@ import pathwaysutils +# add the parent directory (two levels up to say ~/HOME/maxtext) to sys.path if currenlt runnig from +# ~/HOME/maxtext/MaxText/examples + +# Get the directory of the current script +script_dir = os.path.dirname(os.path.abspath(__file__)) + +# Go up two levels to get the project root +project_root = os.path.abspath(os.path.join(script_dir, "..", "..", "..")) + +# Add the project root to the Python path +sys.path.insert(0, project_root) + from tunix.rl import rl_cluster as rl_cluster_lib from tunix.rl.rollout import base_rollout from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner from tunix.sft import metrics_logger -from tunix.models.llama3 import model as llama3_lib from MaxText import max_logging, max_utils, maxtext_utils, pyconfig from MaxText import model_creation_utils from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter - +from MaxText.evaluate_rl import evaluate from MaxText import rl_utils @@ -102,11 +114,11 @@ def get_maxtext_model(config, devices=None): return tunix_model, mesh -def setup_device_allocation(mt_config, use_pathways: bool = False): +def setup_device_allocation(tmvp_config, use_pathways: bool = False): """Setup device allocation for training and inference.""" devices = jax.devices() - num_vms = len(devices) // mt_config.chips_per_vm + num_vms = len(devices) // tmvp_config.chips_per_vm trainer_devices = devices sampler_devices = devices if num_vms >= 2 and use_pathways: @@ -114,18 +126,18 @@ def setup_device_allocation(mt_config, use_pathways: bool = False): # based on 
trainer_devices_fraction and sampler_devices_fraction print(f"{num_vms} VMs detected, allocating trainer and sampler devices, and using Pathways.") num_devices = len(devices) - num_trainer_devices = int(num_devices * mt_config.trainer_devices_fraction) - num_sampler_devices = int(num_devices * mt_config.sampler_devices_fraction) + num_trainer_devices = int(num_devices * tmvp_config.trainer_devices_fraction) + num_sampler_devices = int(num_devices * tmvp_config.sampler_devices_fraction) trainer_devices = devices[:num_trainer_devices] sampler_devices = devices[num_devices - num_sampler_devices :] - if mt_config.trainer_devices_fraction!=1.0: + if tmvp_config.trainer_devices_fraction!=1.0: print(f"Using first {len(trainer_devices)} devices as Trainer devices") - if mt_config.sampler_devices_fraction != 1.0: + if tmvp_config.sampler_devices_fraction != 1.0: print(f"Using last {len(sampler_devices)} devices as Sampler devices") return trainer_devices, sampler_devices, num_vms -def get_dataset(model_tokenizer, mt_config, data_dir, split="train") -> grain.MapDataset: +def get_dataset(model_tokenizer, tmvp_config, data_dir, split="train") -> grain.MapDataset: # Download data if not os.path.exists(data_dir): os.makedirs(data_dir) @@ -140,7 +152,7 @@ def get_dataset(model_tokenizer, mt_config, data_dir, split="train") -> grain.Ma loaded_dataset = ( grain.MapDataset.source(data) - .shuffle(seed=mt_config.data_shuffle_seed) + .shuffle(seed=tmvp_config.data_shuffle_seed) .map( lambda x: { # passed to model forward pass @@ -148,8 +160,8 @@ def get_dataset(model_tokenizer, mt_config, data_dir, split="train") -> grain.Ma [ { "role": "user", - "content": mt_config.template.format( - system_prompt=mt_config.system_prompt, + "content": tmvp_config.template.format( + system_prompt=tmvp_config.system_prompt, question=x["question"].decode("utf-8"), ), }, @@ -166,20 +178,20 @@ def get_dataset(model_tokenizer, mt_config, data_dir, split="train") -> grain.Ma ) return loaded_dataset -def rl_train(mt_config): +def rl_train(tmvp_config): """ Run RL training with the provided configuration. Args: - mt_config: MaxText configuration object + tmvp_config: MaxText configuration object """ # ====== Debug flag for verbose logs ====== - DEBUG = mt_config.debug + DEBUG = tmvp_config.debug print("Starting GRPO Training") # Number of training steps. 
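   # (Worked example with the rl.yml defaults, assuming num_iterations=1:
   #  int(4 * 1 * 1.0 * 1) = 4 training steps.)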
- max_train_steps = int(mt_config.num_batches * mt_config.num_iterations * mt_config.train_fraction * mt_config.num_epochs) + max_train_steps = int(tmvp_config.num_batches * tmvp_config.num_iterations * tmvp_config.train_fraction * tmvp_config.num_epochs) # ====== Data ====== # Setup data directories @@ -192,25 +204,25 @@ def rl_train(mt_config): os.makedirs(test_data_dir) # Create model tokenizer - model_tokenizer = AutoTokenizer.from_pretrained(mt_config.hf_model_name) + model_tokenizer = AutoTokenizer.from_pretrained(tmvp_config.hf_model_name) # Load datasets - dataset = get_dataset(model_tokenizer, mt_config, train_data_dir, "train").batch(mt_config.batch_size)[:mt_config.num_batches] + dataset = get_dataset(model_tokenizer, tmvp_config, train_data_dir, "train").batch(tmvp_config.batch_size)[:tmvp_config.num_batches] - if mt_config.train_fraction == 1.0: - train_dataset = dataset.repeat(mt_config.num_epochs) + if tmvp_config.train_fraction == 1.0: + train_dataset = dataset.repeat(tmvp_config.num_epochs) val_dataset = None else: - train_dataset = dataset[: int(len(dataset) * mt_config.train_fraction)] - train_dataset = train_dataset.repeat(mt_config.num_epochs) + train_dataset = dataset[: int(len(dataset) * tmvp_config.train_fraction)] + train_dataset = train_dataset.repeat(tmvp_config.num_epochs) - val_dataset = dataset[int(len(dataset) * mt_config.train_fraction) :].repeat(mt_config.num_epochs) + val_dataset = dataset[int(len(dataset) * tmvp_config.train_fraction) :].repeat(tmvp_config.num_epochs) - test_dataset = get_dataset(model_tokenizer, mt_config, test_data_dir, "test").batch(mt_config.batch_size)[:mt_config.num_test_batches] + test_dataset = get_dataset(model_tokenizer, tmvp_config, test_data_dir, "test").batch(tmvp_config.batch_size)[:tmvp_config.num_test_batches] # Let's see how one batch of the dataset looks like! 
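   # (Each element is a dict with the keys produced by get_dataset above — roughly
   #  {"prompts": <chat-templated question>, "question": <raw GSM8K question>,
   #   "answer": <text after "####" in the GSM8K solution>}; shown only as an
   #  illustration of the expected structure.)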
- if mt_config.debug: + if tmvp_config.debug: for ele in train_dataset[:1]: pprint(ele) @@ -223,16 +235,16 @@ def rl_train(mt_config): else: use_pathways = False print(f"Use Pathways: {use_pathways}") - trainer_devices, sampler_devices, num_vms = setup_device_allocation(mt_config, use_pathways) + trainer_devices, sampler_devices, num_vms = setup_device_allocation(tmvp_config, use_pathways) # Load reference model print("Creating reference model and also meshes for reference and rollout") - reference_model, reference_mesh = get_maxtext_model(mt_config, trainer_devices) - devices_array = maxtext_utils.create_device_mesh(mt_config, sampler_devices) + reference_model, reference_mesh = get_maxtext_model(tmvp_config, trainer_devices) + devices_array = maxtext_utils.create_device_mesh(tmvp_config, sampler_devices) # if trainer_devices=sampler_devices, then rollout_mesh=reference_mesh # else rollout_mesh uses sampler_devices - rollout_mesh = Mesh(devices_array, mt_config.mesh_axes) - if mt_config.debug: + rollout_mesh = Mesh(devices_array, tmvp_config.mesh_axes) + if tmvp_config.debug: print("Reference Model initialized successfully") nnx.display(reference_model) print(f"Reference mesh shape: {reference_mesh.shape}") @@ -247,10 +259,10 @@ def rl_train(mt_config): # TODO: @mazumdera: change this to use lora # Load policy model print("Creating policy model with same config as reference model on trainer mesh") - policy_model, policy_mesh = get_maxtext_model(mt_config, trainer_devices) + policy_model, policy_mesh = get_maxtext_model(tmvp_config, trainer_devices) actor_mesh = policy_mesh - if mt_config.debug: + if tmvp_config.debug: print("Policy Model initialized successfully") nnx.display(policy_model) print(f"Policy mesh shape: {policy_mesh.shape}") @@ -259,36 +271,36 @@ def rl_train(mt_config): optimizer = optax.adamw( learning_rate=optax.schedules.warmup_cosine_decay_schedule( init_value=0.0, - peak_value=mt_config.learning_rate, + peak_value=tmvp_config.learning_rate, # Linearly increase learning rate from 0. to learning_rate in the first # warmup_steps_fraction training steps, and then gradually decrease the # learning rate to 0 using cosine scheduler. - warmup_steps=int(mt_config.warmup_steps_fraction*mt_config.max_train_steps), + warmup_steps=int(tmvp_config.warmup_steps_fraction*max_train_steps), decay_steps=max_train_steps, ), - b1=mt_config.adam_b1, - b2=mt_config.adam_b2, - weight_decay=mt_config.adam_weight_decay, + b1=tmvp_config.adam_b1, + b2=tmvp_config.adam_b2, + weight_decay=tmvp_config.adam_weight_decay, ) # TODO: @mazumdera: try optimizer offloading with adamw # Add gradient clipping if specified # Grad clipping to prevent large gradients. We find this # important to keep KL divergence in check. 
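   # (optax.chain applies its transformations in order, so gradients are clipped by
   #  global norm first and the clipped values are then passed to the AdamW update.)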
- if mt_config.gradient_clipping_threshold > 0: + if tmvp_config.gradient_clipping_threshold > 0: optimizer = optax.chain( - optax.clip_by_global_norm(max_norm=mt_config.gradient_clipping_threshold), + optax.clip_by_global_norm(max_norm=tmvp_config.gradient_clipping_threshold), optimizer, ) # Setup checkpointing - ckpt_dir = mt_config.checkpoint_dir + ckpt_dir = tmvp_config.checkpoint_dir checkpointing_options = ocp.CheckpointManagerOptions( - save_interval_steps=mt_config.checkpoint_period, mt_config.max_num_checkpoints_to_keep + save_interval_steps=tmvp_config.checkpoint_period, max_to_keep=tmvp_config.max_num_checkpoints_to_keep ) # Setup metrics logging - log_dir=mt_config.tensorboard_dir + log_dir=tmvp_config.tensorboard_dir print(f"TensorBoard logs directory: {log_dir}") print(f"tensorboard --logdir {log_dir} --port=8086") metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=log_dir, flush_every_n_steps=20) @@ -310,7 +322,7 @@ def rl_train(mt_config): offload_to_cpu=False, training_config=rl_cluster_lib.RLTrainingConfig( actor_optimizer=optimizer, - eval_every_n_steps=mt_config.eval_interval, + eval_every_n_steps=tmvp_config.eval_interval, max_steps=max_train_steps, # metrics logging metrics_logging_options=metrics_logging_options, @@ -321,33 +333,33 @@ def rl_train(mt_config): checkpointing_options=checkpointing_options, ), rollout_config=base_rollout.RolloutConfig( - max_tokens_to_generate=mt_config.max_target_length, - max_prompt_length=mt_config.max_prefill_predict_length, - kv_cache_size=mt_config.max_prefill_predict_length - + mt_config.max_target_length - + mt_config.kv_cache_buffer, - temperature=mt_config.decode_sampling_temperature, - top_p=mt_config.decode_sampling_nucleus_p, - top_k=mt_config.decode_sampling_top_k, + max_tokens_to_generate=tmvp_config.max_target_length, + max_prompt_length=tmvp_config.max_prefill_predict_length, + kv_cache_size=tmvp_config.max_prefill_predict_length + + tmvp_config.max_target_length + + tmvp_config.kv_cache_buffer, + temperature=tmvp_config.decode_sampling_temperature, + top_p=tmvp_config.decode_sampling_nucleus_p, + top_k=tmvp_config.decode_sampling_top_k, ), - rollout_vllm_model_version=mt_config.hf_model_name, - rollout_vllm_hbm_utilization=mt_config.hbm_utilization_vllm, + rollout_vllm_model_version=tmvp_config.hf_model_name, + rollout_vllm_hbm_utilization=tmvp_config.hbm_utilization_vllm, rollout_vllm_tpu_backend_type="jax", - rollout_vllm_swap_space_size_gb=mt_config.swap_space_vllm_gb, + rollout_vllm_swap_space_size_gb=tmvp_config.swap_space_vllm_gb, ) # Setup GRPO config grpo_config = GrpoConfig( - num_generations=mt_config.num_generations, - num_iterations=mt_config.num_iterations, - beta=mt_config.grpo_beta, - epsilon=mt_config.grpo_epsilon, - loss_algo=mt_config.loss_algo, + num_generations=tmvp_config.num_generations, + num_iterations=tmvp_config.num_iterations, + beta=tmvp_config.grpo_beta, + epsilon=tmvp_config.grpo_epsilon, + loss_algo=tmvp_config.loss_algo, ) # Create RL cluster print("Creating RL cluster...") - with nn_partitioning.axis_rules(mt_config.logical_axis_rules): + with nn_partitioning.axis_rules(tmvp_config.logical_axis_rules): rl_cluster = rl_cluster_lib.RLCluster( actor=policy_model, reference=reference_model, @@ -360,17 +372,17 @@ def rl_train(mt_config): rl_trainer = GrpoLearner( rl_cluster=rl_cluster, reward_fns=[ # type: ignore - lambda **kwargs: rl_utils.match_format_exactly(mt_config=mt_config, **kwargs), - lambda **kwargs: rl_utils.match_format_approximately(mt_config=mt_config, 
**kwargs), - lambda **kwargs: rl_utils.check_answer(mt_config=mt_config, **kwargs), - lambda **kwargs: rl_utils.check_numbers(mt_config=mt_config, **kwargs), + lambda **kwargs: rl_utils.match_format_exactly(tmvp_config=tmvp_config, **kwargs), + lambda **kwargs: rl_utils.match_format_approximately(tmvp_config=tmvp_config, **kwargs), + lambda **kwargs: rl_utils.check_answer(tmvp_config=tmvp_config, **kwargs), + lambda **kwargs: rl_utils.check_numbers(tmvp_config=tmvp_config, **kwargs), ], grpo_config=grpo_config, ) - if mt_config.debug: + if tmvp_config.debug: # verify if vllm sampler works output = rl_cluster.rollout.generate( ["The capital of France is"], @@ -378,26 +390,43 @@ def rl_train(mt_config): ) print(f"Output: {output}") + # + # + # Before we train the model, let's evaluate the model on the test set so we can + # see the improvement post training. + # + (corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate( + test_dataset, + rl_cluster=rl_cluster, + **tmvp_config.generation_configs["greedy"], + ) + print(f"Pre GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") + # Start training - print("Starting GRPO training...") - with policy_mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules): - rl_trainer.train(train_dataset) - profile_dir = mt_config.tensorboard_dir + profile_dir = tmvp_config.tensorboard_dir max_logging.log(f"Saving profiles to {profile_dir}") + print("Starting GRPO training...") - jax.profiler.start_trace(profile_dir) - with reference_mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules): + # jax.profiler.start_trace(profile_dir) + with reference_mesh, nn_partitioning.axis_rules(tmvp_config.logical_axis_rules): rl_trainer.train(train_dataset) - jax.profiler.stop_trace() + # jax.profiler.stop_trace() print("GRPO Training Completed Successfully!") - return rl_trainer, rl_cluster + # Let's evaluate our model! + (corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate( + test_dataset, + rl_cluster=rl_cluster, + **tmvp_config.generation_configs["greedy"], + ) + print(f"Post GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") + def main(argv: Sequence[str]) -> None: - """Main function to run SFT training. + """Main function to run RL training. Args: argv: Command-line arguments. @@ -410,10 +439,10 @@ def main(argv: Sequence[str]) -> None: os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" ) - mt_config = pyconfig.initialize(argv) + tmvp_config = pyconfig.initialize(argv) max_utils.print_system_information() - rl_train(mt_config) + rl_train(tmvp_config) if __name__ == "__main__": From dc6d0a12e9ef90a2dca39fd99a385fab868338e0 Mon Sep 17 00:00:00 2001 From: A9isha Date: Fri, 31 Oct 2025 19:48:41 +0000 Subject: [PATCH 18/31] update to flag=post-training --- docker_build_dependency_image.sh | 42 ++++---- maxtext_grpo_dependencies.Dockerfile | 147 +++++++++++++++++++++++++++ 2 files changed, 167 insertions(+), 22 deletions(-) create mode 100644 maxtext_grpo_dependencies.Dockerfile diff --git a/docker_build_dependency_image.sh b/docker_build_dependency_image.sh index 0ef64ad55..a732ec454 100644 --- a/docker_build_dependency_image.sh +++ b/docker_build_dependency_image.sh @@ -27,7 +27,7 @@ # works with any custom wheels. 
# bash docker_build_dependency_image.sh MODE=custom_wheels -# bash docker_build_dependency_image.sh MODE=post-training +# bash docker_build_dependency_image.sh MODE=grpo # Enable "exit immediately if any command fails" option set -e @@ -68,17 +68,17 @@ if [[ -z ${MODE} ]]; then export MODE=stable echo "Default MODE=${MODE}" export CUSTOM_JAX=0 - export INSTALL_POST_TRAINING=0 + export INSTALL_GRPO=0 elif [[ ${MODE} == "custom_wheels" ]] ; then export MODE=nightly export CUSTOM_JAX=1 - export INSTALL_POST_TRAINING=0 -elif [[ ${MODE} == "post-training" || ${MODE} == "post-training-experimental" ]] ; then - export INSTALL_POST_TRAINING=1 + export INSTALL_GRPO=0 +elif [[ ${MODE} == "grpo" || ${MODE} == "grpo-experimental" ]] ; then + export INSTALL_GRPO=1 export CUSTOM_JAX=0 else export CUSTOM_JAX=0 - export INSTALL_POST_TRAINING=0 + export INSTALL_GRPO=0 fi if [[ -z ${DEVICE} ]]; then @@ -124,9 +124,9 @@ if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then elif [[ ${MANTARAY} == "true" ]]; then echo "Building with benchmark-db" docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_db_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . - elif [[ ${INSTALL_POST_TRAINING} -eq 1 && ${DEVICE} == "tpu" ]]; then - echo "Installing MaxText stable mode dependencies for Post-Training" - docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . + elif [[ ${INSTALL_GRPO} -eq 1 && ${DEVICE} == "tpu" ]]; then + echo "Installing MaxText stable mode dependencies for GRPO BASEIMAGE=$BASEIMAGE" + docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . else docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . fi @@ -136,29 +136,27 @@ else docker build --network host --build-arg CUSTOM_LIBTPU=true -f ./maxtext_libtpu_path.Dockerfile -t ${LOCAL_IMAGE_NAME} . fi -if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then +if [[ ${INSTALL_GRPO} -eq 1 ]] ; then if [[ ${DEVICE} != "tpu" ]] ; then - echo "Error: MODE=post-training is only supported for DEVICE=tpu" + echo "Error: MODE=grpo is only supported for DEVICE=tpu" exit 1 fi - # # To install tpu_commons from a local path, we copy it into the build context, excluding __pycache__. - # # This assumes vllm, tunix, tpu_commons is a sibling directory to the current one (maxtext). - # rsync -a --exclude='__pycache__' ../tpu_commons . - # # To install vllm from a local path, we copy it into the build context, excluding __pycache__. - # # This assumes vllm is a sibling directory to the current one (maxtext). - # rsync -a --exclude='__pycache__' ../vllm . + # To install from local paths, we copy vllm and tpu-inference into the build context. + # This assumes vllm and tpu-inference are sibling directories to the current one (maxtext). + echo "Copying local vllm and tpu-inference directories into the build context..." + rsync -a --exclude='__pycache__' ../tunix . + rsync -a --exclude='__pycache__' ../tpu-inference . + rsync -a --exclude='__pycache__' ../vllm . 
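+  # (The rsync sources above assume tunix/, tpu-inference/ and vllm/ are checked out
+  #  next to the maxtext directory; adjust the relative paths if your layout differs.)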
- # rsync -a --exclude='__pycache__' ../tunix . - - # # The cleanup is set to run even if the build fails to remove the copied directory. - # trap "rm -rf ./tpu_commons ./vllm ./tunix" EXIT INT TERM + # The cleanup is set to run even if the build fails to remove the copied directories. + trap "echo 'Cleaning up copied directories...' && rm -rf ./tpu-inference ./vllm" EXIT INT TERM docker build \ --network host \ --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \ --build-arg MODE=${MODE} \ - -f ./maxtext_post_training_dependencies.Dockerfile \ + -f ./maxtext_grpo_dependencies.Dockerfile \ -t ${LOCAL_IMAGE_NAME} . fi diff --git a/maxtext_grpo_dependencies.Dockerfile b/maxtext_grpo_dependencies.Dockerfile new file mode 100644 index 000000000..306a90405 --- /dev/null +++ b/maxtext_grpo_dependencies.Dockerfile @@ -0,0 +1,147 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +ARG BASEIMAGE +FROM ${BASEIMAGE} +ARG MODE +ENV MODE=$MODE + +RUN echo "Installing GRPO dependencies (vLLM, tpu-inference) with MODE=${MODE}" +RUN pip uninstall -y jax jaxlib libtpu + +RUN pip install aiohttp==3.12.15 + +# Install Python packages that enable pip to authenticate with Google Artifact Registry automatically. +RUN pip install keyring keyrings.google-artifactregistry-auth + +RUN pip install numba==0.61.2 + +COPY tunix /tunix +RUN pip install -e /tunix --no-cache-dir + + +COPY vllm /vllm +RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir --pre \ + --extra-index-url https://pypi.org/simple/ \ + --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ + --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ + --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ + --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ + --find-links https://storage.googleapis.com/libtpu-releases/index.html \ + --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ + --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + + +COPY tpu-inference /tpu-inference +RUN pip install -e /tpu-inference --no-cache-dir --pre \ + --extra-index-url https://pypi.org/simple/ \ + --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ + --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html + +# # Install vLLM for Jax and TPUs from the artifact registry +# RUN VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \ +# --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \ +# --extra-index-url https://pypi.org/simple/ \ +# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ +# --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ +# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ +# --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ +# --find-links 
https://storage.googleapis.com/libtpu-releases/index.html \ +# --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ +# --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \ +# vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu + +# # Install tpu-commons from the artifact registry +# RUN pip install --no-cache-dir --pre \ +# --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \ +# --extra-index-url https://pypi.org/simple/ \ +# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ +# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ +# tpu-commons==0.1.2 + +# # Uninstall existing jax to avoid conflicts +# # RUN pip uninstall -y jax jaxlib libtpu + +# # --- STAGE 1: Install Static Dependencies --- +# # Install any packages *not* defined in your project dependency files +# RUN --mount=type=cache,target=/root/.cache/pip pip install \ +# aiohttp==3.12.15\ +# keyring \ +# keyrings.google-artifactregistry-auth + +# RUN --mount=type=cache,target=/root/.cache/pip pip install \ +# numba==0.61.2 + +# # RUN VLLM_TARGET_DEVICE="tpu" pip install vllm +# # --- STAGE 2: Install Project Dependencies (The Main Cached Layer) --- + +# # Copy *only* the dependency definition files. +# # This assumes vllm and tpu-inference are in the build context, copied from the parent directory. +# COPY vllm/requirements/tpu.txt /tmp/ +# COPY vllm/requirements/build.txt /tmp/ +# COPY vllm/requirements/common.txt /tmp/ +# COPY tpu-inference/requirements.txt /tmp/ + +# # Run the full dependency installation. +# # This entire layer is cached and will *only* be rebuilt if +# # these .txt files change. +# RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \ +# # Set the target device so pip installs the right JAX/libtpu +# # Install tpu-inference dependencies +# export VLLM_TARGET_DEVICE="tpu" && \ +# pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \ +# --extra-index-url https://pypi.org/simple/ \ +# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ +# --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ +# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ +# --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ +# --find-links https://storage.googleapis.com/libtpu-releases/index.html \ +# --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ +# --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html' + +# # Install tpu-inference dependencies +# RUN --mount=type=cache,target=/root/.cache/pip bash -c ' \ +# pip install -r /tmp/requirements.txt --no-cache-dir --pre \ +# --extra-index-url https://pypi.org/simple/ \ +# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ +# --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ +# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ +# --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ +# --find-links https://storage.googleapis.com/libtpu-releases/index.html \ +# --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ +# --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html' + +# # --- STAGE 3: Install Project Source Code --- + +# # Now, copy the full source code. 
This invalidates cache frequently, +# # but the next step is fast. +# COPY vllm /vllm/ +# COPY tpu-inference /tpu-inference/ +# COPY tunix /tunix + + +# # Install in editable mode. This is lightning-fast because all +# # dependencies were installed and cached in STAGE 2. +# RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/ +# RUN --mount=type=cache,target=/root/.cache/pip pip install -e /tpu-inference/ + +# RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps /tunix/ +# # RUN --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /tpu-inference/ + +RUN if [ "$MODE" = "grpo-experimental" ]; then \ + echo "MODE=grpo-experimental: Re-installing JAX/libtpu"; \ + pip uninstall -y jax jaxlib libtpu && \ + pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \ + pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \ + fi From e640c7ed3f5ed477a85c9025b09b2b573c98982a Mon Sep 17 00:00:00 2001 From: A9isha Date: Fri, 31 Oct 2025 22:01:37 +0000 Subject: [PATCH 19/31] address PR feedeback --- docker_build_dependency_image.sh | 42 ++++--- src/MaxText/configs/rl.yml | 26 ++-- src/MaxText/evaluate_rl.py | 19 ++- src/MaxText/examples/local_installation.sh | 138 +++++++++++++++++++++ src/MaxText/train_rl.py | 45 ++++--- src/MaxText/{rl_utils.py => utils_rl.py} | 0 6 files changed, 208 insertions(+), 62 deletions(-) create mode 100644 src/MaxText/examples/local_installation.sh rename src/MaxText/{rl_utils.py => utils_rl.py} (100%) diff --git a/docker_build_dependency_image.sh b/docker_build_dependency_image.sh index a732ec454..0ef64ad55 100644 --- a/docker_build_dependency_image.sh +++ b/docker_build_dependency_image.sh @@ -27,7 +27,7 @@ # works with any custom wheels. # bash docker_build_dependency_image.sh MODE=custom_wheels -# bash docker_build_dependency_image.sh MODE=grpo +# bash docker_build_dependency_image.sh MODE=post-training # Enable "exit immediately if any command fails" option set -e @@ -68,17 +68,17 @@ if [[ -z ${MODE} ]]; then export MODE=stable echo "Default MODE=${MODE}" export CUSTOM_JAX=0 - export INSTALL_GRPO=0 + export INSTALL_POST_TRAINING=0 elif [[ ${MODE} == "custom_wheels" ]] ; then export MODE=nightly export CUSTOM_JAX=1 - export INSTALL_GRPO=0 -elif [[ ${MODE} == "grpo" || ${MODE} == "grpo-experimental" ]] ; then - export INSTALL_GRPO=1 + export INSTALL_POST_TRAINING=0 +elif [[ ${MODE} == "post-training" || ${MODE} == "post-training-experimental" ]] ; then + export INSTALL_POST_TRAINING=1 export CUSTOM_JAX=0 else export CUSTOM_JAX=0 - export INSTALL_GRPO=0 + export INSTALL_POST_TRAINING=0 fi if [[ -z ${DEVICE} ]]; then @@ -124,9 +124,9 @@ if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then elif [[ ${MANTARAY} == "true" ]]; then echo "Building with benchmark-db" docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_db_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . - elif [[ ${INSTALL_GRPO} -eq 1 && ${DEVICE} == "tpu" ]]; then - echo "Installing MaxText stable mode dependencies for GRPO BASEIMAGE=$BASEIMAGE" - docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . 
+ elif [[ ${INSTALL_POST_TRAINING} -eq 1 && ${DEVICE} == "tpu" ]]; then + echo "Installing MaxText stable mode dependencies for Post-Training" + docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . else docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE -f ./maxtext_dependencies.Dockerfile -t ${LOCAL_IMAGE_NAME} . fi @@ -136,27 +136,29 @@ else docker build --network host --build-arg CUSTOM_LIBTPU=true -f ./maxtext_libtpu_path.Dockerfile -t ${LOCAL_IMAGE_NAME} . fi -if [[ ${INSTALL_GRPO} -eq 1 ]] ; then +if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then if [[ ${DEVICE} != "tpu" ]] ; then - echo "Error: MODE=grpo is only supported for DEVICE=tpu" + echo "Error: MODE=post-training is only supported for DEVICE=tpu" exit 1 fi - # To install from local paths, we copy vllm and tpu-inference into the build context. - # This assumes vllm and tpu-inference are sibling directories to the current one (maxtext). - echo "Copying local vllm and tpu-inference directories into the build context..." - rsync -a --exclude='__pycache__' ../tunix . - rsync -a --exclude='__pycache__' ../tpu-inference . - rsync -a --exclude='__pycache__' ../vllm . + # # To install tpu_commons from a local path, we copy it into the build context, excluding __pycache__. + # # This assumes vllm, tunix, tpu_commons is a sibling directory to the current one (maxtext). + # rsync -a --exclude='__pycache__' ../tpu_commons . + # # To install vllm from a local path, we copy it into the build context, excluding __pycache__. + # # This assumes vllm is a sibling directory to the current one (maxtext). + # rsync -a --exclude='__pycache__' ../vllm . - # The cleanup is set to run even if the build fails to remove the copied directories. - trap "echo 'Cleaning up copied directories...' && rm -rf ./tpu-inference ./vllm" EXIT INT TERM + # rsync -a --exclude='__pycache__' ../tunix . + + # # The cleanup is set to run even if the build fails to remove the copied directory. + # trap "rm -rf ./tpu_commons ./vllm ./tunix" EXIT INT TERM docker build \ --network host \ --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \ --build-arg MODE=${MODE} \ - -f ./maxtext_grpo_dependencies.Dockerfile \ + -f ./maxtext_post_training_dependencies.Dockerfile \ -t ${LOCAL_IMAGE_NAME} . 
fi diff --git a/src/MaxText/configs/rl.yml b/src/MaxText/configs/rl.yml index a0c2626e2..6df1359f0 100644 --- a/src/MaxText/configs/rl.yml +++ b/src/MaxText/configs/rl.yml @@ -28,10 +28,6 @@ debug: True # ====== Reproducibility ====== data_shuffle_seed: 42 -# ====== Checkpoint saving ====== -save_interval_steps: 500 -max_to_keep: 4 - # ====== GRPO ====== # The number of times the policy generates multiple responses for a given prompt @@ -94,19 +90,20 @@ num_epochs: 1 # can potentially train for more epochs gradient_clipping_threshold: 0.1 # ====== Evaluation ====== +eval_sampling_strategy: "greedy" # can be "greedy", "standard", or "liberal" generation_configs: greedy: - temperature: 0.01 - top_k: 1 - top_p: 1.0 + eval_temperature: 0.01 + eval_top_k: 1 + eval_top_p: 1.0 standard: - temperature: 0.7 - top_k: 50 - top_p: 0.95 + eval_temperature: 0.7 + eval_top_k: 50 + eval_top_p: 0.95 liberal: - temperature: 0.85 - top_k: 2000 - top_p: 1.0 + eval_temperature: 0.85 + eval_top_k: 2000 + eval_top_p: 1.0 # greedy eval_temperature: 0.01 @@ -171,7 +168,8 @@ template: | model -# # TODO(@mazumdera): fix this +# # TODO(@mazumdera +): fix this # # Dataset Configuration # dataset_type: hf # Huggingface input pipeline # hf_path: 'gsm8k' diff --git a/src/MaxText/evaluate_rl.py b/src/MaxText/evaluate_rl.py index efc6dd447..2f99c0073 100644 --- a/src/MaxText/evaluate_rl.py +++ b/src/MaxText/evaluate_rl.py @@ -49,7 +49,7 @@ from tunix.rl.rollout.base_rollout import RolloutConfig from MaxText.globals import MAXTEXT_ASSETS_ROOT -from MaxText import rl_utils +from maxtext.src.MaxText import utils_rl # ## Evaluate # We evaluate it in two ways: @@ -78,29 +78,28 @@ def generate_responses( num_passes=1, ): """ - Generate responses for a batch of prompts across multiple passes. + Generate responses for a batch of prompts across potentially multiple passes. 
Args: + tmvp_config: Configuration object prompts: List of prompts to generate responses for rl_cluster: Model cluster for generation num_passes: Number of generation passes - temperature: Sampling temperature - top_k: Top-k sampling parameter - top_p: Top-p sampling parameter Returns: List of lists containing responses for each prompt across passes """ multiple_call_responses = [[] for _ in range(len(prompts))] + eval_strategy = tmvp_config.generation_configs[tmvp_config.eval_sampling_strategy] for p in range(num_passes): responses = rl_cluster.rollout.generate( prompts, rollout_config=RolloutConfig( max_tokens_to_generate=tmvp_config.max_target_length, - temperature=tmvp_config.eval_temperature, - top_k=tmvp_config.eval_top_k, - top_p=tmvp_config.eval_top_p, + temperature=eval_strategy["eval_temperature"], + top_k=eval_strategy["eval_top_k"], + top_p=eval_strategy["eval_top_p"], ), ) responses = responses.text @@ -126,8 +125,8 @@ def score_responses(tmvp_config, question, responses, answer): Returns: Tuple of (is_correct, is_partially_correct, has_correct_format) """ - match_format = rl_utils.get_match_format_regex(tmvp_config) - match_numbers = rl_utils.get_match_numbers_regex(tmvp_config) + match_format = utils_rl.get_match_format_regex(tmvp_config) + match_numbers = utils_rl.get_match_numbers_regex(tmvp_config) if tmvp_config.debug: print("========================================") diff --git a/src/MaxText/examples/local_installation.sh b/src/MaxText/examples/local_installation.sh new file mode 100644 index 000000000..0b7d81edb --- /dev/null +++ b/src/MaxText/examples/local_installation.sh @@ -0,0 +1,138 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +echo "Installing GRPO dependencies (vLLM, tpu-inference) with MODE=${MODE}" +uv pip uninstall -y jax jaxlib libtpu + +uv pip install aiohttp==3.12.15 + +# Install Python packages that enable uv pip to authenticate with Google Artifact Registry automatically. 
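The `eval_sampling_strategy` change above replaces the flat `eval_temperature`/`eval_top_k`/`eval_top_p` settings with a lookup into the named `generation_configs` entries from `rl.yml`. Below is a minimal sketch of that resolution step; the dict literal mirrors the YAML values, while the helper name and return shape are illustrative rather than the actual MaxText or Tunix API.

```python
# Minimal sketch (assumed names) of resolving a named eval sampling strategy
# into rollout sampling parameters, mirroring the generation_configs /
# eval_sampling_strategy lookup in evaluate_rl.py.
GENERATION_CONFIGS = {
    "greedy": {"eval_temperature": 0.01, "eval_top_k": 1, "eval_top_p": 1.0},
    "standard": {"eval_temperature": 0.7, "eval_top_k": 50, "eval_top_p": 0.95},
    "liberal": {"eval_temperature": 0.85, "eval_top_k": 2000, "eval_top_p": 1.0},
}


def resolve_eval_sampling(strategy_name: str) -> dict:
  """Map a strategy name to the sampling kwargs a rollout call expects."""
  strategy = GENERATION_CONFIGS.get(strategy_name, GENERATION_CONFIGS["greedy"])
  return {
      "temperature": strategy["eval_temperature"],
      "top_k": strategy["eval_top_k"],
      "top_p": strategy["eval_top_p"],
  }


print(resolve_eval_sampling("standard"))
# {'temperature': 0.7, 'top_k': 50, 'top_p': 0.95}
```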
+uv pip install keyring keyrings.google-artifactregistry-auth + +uv pip install numba==0.61.2 + +uv pip install -e ../tunix --no-cache-dir + + +VLLM_TARGET_DEVICE="tpu" uv pip install -e ../vllm --no-cache-dir --pre \ + --extra-index-url https://pypi.org/simple/ \ + --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ + --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ + --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ + --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ + --find-links https://storage.googleapis.com/libtpu-releases/index.html \ + --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ + --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + + +uv pip install -e ../tpu-inference --no-cache-dir --pre \ + --extra-index-url https://pypi.org/simple/ \ + --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ + --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html + +# # Install vLLM for Jax and TPUs from the artifact registry +# VLLM_TARGET_DEVICE="tpu" uv pip install --no-cache-dir --pre \ +# --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \ +# --extra-index-url https://pypi.org/simple/ \ +# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ +# --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ +# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ +# --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ +# --find-links https://storage.googleapis.com/libtpu-releases/index.html \ +# --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ +# --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \ +# vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu + +# # Install tpu-commons from the artifact registry +# uv pip install --no-cache-dir --pre \ +# --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \ +# --extra-index-url https://pypi.org/simple/ \ +# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ +# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ +# tpu-commons==0.1.2 + +# # Uninstall existing jax to avoid conflicts +# # uv pip uninstall -y jax jaxlib libtpu + +# # --- STAGE 1: Install Static Dependencies --- +# # Install any packages *not* defined in your project dependency files +# --mount=type=cache,target=/root/.cache/uv pip uv pip install \ +# aiohttp==3.12.15\ +# keyring \ +# keyrings.google-artifactregistry-auth + +# --mount=type=cache,target=/root/.cache/uv pip uv pip install \ +# numba==0.61.2 + +# # VLLM_TARGET_DEVICE="tpu" uv pip install vllm +# # --- STAGE 2: Install Project Dependencies (The Main Cached Layer) --- + +# # Copy *only* the dependency definition files. +# # This assumes vllm and tpu-inference are in the build context, copied from the parent directory. +# COPY vllm/requirements/tpu.txt /tmp/ +# COPY vllm/requirements/build.txt /tmp/ +# COPY vllm/requirements/common.txt /tmp/ +# COPY tpu-inference/requirements.txt /tmp/ + +# # the full dependency installation. +# # This entire layer is cached and will *only* be rebuilt if +# # these .txt files change. 
+# --mount=type=cache,target=/root/.cache/uv pip bash -c ' \ +# # Set the target device so uv pip installs the right JAX/libtpu +# # Install tpu-inference dependencies +# export VLLM_TARGET_DEVICE="tpu" && \ +# uv pip install -r /tmp/tpu.txt -r /tmp/build.txt -r /tmp/common.txt -r /tmp/requirements.txt --no-cache-dir --pre \ +# --extra-index-url https://pypi.org/simple/ \ +# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ +# --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ +# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ +# --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ +# --find-links https://storage.googleapis.com/libtpu-releases/index.html \ +# --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ +# --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html' + +# # Install tpu-inference dependencies +# --mount=type=cache,target=/root/.cache/pip bash -c ' \ +# pip install -r /tmp/requirements.txt --no-cache-dir --pre \ +# --extra-index-url https://pypi.org/simple/ \ +# --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \ +# --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ +# --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \ +# --find-links https://storage.googleapis.com/libtpu-wheels/index.html \ +# --find-links https://storage.googleapis.com/libtpu-releases/index.html \ +# --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ +# --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html' + +# # --- STAGE 3: Install Project Source Code --- + +# # Now, copy the full source code. This invalidates cache frequently, +# # but the next step is fast. +# COPY vllm /vllm/ +# COPY tpu-inference /tpu-inference/ +# COPY tunix /tunix + + +# # Install in editable mode. This is lightning-fast because all +# # dependencies were installed and cached in STAGE 2. +# --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /vllm/ +# --mount=type=cache,target=/root/.cache/pip pip install -e /tpu-inference/ + +# --mount=type=cache,target=/root/.cache/pip pip install --no-deps /tunix/ +# # --mount=type=cache,target=/root/.cache/pip VLLM_TARGET_DEVICE="tpu" pip install -e /tpu-inference/ + +if [ "$MODE" = "grpo-experimental" ]; then \ + echo "MODE=grpo-experimental: Re-installing JAX/libtpu"; \ + pip uninstall -y jax jaxlib libtpu && \ + pip install --pre -U jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ && \ + pip install -U --pre libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \ + fi diff --git a/src/MaxText/train_rl.py b/src/MaxText/train_rl.py index 764ee7224..4c9345509 100644 --- a/src/MaxText/train_rl.py +++ b/src/MaxText/train_rl.py @@ -40,7 +40,6 @@ --load_parameters_path=gs://path/to/checkpoint \\ --base_output_directory=gs://path/to/output \\ --hf_access_token=$HF_TOKEN \\ - --use_pathways=true \\ --steps=100 # Custom dataset @@ -57,6 +56,10 @@ from pprint import pprint from absl import app import os +# The following import is added to address a circular import issue in vLLM. +# By pre-importing PoolingRequestOutput, we ensure it's loaded before the +# problematic import chain is triggered, resolving the ImportError. 
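The comment just above describes a workaround rather than a feature: importing `PoolingRequestOutput` once, early, lets the rest of vLLM's import chain initialize cleanly, and `SKIP_JAX_PRECOMPILE` only trades away a startup-time precompile. A consolidated sketch of the two pieces, in the form this series eventually settles on (the import path moves to `vllm.outputs` in a later patch), is shown below; it is a pattern illustration, not a pinned vLLM API reference.

```python
import os

# Per the train_rl.py comment, this flag skips JAX precompilation for vLLM,
# which only makes startup faster.
os.environ["SKIP_JAX_PRECOMPILE"] = "1"

# Imported purely for its side effect: loading this symbol early avoids the
# circular-import failure reported against vLLM in this patch series.
from vllm.outputs import PoolingRequestOutput  # noqa: F401
```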
+from vllm import PoolingRequestOutput import sys from typing import Sequence @@ -73,6 +76,9 @@ import pathwaysutils +# for vLLM we can skip JAX precompilation with this flag, it makes startup faster +os.environ["SKIP_JAX_PRECOMPILE"] = "1" + # add the parent directory (two levels up to say ~/HOME/maxtext) to sys.path if currenlt runnig from # ~/HOME/maxtext/MaxText/examples @@ -94,7 +100,7 @@ from MaxText import model_creation_utils from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter from MaxText.evaluate_rl import evaluate -from MaxText import rl_utils +from maxtext.src.MaxText import utils_rl # We use OpenAI's GSM8K dataset. GSM8K comprises grade school math word problems. @@ -135,7 +141,7 @@ def setup_device_allocation(tmvp_config, use_pathways: bool = False): if tmvp_config.sampler_devices_fraction != 1.0: print(f"Using last {len(sampler_devices)} devices as Sampler devices") - return trainer_devices, sampler_devices, num_vms + return trainer_devices, sampler_devices def get_dataset(model_tokenizer, tmvp_config, data_dir, split="train") -> grain.MapDataset: # Download data @@ -172,7 +178,7 @@ def get_dataset(model_tokenizer, tmvp_config, data_dir, split="train") -> grain. # passed to reward functions "question": x["question"].decode("utf-8"), # passed to reward functions - "answer": rl_utils.extract_hash_answer(x["answer"].decode("utf-8")), + "answer": utils_rl.extract_hash_answer(x["answer"].decode("utf-8")), } ) ) @@ -235,7 +241,7 @@ def rl_train(tmvp_config): else: use_pathways = False print(f"Use Pathways: {use_pathways}") - trainer_devices, sampler_devices, num_vms = setup_device_allocation(tmvp_config, use_pathways) + trainer_devices, sampler_devices = setup_device_allocation(tmvp_config, use_pathways) # Load reference model print("Creating reference model and also meshes for reference and rollout") @@ -302,8 +308,9 @@ def rl_train(tmvp_config): # Setup metrics logging log_dir=tmvp_config.tensorboard_dir print(f"TensorBoard logs directory: {log_dir}") - print(f"tensorboard --logdir {log_dir} --port=8086") - metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=log_dir, flush_every_n_steps=20) + # Metrics logger + metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=log_dir, flush_every_n_steps=tmvp_config.log_period) + # Profiler configurations # TODO: xfgu@: add profiling @@ -342,13 +349,17 @@ def rl_train(tmvp_config): top_p=tmvp_config.decode_sampling_nucleus_p, top_k=tmvp_config.decode_sampling_top_k, ), + # TODO: @mazumdera: move these to rollout_config when updating to use latest Tunix rollout_vllm_model_version=tmvp_config.hf_model_name, rollout_vllm_hbm_utilization=tmvp_config.hbm_utilization_vllm, rollout_vllm_tpu_backend_type="jax", rollout_vllm_swap_space_size_gb=tmvp_config.swap_space_vllm_gb, ) - # Setup GRPO config + # The metrics logger is now managed by the GrpoLearner, so we don't need to + # register jax.monitoring listeners separately. 
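The `setup_device_allocation` change in the hunk above drops the third `num_vms` return value; what remains is straightforward slicing, where the first `trainer_devices_fraction` of `jax.devices()` becomes the trainer set and the last `sampler_devices_fraction` becomes the sampler set. The sketch below reproduces that arithmetic on a plain list so it can be checked without a TPU runtime; the device count and fractions are illustrative.

```python
def split_devices(devices, trainer_fraction=0.5, sampler_fraction=0.5):
  """Return (trainer_devices, sampler_devices) slices of a device list."""
  num_devices = len(devices)
  num_trainer = int(num_devices * trainer_fraction)
  num_sampler = int(num_devices * sampler_fraction)
  trainer_devices = devices[:num_trainer]
  sampler_devices = devices[num_devices - num_sampler:]
  return trainer_devices, sampler_devices


# With 8 chips and 0.5/0.5 fractions, chips 0-3 train and chips 4-7 sample.
trainer, sampler = split_devices(list(range(8)))
print(trainer, sampler)  # [0, 1, 2, 3] [4, 5, 6, 7]
```

When both fractions are 1.0 the two slices are the same device list, which is how the single-host path above ends up with the trainer and sampler sharing all chips.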
+ cluster_config.training_config.metrics_logging_options = metrics_logging_options + grpo_config = GrpoConfig( num_generations=tmvp_config.num_generations, num_iterations=tmvp_config.num_iterations, @@ -372,10 +383,10 @@ def rl_train(tmvp_config): rl_trainer = GrpoLearner( rl_cluster=rl_cluster, reward_fns=[ # type: ignore - lambda **kwargs: rl_utils.match_format_exactly(tmvp_config=tmvp_config, **kwargs), - lambda **kwargs: rl_utils.match_format_approximately(tmvp_config=tmvp_config, **kwargs), - lambda **kwargs: rl_utils.check_answer(tmvp_config=tmvp_config, **kwargs), - lambda **kwargs: rl_utils.check_numbers(tmvp_config=tmvp_config, **kwargs), + lambda **kwargs: utils_rl.match_format_exactly(tmvp_config=tmvp_config, **kwargs), + lambda **kwargs: utils_rl.match_format_approximately(tmvp_config=tmvp_config, **kwargs), + lambda **kwargs: utils_rl.check_answer(tmvp_config=tmvp_config, **kwargs), + lambda **kwargs: utils_rl.check_numbers(tmvp_config=tmvp_config, **kwargs), ], grpo_config=grpo_config, ) @@ -396,31 +407,29 @@ def rl_train(tmvp_config): # see the improvement post training. # (corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate( + tmvp_config, test_dataset, rl_cluster=rl_cluster, - **tmvp_config.generation_configs["greedy"], ) print(f"Pre GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") # Start training - profile_dir = tmvp_config.tensorboard_dir - max_logging.log(f"Saving profiles to {profile_dir}") print("Starting GRPO training...") - # jax.profiler.start_trace(profile_dir) with reference_mesh, nn_partitioning.axis_rules(tmvp_config.logical_axis_rules): rl_trainer.train(train_dataset) - # jax.profiler.stop_trace() + if rl_trainer.metrics_logger and jax.process_index() == 0: + rl_trainer.metrics_logger.close() print("GRPO Training Completed Successfully!") # Let's evaluate our model! (corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate( + tmvp_config, test_dataset, rl_cluster=rl_cluster, - **tmvp_config.generation_configs["greedy"], ) print(f"Post GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") diff --git a/src/MaxText/rl_utils.py b/src/MaxText/utils_rl.py similarity index 100% rename from src/MaxText/rl_utils.py rename to src/MaxText/utils_rl.py From 3249753e1a13f702c3aecc0efda28c9d0034d12e Mon Sep 17 00:00:00 2001 From: A9isha Date: Fri, 31 Oct 2025 23:30:06 +0000 Subject: [PATCH 20/31] add chat template, metrics issue and PoolingRequestOutput issue still pending --- src/MaxText/configs/base.yml | 1 + src/MaxText/configs/rl.yml | 27 ++------ .../examples/chat_templates/gsm8k_rl.json | 4 ++ src/MaxText/examples/local_installation.sh | 1 + src/MaxText/{ => rl}/evaluate_rl.py | 2 +- src/MaxText/{ => rl}/train_rl.py | 69 ++++++++++--------- src/MaxText/{ => rl}/utils_rl.py | 1 - 7 files changed, 49 insertions(+), 56 deletions(-) create mode 100644 src/MaxText/examples/chat_templates/gsm8k_rl.json rename src/MaxText/{ => rl}/evaluate_rl.py (99%) rename src/MaxText/{ => rl}/train_rl.py (86%) rename src/MaxText/{ => rl}/utils_rl.py (99%) diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index a51181f6a..588972453 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -18,6 +18,7 @@ run_name: "" model_name: "default" # override config settings to match a specific model. other than the override, nothing should use this! 
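The `GrpoLearner` wiring earlier in this patch binds `tmvp_config` into each reward function with a lambda. `functools.partial` achieves the same binding while keeping the wrapped function's name visible in stack traces and logs; the sketch below shows the shape of that call with a toy reward and made-up config values, not the actual Tunix reward signature.

```python
import functools


def match_format_exactly(prompts, completions, answer, *, tmvp_config, **kwargs):
  """Toy reward: full marks when a completion ends with the solution token."""
  return [
      tmvp_config["reward_exact_format_match"]
      if completion.endswith(tmvp_config["solution_end_token"])
      else 0.0
      for completion in completions
  ]


toy_config = {"solution_end_token": "</answer>", "reward_exact_format_match": 3.0}
reward_fns = [functools.partial(match_format_exactly, tmvp_config=toy_config)]

print(reward_fns[0](
    prompts=["2 + 2 = ?"],
    completions=["<reasoning>2 + 2 = 4</reasoning><answer>4</answer>"],
    answer=["4"],
))  # [3.0]
```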
override_model_config: False # When set to true allows overriding model parameters via CLI for the purpose of debugging/testing. +debug: False # Various trainers can set this to True for custom debugging normalization_layer_epsilon: 1.e-05 # epsilon value for rmsnorm, layernorm. ################################## CHECKPOINTING ################################## diff --git a/src/MaxText/configs/rl.yml b/src/MaxText/configs/rl.yml index 6df1359f0..93ef537ee 100644 --- a/src/MaxText/configs/rl.yml +++ b/src/MaxText/configs/rl.yml @@ -22,9 +22,6 @@ trainer_devices_fraction: 0.5 sampler_devices_fraction: 0.5 chips_per_vm: 4 # depends on hardware, for v5p this is 4 -# ====== Debug ====== -debug: True - # ====== Reproducibility ====== data_shuffle_seed: 42 @@ -155,23 +152,9 @@ reasoning_end_token: '' solution_start_token: '' solution_end_token: '' -# ====== System prompt and Templates ====== - -system_prompt: | - You are given a problem. Think about the problem and provide your reasoning. Place it between {reasoning_start_token} and {reasoning_end_token}. Then, provide the final answer (i.e., just one numerical value) between {solution_start_token} and {solution_end_token}. - -template: | - user - {system_prompt} - - {question} - model - -# # TODO(@mazumdera -): fix this -# # Dataset Configuration -# dataset_type: hf # Huggingface input pipeline -# hf_path: 'gsm8k' -# hf_data_split: 'main' -# hf_data_files: 'train' +# # TODO(@mazumdera): fix this +# Dataset Configuration +dataset_name: 'gsm8k' +train_split: 'train' +eval_split: 'test' diff --git a/src/MaxText/examples/chat_templates/gsm8k_rl.json b/src/MaxText/examples/chat_templates/gsm8k_rl.json new file mode 100644 index 000000000..ff4cb35ba --- /dev/null +++ b/src/MaxText/examples/chat_templates/gsm8k_rl.json @@ -0,0 +1,4 @@ +{ + "SYSTEM_PROMPT": "You are given a problem. Think about the problem and provide your reasoning. Place it between {reasoning_start_token} and {reasoning_end_token}. 
Then, provide the final answer (i.e., just one numerical value) between {solution_start_token} and {solution_end_token}.", + "TEMPLATE": "user\n{system_prompt}\n\n{question}\nmodel" +} diff --git a/src/MaxText/examples/local_installation.sh b/src/MaxText/examples/local_installation.sh index 0b7d81edb..8e4dc1adc 100644 --- a/src/MaxText/examples/local_installation.sh +++ b/src/MaxText/examples/local_installation.sh @@ -21,6 +21,7 @@ uv pip install keyring keyrings.google-artifactregistry-auth uv pip install numba==0.61.2 +uv pip uninstall tunix uv pip install -e ../tunix --no-cache-dir diff --git a/src/MaxText/evaluate_rl.py b/src/MaxText/rl/evaluate_rl.py similarity index 99% rename from src/MaxText/evaluate_rl.py rename to src/MaxText/rl/evaluate_rl.py index 2f99c0073..7b0eea406 100644 --- a/src/MaxText/evaluate_rl.py +++ b/src/MaxText/rl/evaluate_rl.py @@ -49,7 +49,7 @@ from tunix.rl.rollout.base_rollout import RolloutConfig from MaxText.globals import MAXTEXT_ASSETS_ROOT -from maxtext.src.MaxText import utils_rl +from MaxText.rl import utils_rl # ## Evaluate # We evaluate it in two ways: diff --git a/src/MaxText/train_rl.py b/src/MaxText/rl/train_rl.py similarity index 86% rename from src/MaxText/train_rl.py rename to src/MaxText/rl/train_rl.py index 4c9345509..e5bd37ae7 100644 --- a/src/MaxText/train_rl.py +++ b/src/MaxText/rl/train_rl.py @@ -56,10 +56,15 @@ from pprint import pprint from absl import app import os +import vllm +print("vLLM module location:", vllm.__file__) +# Also check if there's a local 'vllm' directory that could be interfering +print("Is there a local vllm directory?:", os.path.isdir("vllm")) # The following import is added to address a circular import issue in vLLM. # By pre-importing PoolingRequestOutput, we ensure it's loaded before the # problematic import chain is triggered, resolving the ImportError. from vllm import PoolingRequestOutput + import sys from typing import Sequence @@ -99,8 +104,9 @@ from MaxText import max_logging, max_utils, maxtext_utils, pyconfig from MaxText import model_creation_utils from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter -from MaxText.evaluate_rl import evaluate -from maxtext.src.MaxText import utils_rl +from MaxText.rl.evaluate_rl import evaluate +from MaxText.rl import utils_rl +from MaxText.input_pipeline.instruction_data_processing import load_template_from_file # We use OpenAI's GSM8K dataset. GSM8K comprises grade school math word problems. @@ -149,13 +155,14 @@ def get_dataset(model_tokenizer, tmvp_config, data_dir, split="train") -> grain. os.makedirs(data_dir) data = tfds.data_source( - "gsm8k", + tmvp_config.dataset_name, split=split, data_dir=data_dir, builder_kwargs={"file_format": tfds.core.FileFormat.ARRAY_RECORD}, download=True, ) + template_config = load_template_from_file(tmvp_config.chat_template_path) loaded_dataset = ( grain.MapDataset.source(data) .shuffle(seed=tmvp_config.data_shuffle_seed) @@ -166,8 +173,8 @@ def get_dataset(model_tokenizer, tmvp_config, data_dir, split="train") -> grain. 
[ { "role": "user", - "content": tmvp_config.template.format( - system_prompt=tmvp_config.system_prompt, + "content": template_config["TEMPLATE"].format( + system_prompt=template_config["SYSTEM_PROMPT"], question=x["question"].decode("utf-8"), ), }, @@ -213,7 +220,7 @@ def rl_train(tmvp_config): model_tokenizer = AutoTokenizer.from_pretrained(tmvp_config.hf_model_name) # Load datasets - dataset = get_dataset(model_tokenizer, tmvp_config, train_data_dir, "train").batch(tmvp_config.batch_size)[:tmvp_config.num_batches] + dataset = get_dataset(model_tokenizer, tmvp_config, train_data_dir, tmvp_config.train_split).batch(tmvp_config.batch_size)[:tmvp_config.num_batches] if tmvp_config.train_fraction == 1.0: train_dataset = dataset.repeat(tmvp_config.num_epochs) @@ -224,7 +231,7 @@ def rl_train(tmvp_config): val_dataset = dataset[int(len(dataset) * tmvp_config.train_fraction) :].repeat(tmvp_config.num_epochs) - test_dataset = get_dataset(model_tokenizer, tmvp_config, test_data_dir, "test").batch(tmvp_config.batch_size)[:tmvp_config.num_test_batches] + test_dataset = get_dataset(model_tokenizer, tmvp_config, test_data_dir, tmvp_config.eval_split).batch(tmvp_config.batch_size)[:tmvp_config.num_test_batches] # Let's see how one batch of the dataset looks like! @@ -240,38 +247,38 @@ def rl_train(tmvp_config): use_pathways = True else: use_pathways = False - print(f"Use Pathways: {use_pathways}") + max_logging.log(f"Use Pathways: {use_pathways}") trainer_devices, sampler_devices = setup_device_allocation(tmvp_config, use_pathways) # Load reference model - print("Creating reference model and also meshes for reference and rollout") + max_logging.log("Creating reference model and also meshes for reference and rollout") reference_model, reference_mesh = get_maxtext_model(tmvp_config, trainer_devices) devices_array = maxtext_utils.create_device_mesh(tmvp_config, sampler_devices) # if trainer_devices=sampler_devices, then rollout_mesh=reference_mesh # else rollout_mesh uses sampler_devices rollout_mesh = Mesh(devices_array, tmvp_config.mesh_axes) if tmvp_config.debug: - print("Reference Model initialized successfully") + max_logging.log("Reference Model initialized successfully") nnx.display(reference_model) - print(f"Reference mesh shape: {reference_mesh.shape}") + max_logging.log(f"Reference mesh shape: {reference_mesh.shape}") # Sanity check that weights are loaded correctly. 
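The chat-template change above moves the GSM8K prompt text out of `rl.yml` and into `chat_templates/gsm8k_rl.json`, and the later "fix template issue" patch in this series formats the reasoning and solution tokens into the system prompt before the outer template is applied. The sketch below walks through that final two-stage formatting with an inline copy of the JSON and illustrative token strings, so it runs without a repository checkout.

```python
import json

# Inline stand-in for chat_templates/gsm8k_rl.json.
template_config = json.loads("""
{
  "SYSTEM_PROMPT": "You are given a problem. Think about the problem and provide your reasoning. Place it between {reasoning_start_token} and {reasoning_end_token}. Then, provide the final answer (i.e., just one numerical value) between {solution_start_token} and {solution_end_token}.",
  "TEMPLATE": "user\\n{system_prompt}\\n\\n{question}\\nmodel"
}
""")

# Stage 1: resolve the reasoning/solution tokens inside the system prompt.
# The token strings here are illustrative; the real ones come from rl.yml.
system_prompt = template_config["SYSTEM_PROMPT"].format(
    reasoning_start_token="<reasoning>",
    reasoning_end_token="</reasoning>",
    solution_start_token="<answer>",
    solution_end_token="</answer>",
)

# Stage 2: drop the resolved system prompt and the question into the template.
prompt = template_config["TEMPLATE"].format(
    system_prompt=system_prompt,
    question="A pencil costs 3 dollars. How much do 4 pencils cost?",
)
print(prompt)
```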
_maxtext_state_flatten = nnx.state(reference_model).flat_state() maxtext_state_flatten = {".".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten} - print( + max_logging.log( f"maxtext_state_flatten[base.token_embedder.embedding].value={maxtext_state_flatten['base.token_embedder.embedding'].value}" ) # TODO: @mazumdera: change this to use lora # Load policy model - print("Creating policy model with same config as reference model on trainer mesh") + max_logging.log("Creating policy model with same config as reference model on trainer mesh") policy_model, policy_mesh = get_maxtext_model(tmvp_config, trainer_devices) actor_mesh = policy_mesh if tmvp_config.debug: - print("Policy Model initialized successfully") + max_logging.log("Policy Model initialized successfully") nnx.display(policy_model) - print(f"Policy mesh shape: {policy_mesh.shape}") + max_logging.log(f"Policy mesh shape: {policy_mesh.shape}") # Setup optimizer optimizer = optax.adamw( @@ -306,8 +313,9 @@ def rl_train(tmvp_config): ) # Setup metrics logging - log_dir=tmvp_config.tensorboard_dir - print(f"TensorBoard logs directory: {log_dir}") + log_dir = os.path.join(tmvp_config.tensorboard_dir, f"worker_{jax.process_index()}") + + max_logging.log(f"TensorBoard logs directory: {log_dir}") # Metrics logger metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=log_dir, flush_every_n_steps=tmvp_config.log_period) @@ -347,13 +355,12 @@ def rl_train(tmvp_config): + tmvp_config.kv_cache_buffer, temperature=tmvp_config.decode_sampling_temperature, top_p=tmvp_config.decode_sampling_nucleus_p, - top_k=tmvp_config.decode_sampling_top_k, + top_k=tmvp_config.decode_sampling_top_k, + rollout_vllm_model_version=tmvp_config.hf_model_name, + rollout_vllm_hbm_utilization=tmvp_config.hbm_utilization_vllm, + rollout_vllm_tpu_backend_type="jax", + rollout_vllm_swap_space_size_gb=tmvp_config.swap_space_vllm_gb, ), - # TODO: @mazumdera: move these to rollout_config when updating to use latest Tunix - rollout_vllm_model_version=tmvp_config.hf_model_name, - rollout_vllm_hbm_utilization=tmvp_config.hbm_utilization_vllm, - rollout_vllm_tpu_backend_type="jax", - rollout_vllm_swap_space_size_gb=tmvp_config.swap_space_vllm_gb, ) # The metrics logger is now managed by the GrpoLearner, so we don't need to @@ -369,7 +376,7 @@ def rl_train(tmvp_config): ) # Create RL cluster - print("Creating RL cluster...") + max_logging.log("Creating RL cluster...") with nn_partitioning.axis_rules(tmvp_config.logical_axis_rules): rl_cluster = rl_cluster_lib.RLCluster( actor=policy_model, @@ -379,7 +386,7 @@ def rl_train(tmvp_config): ) # Create GRPO trainer - print("Setting up GRPO trainer...") + max_logging.log("Setting up GRPO trainer...") rl_trainer = GrpoLearner( rl_cluster=rl_cluster, reward_fns=[ # type: ignore @@ -400,7 +407,7 @@ def rl_train(tmvp_config): rollout_config=base_rollout.RolloutConfig(max_tokens_to_generate=64, temperature=0.1), ) - print(f"Output: {output}") + max_logging.log(f"Output: {output}") # # # Before we train the model, let's evaluate the model on the test set so we can @@ -411,19 +418,17 @@ def rl_train(tmvp_config): test_dataset, rl_cluster=rl_cluster, ) - print(f"Pre GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") + max_logging.log(f"Pre GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") # Start training - print("Starting GRPO training...") + max_logging.log("Starting GRPO training...") with reference_mesh, 
nn_partitioning.axis_rules(tmvp_config.logical_axis_rules): rl_trainer.train(train_dataset) - if rl_trainer.metrics_logger and jax.process_index() == 0: - rl_trainer.metrics_logger.close() - print("GRPO Training Completed Successfully!") + max_logging.log("GRPO Training Completed Successfully!") # Let's evaluate our model! (corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate( @@ -431,7 +436,7 @@ def rl_train(tmvp_config): test_dataset, rl_cluster=rl_cluster, ) - print(f"Post GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") + max_logging.log(f"Post GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") def main(argv: Sequence[str]) -> None: diff --git a/src/MaxText/utils_rl.py b/src/MaxText/rl/utils_rl.py similarity index 99% rename from src/MaxText/utils_rl.py rename to src/MaxText/rl/utils_rl.py index dd6c59bc6..cb1f64baa 100644 --- a/src/MaxText/utils_rl.py +++ b/src/MaxText/rl/utils_rl.py @@ -48,7 +48,6 @@ from tunix.rl.rollout.base_rollout import RolloutConfig from MaxText.globals import MAXTEXT_ASSETS_ROOT -import pathwaysutils # Let's define a RegEx for checking whether the format matches. # From fc1cbf612d4a9cf3f9f39ceb99a3aac04a229037 Mon Sep 17 00:00:00 2001 From: A9isha Date: Sat, 1 Nov 2025 23:51:30 +0000 Subject: [PATCH 21/31] permute import orderings --- src/MaxText/rl/train_rl.py | 43 ++++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/rl/train_rl.py index e5bd37ae7..e1e938df2 100644 --- a/src/MaxText/rl/train_rl.py +++ b/src/MaxText/rl/train_rl.py @@ -53,34 +53,42 @@ --steps=100 """ -from pprint import pprint -from absl import app +import functools import os -import vllm -print("vLLM module location:", vllm.__file__) -# Also check if there's a local 'vllm' directory that could be interfering -print("Is there a local vllm directory?:", os.path.isdir("vllm")) -# The following import is added to address a circular import issue in vLLM. -# By pre-importing PoolingRequestOutput, we ensure it's loaded before the -# problematic import chain is triggered, resolving the ImportError. 
-from vllm import PoolingRequestOutput - +from pprint import pprint +import re import sys -from typing import Sequence + +from datetime import datetime +from flax import nnx +from flax.linen import partitioning as nn_partitioning +import grain +import humanize + import jax from jax.sharding import Mesh -from flax.linen import partitioning as nn_partitioning -from flax import nnx import optax from orbax import checkpoint as ocp import tensorflow_datasets as tfds +from tqdm.auto import tqdm +from tunix.rl import rl_cluster as rl_cluster_lib +from tunix.rl.rollout import base_rollout +from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner +from tunix.sft import metrics_logger + + from transformers import AutoTokenizer -import grain +from flax import linen as nn +import numpy as np +from etils import epath +from MaxText.globals import MAXTEXT_ASSETS_ROOT import pathwaysutils +pathwaysutils.initialize() + # for vLLM we can skip JAX precompilation with this flag, it makes startup faster os.environ["SKIP_JAX_PRECOMPILE"] = "1" @@ -96,11 +104,6 @@ # Add the project root to the Python path sys.path.insert(0, project_root) -from tunix.rl import rl_cluster as rl_cluster_lib -from tunix.rl.rollout import base_rollout -from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner -from tunix.sft import metrics_logger - from MaxText import max_logging, max_utils, maxtext_utils, pyconfig from MaxText import model_creation_utils from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter From 4caffb1ac30afeea23e094fe4e67dca8194b1ceb Mon Sep 17 00:00:00 2001 From: A9isha Date: Mon, 3 Nov 2025 23:52:27 +0000 Subject: [PATCH 22/31] fix PoolingRequest issue --- src/MaxText/configs/rl.yml | 2 ++ src/MaxText/rl/__init__.py | 0 src/MaxText/rl/train_rl.py | 43 +++++++++++++++++++++----------------- 3 files changed, 26 insertions(+), 19 deletions(-) create mode 100644 src/MaxText/rl/__init__.py diff --git a/src/MaxText/configs/rl.yml b/src/MaxText/configs/rl.yml index 93ef537ee..7f5cf8045 100644 --- a/src/MaxText/configs/rl.yml +++ b/src/MaxText/configs/rl.yml @@ -68,6 +68,8 @@ decoder_layer_input: 'offload' query_proj: 'offload' key_proj: 'offload' value_proj: 'offload' +checkpoint_storage_use_ocdbt: False # For Pathways +checkpoint_storage_use_zarr3: False # For Pathways # ====== Training ====== diff --git a/src/MaxText/rl/__init__.py b/src/MaxText/rl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/rl/train_rl.py index e1e938df2..695265f88 100644 --- a/src/MaxText/rl/train_rl.py +++ b/src/MaxText/rl/train_rl.py @@ -53,6 +53,7 @@ --steps=100 """ +from typing import Any, Sequence import functools import os from pprint import pprint @@ -60,12 +61,13 @@ import sys from datetime import datetime +from absl import app from flax import nnx from flax.linen import partitioning as nn_partitioning import grain import humanize - +import vllm import jax from jax.sharding import Mesh import optax @@ -245,43 +247,46 @@ def rl_train(tmvp_config): # Setup device allocation - if jax.extend.backend.get_backend().platform_version == "Pathways": - max_logging.log("Pathways backend detected. Disabling setting profile options.") + + if jax.extend.backend.get_backend().platform_version.strip() == "Pathways": + print("Pathways backend detected. 
Disabling setting profile options.") use_pathways = True else: use_pathways = False - max_logging.log(f"Use Pathways: {use_pathways}") + print(f"jax.extend.backend.get_backend().platform_version={jax.extend.backend.get_backend().platform_version}") + use_pathways = True + print(f"Use Pathways: {use_pathways}") trainer_devices, sampler_devices = setup_device_allocation(tmvp_config, use_pathways) # Load reference model - max_logging.log("Creating reference model and also meshes for reference and rollout") + print("Creating reference model and also meshes for reference and rollout") reference_model, reference_mesh = get_maxtext_model(tmvp_config, trainer_devices) devices_array = maxtext_utils.create_device_mesh(tmvp_config, sampler_devices) # if trainer_devices=sampler_devices, then rollout_mesh=reference_mesh # else rollout_mesh uses sampler_devices rollout_mesh = Mesh(devices_array, tmvp_config.mesh_axes) if tmvp_config.debug: - max_logging.log("Reference Model initialized successfully") + print("Reference Model initialized successfully") nnx.display(reference_model) - max_logging.log(f"Reference mesh shape: {reference_mesh.shape}") + print(f"Reference mesh shape: {reference_mesh.shape}") # Sanity check that weights are loaded correctly. _maxtext_state_flatten = nnx.state(reference_model).flat_state() maxtext_state_flatten = {".".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten} - max_logging.log( + print( f"maxtext_state_flatten[base.token_embedder.embedding].value={maxtext_state_flatten['base.token_embedder.embedding'].value}" ) # TODO: @mazumdera: change this to use lora # Load policy model - max_logging.log("Creating policy model with same config as reference model on trainer mesh") + print("Creating policy model with same config as reference model on trainer mesh") policy_model, policy_mesh = get_maxtext_model(tmvp_config, trainer_devices) actor_mesh = policy_mesh if tmvp_config.debug: - max_logging.log("Policy Model initialized successfully") + print("Policy Model initialized successfully") nnx.display(policy_model) - max_logging.log(f"Policy mesh shape: {policy_mesh.shape}") + print(f"Policy mesh shape: {policy_mesh.shape}") # Setup optimizer optimizer = optax.adamw( @@ -318,7 +323,7 @@ def rl_train(tmvp_config): # Setup metrics logging log_dir = os.path.join(tmvp_config.tensorboard_dir, f"worker_{jax.process_index()}") - max_logging.log(f"TensorBoard logs directory: {log_dir}") + print(f"TensorBoard logs directory: {log_dir}") # Metrics logger metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=log_dir, flush_every_n_steps=tmvp_config.log_period) @@ -379,7 +384,7 @@ def rl_train(tmvp_config): ) # Create RL cluster - max_logging.log("Creating RL cluster...") + print("Creating RL cluster...") with nn_partitioning.axis_rules(tmvp_config.logical_axis_rules): rl_cluster = rl_cluster_lib.RLCluster( actor=policy_model, @@ -389,7 +394,7 @@ def rl_train(tmvp_config): ) # Create GRPO trainer - max_logging.log("Setting up GRPO trainer...") + print("Setting up GRPO trainer...") rl_trainer = GrpoLearner( rl_cluster=rl_cluster, reward_fns=[ # type: ignore @@ -410,7 +415,7 @@ def rl_train(tmvp_config): rollout_config=base_rollout.RolloutConfig(max_tokens_to_generate=64, temperature=0.1), ) - max_logging.log(f"Output: {output}") + print(f"Output: {output}") # # # Before we train the model, let's evaluate the model on the test set so we can @@ -421,17 +426,17 @@ def rl_train(tmvp_config): test_dataset, rl_cluster=rl_cluster, ) - max_logging.log(f"Pre GRPO 
Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") + print(f"Pre GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") # Start training - max_logging.log("Starting GRPO training...") + print("Starting GRPO training...") with reference_mesh, nn_partitioning.axis_rules(tmvp_config.logical_axis_rules): rl_trainer.train(train_dataset) - max_logging.log("GRPO Training Completed Successfully!") + print("GRPO Training Completed Successfully!") # Let's evaluate our model! (corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate( @@ -439,7 +444,7 @@ def rl_train(tmvp_config): test_dataset, rl_cluster=rl_cluster, ) - max_logging.log(f"Post GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") + print(f"Post GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") def main(argv: Sequence[str]) -> None: From f963ef80fd874682f4093f8b17132577059fef5d Mon Sep 17 00:00:00 2001 From: A9isha Date: Tue, 4 Nov 2025 00:20:02 +0000 Subject: [PATCH 23/31] debug=True --- src/MaxText/configs/rl.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/MaxText/configs/rl.yml b/src/MaxText/configs/rl.yml index 7f5cf8045..ea592dff2 100644 --- a/src/MaxText/configs/rl.yml +++ b/src/MaxText/configs/rl.yml @@ -71,6 +71,8 @@ value_proj: 'offload' checkpoint_storage_use_ocdbt: False # For Pathways checkpoint_storage_use_zarr3: False # For Pathways +# ====== Debugging ====== +debug: True # ====== Training ====== batch_size: 1 From 5d2633dad277559b15e995644ea721cb3fe598b0 Mon Sep 17 00:00:00 2001 From: A9isha Date: Tue, 4 Nov 2025 05:22:28 +0000 Subject: [PATCH 24/31] fix template issue --- src/MaxText/configs/rl.yml | 9 --------- src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb | 2 +- src/MaxText/rl/evaluate_rl.py | 5 ++++- src/MaxText/rl/train_rl.py | 10 +++++++--- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/MaxText/configs/rl.yml b/src/MaxText/configs/rl.yml index ea592dff2..8af0e7575 100644 --- a/src/MaxText/configs/rl.yml +++ b/src/MaxText/configs/rl.yml @@ -106,15 +106,6 @@ generation_configs: eval_top_k: 2000 eval_top_p: 1.0 -# greedy -eval_temperature: 0.01 -eval_top_k: 1 -eval_top_p: 1.0 -# # some randomness -# 'standard': {'eval_temperature': 0.7, 'eval_top_k': 50, 'eval_top_p': 0.95}, -# # liberal -# # 'liberal': {'eval_temperature': 0.85, 'eval_top_k': 2000, 'eval_top_p': 1.0}, - # ====== Inference ====== # for vLLM diff --git a/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb b/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb index 41fbeccfc..e24cbba99 100644 --- a/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb +++ b/src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb @@ -115,7 +115,7 @@ "\n", "# Import required modules\n", "from MaxText import pyconfig\n", - "from MaxText.rl.rl_trainer import rl_train\n", + "from MaxText.train_rl import rl_train\n", "\n", "print(\"✅ Successfully imported GRPO training function\")\n", "print(f\"📁 MaxText path: {maxtext_path}\")\n", diff --git a/src/MaxText/rl/evaluate_rl.py b/src/MaxText/rl/evaluate_rl.py index 7b0eea406..db2d5bf9f 100644 --- a/src/MaxText/rl/evaluate_rl.py +++ b/src/MaxText/rl/evaluate_rl.py @@ -90,13 +90,16 @@ def generate_responses( List of lists containing responses for each prompt across passes """ multiple_call_responses = [[] for _ in range(len(prompts))] - eval_strategy = tmvp_config.generation_configs[tmvp_config.eval_sampling_strategy] + for p in 
range(num_passes): responses = rl_cluster.rollout.generate( prompts, rollout_config=RolloutConfig( max_tokens_to_generate=tmvp_config.max_target_length, + # temperature=eval_strategy.eval_temperature, + # top_k=eval_strategy.eval_top_k, + # top_p=eval_strategy.eval_top_p, temperature=eval_strategy["eval_temperature"], top_k=eval_strategy["eval_top_k"], top_p=eval_strategy["eval_top_p"], diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/rl/train_rl.py index 695265f88..5f26f6692 100644 --- a/src/MaxText/rl/train_rl.py +++ b/src/MaxText/rl/train_rl.py @@ -179,7 +179,12 @@ def get_dataset(model_tokenizer, tmvp_config, data_dir, split="train") -> grain. { "role": "user", "content": template_config["TEMPLATE"].format( - system_prompt=template_config["SYSTEM_PROMPT"], + system_prompt=template_config["SYSTEM_PROMPT"].format( + reasoning_start_token=tmvp_config.reasoning_start_token, + reasoning_end_token=tmvp_config.reasoning_end_token, + solution_start_token=tmvp_config.solution_start_token, + solution_end_token=tmvp_config.solution_end_token, + ), question=x["question"].decode("utf-8"), ), }, @@ -244,7 +249,6 @@ def rl_train(tmvp_config): for ele in train_dataset[:1]: pprint(ele) - # Setup device allocation @@ -321,7 +325,7 @@ def rl_train(tmvp_config): ) # Setup metrics logging - log_dir = os.path.join(tmvp_config.tensorboard_dir, f"worker_{jax.process_index()}") + log_dir = os.path.join(tmvp_config.tensorboard_dir, f"worker_{jax.process_index()}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}") print(f"TensorBoard logs directory: {log_dir}") # Metrics logger From 8d422998778dab6b6446bca430795b3aa34d1ec6 Mon Sep 17 00:00:00 2001 From: A9isha Date: Tue, 4 Nov 2025 23:11:05 +0000 Subject: [PATCH 25/31] fix get_match_format_regex --- src/MaxText/configs/rl.yml | 2 +- .../examples/grpo_llama3_1_70b_demo_pw.py | 24 +++++++++++-------- src/MaxText/rl/train_rl.py | 9 ++----- src/MaxText/rl/utils_rl.py | 2 +- 4 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/MaxText/configs/rl.yml b/src/MaxText/configs/rl.yml index 8af0e7575..1ba826e3a 100644 --- a/src/MaxText/configs/rl.yml +++ b/src/MaxText/configs/rl.yml @@ -114,7 +114,7 @@ hf_model_name: None # === Generation during GRPO training === # max Lengths for prompt and completion max_prefill_predict_length: 256 -max_target_length: 768 +max_target_length: 1024 kv_cache_buffer: 256 hbm_utilization_vllm: 0.72 swap_space_vllm_gb: 2 diff --git a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py index 33d7fcd33..c2446f601 100644 --- a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py +++ b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py @@ -147,20 +147,24 @@ # ====== Input Checkpoint directory ===== -MODEL_CHECKPOINT_PATH = "/path/to/scanned/model/ckpt_load_dir/" +# MODEL_CHECKPOINT_PATH = "/path/to/scanned/model/ckpt_load_dir/" +MODEL_CHECKPOINT_PATH = "gs://maxtext-model-checkpoints/llama3.1_70b_instruct/2025-10-15/pathways/scanned/0/items" # ====== Checkpoint directory ===== -LOG_DIR = f"{HOME}/content/tensorboard/grpo/logs_llama3/" -if not os.path.exists(LOG_DIR): - os.makedirs(LOG_DIR) +# LOG_DIR = f"{HOME}/content/tensorboard/grpo/logs_llama3/" +LOG_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/rl-tuning/grpo/anisha-{run_id}/tensorboard/grpo/logs_llama3/" +# if not os.path.exists(LOG_DIR): + # os.makedirs(LOG_DIR) # ===== Profiling ===== -PROFILE_DIR = f"/path/to/profile_dir/{run_id}/profiles_llama3/" +# PROFILE_DIR = f"/path/to/profile_dir/{run_id}/profiles_llama3/" 
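A small but easy-to-miss change in the hunk above is the TensorBoard location: each worker now logs under its own `worker_{process_index}_{timestamp}` subdirectory, which keeps multi-host workers and repeated runs from writing event files into the same place. A runnable sketch of that layout, with `jax.process_index()` replaced by a plain integer and a made-up bucket path, is below.

```python
import os
from datetime import datetime


def worker_log_dir(tensorboard_dir: str, process_index: int) -> str:
  """Build a per-worker, per-run TensorBoard directory name."""
  stamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
  return os.path.join(tensorboard_dir, f"worker_{process_index}_{stamp}")


# e.g. gs://my-bucket/tensorboard/worker_0_2025-11-05-07-47-21
print(worker_log_dir("gs://my-bucket/tensorboard", 0))
```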
+PROFILE_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/rl-tuning/grpo/anisha-{run_id}/profiles_llama3/" if not epath.Path(PROFILE_DIR).exists(): epath.Path(PROFILE_DIR).mkdir(parents=True) # ====== Checkpoint saving ====== -CKPT_DIR = f"/path/to/ckpt_save_dir/{run_id}/ckpts_llama3/" +# CKPT_DIR = f"/path/to/ckpt_save_dir/{run_id}/ckpts_llama3/" +CKPT_DIR = f"gs://mazumdera-test-bucket-us-central2/grpo/rl-tuning/grpo/anisha-{run_id}/ckpts_llama3/" if not epath.Path(CKPT_DIR).exists(): epath.Path(CKPT_DIR).mkdir(parents=True) @@ -952,11 +956,11 @@ def evaluate( temperature=TEMPERATURE, top_p=TOP_P, top_k=TOP_K, + rollout_vllm_model_version="meta-llama/Llama-3.1-70B-Instruct", + rollout_vllm_hbm_utilization=HBM_UTILIZATION_VLLM, + rollout_vllm_tpu_backend_type="jax", + rollout_vllm_swap_space_size_gb=SWAP_SPACE_VLLM_GB, ), - rollout_vllm_model_version="meta-llama/Llama-3.1-70B-Instruct", - rollout_vllm_hbm_utilization=HBM_UTILIZATION_VLLM, - rollout_vllm_tpu_backend_type="jax", - rollout_vllm_swap_space_size_gb=SWAP_SPACE_VLLM_GB, ) grpo_config = GrpoConfig( diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/rl/train_rl.py index 5f26f6692..1fdfa41ad 100644 --- a/src/MaxText/rl/train_rl.py +++ b/src/MaxText/rl/train_rl.py @@ -360,10 +360,9 @@ def rl_train(tmvp_config): checkpointing_options=checkpointing_options, ), rollout_config=base_rollout.RolloutConfig( - max_tokens_to_generate=tmvp_config.max_target_length, + max_tokens_to_generate=tmvp_config.max_target_length-tmvp_config.max_prefill_predict_length, max_prompt_length=tmvp_config.max_prefill_predict_length, - kv_cache_size=tmvp_config.max_prefill_predict_length - + tmvp_config.max_target_length + kv_cache_size=tmvp_config.max_target_length + tmvp_config.kv_cache_buffer, temperature=tmvp_config.decode_sampling_temperature, top_p=tmvp_config.decode_sampling_nucleus_p, @@ -375,10 +374,6 @@ def rl_train(tmvp_config): ), ) - # The metrics logger is now managed by the GrpoLearner, so we don't need to - # register jax.monitoring listeners separately. - cluster_config.training_config.metrics_logging_options = metrics_logging_options - grpo_config = GrpoConfig( num_generations=tmvp_config.num_generations, num_iterations=tmvp_config.num_iterations, diff --git a/src/MaxText/rl/utils_rl.py b/src/MaxText/rl/utils_rl.py index cb1f64baa..caa9334a3 100644 --- a/src/MaxText/rl/utils_rl.py +++ b/src/MaxText/rl/utils_rl.py @@ -55,7 +55,7 @@ def get_match_format_regex(tmvp_config): """Returns a compiled regex to extract the answer from a completion.""" match_format = re.compile( ( - r"^[\s]{0,}}" + r"^[\s]{0,}" rf"{tmvp_config.reasoning_start_token}.+?{tmvp_config.reasoning_end_token}.*?" 
rf"{tmvp_config.solution_start_token}(.+?){tmvp_config.solution_end_token}" r"[\s]{0,}$" From b75bb3e1a12e5eedee4ce773ad5f620c0517ddd0 Mon Sep 17 00:00:00 2001 From: A9isha Date: Wed, 5 Nov 2025 01:47:01 +0000 Subject: [PATCH 26/31] import PoolingRequestOutput --- src/MaxText/rl/train_rl.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/rl/train_rl.py index 1fdfa41ad..204f324f1 100644 --- a/src/MaxText/rl/train_rl.py +++ b/src/MaxText/rl/train_rl.py @@ -67,7 +67,7 @@ import grain import humanize -import vllm +from vllm.outputs import PoolingRequestOutput import jax from jax.sharding import Mesh import optax @@ -373,7 +373,6 @@ def rl_train(tmvp_config): rollout_vllm_swap_space_size_gb=tmvp_config.swap_space_vllm_gb, ), ) - grpo_config = GrpoConfig( num_generations=tmvp_config.num_generations, num_iterations=tmvp_config.num_iterations, From 377abfa5728a5ef60f7917644603381066837823 Mon Sep 17 00:00:00 2001 From: A9isha Date: Wed, 5 Nov 2025 07:47:21 +0000 Subject: [PATCH 27/31] update learning rate --- src/MaxText/configs/rl.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/MaxText/configs/rl.yml b/src/MaxText/configs/rl.yml index 1ba826e3a..2dae0395a 100644 --- a/src/MaxText/configs/rl.yml +++ b/src/MaxText/configs/rl.yml @@ -88,6 +88,9 @@ eval_interval: 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`. num_epochs: 1 # can potentially train for more epochs +learning_rate: 3e-6 +adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. +adam_b2: 0.99 # Exponential decay rate to track the second moment of past gradients. gradient_clipping_threshold: 0.1 # ====== Evaluation ====== From 592bca90cb0d852c55547e783243347f453352e9 Mon Sep 17 00:00:00 2001 From: A9isha Date: Wed, 5 Nov 2025 08:18:24 +0000 Subject: [PATCH 28/31] nit updates --- src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py | 3 ++- src/MaxText/rl/train_rl.py | 1 + src/MaxText/rl/utils_rl.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py index c2446f601..4572e2237 100644 --- a/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py +++ b/src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py @@ -956,11 +956,12 @@ def evaluate( temperature=TEMPERATURE, top_p=TOP_P, top_k=TOP_K, + ), rollout_vllm_model_version="meta-llama/Llama-3.1-70B-Instruct", rollout_vllm_hbm_utilization=HBM_UTILIZATION_VLLM, rollout_vllm_tpu_backend_type="jax", rollout_vllm_swap_space_size_gb=SWAP_SPACE_VLLM_GB, - ), + ) grpo_config = GrpoConfig( diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/rl/train_rl.py index 204f324f1..b3c3df875 100644 --- a/src/MaxText/rl/train_rl.py +++ b/src/MaxText/rl/train_rl.py @@ -302,6 +302,7 @@ def rl_train(tmvp_config): # learning rate to 0 using cosine scheduler. warmup_steps=int(tmvp_config.warmup_steps_fraction*max_train_steps), decay_steps=max_train_steps, + end_value=0.0, ), b1=tmvp_config.adam_b1, b2=tmvp_config.adam_b2, diff --git a/src/MaxText/rl/utils_rl.py b/src/MaxText/rl/utils_rl.py index caa9334a3..8db39caf8 100644 --- a/src/MaxText/rl/utils_rl.py +++ b/src/MaxText/rl/utils_rl.py @@ -137,7 +137,7 @@ def check_answer(prompts, completions, answer, tmvp_config, **kargs): if guess is None: scores.append(0) continue - # Correct answer gets points! + # Correct answer gets tmvp_config.reward_exact_format_match points! 
if guess == true_answer: score += tmvp_config.reward_exact_format_match # Match if spaces are seen From 8784eb1f86684771e9c9b3e3be546cb2285acbc1 Mon Sep 17 00:00:00 2001 From: A9isha Date: Wed, 5 Nov 2025 08:29:04 +0000 Subject: [PATCH 29/31] use use_pathways flag --- src/MaxText/configs/rl.yml | 1 + src/MaxText/rl/train_rl.py | 14 ++++---------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/MaxText/configs/rl.yml b/src/MaxText/configs/rl.yml index 2dae0395a..603a7005c 100644 --- a/src/MaxText/configs/rl.yml +++ b/src/MaxText/configs/rl.yml @@ -70,6 +70,7 @@ key_proj: 'offload' value_proj: 'offload' checkpoint_storage_use_ocdbt: False # For Pathways checkpoint_storage_use_zarr3: False # For Pathways +use_pathways: True # ====== Debugging ====== debug: True diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/rl/train_rl.py index b3c3df875..1f49e3f49 100644 --- a/src/MaxText/rl/train_rl.py +++ b/src/MaxText/rl/train_rl.py @@ -131,14 +131,14 @@ def get_maxtext_model(config, devices=None): return tunix_model, mesh -def setup_device_allocation(tmvp_config, use_pathways: bool = False): +def setup_device_allocation(tmvp_config): """Setup device allocation for training and inference.""" devices = jax.devices() num_vms = len(devices) // tmvp_config.chips_per_vm trainer_devices = devices sampler_devices = devices - if num_vms >= 2 and use_pathways: + if num_vms >= 2 and tmvp_config.use_pathways: # Multiple hosts with Pathways - potentially split devices for trainer and sampler # based on trainer_devices_fraction and sampler_devices_fraction print(f"{num_vms} VMs detected, allocating trainer and sampler devices, and using Pathways.") @@ -252,15 +252,9 @@ def rl_train(tmvp_config): # Setup device allocation - if jax.extend.backend.get_backend().platform_version.strip() == "Pathways": + if tmvp_config.use_pathways: print("Pathways backend detected. Disabling setting profile options.") - use_pathways = True - else: - use_pathways = False - print(f"jax.extend.backend.get_backend().platform_version={jax.extend.backend.get_backend().platform_version}") - use_pathways = True - print(f"Use Pathways: {use_pathways}") - trainer_devices, sampler_devices = setup_device_allocation(tmvp_config, use_pathways) + trainer_devices, sampler_devices = setup_device_allocation(tmvp_config) # Load reference model print("Creating reference model and also meshes for reference and rollout") From ef88996baedc12154f7165c48899794a9d8ec9f5 Mon Sep 17 00:00:00 2001 From: A9isha Date: Wed, 5 Nov 2025 09:18:12 +0000 Subject: [PATCH 30/31] use max_logging.log --- src/MaxText/configs/base.yml | 4 +- src/MaxText/configs/rl.yml | 3 +- src/MaxText/rl/evaluate_rl.py | 29 ++++++----- src/MaxText/rl/train_rl.py | 96 +++++++++++++++-------------------- src/MaxText/rl/utils_rl.py | 27 ++++------ 5 files changed, 72 insertions(+), 87 deletions(-) diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index 588972453..c334e2586 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -18,7 +18,9 @@ run_name: "" model_name: "default" # override config settings to match a specific model. other than the override, nothing should use this! override_model_config: False # When set to true allows overriding model parameters via CLI for the purpose of debugging/testing. -debug: False # Various trainers can set this to True for custom debugging +debug: + rl: False # RL-specific debugging + normalization_layer_epsilon: 1.e-05 # epsilon value for rmsnorm, layernorm. 
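The two `utils_rl.py` touch-ups in the last few patches, removing the stray `}` from the start of the match-format regex and clarifying the scoring comment, are easiest to see end to end: extract the answer span from a completion, then grade it exactly, loosely (ignoring surrounding whitespace), or partially (within 10 percent of the reference). The sketch below does both steps with illustrative tokens and made-up reward values; the real pattern and rewards are built from `tmvp_config`.

```python
import re

# Illustrative stand-ins for the configured reasoning/solution tokens.
reasoning_start, reasoning_end = "<reasoning>", "</reasoning>"
solution_start, solution_end = "<answer>", "</answer>"

# The corrected prefix is r"^[\s]{0,}"; the stray '}' removed by the
# "fix get_match_format_regex" patch demanded a literal brace at the start
# of every completion, so well-formed responses never matched.
match_format = re.compile(
    r"^[\s]{0,}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    r"[\s]{0,}$",
    flags=re.DOTALL,
)


def score_answer(guess, true_answer):
  """Exact match > whitespace-insensitive match > within-10-percent credit."""
  if guess is None:
    return 0.0
  if guess == true_answer:
    return 3.0
  if guess.strip() == true_answer.strip():
    return 1.5
  try:
    ratio = float(guess) / float(true_answer)
  except (ValueError, ZeroDivisionError):
    return 0.0
  return 0.5 if 0.9 <= ratio <= 1.1 else 0.0


completion = "<reasoning>48 / 2 = 24</reasoning><answer>24</answer>"
match = match_format.search(completion)
guess = match.group(1) if match else None
print(guess, score_answer(guess, "24"))  # 24 3.0
```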
################################## CHECKPOINTING ################################## diff --git a/src/MaxText/configs/rl.yml b/src/MaxText/configs/rl.yml index 603a7005c..56c7e75f3 100644 --- a/src/MaxText/configs/rl.yml +++ b/src/MaxText/configs/rl.yml @@ -73,7 +73,8 @@ checkpoint_storage_use_zarr3: False # For Pathways use_pathways: True # ====== Debugging ====== -debug: True +debug: + rl: True # ====== Training ====== batch_size: 1 diff --git a/src/MaxText/rl/evaluate_rl.py b/src/MaxText/rl/evaluate_rl.py index db2d5bf9f..ce1ba5c19 100644 --- a/src/MaxText/rl/evaluate_rl.py +++ b/src/MaxText/rl/evaluate_rl.py @@ -50,6 +50,7 @@ from MaxText.globals import MAXTEXT_ASSETS_ROOT from MaxText.rl import utils_rl +from MaxText import max_logging # ## Evaluate # We evaluate it in two ways: @@ -107,8 +108,8 @@ def generate_responses( ) responses = responses.text - if tmvp_config.debug: - print(f"Pass {p+1}/{num_passes}, responses: {responses}") + if tmvp_config.debug["rl"]: + max_logging.log(f"Pass {p+1}/{num_passes}, responses: {responses}") for idx, response in enumerate(responses): multiple_call_responses[idx].append(response) @@ -131,12 +132,12 @@ def score_responses(tmvp_config, question, responses, answer): match_format = utils_rl.get_match_format_regex(tmvp_config) match_numbers = utils_rl.get_match_numbers_regex(tmvp_config) - if tmvp_config.debug: - print("========================================") - print(f"Evaluation Question: {question}") - print(f"Evaluation Answer: {answer}") - print(f"Evaluation Responses: {responses}") - print("========================================") + if tmvp_config.debug["rl"]: + max_logging.log("========================================") + max_logging.log(f"Evaluation Question: {question}") + max_logging.log(f"Evaluation Answer: {answer}") + max_logging.log(f"Evaluation Responses: {responses}") + max_logging.log("========================================") is_correct = False is_partially_correct = False @@ -146,8 +147,8 @@ def score_responses(tmvp_config, question, responses, answer): # Extract numerical response extracted_response = guess.group(1) if (guess := match_numbers.search(response)) is not None else "-1000000" - if tmvp_config.debug: - print(f"Evaluation extracted_response: {extracted_response}") + if tmvp_config.debug["rl"]: + max_logging.log(f"Evaluation extracted_response: {extracted_response}") # Check exact correctness try: @@ -159,9 +160,9 @@ def score_responses(tmvp_config, question, responses, answer): if 0.9 <= ratio <= 1.1: is_partially_correct = True except Exception as e: - if tmvp_config.debug: - print(f"Evaluation Exception: {e}") - print("SKIPPED") + if tmvp_config.debug["rl"]: + max_logging.log(f"Evaluation Exception: {e}") + max_logging.log("SKIPPED") # Check format correctness if match_format.search(response) is not None: @@ -243,7 +244,7 @@ def evaluate( # Print progress every 10 items if total % 10 == 0: - print( + max_logging.log( f"===> {corr=}, {total=}, {corr / total * 100=}, " f"{partially_corr / total * 100=}, {corr_format / total * 100=}" ) diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/rl/train_rl.py index 1f49e3f49..276b5fe49 100644 --- a/src/MaxText/rl/train_rl.py +++ b/src/MaxText/rl/train_rl.py @@ -25,7 +25,7 @@ Usage Examples: # Llama3.1-8B (single host) - python3 src/MaxText/examples/train_rl \\ + python3 -m src.MaxText.rl.train_rl \\ --model_name=llama3.1-8b \\ --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \\ --load_parameters_path=gs://path/to/checkpoint \\ @@ -34,7 +34,7 @@ --steps=100 # Llama3.1-70B 
with Pathways (multi-host) - python3 src/MaxText/examples/train_rl \\ + python3 -m src.MaxText.rl.train_rl \\ --model_name=llama3.1-70b \\ --tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \\ --load_parameters_path=gs://path/to/checkpoint \\ @@ -43,7 +43,7 @@ --steps=100 # Custom dataset - python3 src/MaxText/examples/train_rl \\ + python3 -m src.MaxText.rl.train_rl \\ --model_name=llama3.1-8b \\ --tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \\ --load_parameters_path=gs://path/to/checkpoint \\ @@ -94,17 +94,6 @@ # for vLLM we can skip JAX precompilation with this flag, it makes startup faster os.environ["SKIP_JAX_PRECOMPILE"] = "1" -# add the parent directory (two levels up to say ~/HOME/maxtext) to sys.path if currenlt runnig from -# ~/HOME/maxtext/MaxText/examples - -# Get the directory of the current script -script_dir = os.path.dirname(os.path.abspath(__file__)) - -# Go up two levels to get the project root -project_root = os.path.abspath(os.path.join(script_dir, "..", "..", "..")) - -# Add the project root to the Python path -sys.path.insert(0, project_root) from MaxText import max_logging, max_utils, maxtext_utils, pyconfig from MaxText import model_creation_utils @@ -114,15 +103,20 @@ from MaxText.input_pipeline.instruction_data_processing import load_template_from_file -# We use OpenAI's GSM8K dataset. GSM8K comprises grade school math word problems. - - - def get_maxtext_model(config, devices=None): """ Load MaxText model with Tunix adapter. - # Note: pass the path to your scanned checkpoint for "load_parameters_path". To generate a scanned checkpoint, you can use the `scanned_checkpoint.py` script in MaxText. - # To create a scanned checkpoint, you can use /maxtext/MaxText/utils/ckpt_conversion/to_maxtext.py + # Note: pass the path to your scanned checkpoint for 'load_parameters_path'. 
+ # To create a scanned checkpoint, you can use /maxtext/src/MaxText/utils/ckpt_conversion/to_maxtext.py and if + # using Pathways, please set `checkpoint_storage_use_ocdbt=False checkpoint_storage_use_zarr3=False` + # python src/MaxText/utils/ckpt_conversion/to_maxtext.py \ + # --model_name="gemma2-2b" \ + # --base_output_directory="/path/to/your/output/directory" \ + # --scan_layers=True \ + # --checkpoint_storage_use_ocdbt=False\ + # checkpoint_storage_use_zarr3=False + # Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e., + # load_parameters_path=/path/to/your/output/directory/0/items """ model, mesh = model_creation_utils.create_nnx_model(config, devices) with mesh: @@ -141,16 +135,16 @@ def setup_device_allocation(tmvp_config): if num_vms >= 2 and tmvp_config.use_pathways: # Multiple hosts with Pathways - potentially split devices for trainer and sampler # based on trainer_devices_fraction and sampler_devices_fraction - print(f"{num_vms} VMs detected, allocating trainer and sampler devices, and using Pathways.") + max_logging.log(f"{num_vms} VMs detected, allocating trainer and sampler devices, and using Pathways.") num_devices = len(devices) num_trainer_devices = int(num_devices * tmvp_config.trainer_devices_fraction) num_sampler_devices = int(num_devices * tmvp_config.sampler_devices_fraction) trainer_devices = devices[:num_trainer_devices] sampler_devices = devices[num_devices - num_sampler_devices :] if tmvp_config.trainer_devices_fraction!=1.0: - print(f"Using first {len(trainer_devices)} devices as Trainer devices") + max_logging.log(f"Using first {len(trainer_devices)} devices as Trainer devices") if tmvp_config.sampler_devices_fraction != 1.0: - print(f"Using last {len(sampler_devices)} devices as Sampler devices") + max_logging.log(f"Using last {len(sampler_devices)} devices as Sampler devices") return trainer_devices, sampler_devices @@ -208,10 +202,8 @@ def rl_train(tmvp_config): Args: tmvp_config: MaxText configuration object """ - # ====== Debug flag for verbose logs ====== - DEBUG = tmvp_config.debug - print("Starting GRPO Training") + max_logging.log("Starting GRPO Training") # Number of training steps. max_train_steps = int(tmvp_config.num_batches * tmvp_config.num_iterations * tmvp_config.train_fraction * tmvp_config.num_epochs) @@ -245,7 +237,7 @@ def rl_train(tmvp_config): # Let's see how one batch of the dataset looks like! - if tmvp_config.debug: + if tmvp_config.debug["rl"]: for ele in train_dataset[:1]: pprint(ele) @@ -253,38 +245,37 @@ def rl_train(tmvp_config): # Setup device allocation if tmvp_config.use_pathways: - print("Pathways backend detected. Disabling setting profile options.") + max_logging.log("Pathways backend detected. 
Disabling setting profile options.") trainer_devices, sampler_devices = setup_device_allocation(tmvp_config) # Load reference model - print("Creating reference model and also meshes for reference and rollout") + max_logging.log("Creating reference model and also meshes for reference and rollout") reference_model, reference_mesh = get_maxtext_model(tmvp_config, trainer_devices) devices_array = maxtext_utils.create_device_mesh(tmvp_config, sampler_devices) # if trainer_devices=sampler_devices, then rollout_mesh=reference_mesh # else rollout_mesh uses sampler_devices rollout_mesh = Mesh(devices_array, tmvp_config.mesh_axes) - if tmvp_config.debug: - print("Reference Model initialized successfully") + if tmvp_config.debug["rl"]: + max_logging.log("Reference Model initialized successfully") nnx.display(reference_model) - print(f"Reference mesh shape: {reference_mesh.shape}") + max_logging.log(f"Reference mesh shape: {reference_mesh.shape}") # Sanity check that weights are loaded correctly. _maxtext_state_flatten = nnx.state(reference_model).flat_state() maxtext_state_flatten = {".".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten} - print( + max_logging.log( f"maxtext_state_flatten[base.token_embedder.embedding].value={maxtext_state_flatten['base.token_embedder.embedding'].value}" ) # TODO: @mazumdera: change this to use lora # Load policy model - print("Creating policy model with same config as reference model on trainer mesh") - policy_model, policy_mesh = get_maxtext_model(tmvp_config, trainer_devices) - actor_mesh = policy_mesh + max_logging.log("Creating policy model with same config as reference model on trainer mesh") + actor_model, actor_mesh = get_maxtext_model(tmvp_config, trainer_devices) - if tmvp_config.debug: - print("Policy Model initialized successfully") - nnx.display(policy_model) - print(f"Policy mesh shape: {policy_mesh.shape}") + if tmvp_config.debug["rl"]: + max_logging.log("Policy Model initialized successfully") + nnx.display(actor_model) + max_logging.log(f"Policy mesh shape: {actor_mesh.shape}") # Setup optimizer optimizer = optax.adamw( @@ -322,7 +313,7 @@ def rl_train(tmvp_config): # Setup metrics logging log_dir = os.path.join(tmvp_config.tensorboard_dir, f"worker_{jax.process_index()}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}") - print(f"TensorBoard logs directory: {log_dir}") + max_logging.log(f"TensorBoard logs directory: {log_dir}") # Metrics logger metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=log_dir, flush_every_n_steps=tmvp_config.log_period) @@ -377,17 +368,17 @@ def rl_train(tmvp_config): ) # Create RL cluster - print("Creating RL cluster...") + max_logging.log("Creating RL cluster...") with nn_partitioning.axis_rules(tmvp_config.logical_axis_rules): rl_cluster = rl_cluster_lib.RLCluster( - actor=policy_model, + actor=actor_model, reference=reference_model, tokenizer=model_tokenizer, cluster_config=cluster_config, ) # Create GRPO trainer - print("Setting up GRPO trainer...") + max_logging.log("Setting up GRPO trainer...") rl_trainer = GrpoLearner( rl_cluster=rl_cluster, reward_fns=[ # type: ignore @@ -399,18 +390,15 @@ def rl_train(tmvp_config): grpo_config=grpo_config, ) - - - if tmvp_config.debug: + if tmvp_config.debug["rl"]: # verify if vllm sampler works output = rl_cluster.rollout.generate( ["The capital of France is"], rollout_config=base_rollout.RolloutConfig(max_tokens_to_generate=64, temperature=0.1), ) - print(f"Output: {output}") - # - # + max_logging.log(f"Output: {output}") + # Before we 
train the model, let's evaluate the model on the test set so we can # see the improvement post training. # @@ -419,17 +407,17 @@ def rl_train(tmvp_config): test_dataset, rl_cluster=rl_cluster, ) - print(f"Pre GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") + max_logging.log(f"Pre GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") # Start training - print("Starting GRPO training...") + max_logging.log("Starting GRPO training...") with reference_mesh, nn_partitioning.axis_rules(tmvp_config.logical_axis_rules): rl_trainer.train(train_dataset) - print("GRPO Training Completed Successfully!") + max_logging.log("GRPO Training Completed Successfully!") # Let's evaluate our model! (corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate( @@ -437,7 +425,7 @@ def rl_train(tmvp_config): test_dataset, rl_cluster=rl_cluster, ) - print(f"Post GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") + max_logging.log(f"Post GRPO Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") def main(argv: Sequence[str]) -> None: diff --git a/src/MaxText/rl/utils_rl.py b/src/MaxText/rl/utils_rl.py index 8db39caf8..2d0b71a44 100644 --- a/src/MaxText/rl/utils_rl.py +++ b/src/MaxText/rl/utils_rl.py @@ -41,13 +41,8 @@ from transformers import AutoTokenizer -from flax import linen as nn -import numpy as np -from etils import epath - -from tunix.rl.rollout.base_rollout import RolloutConfig - from MaxText.globals import MAXTEXT_ASSETS_ROOT +from MaxText import max_logging # Let's define a RegEx for checking whether the format matches. # @@ -62,7 +57,7 @@ def get_match_format_regex(tmvp_config): ), flags=re.MULTILINE | re.DOTALL, ) - if tmvp_config.debug: + if tmvp_config.debug["rl"]: match_format.search( f"{tmvp_config.reasoning_start_token}Let me" f" think!{tmvp_config.reasoning_end_token}{tmvp_config.solution_start_token}2{tmvp_config.solution_end_token}", ) @@ -168,7 +163,7 @@ def get_match_numbers_regex(tmvp_config): match_numbers = re.compile( rf"{tmvp_config.solution_start_token}.*?([\d\.]{{1,}})", flags=re.MULTILINE | re.DOTALL ) - if tmvp_config.debug: + if tmvp_config.debug["rl"]: match_numbers.findall(f"{tmvp_config.solution_start_token} 0.34 {tmvp_config.solution_end_token}") return match_numbers @@ -185,13 +180,13 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs): ] scores = [] - if tmvp_config.debug: - print("START ============================") - print(f"Question: {question[0]}") - print(f"Answer: {answer[0]}") - print(f"Response: {completions[0]}") - print(f"Extracted: {extracted_responses[0]}") - print("END ==============================") + if tmvp_config.debug["rl"]: + max_logging.log("START ============================") + max_logging.log(f"Question: {question[0]}") + max_logging.log(f"Answer: {answer[0]}") + max_logging.log(f"Response: {completions[0]}") + max_logging.log(f"Extracted: {extracted_responses[0]}") + max_logging.log("END ==============================") for guess, true_answer in zip(extracted_responses, answer): if guess is None: scores.append(0) @@ -208,8 +203,6 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs): def extract_hash_answer(text: str, debug: bool = False) -> str | None: """Function to extract only the answer hash from the text.""" - if debug: - print(f"Extracting answer from: {text}") if "####" not in text: return None return 
text.split("####")[1].strip() From 0b83e8b1c046a79db1eab99f181697424db3d92e Mon Sep 17 00:00:00 2001 From: A9isha Date: Wed, 5 Nov 2025 19:32:03 +0000 Subject: [PATCH 31/31] fix lint --- src/MaxText/rl/evaluate_rl.py | 40 ++++------------------------------- src/MaxText/rl/train_rl.py | 13 ++---------- src/MaxText/rl/utils_rl.py | 27 ----------------------- 3 files changed, 6 insertions(+), 74 deletions(-) diff --git a/src/MaxText/rl/evaluate_rl.py b/src/MaxText/rl/evaluate_rl.py index ce1ba5c19..de15e0610 100644 --- a/src/MaxText/rl/evaluate_rl.py +++ b/src/MaxText/rl/evaluate_rl.py @@ -13,42 +13,12 @@ # limitations under the License. # pylint: disable=bare-except, consider-using-generator - -import functools -import os -from pprint import pprint -import re - -import sys - -from datetime import datetime -from flax import nnx -from flax.linen import partitioning as nn_partitioning -import grain -import humanize - - -import jax -from jax.sharding import Mesh -import optax -from orbax import checkpoint as ocp -import tensorflow_datasets as tfds +""" +RL Evaluation Module. +""" from tqdm.auto import tqdm -from tunix.rl import rl_cluster as rl_cluster_lib -from tunix.rl.rollout import base_rollout -from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner -from tunix.sft import metrics_logger - - -from transformers import AutoTokenizer - -from flax import linen as nn -import numpy as np -from etils import epath - from tunix.rl.rollout.base_rollout import RolloutConfig -from MaxText.globals import MAXTEXT_ASSETS_ROOT from MaxText.rl import utils_rl from MaxText import max_logging @@ -258,6 +228,4 @@ def evaluate( corr_format / total * 100, ) - if make_lst: - return to_return, response_lst - return to_return + return to_return, response_lst diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/rl/train_rl.py index 276b5fe49..88e2c39e3 100644 --- a/src/MaxText/rl/train_rl.py +++ b/src/MaxText/rl/train_rl.py @@ -53,19 +53,15 @@ --steps=100 """ -from typing import Any, Sequence -import functools +from typing import Sequence import os from pprint import pprint -import re -import sys from datetime import datetime from absl import app from flax import nnx from flax.linen import partitioning as nn_partitioning import grain -import humanize from vllm.outputs import PoolingRequestOutput import jax @@ -73,7 +69,6 @@ import optax from orbax import checkpoint as ocp import tensorflow_datasets as tfds -from tqdm.auto import tqdm from tunix.rl import rl_cluster as rl_cluster_lib from tunix.rl.rollout import base_rollout from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner @@ -82,11 +77,7 @@ from transformers import AutoTokenizer -from flax import linen as nn -import numpy as np -from etils import epath -from MaxText.globals import MAXTEXT_ASSETS_ROOT import pathwaysutils pathwaysutils.initialize() @@ -402,7 +393,7 @@ def rl_train(tmvp_config): # Before we train the model, let's evaluate the model on the test set so we can # see the improvement post training. 
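The `extract_hash_answer` helper touched in the `utils_rl.py` hunk above strips the GSM8K working-out and keeps only the text after the `####` marker. A self-contained copy with a quick usage check (the sample strings are made up for illustration):

```python
# Standalone copy of extract_hash_answer from utils_rl.py: GSM8K labels place the
# final answer after a "####" marker, and everything before it is reasoning.
def extract_hash_answer(text: str) -> str | None:
  if "####" not in text:
    return None
  return text.split("####")[1].strip()


assert extract_hash_answer("6 eggs at $2 each is $12.\n#### 12") == "12"
assert extract_hash_answer("no final answer marker") is None
```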
# - (corr, total, accuracy, partial_accuracy, format_accuracy) = evaluate( + (corr, total, accuracy, partial_accuracy, format_accuracy), _ = evaluate( tmvp_config, test_dataset, rl_cluster=rl_cluster, diff --git a/src/MaxText/rl/utils_rl.py b/src/MaxText/rl/utils_rl.py index 2d0b71a44..303147bf1 100644 --- a/src/MaxText/rl/utils_rl.py +++ b/src/MaxText/rl/utils_rl.py @@ -14,34 +14,7 @@ # pylint: disable=bare-except, consider-using-generator -import functools -import os -from pprint import pprint import re -import sys - -from datetime import datetime -from flax import nnx -from flax.linen import partitioning as nn_partitioning -import grain -import humanize - - -import jax -from jax.sharding import Mesh -import optax -from orbax import checkpoint as ocp -import tensorflow_datasets as tfds -from tqdm.auto import tqdm -from tunix.rl import rl_cluster as rl_cluster_lib -from tunix.rl.rollout import base_rollout -from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner -from tunix.sft import metrics_logger - - -from transformers import AutoTokenizer - -from MaxText.globals import MAXTEXT_ASSETS_ROOT from MaxText import max_logging # Let's define a RegEx for checking whether the format matches.
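The lint-fix patch above also changes `evaluate` to always return the metrics tuple together with the collected responses, which is why the `train_rl.py` call site now unpacks two values and discards the second. A rough sketch of the new calling convention (`fake_evaluate` is a stand-in, not the real MaxText function):

```python
# Sketch of the post-patch evaluate() contract: a (metrics, responses) pair is
# always returned, instead of conditionally returning one or two values.
def fake_evaluate():
  metrics = (8, 10, 80.0, 90.0, 100.0)  # corr, total, accuracy, partial, format
  responses = ["sample model response"] * 10
  return metrics, responses


# Call sites unpack both values and drop the responses when they are not needed.
(corr, total, accuracy, partial_accuracy, format_accuracy), _ = fake_evaluate()
print(f"Pre GRPO Training: {corr=}, {total=}, {accuracy=}%")
```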