Skip to content

Commit dd0492f

Browse files
committed
Expect vllm.LLMEngine as processor's argument
1 parent 32047ab commit dd0492f

File tree

3 files changed

+72
-35
lines changed

3 files changed

+72
-35
lines changed

docs/reference/vllm.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@ You can then query the model in shell by passing a prompt and either
2828
1. a [JSON Schema][jsonschema]{:target="_blank"} specification or
2929
2. a [Regex][regex]{:target="_blank"} pattern
3030

31+
with the `schema`, `regex` or `cfg` parameters, respectively, to the `/generate` endpoint. If several are specified, the schema will be used. If none is specified, the generated text will be unconstrained.
3236
3337
For example, to generate a string that matches the schema `{"type": "string"}` (any string):
3438

examples/vllm_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class User(BaseModel):
1414

1515

1616
llm = vllm.LLM(model="gpt2")
17-
logits_processor = JSONLogitsProcessor(User, llm)
17+
logits_processor = JSONLogitsProcessor(User, llm.llm_engine)
1818
result = llm.generate(
1919
["A prompt", "Another prompt"],
2020
sampling_params=vllm.SamplingParams(

outlines/serve/vllm.py

Lines changed: 67 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
import json
33
import math
44
from collections import defaultdict
5-
from typing import DefaultDict, List
5+
from typing import DefaultDict, List, Callable
66

77
import torch
8+
from vllm import LLMEngine
89

9-
from outlines.fsm.fsm import RegexFSM
10+
from outlines.fsm.fsm import RegexFSM, CFGFSM, FSM
1011
from outlines.fsm.json_schema import build_regex_from_object
1112

1213

@@ -39,21 +40,45 @@ def _patched_apply_logits_processors(
3940
return logits
4041

4142

42-
class RegexLogitsProcessor:
43-
def __init__(self, regex_string, llm):
44-
"""Compile the FSM that drives the regex-guided generation.
43+
def _adapt_tokenizer(tokenizer):
44+
"""Adapt vLLM's tokenizer to use to compile the FSM.
4545
46-
Parameters
47-
----------
48-
regex_string
49-
A string that represents a regular expression
50-
llm
51-
An instance of `vllm.LLM`
46+
The API of Outlines tokenizers is slightly different to that of
47+
`transformers`. In addition we need to handle the missing spaces in
48+
Llama's tokenizer's output to be able to compile FSMs for this model.
49+
50+
"""
51+
tokenizer.vocabulary = tokenizer.get_vocab()
52+
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
53+
54+
def convert_token_to_string(token: str) -> str:
55+
from transformers.file_utils import SPIECE_UNDERLINE
56+
57+
string = tokenizer.convert_tokens_to_string([token])
58+
59+
# A hack to handle missing spaces in HF's Llama tokenizers
60+
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
61+
return " " + string
62+
63+
return string
64+
65+
def change_decoder(
66+
decoder: Callable[[List[int]], str]
67+
) -> Callable[[List[int]], List[str]]:
68+
def new_decoder(inp_tokens: List[int]) -> List[str]:
69+
return [decoder(inp_tokens)]
70+
71+
return new_decoder
72+
73+
tokenizer.convert_token_to_string = convert_token_to_string
74+
tokenizer.decode = change_decoder(tokenizer.decode)
75+
76+
return tokenizer
5277

53-
"""
54-
tokenizer = self.adapt_tokenizer(llm.tokenizer)
5578

56-
fsm = RegexFSM(regex_string, tokenizer)
79+
class FSMLogitsProcessor:
80+
def __init__(self):
81+
fsm = FSM()
5782
self.fsm = fsm
5883

5984
def __call__(
@@ -77,43 +102,51 @@ def __call__(
77102

78103
return biased_scores
79104

80-
def adapt_tokenizer(self, tokenizer):
81-
"""Adapt vLLM's tokenizer to use to compile the FSM.
82105

83-
The API of Outlines tokenizers is slightly different to that of
84-
`transformers`. In addition we need to handle the missing spaces to
85-
Llama's tokenizer to be able to compile FSMs for this model.
86-
87-
"""
88-
tokenizer.vocabulary = tokenizer.get_vocab()
89-
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
106+
class RegexLogitsProcessor(FSMLogitsProcessor):
107+
def __init__(self, regex_string, llm: LLMEngine):
108+
"""Compile the FSM that drives the regex-guided generation.
90109
91-
def convert_token_to_string(token: str) -> str:
92-
from transformers.file_utils import SPIECE_UNDERLINE
110+
Parameters
111+
----------
112+
regex_string
113+
A string that represents a regular expression
114+
llm
115+
An instance of `vllm.LLMEngine`
93116
94-
string = tokenizer.convert_tokens_to_string([token])
117+
"""
118+
adapted_tokenizer = _adapt_tokenizer(llm.tokenizer)
119+
fsm = RegexFSM(regex_string, adapted_tokenizer)
120+
self.fsm = fsm
95121

96-
# A hack to handle missing spaces to HF's Llama tokenizers
97-
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
98-
return " " + string
99122

100-
return string
123+
class CFGLogitsProcessor(FSMLogitsProcessor):
124+
def __init__(self, cfg_string, llm: LLMEngine):
125+
"""Compile the FSM that drives the cfg-guided generation.
101126
102-
tokenizer.convert_token_to_string = convert_token_to_string
127+
Parameters
128+
----------
129+
cfg_string
130+
A string that represents a context-free grammar
131+
llm
132+
An instance of `vllm.LLMEngine`
103133
104-
return tokenizer
134+
"""
135+
adapted_tokenizer = _adapt_tokenizer(llm.tokenizer)
136+
fsm = CFGFSM(cfg_string, adapted_tokenizer)
137+
self.fsm = fsm
105138

106139

107140
class JSONLogitsProcessor(RegexLogitsProcessor):
108-
def __init__(self, schema, llm):
141+
def __init__(self, schema, llm: LLMEngine):
109142
"""Compile the FSM that drives the JSON-guided generation.
110143
111144
Parameters
112145
----------
113146
schema
114147
A JSON schema that encodes the structure we want the model to generate
115148
llm
116-
An instance of `vllm.LLM`
149+
An instance of `vllm.LLMEngine`
117150
118151
"""
119152
if isinstance(schema, dict):

0 commit comments

Comments
 (0)