Skip to content

Commit fde61a8

Browse files
mory91 and rlouf
authored and committed
Add CFG to vllm serving
1 parent 04bbb96 commit fde61a8

File tree

3 files changed

+81
-36
lines changed

3 files changed

+81
-36
lines changed

docs/reference/vllm.md

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ You can then query the model in shell by passing a prompt and either
2424

2525
1. a [JSON Schema][jsonschema]{:target="_blank"} specification or
2626
2. a [Regex][regex]{:target="_blank"} pattern
27+
3. an EBNF grammar
2728

28-
with the `schema` or `regex` parameters, respectively, to the `/generate` endpoint. If both are specified, the schema will be used. If neither is specified, the generated text will be unconstrained.
29+
with the `schema`, `regex`, or `cfg` parameters, respectively, to the `/generate` endpoint. If more than one is specified, the schema will be used. If none is specified, the generated text will be unconstrained.
2930

3031
For example, to generate a string that matches the schema `{"type": "string"}` (any string):
3132

@@ -47,6 +48,16 @@ curl http://127.0.0.1:8000/generate \
4748
}'
4849
```
4950

51+
To generate a string that matches the grammar `<grammar>`:
52+
53+
```bash
54+
curl http://127.0.0.1:8000/generate \
55+
-d '{
56+
"prompt": "What is Pi? Give me the first 15 digits: ",
57+
"cfg": <grammar>
58+
}'
59+
```
60+
5061
Instead of `curl`, you can also use the [requests][requests]{:target="_blank"} library from another python program.
5162

5263
Please consult the [vLLM documentation][vllm]{:target="_blank"} for details on additional request parameters. You can also [read the code](https://github.com/outlines-dev/outlines/blob/main/outlines/serve/serve.py) in case you need to customize the solution to your needs.

outlines/serve/serve.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from vllm.utils import random_uuid
2626

2727
from .vllm import (
28+
CFGLogitsProcessor,
2829
JSONLogitsProcessor,
2930
RegexLogitsProcessor,
3031
_patched_apply_logits_processors,
@@ -65,10 +66,13 @@ async def generate(request: Request) -> Response:
6566

6667
json_schema = request_dict.pop("schema", None)
6768
regex_string = request_dict.pop("regex", None)
69+
cfg_string = request_dict.pop("cfg", None)
6870
if json_schema is not None:
6971
logits_processors = [JSONLogitsProcessor(json_schema, engine.engine)]
7072
elif regex_string is not None:
7173
logits_processors = [RegexLogitsProcessor(regex_string, engine.engine)]
74+
elif cfg_string is not None:
75+
logits_processors = [CFGLogitsProcessor(cfg_string, engine.engine)]
7276
else:
7377
logits_processors = []
7478

outlines/serve/vllm.py

Lines changed: 65 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,50 @@
22
import json
33
import math
44
from collections import defaultdict
5-
from typing import DefaultDict, List
5+
from typing import Callable, DefaultDict, List
66

77
import torch
88

9-
from outlines.fsm.fsm import RegexFSM
9+
from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
1010
from outlines.fsm.json_schema import build_regex_from_object
1111

1212

13+
def _adapt_tokenizer(tokenizer):
14+
"""Adapt vLLM's tokenizer to use to compile the FSM.
15+
16+
The API of Outlines tokenizers is slightly different to that of
17+
`transformers`. In addition we need to handle the missing spaces in
18+
Llama's tokenizer to be able to compile FSMs for this model.
19+
20+
"""
21+
tokenizer.vocabulary = tokenizer.get_vocab()
22+
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
23+
24+
def convert_token_to_string(token: str) -> str:
25+
from transformers.file_utils import SPIECE_UNDERLINE
26+
27+
string = tokenizer.convert_tokens_to_string([token])
28+
29+
# A hack to handle missing spaces to HF's Llama tokenizers
30+
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
31+
return " " + string
32+
33+
return string
34+
35+
def change_decoder(
36+
decoder: Callable[[List[int]], str]
37+
) -> Callable[[List[int]], List[str]]:
38+
def new_decoder(inp_tokens: List[int]) -> List[str]:
39+
return [decoder(inp_tokens)]
40+
41+
return new_decoder
42+
43+
tokenizer.convert_token_to_string = convert_token_to_string
44+
tokenizer.decode = change_decoder(tokenizer.decode)
45+
46+
return tokenizer
47+
48+
1349
def _patched_apply_logits_processors(
1450
logits,
1551
sampling_metadata,
@@ -39,21 +75,9 @@ def _patched_apply_logits_processors(
3975
return logits
4076

4177

42-
class RegexLogitsProcessor:
43-
def __init__(self, regex_string, llm):
44-
"""Compile the FSM that drives the regex-guided generation.
45-
46-
Parameters
47-
----------
48-
regex_string
49-
A string that represents a regular expression
50-
llm
51-
An instance of `vllm.LLM`
52-
53-
"""
54-
tokenizer = self.adapt_tokenizer(llm.tokenizer)
55-
56-
fsm = RegexFSM(regex_string, tokenizer)
78+
class FSMLogitsProcessor:
79+
def __init__(self):
80+
fsm = FSM()
5781
self.fsm = fsm
5882

5983
def __call__(
@@ -77,31 +101,37 @@ def __call__(
77101

78102
return biased_scores
79103

80-
def adapt_tokenizer(self, tokenizer):
81-
"""Adapt vLLM's tokenizer to use to compile the FSM.
82-
83-
The API of Outlines tokenizers is slightly different to that of
84-
`transformers`. In addition we need to handle the missing spaces to
85-
Llama's tokenizer to be able to compile FSMs for this model.
86104

87-
"""
88-
tokenizer.vocabulary = tokenizer.get_vocab()
89-
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
105+
class RegexLogitsProcessor(FSMLogitsProcessor):
106+
def __init__(self, regex_string, llm):
107+
"""Compile the FSM that drives the regex-guided generation.
90108
91-
def convert_token_to_string(token: str) -> str:
92-
from transformers.file_utils import SPIECE_UNDERLINE
109+
Parameters
110+
----------
111+
regex_string
112+
A string that represents a regular expression
113+
llm
114+
An instance of `vllm.LLM`
93115
94-
string = tokenizer.convert_tokens_to_string([token])
116+
"""
117+
fsm = RegexFSM(regex_string, llm.tokenizer)
118+
self.fsm = fsm
95119

96-
# A hack to handle missing spaces to HF's Llama tokenizers
97-
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
98-
return " " + string
99120

100-
return string
121+
class CFGLogitsProcessor(FSMLogitsProcessor):
122+
def __init__(self, cfg_string, llm):
123+
"""Compile the FSM that drives the cfg-guided generation.
101124
102-
tokenizer.convert_token_to_string = convert_token_to_string
125+
Parameters
126+
----------
127+
regex_string
128+
A string that represents a regular expression
129+
llm
130+
An instance of `vllm.LLM`
103131
104-
return tokenizer
132+
"""
133+
fsm = CFGFSM(cfg_string, llm.tokenizer)
134+
self.fsm = fsm
105135

106136

107137
class JSONLogitsProcessor(RegexLogitsProcessor):

0 commit comments

Comments
 (0)