Description
Describe the issue as clearly as possible:
Hi,
I recently tried to use `RegexLogitsProcessor` with vLLM, as introduced in #481.
When I use it with a "small" model, such as a 7B one, on a single GPU, it works fine. But when I try a big one, namely Mixtral, on multiple GPUs using vLLM's `tensor_parallel_size` engine argument, I run into several problems: the monkey patching does not take effect, and `fsm_state` in the processor is not initialized. I suspect the multiple Ray workers are the cause (the monkey patch may not be propagated to all the workers, and the same may apply to the `fsm_state`s).
I may have missed some relevant information, but it seems that #481 was only tested without tensor parallelism.
Steps/code to reproduce the bug:
import vllm
import vllm.model_executor.layers.sampler as sampler
from pydantic import BaseModel
from outlines.serve.vllm import JSONLogitsProcessor, _patched_apply_logits_processors

# Patch the _apply_logits_processors so it is compatible with `JSONLogitsProcessor`
# (this assignment runs in the process executing this script)
sampler._apply_logits_processors = _patched_apply_logits_processors


class User(BaseModel):
    id: int
    name: str


model = "mistralai/Mixtral-8X7B-Instruct-v0.1"  # fails with tensor_parallel_size=4
#model = "mistralai/Mistral-7B-Instruct-v0.2"  # works when run on a single GPU
llm = vllm.LLM(model=model, dtype='float16', max_model_len=1024, tensor_parallel_size=4, max_num_seqs=512, enforce_eager=True)

logits_processor = JSONLogitsProcessor(User, llm)
result = llm.generate(
    ["A prompt", "Another prompt"],
    sampling_params=vllm.SamplingParams(
        max_tokens=100, logits_processors=[logits_processor]
    ),
)
print(result)
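A related sketch for the `fsm_state` suspicion, assuming the processor passed in `SamplingParams` reaches each worker as a serialized copy rather than as a shared object. `ToyProcessor` is hypothetical and only mimics state that is created lazily on the first call; it is not the Outlines class, and `copy.deepcopy` stands in for the serialization round trip.

```python
import copy
from collections import defaultdict


class ToyProcessor:
    """Hypothetical processor that creates its per-sequence state lazily."""

    def __call__(self, seq_id, input_ids, scores):
        if len(input_ids) == 0:
            # First call of a generation: the state is created on *this* object only.
            self.fsm_state = defaultdict(int)
        else:
            # Toy state transition standing in for an FSM step.
            self.fsm_state[seq_id] += 1
        return scores


driver_copy = ToyProcessor()
# deepcopy stands in for shipping the object to a worker process:
# the worker receives an independent copy, not a reference.
worker_copy = copy.deepcopy(driver_copy)

worker_copy(0, [], [0.0])    # the worker initializes its own state...
worker_copy(0, [42], [0.0])  # ...and advances it

print(hasattr(worker_copy, "fsm_state"))  # True
print(hasattr(driver_copy, "fsm_state"))  # False

driver_error = None
try:
    # A later step on an object that never saw the first call.
    driver_copy(0, [42], [0.0])
except AttributeError as exc:
    driver_error = type(exc).__name__
print(driver_error)  # AttributeError
```

This only shows that lazily created state lives on whichever copy saw the first call; it does not confirm how vLLM distributes the processor across workers.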
Expected result:
The generated outputs from vLLM, as with the 7B model on a single GPU.
Error message:
Traceback (most recent call last):
  File "/workspace/./python_scripts/vllm_integration.py", line 19, in <module>
    result = llm.generate(
  File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py", line 165, in generate
    return self._run_engine(use_tqdm)
  File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py", line 185, in _run_engine
    step_outputs = self.llm_engine.step()
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 581, in step
    output = self._run_workers(
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 755, in _run_workers
    self._run_workers_in_batch(workers, method, *args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 732, in _run_workers_in_batch
    all_outputs = ray.get(all_outputs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2624, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(TypeError): ray::RayWorkerVllm.execute_method() (pid=250471, ip=172.17.0.12, actor_id=6003ef308c0e22f0e05d73c901000000, repr=<vllm.engine.ray_utils.RayWorkerVllm object at 0x7f8c9c37a770>)
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/ray_utils.py", line 31, in execute_method
    return executor(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py", line 159, in execute_model
    output = self.model_runner.execute_model(seq_group_metadata_list,
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 354, in execute_model
    output = self.model.sample(
  File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/mixtral.py", line 390, in sample
    next_tokens = self.sampler(self.lm_head.weight, hidden_states,
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/sampler.py", line 52, in forward
    logits = _apply_logits_processors(logits, sampling_metadata)
  File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/sampler.py", line 172, in _apply_logits_processors
    logits_row = logits_processor(token_ids, logits_row)
TypeError: RegexLogitsProcessor.__call__() missing 1 required positional argument: 'scores'
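The final `TypeError` is what a call-signature mismatch looks like. The message implies that `RegexLogitsProcessor.__call__` requires three positional arguments ending in `scores` (it names `RegexLogitsProcessor` even though the repro uses `JSONLogitsProcessor`, which suggests the latter inherits `__call__`), while the frame at `sampler.py` line 172 calls `logits_processor(token_ids, logits_row)` with only two. The sketch below reproduces the message with a hypothetical `ToyRegexProcessor`; its `(seq_id, input_ids, scores)` parameters and the `patched_apply` caller are assumptions for illustration, not the Outlines or vLLM code.

```python
class ToyRegexProcessor:
    """Hypothetical stand-in: a processor that also needs the sequence id."""

    def __call__(self, seq_id, input_ids, scores):
        return scores


def stock_apply(processor, token_ids, logits_row):
    # Mirrors the two-argument call shown in the traceback (sampler.py, line 172).
    return processor(token_ids, logits_row)


def patched_apply(processor, seq_id, token_ids, logits_row):
    # Assumed shape of a patched caller that also forwards the sequence id.
    return processor(seq_id, token_ids, logits_row)


proc = ToyRegexProcessor()
print(patched_apply(proc, 0, [1, 2], [0.5]))  # [0.5]

error_message = ""
try:
    stock_apply(proc, [1, 2], [0.5])
except TypeError as exc:
    error_message = str(exc)
print(error_message)  # ...missing 1 required positional argument: 'scores'
```

The patched caller succeeds and the stock one fails with the same kind of message as the report, which fits the idea that the Ray worker is still running vLLM's unpatched `_apply_logits_processors`.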
Outlines/Python version information:
Outlines 0.0.22
Python 3.10.12
Context for the issue:
No response