⚡ vLLM for fast generation in GRPO #2600
Conversation
qgallouedec commented on Jan 21, 2025 (edited)
Looks great with just some minor nits, and a question again about whether we really need to load the model in float32 (double the VRAM otherwise).
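As a side note on that VRAM point, here is a minimal sketch (not part of this PR) of how the load dtype changes weight memory, reusing the model that appears later in this thread; `torch_dtype` is the standard `transformers` argument:

```python
import torch
from transformers import AutoModelForCausalLM

# float32 stores 4 bytes per parameter, bfloat16 stores 2, so loading in
# float32 roughly doubles the weight memory compared to bf16.
model_fp32 = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", torch_dtype=torch.float32
)
model_bf16 = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", torch_dtype=torch.bfloat16
)
print(model_fp32.get_memory_footprint())  # roughly 2x the bf16 footprint
print(model_bf16.get_memory_footprint())
```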
Co-authored-by: lewtun <[email protected]>
Could we consider a more flexible solution that is not tied to vLLM? For most models there are faster inference engines, like SGLang. And if you want to stick to one inference engine, SGLang already has an API to update model weights: https://docs.sglang.ai/backend/native_api.html#Update-Weights-From-Disk
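For reference, a rough sketch of the SGLang weight-update call mentioned above; the endpoint name and payload follow the linked docs, while the server URL and checkpoint path are placeholders:

```python
import requests

# Hypothetical setup: an SGLang server already running locally and a
# checkpoint directory written by the trainer.
SGLANG_URL = "http://localhost:30000"
CHECKPOINT_DIR = "/tmp/grpo-checkpoint"

# Ask the server to reload weights from disk (endpoint per the linked docs).
response = requests.post(
    f"{SGLANG_URL}/update_weights_from_disk",
    json={"model_path": CHECKPOINT_DIR},
)
response.raise_for_status()
print(response.json())  # should report whether the weight reload succeeded
```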
@qgallouedec does it work for multi-node setups (say, 2 nodes with 8 GPUs)?
No idea. Have you tried?
@sfc-gh-zhyao this will not work for the multi-node case... |
@qgallouedec I am getting this error with ZeRO-3 on a node:
Gist to repro: https://gist.github.com/lewtun/d3c1ac9dbe96514b8fd6fafcc657f1bc
I'm also trying to isolate the cause (maybe gradient checkpointing?)
Update: the error persists even when gradient checkpointing is disabled.
Ah, maybe it's the deepspeed version. I am currently using
Co-authored-by: lewtun <[email protected]>
OK, regarding ZeRO-3, the following script from @qgallouedec works:

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Dummy reward function: the closer the completion is to 20 characters, the higher the reward
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(
    output_dir="Qwen2.5-0.5B-GRPO",
    logging_steps=2,
    use_vllm=True,
    vllm_gpu_memory_utilization=0.7,
    max_prompt_length=128,
    bf16=True,
    gradient_accumulation_steps=4,
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```

Run with:

```
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml --num_processes 7 grpo.py
```

Once the tests pass, let's go!
Can this solve the OOM problem for >14B models?
Still running into the mentioned issue!
Which one?
[rank0]: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:7 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)
Running the given scripts results in the above error. How can I fix this?
Please provide the full traceback.
@cjfcsjt this error is due to an older vLLM version; please upgrade vLLM to 0.7.0 and try again.
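For anyone checking their environment, a quick sketch of how to verify the installed vLLM version; nothing here is specific to this PR:

```python
from importlib.metadata import version

# Print the installed vLLM version; the suggestion above is to be on >= 0.7.0.
print(version("vllm"))
```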
@kashif Fixed. Thank you! May I ask how the updates in the new vLLM version resolved this issue?
One more quick question: it seems this patch is not compatible with a multi-node setup. Can it be ported to support multi-node too? (I tested it myself and
`training_args = GRPOConfig(`
I understand that this way I could redirect vLLM to a GPU, but how could I redirect it to more GPUs, for example "cuda:2", "cuda:4", "cuda:5", or other GPUs that I want?
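For the single-GPU case, a hedged sketch of pinning generation to one device: the `vllm_device` argument and its `"auto"` default are as described in this PR's docs, the device index is hypothetical, and spreading vLLM over several GPUs is not something this thread shows support for:

```python
from trl import GRPOConfig

# Sketch only: pin vLLM generation to a single, specific GPU.
# The config takes one device string; a list of devices such as
# "cuda:2", "cuda:4", "cuda:5" is not covered here.
training_args = GRPOConfig(
    output_dir="Qwen2.5-0.5B-GRPO",
    use_vllm=True,
    vllm_device="cuda:2",  # hypothetical index; "auto" is the documented default
    vllm_gpu_memory_utilization=0.7,
)
```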
Hey @qgallouedec, with `use_vllm=True`, running `accelerate launch --multi_gpu --num_processes 1 train_grpo.py` doesn't work, since `--multi_gpu` only works when `--num_processes > 1`. I was trying to run this on a 2xA100 setup.