
⚡ vLLM for fast generation in GRPO #2600

Merged · 72 commits into main · Jan 29, 2025
Conversation

@qgallouedec (Member) commented Jan 21, 2025

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
import random

def random_reward(completions, **kwargs):
    return [random.random() for _ in completions]


def main():
    # Load the dataset
    dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train[:5%]")

    training_args = GRPOConfig(
        output_dir="Qwen2-0.5B-GRPO",
        logging_steps=2,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=16,
        max_prompt_length=64,
        max_completion_length=32,
        num_generations=4,
        num_train_epochs=1,
        use_vllm=True,         # generate completions with vLLM instead of model.generate
        vllm_device="cuda:2",  # GPU reserved for the vLLM generation engine
    )
    trainer = GRPOTrainer(
        model="Qwen/Qwen2-0.5B-Instruct",
        reward_funcs=random_reward,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()

if __name__ == "__main__":
    main()

Run with:

accelerate launch --num_processes 2 train_grpo.py

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@lewtun (Member) left a comment

Looks great with just some minor nits, and a question again about where we really need to load the model in float32 (double the VRAM otherwise).
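For illustration only: a minimal sketch of how the dtype could be overridden so the policy is not materialized in float32, assuming GRPOConfig forwards model_init_kwargs to from_pretrained (treat the exact kwargs as an assumption to verify):

import torch
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    use_vllm=True,
    # Assumed knob: pass torch_dtype through to AutoModelForCausalLM.from_pretrained
    # so the policy weights are loaded in bf16 (roughly half the fp32 footprint).
    model_init_kwargs={"torch_dtype": torch.bfloat16},
)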

@DreamGenX

Could we consider a more flexible solution that is not tied to vLLM? For most models, there are faster inference engines, like SGLang.

And if you want to stick to one inference engine, SGLang already has an API to update model weights: https://docs.sglang.ai/backend/native_api.html#Update-Weights-From-Disk
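For illustration, a rough sketch of what refreshing weights through that endpoint might look like (server URL, endpoint name, and checkpoint path are assumptions; check the linked docs for the exact request shape):

import requests

# Assumed SGLang native API: ask a running server to reload weights from a
# checkpoint directory that the trainer has just written to disk.
resp = requests.post(
    "http://localhost:30000/update_weights_from_disk",
    json={"model_path": "/checkpoints/step-100"},
)
resp.raise_for_status()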

@sfc-gh-zhyao

@qgallouedec does it work for multi-node setups (say 2 nodes with 8 GPUs)?

@qgallouedec (Member, Author)

@qgallouedec does it work for multi-node setups (say 2 nodes with 8 GPUs)?

No idea. Have you tried?

@kashif (Collaborator) commented Jan 28, 2025

@sfc-gh-zhyao this will not work for the multi-node case...

@lewtun (Member) commented Jan 29, 2025

@qgallouedec I am getting this error with ZeRO-3 on a node:

[rank2]: IndexError: pop from an empty deque
[rank0]: Traceback (most recent call last):
[rank0]:   File "/fsx/lewis/git/hf/trl/scratch/grpo_demo.py", line 55, in <module>
[rank0]:     main()
[rank0]:   File "/fsx/lewis/git/hf/trl/scratch/grpo_demo.py", line 52, in main
[rank0]:     trainer.train()
[rank0]:   File "/fsx/lewis/miniconda3/envs/trl/lib/python3.11/site-packages/transformers/trainer.py", line 2171, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/lewis/miniconda3/envs/trl/lib/python3.11/site-packages/transformers/trainer.py", line 2531, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/lewis/miniconda3/envs/trl/lib/python3.11/site-packages/transformers/trainer.py", line 3675, in training_step
[rank0]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/lewis/git/hf/trl/trl/trainer/grpo_trainer.py", line 442, in compute_loss
[rank0]:     per_token_logps = get_per_token_logps(model, prompt_completion_ids)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/lewis/git/hf/trl/trl/trainer/grpo_trainer.py", line 431, in get_per_token_logps
[rank0]:     logits = model(input_ids).logits  # (B, L, V)
[rank0]:              ^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/lewis/miniconda3/envs/trl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/lewis/miniconda3/envs/trl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/lewis/miniconda3/envs/trl/lib/python3.11/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/lewis/miniconda3/envs/trl/lib/python3.11/site-packages/deepspeed/runtime/engine.py", line 1914, in forward
[rank0]:     loss = self.module(*inputs, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/lewis/miniconda3/envs/trl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/lewis/miniconda3/envs/trl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/fsx/lewis/miniconda3/envs/trl/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in inner
[rank0]:     args_result = hook(self, args)
[rank0]:                   ^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/lewis/miniconda3/envs/trl/lib/python3.11/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/lewis/miniconda3/envs/trl/lib/python3.11/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 241, in _start_of_forward_hook
[rank0]:     self.get_param_coordinator().reset_step()
[rank0]:   File "/fsx/lewis/miniconda3/envs/trl/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/lewis/miniconda3/envs/trl/lib/python3.11/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 235, in reset_step
[rank0]:     self.construct_parameter_trace_from_module_trace()
[rank0]:   File "/fsx/lewis/miniconda3/envs/trl/lib/python3.11/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 219, in construct_parameter_trace_from_module_trace
[rank0]:     self.record_parameters(sub_module)
[rank0]:   File "/fsx/lewis/miniconda3/envs/trl/lib/python3.11/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 211, in record_parameters
[rank0]:     step_id = self.__step_id_module_fetched_for[sub_module.id].popleft()
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: IndexError: pop from an empty deque

Gist to repro: https://gist.github.com/lewtun/d3c1ac9dbe96514b8fd6fafcc657f1bc

I'm also trying to isolate the cause (maybe gradient checkpointing?)

Update: error persists even when gradient checkpointing is disabled.

@lewtun (Member) commented Jan 29, 2025

Ah, maybe it's the deepspeed version. I am currently using deepspeed==0.16.3. Will roll back to 0.15.4 to check.
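A quick way to pin the version for that check (illustrative command):

pip install "deepspeed==0.15.4"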

@lewtun (Member) commented Jan 29, 2025

OK regarding ZeRO-3, the following script from @qgallouedec works:

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")


# Dummy reward function: the closer the completion is to 20 characters, the higher the reward
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]


training_args = GRPOConfig(
    output_dir="Qwen2.5-0.5B-GRPO",
    logging_steps=2,
    use_vllm=True,
    vllm_gpu_memory_utilization=0.7,
    max_prompt_length=128,
    bf16=True,
    gradient_accumulation_steps=4,
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

Run with:

accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml --num_processes 7 grpo.py

Once the tests pass, let's go!

@qgallouedec merged commit ed14ed9 into main on Jan 29, 2025 (14 checks passed)
@qgallouedec deleted the grpo_vllm branch on January 29, 2025 at 12:01
@yiyepiaoling0715

Can this solve the OOM problem for models larger than 14B?
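For context, a hedged sketch that combines the memory-related options already mentioned in this thread (values are illustrative, not a tested recipe for 14B+ models; very large models typically also need ZeRO-3 sharding via the accelerate/deepspeed config):

from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="large-model-GRPO",
    bf16=True,                         # train in bfloat16 instead of float32
    gradient_checkpointing=True,       # trade compute for activation memory
    gradient_accumulation_steps=16,    # keep per-device batches small
    use_vllm=True,
    vllm_gpu_memory_utilization=0.7,   # leave headroom on the generation GPU
)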

@LukasNel

Still running into the mentioned issue!

@qgallouedec (Member, Author)

Still running into the mentioned issue!

Which one?

@cjfcsjt commented Jan 30, 2025

[rank0]: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:7 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

Running the given scripts results in the above error. How can I fix this?

vllm 0.6.6.post1
deepspeed 0.15.4
torch 2.5.1+cu121

@qgallouedec @lewtun

@qgallouedec (Member, Author)

Please provide the full traceback

@kashif (Collaborator) commented Jan 30, 2025

@cjfcsjt this error is due to an older vLLM version; please upgrade vLLM to 0.7.0 and try again.
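For example (illustrative command):

pip install -U "vllm>=0.7.0"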

@cjfcsjt commented Jan 30, 2025

@kashif Fixed. Thank you! May I ask how the updates in the new vllm version resolved this issue?

@valayDave

One more quick question: it seems this patch is not compatible with multi-node setups. Can it be ported to support multi-node too? (I tested it myself, and --num_processes only accounts for the single-node case.)

@NickyDark1

training_args = GRPOConfig(
    ...
    vllm_device="cuda:1",
    ...
)

I understand that this way I can point vLLM at one GPU, but how could I direct it to more GPUs, for example "cuda:2", "cuda:4", "cuda:5", or whichever other GPUs I want?

@thetushargoyal

hey @qgallouedec

accelerate launch --multi_gpu --num_processes 1 train_grpo.py

With use_vllm=True this doesn't work, since --multi_gpu only takes effect when --num_processes > 1.

I was trying to run this on a 2xA100 setup.
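For what it's worth, on 2 GPUs with use_vllm=True the pattern used elsewhere in this thread is to reserve one GPU for vLLM (via vllm_device) and launch training on the remaining one, i.e. without --multi_gpu (a sketch, not a confirmed answer):

accelerate launch --num_processes 1 train_grpo.py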
