
⚡ vLLM for fast generation in GRPO #2600

Merged 72 commits from grpo_vllm into main on Jan 29, 2025

Commits (72)
933455b
doc
qgallouedec Jan 21, 2025
de40989
fsdp
qgallouedec Jan 21, 2025
b906183
use vllm config
qgallouedec Jan 21, 2025
b4a7aa5
vllm
qgallouedec Jan 21, 2025
8e9ceaf
Update trl/trainer/grpo_config.py
qgallouedec Jan 22, 2025
59eafd0
Update trl/trainer/grpo_config.py
qgallouedec Jan 22, 2025
c32e815
typo
qgallouedec Jan 22, 2025
e49f38d
top_k, top_p
qgallouedec Jan 22, 2025
d17039b
Link to vllm pr
qgallouedec Jan 22, 2025
3a280d6
Merge branch 'main' into grpo_vllm
qgallouedec Jan 22, 2025
9e5b8d0
Merge branch 'main' into grpo_vllm
kashif Jan 24, 2025
1ba1ecf
fix missing device
kashif Jan 24, 2025
f8a33e3
fix tests
kashif Jan 24, 2025
693bb4e
fix citation
kashif Jan 24, 2025
b0b203c
fix title and paper_id
kashif Jan 24, 2025
383b795
Merge branch 'main' into grpo_vllm
kashif Jan 25, 2025
4abe3ea
Merge branch 'main' into grpo_vllm
kashif Jan 25, 2025
2d956c7
formatting
kashif Jan 25, 2025
b151cc1
output the correct number of generations
kashif Jan 25, 2025
136dd89
initial async vllm
kashif Jan 25, 2025
ca4b818
fix missing args
kashif Jan 25, 2025
edbf2ed
fix promps
kashif Jan 25, 2025
3b7fd21
Pass prompt_token_ids directly
kashif Jan 25, 2025
eff4263
Repeat each prompt num_generations times
kashif Jan 25, 2025
a7483bc
get the slice of results per processor
kashif Jan 25, 2025
d882f09
Merge branch 'main' into grpo_vllm
kashif Jan 25, 2025
9a67705
undo citation
kashif Jan 26, 2025
f97c875
Merge branch 'main' into grpo_vllm
kashif Jan 27, 2025
22fd22d
OMG
qgallouedec Jan 27, 2025
7db6847
nothing can resist me!!!!
qgallouedec Jan 28, 2025
aa9309a
working
qgallouedec Jan 28, 2025
ae63a04
vllm_device to "auto"
kashif Jan 28, 2025
a5256cf
add vllm test
kashif Jan 28, 2025
298e09c
add initial vllm docs
kashif Jan 28, 2025
e56d653
add vllm link and pip instructions
kashif Jan 28, 2025
b9f3dfb
add multi-gpu strategy fot vllm
kashif Jan 28, 2025
07f311f
Update docs/source/grpo_trainer.md
kashif Jan 28, 2025
b41751c
Update docs/source/grpo_trainer.md
kashif Jan 28, 2025
43bbbd4
Update docs/source/grpo_trainer.md
kashif Jan 28, 2025
c789ffe
Merge branch 'main' into grpo_vllm
kashif Jan 28, 2025
59157a8
add doc strings
kashif Jan 28, 2025
864a03c
Update docs/source/grpo_trainer.md
qgallouedec Jan 28, 2025
5944163
Update trl/trainer/grpo_trainer.py
qgallouedec Jan 28, 2025
1faed5a
Update docs/source/grpo_trainer.md
qgallouedec Jan 28, 2025
500c362
add important tag
kashif Jan 28, 2025
6e3c074
fix typo
kashif Jan 28, 2025
9fced92
overrides default batch size and grad accum and better doc
qgallouedec Jan 28, 2025
24b8212
Under no circumstances should you examine the contents of this commit.
qgallouedec Jan 28, 2025
1e9fc01
auto device, warnings, errors
qgallouedec Jan 28, 2025
5811812
better error message
qgallouedec Jan 28, 2025
78a1a66
require_torch_accelerator test vllm
qgallouedec Jan 28, 2025
c20d30d
speeding up traing doc
qgallouedec Jan 28, 2025
2936ad7
device as str
qgallouedec Jan 28, 2025
5738a4f
does it prevent deepspeed init to hang?
qgallouedec Jan 28, 2025
8f2fbc0
update docs
kashif Jan 28, 2025
a0c959a
require torch accelertor for vllm test
qgallouedec Jan 28, 2025
7d4b589
unwrap compat with ds z3
qgallouedec Jan 28, 2025
059ac9c
simplify examble in doc
qgallouedec Jan 28, 2025
37b26a3
More comments, fix ds3 hanging
qgallouedec Jan 28, 2025
36baa40
faster, not sure why
qgallouedec Jan 28, 2025
6bc69bd
style
qgallouedec Jan 28, 2025
d68b261
move doc about speed
qgallouedec Jan 28, 2025
3ff008b
revert change in config files
qgallouedec Jan 28, 2025
b79992a
fix default value in doc [ci skip]
qgallouedec Jan 28, 2025
355726e
style [ci skip]
qgallouedec Jan 28, 2025
6c18b5d
better comment [ci skip]
qgallouedec Jan 28, 2025
249fe97
fix warning
qgallouedec Jan 28, 2025
2e3ae36
Update grpo_config.py
qgallouedec Jan 29, 2025
8d283d0
Update deepspeed_zero1.yaml
qgallouedec Jan 29, 2025
ccc4d43
Update trl/trainer/grpo_trainer.py
qgallouedec Jan 29, 2025
3ff5de4
Apply suggestions from code review
qgallouedec Jan 29, 2025
be10004
Update docs/source/grpo_trainer.md
qgallouedec Jan 29, 2025
34 changes: 20 additions & 14 deletions docs/source/grpo_trainer.md
@@ -14,7 +14,7 @@ This post-training method was contributed by [Quentin Gallouédec](https://huggi

## Quick start

This example demonstrates how to train a model using the GRPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) as the base model and the [RM-Gemma-2B model](https://huggingface.co/weqweasdas/RM-Gemma-2B) as the reward model. We use the prompts from the [TLDR dataset](https://huggingface.co/datasets/trl-lib/tldr) (the completion column is ignored!). You can view the data in the dataset here:
This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [TLDR dataset](https://huggingface.co/datasets/trl-lib/tldr) (the completion column is ignored!). You can view the data in the dataset here:

<iframe
src="https://huggingface.co/datasets/trl-lib/tldr/embed/viewer/default/train?row=0"
@@ -23,32 +23,26 @@ This example demonstrates how to train a model using the
height="560px"
></iframe>

Below is the script to train the model. We use PEFT to reduce the memory requirements.
Below is the script to train the model.

```diff
 # train_grpo.py
 from datasets import load_dataset
-from peft import LoraConfig
 from trl import GRPOConfig, GRPOTrainer

 # Load the dataset
 dataset = load_dataset("trl-lib/tldr", split="train")

-training_args = GRPOConfig(
-    output_dir="Qwen2-0.5B-GRPO",
-    learning_rate=1e-5,
-    logging_steps=10,
-    gradient_accumulation_steps=16,
-    max_completion_length=128,
-)
+# Define the reward function, which rewards completions that are close to 20 characters
+def reward_len(completions, **kwargs):
+    return [abs(20 - len(completion)) for completion in completions]

+training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10)
 trainer = GRPOTrainer(
     model="Qwen/Qwen2-0.5B-Instruct",
-    reward_funcs="weqweasdas/RM-Gemma-2B",
+    reward_funcs=reward_len,
     args=training_args,
     train_dataset=dataset,
-    peft_config=LoraConfig(task_type="CAUSAL_LM"),
 )

 trainer.train()
```
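The TLDR dataset above is loaded from the Hub, but for quick experiments a prompt-only dataset can also be built in memory. A small sketch (the prompts are made up, and the standard `prompt` column name is assumed):

```python
# Illustrative: a tiny in-memory prompt-only dataset with the assumed "prompt" column
from datasets import Dataset

toy_dataset = Dataset.from_dict(
    {"prompt": ["Summarize: The cat sat on the mat.", "Summarize: TRL adds vLLM generation to GRPO."]}
)
print(toy_dataset[0]["prompt"])
```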

@@ -118,6 +112,18 @@ The GRPO Trainer logs the following metrics:

## Customization

## Speed up training with vLLM-powered generation

Generation is often the main bottleneck that makes training slow with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation. To enable it, pass `use_vllm=True` in the training arguments.

```python
from trl import GRPOConfig

training_args = GRPOConfig(..., use_vllm=True)
```

For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).
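Putting the quick-start script and the vLLM flag together, here is a minimal end-to-end sketch. It is illustrative only: the output directory and logging interval are arbitrary, the toy reward returns the negative distance to 20 characters so that closer completions score higher, and, as described in the section linked above, vLLM needs at least one GPU that the trainer does not use.

```python
# train_grpo_vllm.py -- illustrative sketch, not part of the PR
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Toy reward: higher when the completion length is close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO-vLLM",
    logging_steps=10,
    use_vllm=True,  # generation runs on a dedicated GPU through vLLM
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```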

### Using a custom reward function

The [`GRPOTrainer`] supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements:
52 changes: 48 additions & 4 deletions docs/source/speeding_up_training.md
@@ -8,14 +8,21 @@ Section under construction. Feel free to contribute!

## vLLM for fast generation in online methods

Online methods such as Online DPO or Nash-MD require the model to generate completions, which is often a slow process and can significantly impact training time.
To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation through PagedAttention. TRL's online trainers support vLLM, greatly improving training speed.
Online methods such as GRPO or Online DPO require the model to generate completions, which is often a slow process and can significantly impact training time.
To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation through, among other things, PagedAttention. TRL's online trainers support vLLM, greatly improving training speed.

To use [vLLM](https://github.com/vllm-project/vllm), first install it using:

To use vLLM, first install it using:
```bash
pip install vllm
```

or

```bash
pip install "trl[vllm]"
```
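If you want to sanity-check the vLLM install on its own before wiring it into a trainer, here is a minimal generation sketch using vLLM's `LLM` and `SamplingParams` classes (the model name and prompt are illustrative):

```python
# Quick vLLM smoke test, independent of TRL (illustrative model and prompt)
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-0.5B-Instruct")
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)

outputs = llm.generate(["The quick brown fox"], sampling_params)
print(outputs[0].outputs[0].text)
```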

<hfoptions id="vllm examples">
<hfoption id="Online DPO">

@@ -24,7 +31,44 @@ Then, enable it by passing `use_vllm=True` in the training arguments.
```diff
 from trl import OnlineDPOConfig

-training_args = DPOConfig(..., use_vllm=True)
+training_args = OnlineDPOConfig(..., use_vllm=True)
```

</hfoption>
<hfoption id="GRPO">

Then, enable it by passing `use_vllm=True` in the training arguments.

```python
from trl import GRPOConfig

training_args = GRPOConfig(..., use_vllm=True)
```

The strategy here is to use a dedicated GPU for generation powered by vLLM, while using the remainder for training.

<Tip warning={true}>

When using vLLM, an additional GPU is required exclusively for generation. This means you need at least two available GPUs and must ensure that one remains unused by the trainer. To achieve this, run the training with `--num_processes <NUMBER_OF_GPUs - 1>`.

For example, if you have 4 GPUs, set `--num_processes 3` to allocate three GPUs for training while reserving one for generation.
```bash
accelerate launch --multi_gpu --num_processes 3 train_grpo.py
```

![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/1_gpu_for_generation.png)

</Tip>
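To make the device split concrete, here is a small sketch of the allocation the tip describes, assuming 4 GPUs and the default `vllm_device="auto"`; this is illustrative and not the trainer's actual selection code:

```python
# Illustrative only: the GPU split implied by `--num_processes 3` on a 4-GPU machine
import torch

num_gpus = torch.cuda.device_count()                           # e.g. 4
training_devices = [f"cuda:{i}" for i in range(num_gpus - 1)]  # cuda:0 .. cuda:2 run the trainer
generation_device = f"cuda:{num_gpus - 1}"                     # cuda:3 is left free for vLLM
print(training_devices, generation_device)
```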

You can further tune the vLLM configuration by setting a specific `vllm_device` and `vllm_gpu_memory_utilization` in the [`GRPOConfig`].

```python
training_args = GRPOConfig(
    ...,
    use_vllm=True,
    vllm_device="cuda:4",
    vllm_gpu_memory_utilization=0.7,
)
```

</hfoption>
37 changes: 36 additions & 1 deletion tests/test_grpo_trainer.py
@@ -19,10 +19,11 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.testing_utils import require_peft, require_torch_accelerator
from transformers.utils import is_peft_available

from trl import GRPOConfig, GRPOTrainer
from trl.import_utils import is_vllm_available


if is_peft_available():
@@ -330,3 +331,37 @@ def reward_func(completions, some_values, **kwargs):
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    @unittest.skipIf(not is_vllm_available(), "vLLM is not available")
    @require_torch_accelerator
    def test_training_vllm(self):
        """Test that training works with vLLM for generation."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=2,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=32,  # reduce the completion length to reduce memory usage
                report_to="none",
                use_vllm=True,
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/small-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
64 changes: 64 additions & 0 deletions trl/trainer/grpo_config.py
@@ -51,11 +51,31 @@ class GRPOConfig(TrainingArguments):
        max_completion_length (`int` or `None`, *optional*, defaults to `256`):
            Maximum length of the generated completion.

        > Parameters that control generation acceleration powered by vLLM

        use_vllm (`bool`, *optional*, defaults to `False`):
            Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
            training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
        vllm_device (`str`, *optional*, defaults to `"auto"`):
            Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
            automatically select the next available GPU after the last one used for training. This assumes that
            training has not already occupied all available GPUs.
        vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
            Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
            device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
            improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
            during initialization.

        > Parameters that control the training

        learning_rate (`float`, *optional*, defaults to `1e-6`):
            Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
            [`~transformers.TrainingArguments`].
        per_device_train_batch_size (`int`, *optional*, defaults to `1`):
            Number of prompts sampled per device for training. The actual batch passed into the model will be this
            value multiplied by `num_generations`.
        gradient_accumulation_steps (`int`, *optional*, defaults to `8`):
            Number of update steps to accumulate the gradients for, before performing a backward/update pass.
        beta (`float`, *optional*, defaults to `0.04`):
            KL coefficient.
    """
@@ -98,6 +118,33 @@ class GRPOConfig(TrainingArguments):
        metadata={"help": "Maximum length of the generated completion."},
    )

    # Parameters that control generation acceleration powered by vLLM
    use_vllm: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept "
            "unused for training, as vLLM will require one for generation. vLLM must be installed "
            "(`pip install vllm`)."
        },
    )
    vllm_device: Optional[str] = field(
        default="auto",
        metadata={
            "help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system "
            "will automatically select the next available GPU after the last one used for training. This assumes "
            "that training has not already occupied all available GPUs."
        },
    )
    vllm_gpu_memory_utilization: float = field(
        default=0.9,
        metadata={
            "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
            "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
            "size and thus improve the model's throughput. However, if the value is too high, it may cause "
            "out-of-memory (OOM) errors during initialization."
        },
    )

    # Parameters that control the training
    learning_rate: float = field(
        default=1e-6,
@@ -106,6 +153,23 @@ class GRPOConfig(TrainingArguments):
            "`transformers.TrainingArguments`."
        },
    )
    # GRPO generates multiple completions per prompt, increasing memory usage.
    # To accommodate this, the per-device train batch size is decreased (overridden from the parent class),
    # and the number of gradient accumulation steps is increased to maintain the effective batch size.
    per_device_train_batch_size: int = field(
        default=1,
        metadata={
            "help": "Number of prompts sampled per device for training. The actual batch passed into the model will "
            "be this value multiplied by `num_generations`."
        },
    )
    gradient_accumulation_steps: int = field(
        default=8,
        metadata={
            "help": "Number of update steps to accumulate the gradients for, before performing a backward/update "
            "pass."
        },
    )
    beta: float = field(
        default=0.04,
        metadata={"help": "KL coefficient."},