Commit d6e053a: WIP src/MaxText/rl/rl_trainer.py, delete grpo_runner
Parent: d1dfbb8
File tree: 7 files changed (+258, -467 lines)

7 files changed

+258
-467
lines changed

src/MaxText/configs/grpo.yml

Lines changed: 99 additions & 6 deletions

@@ -12,14 +12,109 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# GRPO Configuration
-# This config consolidates common parameters for GRPO training across different model sizes
+# RL Configuration
+# This config consolidates common parameters for RL training across different model sizes
 
 base_config: "base.yml"
 
-use_grpo: True
-train_data_columns: 'prompt'
+# ====== Hardware =====
+trainer_devices_fraction: 0.5
+sampler_devices_fraction: 0.5
+chips_per_vm: 4 # depends on hardware, for v5p this is 4
 
+# ====== Debug ======
+debug: True
+
+# ====== Reproducibility ======
+data_shuffle_seed: 42
+loss_algo: 'grpo' # grpo or gspo-token
+
+# ====== Checkpoint saving ======
+save_interval_steps: 500
+max_to_keep: 4
+
+# ====== GRPO ======
+# === Generation during GRPO training ===
+max_prompt_length: 256
+total_generation_steps: 768
+
+# The number of responses the policy generates for a given prompt within a
+# single training step. This corresponds to `G` in Algorithm 1 in the paper.
+# The "group" in GRPO comes from here.
+num_generations: 2
+
+# === other GRPO configs ===
+# The number of iterations per batch (μ in GRPO Algorithm 1).
+num_iterations: 1
+
+# The coefficient for the KL divergence penalty (β) in the GRPO loss function.
+# Important to keep a high enough value for this; otherwise the KL divergence
+# can increase unchecked.
+beta: 0.08
+# Epsilon value for clipping (ε in GRPO loss in paper). Similar to PPO, for
+# stable updates.
+epsilon: 0.2
+
+# ====== Training ======
+
+batch_size: 1
+# Increase `batch_size` and `MAX_STEPS` for better results.
+# NUM_BATCHES = 3738
+NUM_BATCHES = 4 # 200
+# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
+# increased to a max. of 330 (if batch size is 4).
+NUM_TEST_BATCHES = 5 # 200
+
+EVAL_EVERY_N_STEPS = 10 # this doesn't matter if `TRAIN_FRACTION = 1.0`.
+NUM_EPOCHS = 1 # can potentially train for more epochs
+
+
+# === AdamW, warmup, cosine scheduler ===
+LEARNING_RATE = 3e-6
+B1 = 0.9
+B2 = 0.99
+WEIGHT_DECAY = 0.1
+# == Cosine decay with warmup scheduler ==
+# Linearly increase the learning rate from 0 to 5e-6 over the first 10% of
+# training steps, and then gradually decrease the learning rate to 0 using a
+# cosine scheduler.
+WARMUP_STEPS = int(0.1 * MAX_STEPS)
+# == Grad clipping ==
+# Grad clipping to prevent large gradients. Found this
+# important to keep KL divergence in check.
+MAX_GRAD_NORM = 0.1
+
+
+# ====== Inference ======
+# Important to keep a high-ish temperature for varied, diverse responses during
+# training.
+# greedy search
+temperature: 0.01
+top_p: 1.0
+top_k: 1
+
+# # some randomness
+# "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
+# # liberal
+# "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
+
+TRAINER_DEVICES_FRACTION = 0.5
+SAMPLER_DEVICES_FRACTION = 0.5
+HBM_UTILIZATION_VLLM = 0.72
+SWAP_SPACE_VLLM_GB = 2
+
+
+# ====== Reward ======
+REWARD_EXACT_FORMAT_MATCH = 3.0
+REWARD_WHITE_SPACE_FORMAT_MATCH = 1.5
+REWARD_PARTIAL_FORMAT_MATCH = 0.5
+REWARD_RATIO_GUESS_TO_ANSWER_HIGH = 0.5
+REWARD_RATIO_GUESS_TO_ANSWER_LOW = 0.25
+PENALTY_INCORRECT_FORMAT = -0.5
+PENALTY_INCORRECT_ANSWER = -1.0
+
+
+# TODO: fix this
 # Dataset Configuration
 dataset_type: hf # Huggingface input pipeline
 hf_path: 'gsm8k'
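For reference, the `num_generations`, `beta`, and `epsilon` keys added above correspond to G, β, and ε in the GRPO objective as usually stated in the literature (DeepSeekMath, Algorithm 1). A sketch of that standard objective follows; the exact loss implemented by the WIP rl_trainer.py (including the gspo-token variant selected via `loss_algo`) may differ in details such as the KL estimator:

$$ \mathcal{J}_{\mathrm{GRPO}}(\theta) = \mathbb{E}\Bigg[\frac{1}{G}\sum_{i=1}^{G}\frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\Big(\min\big(r_{i,t}(\theta)\,\hat{A}_{i,t},\ \operatorname{clip}(r_{i,t}(\theta),\,1-\varepsilon,\,1+\varepsilon)\,\hat{A}_{i,t}\big)-\beta\,\mathbb{D}_{\mathrm{KL}}\big[\pi_\theta\,\|\,\pi_{\mathrm{ref}}\big]\Big)\Bigg] $$

where $r_{i,t}(\theta)=\pi_\theta(o_{i,t}\mid q,o_{i,<t})/\pi_{\theta_{\mathrm{old}}}(o_{i,t}\mid q,o_{i,<t})$ and $\hat{A}_{i,t}$ is the group-normalized advantage computed over the G responses sampled for the same prompt.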
@@ -56,8 +151,6 @@ decode_sampling_top_k: 50
 # Training Loop Configuration
 steps: 100
 per_device_batch_size: 1
-eval_interval: 10
-eval_steps: 5
 
 # Checkpoint Configuration
 enable_checkpointing: True
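The AdamW / warmup-cosine / grad-clipping settings in the hunk above are still spelled as Python constants rather than YAML keys (see the `# TODO: fix this` note). As a rough illustration of the optimizer they describe, here is a minimal optax sketch; it assumes MAX_STEPS equals the config's `steps: 100` and that optax is the optimizer library in play, which the WIP trainer may wire up differently:

import optax

MAX_STEPS = 100                      # assumed: the config's `steps: 100`
LEARNING_RATE = 3e-6
WARMUP_STEPS = int(0.1 * MAX_STEPS)  # linear warmup over the first 10% of steps

# Linear warmup from 0 to the peak learning rate, then cosine decay back to 0.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    decay_steps=MAX_STEPS,
    end_value=0.0,
)

# AdamW with the B1/B2/WEIGHT_DECAY values from the config, preceded by
# global-norm gradient clipping (MAX_GRAD_NORM) to help keep the KL in check.
optimizer = optax.chain(
    optax.clip_by_global_norm(0.1),
    optax.adamw(learning_rate=schedule, b1=0.9, b2=0.99, weight_decay=0.1),
)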

src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@
 # pylint: disable=bare-except, consider-using-generator
 """
 DEPRECATED: This file is deprecated and kept for reference only.
-Please use the new unified CLI interface: grpo_runner.py
+Please use the new unified CLI interface: rl_trainer.py
 
 See GRPO_README.md for migration guide and usage examples.

src/MaxText/examples/grpo_llama3_1_8b_demo.ipynb

Lines changed: 6 additions & 9 deletions

@@ -6,7 +6,7 @@
 "source": [
 "# GRPO Llama3.1-8B Demo: Direct Function Call\n",
 "\n",
-"This notebook demonstrates GRPO training by directly calling the `grpo_train` function from `grpo_tunix_trainer.py`.\n",
+"This notebook demonstrates GRPO training by directly calling the `rl_train` function from `rl_trainer.py`.\n",
 "\n",
 "## What is GRPO?\n",
 "\n",
@@ -16,7 +16,7 @@
 "3. Calculating relative advantages to update the policy\n",
 "\n",
 "\n",
-"This notebook imports and calls the `grpo_train` function \n",
+"This notebook imports and calls the `rl_train` function \n",
 "\n",
 "## Hardware Requirements\n",
 "\n",
@@ -115,7 +115,7 @@
 "\n",
 "# Import required modules\n",
 "from MaxText import pyconfig\n",
-"from MaxText.experimental.rl.grpo_tunix_trainer import grpo_train\n",
+"from MaxText.rl.rl_trainer import rl_train\n",
 "\n",
 "print(\"✅ Successfully imported GRPO training function\")\n",
 "print(f\"📁 MaxText path: {maxtext_path}\")\n",
@@ -145,8 +145,6 @@
 " \"num_generations=2\",\n",
 " \"grpo_beta=0.08\",\n",
 " \"grpo_epsilon=0.2\",\n",
-" \"trainer_devices_fraction=0.5\",\n",
-" \"sampler_devices_fraction=0.5\",\n",
 " \"chips_per_vm=4\"\n",
 "]\n",
 "\n",
@@ -168,14 +166,13 @@
 "source": [
 "# Execute GRPO training directly\n",
 "try:\n",
-" # Call the grpo_train function\n",
-" grpo_trainer, rl_cluster = grpo_train(config)\n",
+" # Call the rl_train function\n",
+" grpo_trainer, rl_cluster = rl_train(config)\n",
 " \n",
 " print(\"\\n\" + \"=\"*80)\n",
 " print(\"✅ GRPO Training Completed Successfully!\")\n",
 " print(\"=\"*80)\n",
-" print(f\"📁 Checkpoints saved to: {config.base_output_directory}/checkpoints\")\n",
-" print(f\"📊 Logs available in: {config.base_output_directory}/logs\")\n",
+" print(f\"📁 Checkpoints and logs saved to: {config.base_output_directory}\")\n",
 " print(f\"🎯 Final model ready for inference!\")\n",
 " \n",
 "except Exception as e:\n",

src/MaxText/examples/grpo_llama3_1_8b_demo.py

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@
 # pylint: disable=bare-except, consider-using-generator
 """
 DEPRECATED: This file is deprecated and kept for reference only.
-Please use the new unified CLI interface: grpo_runner.py
+Please use the new unified CLI interface: rl_trainer.py
 
 See GRPO_README.md for migration guide and usage examples.

src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@
 # pylint: disable=bare-except, consider-using-generator
 """
 DEPRECATED: This file is deprecated and kept for reference only.
-Please use the new unified CLI interface: grpo_runner.py
+Please use the new unified CLI interface: rl_trainer.py
 
 See GRPO_README.md for migration guide and usage examples.
