Add CFG-guided generation to the vLLM integration #541

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed · wants to merge 4 commits
12 changes: 11 additions & 1 deletion docs/reference/vllm.md
@@ -23,7 +23,7 @@
You can then query the model in shell by passing a prompt and either
1. a [JSON Schema][jsonschema]{:target="_blank"} specification or
2. a [Regex][regex]{:target="_blank"} pattern

-with the `schema` or `regex` parameters, respectively, to the `/generate` endpoint. If both are specified, the schema will be used. If neither is specified, the generated text will be unconstrained.
+with the `schema`, `regex`, or `grammar` parameter to the `/generate` endpoint. If several are specified, the schema takes precedence over the regex, and the regex over the grammar. If none is specified, the generated text will be unconstrained.

For example, to generate a string that matches the schema `{"type": "string"}` (any string):

@@ -45,6 +45,16 @@
curl http://127.0.0.1:8000/generate \
}'
```

+To generate a string that matches a given grammar `<grammar>`:
+
+```bash
+curl http://127.0.0.1:8000/generate \
+    -d '{
+        "prompt": "What is Pi? Give me the first 15 digits: ",
+        "grammar": "start: DECIMAL \r\nDIGIT: \"0\"..\"9\"\r\nINT: DIGIT+\r\nDECIMAL: INT \".\" INT? | \".\" INT"
+    }'
+```

Instead of `curl`, you can also use the [requests][requests]{:target="_blank"} library from another python program.

Please consult the [vLLM documentation][vllm]{:target="_blank"} for details on additional request parameters. You can also [read the code](https://github.com/outlines-dev/outlines/blob/main/outlines/serve/serve.py) in case you need to customize the solution to your needs.
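A sketch of the `requests` variant mentioned in the docs above (the `/generate` endpoint and the `grammar` parameter come from this PR; the server address assumes the local serve command from the docs):

```python
import requests

# Same decimal-number grammar as in the curl example above.
grammar = 'start: DECIMAL\nDIGIT: "0".."9"\nINT: DIGIT+\nDECIMAL: INT "." INT? | "." INT'

response = requests.post(
    "http://127.0.0.1:8000/generate",
    json={
        "prompt": "What is Pi? Give me the first 15 digits: ",
        "grammar": grammar,
    },
)
print(response.json())
```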
10 changes: 7 additions & 3 deletions examples/vllm_integration.py
@@ -14,11 +14,15 @@ class User(BaseModel):


llm = vllm.LLM(model="gpt2")
-logits_processor = JSONLogitsProcessor(User, llm)
-result = llm.generate(
+logits_processor = JSONLogitsProcessor(User, llm.llm_engine)
+outputs = llm.generate(
    ["A prompt", "Another prompt"],
    sampling_params=vllm.SamplingParams(
        max_tokens=100, logits_processors=[logits_processor]
    ),
)
-print(result)
+
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
5 changes: 5 additions & 0 deletions outlines/serve/serve.py
@@ -25,6 +25,7 @@
from vllm.utils import random_uuid

from .vllm import (
+    CFGLogitsProcessor,
    JSONLogitsProcessor,
    RegexLogitsProcessor,
    _patched_apply_logits_processors,
@@ -65,10 +66,14 @@ async def generate(request: Request) -> Response:

    json_schema = request_dict.pop("schema", None)
    regex_string = request_dict.pop("regex", None)
+    cfg_string = request_dict.pop("grammar", None)

    if json_schema is not None:
        logits_processors = [JSONLogitsProcessor(json_schema, engine.engine)]
    elif regex_string is not None:
        logits_processors = [RegexLogitsProcessor(regex_string, engine.engine)]
+    elif cfg_string is not None:
+        logits_processors = [CFGLogitsProcessor(cfg_string, engine.engine)]
    else:
        logits_processors = []

107 changes: 74 additions & 33 deletions outlines/serve/vllm.py
@@ -2,11 +2,11 @@
import json
import math
from collections import defaultdict
-from typing import DefaultDict, List
+from typing import Callable, DefaultDict, List

import torch

-from outlines.fsm.fsm import RegexFSM
+from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
from outlines.fsm.json_schema import build_regex_from_object


@@ -39,21 +39,54 @@ def _patched_apply_logits_processors(
    return logits


-class RegexLogitsProcessor:
-    def __init__(self, regex_string, llm):
-        """Compile the FSM that drives the regex-guided generation.
+def _adapt_tokenizer(tokenizer):
+    """Adapt vLLM's tokenizer so it can be used to compile the FSM.

-        Parameters
-        ----------
-        regex_string
-            A string that represents a regular expression
-        llm
-            An instance of `vllm.LLM`
+    The API of Outlines tokenizers is slightly different from that of
+    `transformers`: Outlines' decoder returns a list of strings, whereas
+    vLLM's `decode` returns a single string, so the decoder has to be
+    adapted. In addition, we need to handle the missing spaces in
+    Llama's tokenizer to be able to compile FSMs for this model.

-        """
-        tokenizer = self.adapt_tokenizer(llm.tokenizer.tokenizer)
+    """
+    if getattr(tokenizer, "_outlines_adapted", False):
+        return tokenizer
+
+    tokenizer.vocabulary = tokenizer.get_vocab()
+    tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+    def convert_token_to_string(token: str) -> str:
+        from transformers.file_utils import SPIECE_UNDERLINE
+
+        string = tokenizer.convert_tokens_to_string([token])
+
+        # A hack to handle missing spaces in HF's Llama tokenizers
+        if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+            return " " + string
+
+        return string
+
+    def change_decoder(
+        decoder: Callable[[List[int]], str]
+    ) -> Callable[[List[int]], List[str]]:
+        """Sync vLLM's decoder with Outlines' expectations by returning a list."""
+
+        def new_decoder(inp_tokens: List[int]) -> List[str]:
+            return [decoder(inp_tokens)]
+
+        return new_decoder
+
+    tokenizer.convert_token_to_string = convert_token_to_string
+    tokenizer.decode = change_decoder(tokenizer.decode)
+    setattr(tokenizer, "_outlines_adapted", True)
+
+    return tokenizer

-        fsm = RegexFSM(regex_string, tokenizer)
+
+class FSMLogitsProcessor:
+    def __init__(self):
+        fsm = FSM()
        self.fsm = fsm

    def __call__(
@@ -77,31 +110,39 @@ def __call__(

        return biased_scores

-    def adapt_tokenizer(self, tokenizer):
-        """Adapt vLLM's tokenizer to use to compile the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. In addition we need to handle the missing spaces to
-        Llama's tokenizer to be able to compile FSMs for this model.
-
-        """
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-
-            string = tokenizer.convert_tokens_to_string([token])
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-
-        return tokenizer
+class RegexLogitsProcessor(FSMLogitsProcessor):
+    def __init__(self, regex_string, llm):
+        """Compile the FSM that drives the regex-guided generation.
+
+        Parameters
+        ----------
+        regex_string
+            A string that represents a regular expression
+        llm
+            An instance of `vllm.LLMEngine`
+
+        """
+        adapted_tokenizer = _adapt_tokenizer(llm.tokenizer)

Review comment: I think this should be `llm.tokenizer.tokenizer`.

+        fsm = RegexFSM(regex_string, adapted_tokenizer)
+        self.fsm = fsm
+
+
+class CFGLogitsProcessor(FSMLogitsProcessor):
+    def __init__(self, cfg_string, llm):
+        """Compile the FSM that drives the CFG-guided generation.
+
+        Parameters
+        ----------
+        cfg_string
+            A string that represents a context-free grammar
+        llm
+            An instance of `vllm.LLMEngine`
+
+        """
+        adapted_tokenizer = _adapt_tokenizer(llm.tokenizer)

Review comment: same as above, `llm.tokenizer.tokenizer`.

+        fsm = CFGFSM(cfg_string, adapted_tokenizer)
+        self.fsm = fsm


class JSONLogitsProcessor(RegexLogitsProcessor):
@@ -113,7 +154,7 @@ def __init__(self, schema, llm):
        schema
            A JSON schema that encodes the structure we want the model to generate
        llm
-            An instance of `vllm.LLM`
+            An instance of `vllm.LLMEngine`

        """
        if isinstance(schema, dict):
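For completeness, a CFG-guided version of the offline example from examples/vllm_integration.py might look as follows. This is only a sketch under the same assumptions as that example (and note the open review question above about passing `llm.tokenizer` vs. `llm.tokenizer.tokenizer`):

```python
import vllm

from outlines.serve.vllm import CFGLogitsProcessor

# The same decimal-number grammar used in the docs and the tests.
CFG = """
start: DECIMAL
DIGIT: "0".."9"
INT: DIGIT+
DECIMAL: INT "." INT? | "." INT
"""

llm = vllm.LLM(model="gpt2")
logits_processor = CFGLogitsProcessor(CFG, llm.llm_engine)
outputs = llm.generate(
    ["What is Pi? Give me the first 15 digits: "],
    sampling_params=vllm.SamplingParams(
        max_tokens=32, logits_processors=[logits_processor]
    ),
)
print(outputs[0].outputs[0].text)
```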
43 changes: 43 additions & 0 deletions tests/test_vllm.py
@@ -0,0 +1,43 @@
import pytest
import torch
from transformers import AutoTokenizer

from outlines.serve.vllm import (
CFGLogitsProcessor,
JSONLogitsProcessor,
RegexLogitsProcessor,
)

TEST_REGEX = r"(-)?(0|[1-9][0-9]*)(\.[0-9]+)?([eE][+-][0-9]+)?"
TEST_CFG = """
start: DECIMAL
DIGIT: "0".."9"
INT: DIGIT+
DECIMAL: INT "." INT? | "." INT
"""
TEST_SCHEMA = '{"type": "string", "maxLength": 5}'

LOGIT_PROCESSORS = (
(CFGLogitsProcessor, TEST_CFG),
(RegexLogitsProcessor, TEST_REGEX),
(JSONLogitsProcessor, TEST_SCHEMA),
)

TEST_MODEL = "hf-internal-testing/tiny-random-GPTJForCausalLM"


@pytest.mark.parametrize("logit_processor, fsm_str", LOGIT_PROCESSORS)
def test_logit_processor(logit_processor, fsm_str: str):
Review comment (Member): What is this test doing?

    class MockvLLMEngine:
        def __init__(self, tokenizer):
            self.tokenizer = tokenizer

        def __call__(*_):
            return torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.float), None

    # Building the processor compiles the FSM and adapts the mock engine's
    # tokenizer in place.
    tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL)
    engine = MockvLLMEngine(tokenizer)
    logit_processor(fsm_str, engine)
    assert isinstance(engine.tokenizer.decode([0, 1, 2, 3]), list)
    # Construct the processor a second time to check that the adaptation is
    # idempotent: decode must still return a list of strings.
    logit_processor(fsm_str, engine)
    assert isinstance(engine.tokenizer.decode([0, 1, 2, 3]), list)
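For reference, a minimal sketch of what the adaptation exercised by this test does to a plain `transformers` tokenizer (only `_adapt_tokenizer` comes from this PR; the model is the tiny test model used above):

```python
from transformers import AutoTokenizer

from outlines.serve.vllm import _adapt_tokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "hf-internal-testing/tiny-random-GPTJForCausalLM"
)
adapted = _adapt_tokenizer(tokenizer)

# After adaptation, decode returns a list of strings, as Outlines' FSMs expect.
assert isinstance(adapted.decode([0, 1, 2, 3]), list)
# Adapting twice is a no-op thanks to the `_outlines_adapted` flag.
assert _adapt_tokenizer(adapted) is adapted
```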