206 changes: 206 additions & 0 deletions src/MaxText/configs/grpo.yml
@@ -0,0 +1,206 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# RL Configuration
# This config consolidates common parameters for RL training across different model sizes

base_config: "base.yml"

# ====== Hardware ======
trainer_devices_fraction: 0.5
sampler_devices_fraction: 0.5
chips_per_vm: 4 # depends on hardware, for v5p this is 4
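# For illustration: on a v5p-128 (64 chips), a 0.5 / 0.5 split gives 32 chips
# to the trainer mesh and 32 chips to the sampler (rollout) mesh.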

# ====== Debug ======
debug: True

# ====== Reproducibility ======
data_shuffle_seed: 42
loss_algo: 'grpo' # grpo or gspo-token

# ====== Checkpoint saving ======
save_interval_steps: 500
max_to_keep: 4

# ====== GRPO ======
# === Generation during GRPO training ===
max_prompt_length: 256
total_generation_steps: 768

# The number of responses the policy generates for each prompt within a single
# training step. This corresponds to `G` in Algorithm 1 of the GRPO paper; the
# "group" in GRPO refers to this set of responses.
num_generations: 2
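# For illustration (not a config option): with rewards r_1..r_G for the G
# responses sampled for one prompt, GRPO's group-relative advantage is roughly
#   A_i = (r_i - mean(r_1..r_G)) / std(r_1..r_G)
# so each response is scored against its own group rather than a learned critic.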

# === other GRPO configs ===
# The number of optimization iterations per batch (𝜇 in Algorithm 1 of the GRPO paper).
num_iterations: 1

# The coefficient for the KL divergence penalty (𝛽) in the GRPO loss function.
# Keep this high enough; otherwise the KL divergence can grow unchecked.
beta: 0.08
# Epsilon value for clipping (𝜀 in the GRPO loss). As in PPO, it keeps policy
# updates stable.
epsilon: 0.2
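# For reference (illustration, not a config option): the per-token GRPO
# objective is roughly
#   min(ratio * A, clip(ratio, 1 - epsilon, 1 + epsilon) * A) - beta * KL(pi_theta || pi_ref)
# where ratio = pi_theta(token) / pi_old(token) and A is the group-relative
# advantage described above.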

# ====== Training ======

batch_size: 1
# Increase `batch_size` and `steps` for better results.
# num_batches: 3738
num_batches: 4 # 200
# Keep `num_test_batches` low so that evaluation runs quickly. It can be
# increased to a maximum of 330 (if batch size is 4).
num_test_batches: 5 # 200

eval_every_n_steps: 10 # has no effect if the train fraction is 1.0
num_epochs: 1 # can potentially train for more epochs


# === AdamW, warmup, cosine scheduler ===
# The AdamW hyperparameters are set under "Training Hyperparameters" below:
# learning_rate, adam_b1, adam_b2, weight_decay.
# == Cosine decay with warmup scheduler ==
# Linearly increase the learning rate from 0 to its peak value (learning_rate,
# 3e-6) over the first 10% of training steps, then decay it to 0 with a cosine
# schedule.
warmup_steps_fraction: 0.1 # warm up over the first 10% of training steps
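# For reference (illustration): with T total training steps and W = 0.1 * T
# warmup steps, the schedule is approximately
#   lr(t) = learning_rate * t / W                                    for t < W
#   lr(t) = learning_rate * 0.5 * (1 + cos(pi * (t - W) / (T - W)))  for W <= t <= T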
# == Grad clipping ==
# Gradient clipping (see max_grad_norm below) prevents large gradients; this
# was found to be important for keeping the KL divergence in check.


# ====== Inference ======
# A high-ish temperature gives varied, diverse responses during training (see
# `decode_sampling_*` under "Generation Configuration During Training" below);
# the values here correspond to greedy search.
temperature: 0.01
top_p: 1.0
top_k: 1

# Alternative sampling presets:
# # some randomness
# standard: {temperature: 0.7, top_k: 50, top_p: 0.95}
# # liberal
# liberal: {temperature: 0.85, top_k: 2000, top_p: 1.0}

hbm_utilization_vllm: 0.72
swap_space_vllm_gb: 2


# ====== Reward ======
reward_exact_format_match: 3.0
reward_white_space_format_match: 1.5
reward_partial_format_match: 0.5
reward_ratio_guess_to_answer_high: 0.5
reward_ratio_guess_to_answer_low: 0.25
penalty_incorrect_format: -0.5
penalty_incorrect_answer: -1.0
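# Illustration (assuming per-response rewards are summed): an exactly formatted
# response whose numeric guess lands in the closer ratio band scores
#   reward_exact_format_match + reward_ratio_guess_to_answer_high = 3.0 + 0.5 = 3.5
# while a badly formatted, incorrect answer scores
#   penalty_incorrect_format + penalty_incorrect_answer = -0.5 + (-1.0) = -1.5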


# TODO: fix this
# Dataset Configuration
dataset_type: hf # Huggingface input pipeline
hf_path: 'gsm8k'
hf_data_split: 'main'
hf_data_files: 'train'

# Model and Tokenizer Configuration
# Override these via CLI:
# model_name, tokenizer_path, load_parameters_path

# Sequence Lengths
max_prefill_predict_length: 256
max_target_length: 768

# Training Hyperparameters
learning_rate: 3.e-6
adam_b1: 0.9
adam_b2: 0.99
weight_decay: 0.1
max_grad_norm: 0.1

# Group Relative Policy Optimization (GRPO) Parameters
grpo_beta: 0.08 # KL divergence penalty coefficient
grpo_epsilon: 0.2 # Clipping value for stable updates
inference_rollouts: 1

# Generation Configuration During Training
decode_sampling_strategy: "weighted"
decode_sampling_temperature: 0.9
decode_sampling_top_p: 1.0
decode_sampling_top_k: 50

# Training Loop Configuration
steps: 100
per_device_batch_size: 1

# Checkpoint Configuration
enable_checkpointing: True
async_checkpointing: True
checkpoint_period: 50

# Pathways Inference Configuration
# For multi-host/multi-slice setups
use_pathways_reshard: False
inference_devices_per_replica: 4
inference_replicas: 1

# Tokenizer Settings
add_bos: False
add_eos: False
return_log_prob: True

# Performance and Memory
weight_dtype: bfloat16
dtype: bfloat16

# Profiling
profiler: xplane
skip_first_n_steps_for_profiler: 5
profiler_steps: 3

# Splash Attention Block Sizes
# Tuned for GRPO workloads
sa_block_q: 128
sa_block_kv: 128
sa_block_kv_compute: 128
sa_block_q_dkv: 128
sa_block_kv_dkv: 128
sa_block_kv_dkv_compute: 128
sa_block_q_dq: 128
sa_block_kv_dq: 128
sa_use_fused_bwd_kernel: False
sa_q_layout: "HEAD_DIM_MINOR"
sa_k_layout: "HEAD_DIM_MINOR"
sa_v_layout: "HEAD_DIM_MINOR"

# Model-Specific Overrides (examples)
# For Llama3.1-8B:
# model_name: llama3.1-8b
# tokenizer_path: meta-llama/Llama-3.1-8B-Instruct
# ici_fsdp_parallelism: 8
#
# For Llama3.1-70B with Pathways:
# model_name: llama3.1-70b
# tokenizer_path: meta-llama/Llama-3.1-70B-Instruct
# use_pathways_reshard: True
# ici_fsdp_parallelism: 16

@@ -127,7 +127,38 @@ Use the link for Jupyter Lab as the link for "Connect to a local runtime" in Colab

### GRPO Training

- **`grpo_llama3_demo.ipynb`** → GRPO training on math dataset
- **`grpo_llama3_1_8b_demo.ipynb`** → GRPO training on math dataset (Colab/notebook)
- **`grpo_runner.py`** → Unified CLI for GRPO training (any model)

#### GRPO Colab Usage

For interactive GRPO training in Google Colab or Jupyter:

1. **Open** `src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb`
2. **Enable TPU runtime** (Runtime → Change runtime type → TPU)
3. **Run the cells** to train Llama3.1-8B with GRPO on the GSM8K dataset

#### GRPO Python Script Usage

```bash
# Llama3.1-8B
python3 src/MaxText/examples/grpo_runner.py \
--model_name=llama3.1-8b \
--tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \
--load_parameters_path=gs://path/to/checkpoint \
--base_output_directory=/tmp/grpo_output \
--hf_access_token=$HF_TOKEN \
--steps=100

# Qwen2.5-7B
python3 src/MaxText/examples/grpo_runner.py \
--model_name=qwen2.5-7b \
--tokenizer_path=Qwen/Qwen2.5-7B-Instruct \
--load_parameters_path=gs://path/to/checkpoint \
--base_output_directory=/tmp/grpo_output \
--hf_access_token=$HF_TOKEN \
--steps=100
```
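
For larger models such as Llama3.1-70B, `src/MaxText/configs/grpo.yml` lists model-specific overrides (e.g. `use_pathways_reshard`) intended for multi-host Pathways setups. A minimal sketch of such a run, assuming `grpo_runner.py` accepts the same flags as above and that Pathways settings are picked up from the config (the checkpoint path is a placeholder):

```bash
# Llama3.1-70B (multi-host / Pathways; illustrative sketch)
python3 src/MaxText/examples/grpo_runner.py \
  --model_name=llama3.1-70b \
  --tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \
  --load_parameters_path=gs://path/to/checkpoint \
  --base_output_directory=/tmp/grpo_output \
  --hf_access_token=$HF_TOKEN \
  --steps=100
```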

## Common Pitfalls & Debugging

7 changes: 6 additions & 1 deletion src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py
@@ -14,6 +14,11 @@

# pylint: disable=bare-except, consider-using-generator
"""
DEPRECATED: This file is deprecated and kept for reference only.
Please use the new unified CLI interface: rl_trainer.py

See GRPO_README.md for the migration guide and usage examples.

This tutorial demonstrates training the Llama3.1 70B-IT model on
the GSM8K math reasoning benchmark using Group Relative Policy Optimization (GRPO).
GRPO can enhance your model's problem-solving skills on mathematical word problems,
@@ -34,7 +39,7 @@
# We use Tunix as the library for GRPO.
# And we use vLLM as the library for efficient model inference and generation.
#
# In this tutorial we use a single host TPUVM such as `v6e-8/v5p-8`. Let's get started!
# In this tutorial we use `v5p-256` or `v5p-128`. Let's get started!


# ## Install necessary libraries