
vLLM tensor-parallel and RegexLogitsProcessor #524

Closed
@BenoitHardier

Description


Describe the issue as clearly as possible:

Hi,
I recently tried to use the RegexLogitsProcessor with vLLM, as introduced by #481.

With a "small" model such as a 7B one on a single GPU it works fine, but with a big one, namely Mixtral, on multiple GPUs via the vLLM engine argument tensor-parallel, I ran into several problems (the monkey patching does not take effect, and fsm_state in the processor is never initialized). I suspect the multiple Ray workers are the cause: the monkey patch may not be propagated to all the workers, and the same presumably goes for the fsm_states.

I may have missed some relevant information, but it seems that #481 was only tested without tensor-parallel.
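
To illustrate the suspicion, here is a minimal, self-contained sketch of my own (using json.dumps as a stand-in for vLLM's sampler function, not vLLM itself): Ray workers are separate Python processes, so a module-level monkey patch applied in the driver is invisible to them.

import json
import ray

# Monkey-patch a module-level function in the driver process only.
json.dumps = lambda obj, **kwargs: "PATCHED"

@ray.remote
def dumps_in_worker() -> str:
    # Runs in a separate worker process, which imports the stock json
    # module: the driver-side patch is not visible here.
    import json
    return json.dumps({"id": 1})

ray.init()
print(json.dumps({"id": 1}))              # "PATCHED" (driver sees the patch)
print(ray.get(dumps_in_worker.remote()))  # '{"id": 1}' (worker does not)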

Steps/code to reproduce the bug:

import vllm
import vllm.model_executor.layers.sampler as sampler
from pydantic import BaseModel

from outlines.serve.vllm import JSONLogitsProcessor, _patched_apply_logits_processors

# Patch the _apply_logits_processors so it is compatible with `JSONLogitsProcessor`
sampler._apply_logits_processors = _patched_apply_logits_processors


class User(BaseModel):
    id: int
    name: str

model = "mistralai/Mixtral-8X7B-Instruct-v0.1"
#model = "mistralai/Mistral-7B-Instruct-v0.2"
llm = vllm.LLM(
    model=model,
    dtype='float16',
    max_model_len=1024,
    tensor_parallel_size=4,
    max_num_seqs=512,
    enforce_eager=True,
)
logits_processor = JSONLogitsProcessor(User, llm)
result = llm.generate(
    ["A prompt", "Another prompt"],
    sampling_params=vllm.SamplingParams(
        max_tokens=100, logits_processors=[logits_processor]
    ),
)
print(result)
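
And a quick diagnostic in the same spirit (my own sketch, not part of the failing script above): after applying the patch in the driver, ask a Ray task whether it sees it. If my suspicion is right, this prints False.

import ray
import vllm.model_executor.layers.sampler as sampler
from outlines.serve.vllm import _patched_apply_logits_processors

# Apply the patch in the driver process, as in the repro script.
sampler._apply_logits_processors = _patched_apply_logits_processors

@ray.remote
def worker_sees_patch() -> bool:
    # Re-import in the worker process and compare with the patched function.
    import vllm.model_executor.layers.sampler as worker_sampler
    from outlines.serve.vllm import _patched_apply_logits_processors as patched
    return worker_sampler._apply_logits_processors is patched

print(ray.get(worker_sees_patch.remote()))  # expected: False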

Expected result:

Generation results from vLLM (JSON constrained to the User schema)

Error message:

Traceback (most recent call last):
  File "/workspace/./python_scripts/vllm_integration.py", line 19, in <module>
    result = llm.generate(
  File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py", line 165, in generate
    return self._run_engine(use_tqdm)
  File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py", line 185, in _run_engine
    step_outputs = self.llm_engine.step()
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 581, in step
    output = self._run_workers(
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 755, in _run_workers
    self._run_workers_in_batch(workers, method, *args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 732, in _run_workers_in_batch
    all_outputs = ray.get(all_outputs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2624, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(TypeError): ray::RayWorkerVllm.execute_method() (pid=250471, ip=172.17.0.12, actor_id=6003ef308c0e22f0e05d73c901000000, repr=<vllm.engine.ray_utils.RayWorkerVllm object at 0x7f8c9c37a770>)
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/ray_utils.py", line 31, in execute_method
    return executor(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py", line 159, in execute_model
    output = self.model_runner.execute_model(seq_group_metadata_list,
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 354, in execute_model
    output = self.model.sample(
  File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/mixtral.py", line 390, in sample
    next_tokens = self.sampler(self.lm_head.weight, hidden_states,
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/sampler.py", line 52, in forward
    logits = _apply_logits_processors(logits, sampling_metadata)
  File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/sampler.py", line 172, in _apply_logits_processors
    logits_row = logits_processor(token_ids, logits_row)
TypeError: RegexLogitsProcessor.__call__() missing 1 required positional argument: 'scores'
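
If I read the traceback correctly, the Ray workers still run vLLM's stock _apply_logits_processors, which calls each processor with two positional arguments (sampler.py line 172 above), while the Outlines processor expects a third one. A simplified sketch of the mismatch (the signature is my reading of the error message, not the verbatim Outlines source):

# Unpatched vLLM sampler, as executed on the workers:
#     logits_row = logits_processor(token_ids, logits_row)

# Simplified shape of the Outlines processor being called:
class RegexLogitsProcessor:
    def __call__(self, seq_id, input_ids, scores):
        ...

# The two positional arguments bind to seq_id and input_ids, leaving the
# required 'scores' argument unfilled -- hence the TypeError. The patched
# _apply_logits_processors supplies the extra argument, which is why the
# single-GPU (driver-only) run works.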

Outlines/Python version information:

Outlines 0.0.22
Python 3.10.12

Context for the issue:

No response

Metadata

    Labels

    bug, vLLM (Things involving vLLM support)
