GRPO refactoring #2530
src/MaxText/examples/GRPO_README.md
Outdated
| @@ -0,0 +1,226 @@
| # GRPO Demo - Unified Training Interface
|
| This directory contains a unified interface for running GRPO (Group Relative Policy Optimization) training demos across different model sizes and configurations.
Should we be using the word "demo"?
Don't we expect users to run these scripts directly?
Yeah, let's call it `grpo_runner.py`, which calls into `grpo_tunix_trainer.py`.
Renamed
src/MaxText/examples/GRPO_README.md
Outdated
| - `grpo_llama3_1_8b_demo_pw.py` - Pathways-based 8B model
| - `grpo_llama3_1_70b_demo_pw.py` - Pathways-based 70B model
|
| These have been consolidated into a single **unified CLI script** (`grpo_demo.py`) that works with the new **grpo.yml** configuration file.
Again, should this be "demo"?
To me, "demo" indicates it may not be suitable for production workloads.
Signed-off-by: Vladimir Suvorov <[email protected]>
149a3dc to e72906d
Signed-off-by: Vladimir Suvorov <[email protected]>
src/MaxText/examples/README.md
Outdated
| - **`grpo_llama3_demo.ipynb`** → GRPO training on math dataset
| - **`grpo_demo.py`** → Unified CLI for GRPO training (any model)
|
| #### GRPO Usage
Let's call it "GRPO Python script usage".
Done
src/MaxText/examples/README.md
Outdated
|
| ### GRPO Training
|
| - **`grpo_llama3_demo.ipynb`** → GRPO training on math dataset
Since you are using `####` for the Python script, maybe put `#### GRPO colab usage` here. Also, can we call it `grpo_llama3_1_8b_demo.ipynb`?
Yes put there
| if num_vms >= 2:
|   # Multi-VM single host setup
|   num_devices = len(devices)
|   num_trainer_devices = int(num_devices * 0.5)  # 50% for training
This is not correct.
For Pathways we use the following, and split the mesh between trainer and inference if multiple hosts are present:
TRAINER_DEVICES_FRACTION = 0.5
SAMPLER_DEVICES_FRACTION = 0.5
If not using Pathways, or if there is only one host:
trainer_devices = devices
sampler_devices = devices
Changed, and made them parameters in grpo.yml.
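For reference, the split described in this thread could be sketched roughly as follows. This is a minimal illustration, not the code merged in the PR: the function name `split_devices`, the `use_pathways` and `num_hosts` arguments, and the choice to give the trainer the first slice of devices are all assumptions made for the example. Only the two fraction constants and the "share all devices" fallback come from the review comment. In a real run the device list would presumably come from `jax.devices()`; plain integers stand in for devices here so the sketch runs without JAX.

```python
from typing import List, Sequence, Tuple, TypeVar

D = TypeVar("D")

# Values quoted in the review comment.
TRAINER_DEVICES_FRACTION = 0.5
SAMPLER_DEVICES_FRACTION = 0.5


def split_devices(
    devices: Sequence[D],
    use_pathways: bool,
    num_hosts: int,
    trainer_fraction: float = TRAINER_DEVICES_FRACTION,
    sampler_fraction: float = SAMPLER_DEVICES_FRACTION,
) -> Tuple[List[D], List[D]]:
  """Returns (trainer_devices, sampler_devices). Hypothetical helper."""
  devices = list(devices)

  # Without Pathways, or with a single host, trainer and sampler
  # share the full device list.
  if not use_pathways or num_hosts <= 1:
    return devices, devices

  if trainer_fraction + sampler_fraction > 1.0:
    raise ValueError("Device fractions must sum to at most 1.0")

  # With Pathways and multiple hosts, carve out disjoint slices.
  # Assumption: the trainer takes the leading slice.
  num_trainer = int(len(devices) * trainer_fraction)
  num_sampler = int(len(devices) * sampler_fraction)
  trainer_devices = devices[:num_trainer]
  sampler_devices = devices[num_trainer:num_trainer + num_sampler]
  return trainer_devices, sampler_devices


if __name__ == "__main__":
  fake_devices = list(range(8))  # stand-ins for real device objects
  print(split_devices(fake_devices, use_pathways=True, num_hosts=2))
  print(split_devices(fake_devices, use_pathways=False, num_hosts=2))
```

Exposing the two fractions as arguments mirrors the author's reply that they became parameters in grpo.yml; the actual key names in that file are not shown in this thread.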
Signed-off-by: Vladimir Suvorov <[email protected]>
Description
Refactoring of GRPO. Adds new unified functionality that makes it easy to add models.
Tests
Please describe how you tested this change, and include any instructions and/or
commands to reproduce.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.