
Commit 703e42e

felixzhu555, br3no, and simon-mo authored

Add guided decoding for OpenAI API server (#2819)

Co-authored-by: br3no <[email protected]>
Co-authored-by: simon-mo <[email protected]>

1 parent 29a8d6a, commit 703e42e
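For orientation, the sketch below mirrors the tests added in this commit and shows how a client opts into guided decoding by passing vLLM-specific fields through the OpenAI client's extra_body hook. The server address, API key, and regex are placeholder assumptions, not part of this diff.

# Minimal sketch, assuming a vLLM OpenAI-compatible server on localhost:8000.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    prompt="Give an example IPv4 address: ",
    max_tokens=20,
    # vLLM-specific guided decoding fields ride along in extra_body.
    extra_body=dict(guided_regex=r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
                    r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"),
)
print(completion.choices[0].text)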

File tree

9 files changed, +597 -1 lines changed


requirements.txt (+1)

@@ -12,4 +12,5 @@ pydantic >= 2.0  # Required for OpenAI server.
 prometheus_client >= 0.18.0
 pynvml == 11.5.0
 triton >= 2.1.0
+outlines >= 0.0.27
 cupy-cuda12x == 12.1.0  # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
New test file (+75)

@@ -0,0 +1,75 @@
+# This unit test should be moved to a new
+# tests/test_guided_decoding directory.
+
+from transformers import AutoTokenizer
+import torch
+
+from vllm.model_executor.guided_logits_processors import (RegexLogitsProcessor,
+                                                          JSONLogitsProcessor)
+
+TEST_SCHEMA = {
+    "type": "object",
+    "properties": {
+        "name": {
+            "type": "string"
+        },
+        "age": {
+            "type": "integer"
+        },
+        "skills": {
+            "type": "array",
+            "items": {
+                "type": "string",
+                "maxLength": 10
+            },
+            "minItems": 3
+        },
+        "work history": {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "company": {
+                        "type": "string"
+                    },
+                    "duration": {
+                        "type": "string"
+                    },
+                    "position": {
+                        "type": "string"
+                    }
+                },
+                "required": ["company", "position"]
+            }
+        }
+    },
+    "required": ["name", "age", "skills", "work history"]
+}
+
+TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \
+    r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
+
+
+def test_guided_logits_processors():
+    """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
+    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
+    regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer)
+    json_LP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer)
+
+    regex_LP.init_state()
+    token_ids = tokenizer.encode(
+        f"Give an example IPv4 address with this regex: {TEST_REGEX}")
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    regex_LP(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert not torch.allclose(tensor, original_tensor)
+
+    json_LP.init_state()
+    token_ids = tokenizer.encode(
+        f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    json_LP(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert not torch.allclose(tensor, original_tensor)
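Since these processors are plain callables over (token_ids, logits), they can in principle also be used outside the OpenAI server, for example through vLLM's SamplingParams.logits_processors hook for offline generation. The sketch below is illustrative only and not part of this commit; the model, the date regex, and the assumption that the hook accepts this callable signature are all hypothetical.

# Sketch only: reusing RegexLogitsProcessor with offline vLLM generation.
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.model_executor.guided_logits_processors import RegexLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
date_regex = r"\d{4}-\d{2}-\d{2}"  # illustrative ISO-style date pattern
processor = RegexLogitsProcessor(date_regex, tokenizer)
processor.init_state()  # reset the FSM state before starting a new request

llm = LLM(model="HuggingFaceH4/zephyr-7b-beta")
params = SamplingParams(temperature=0.0, max_tokens=16,
                        logits_processors=[processor])
outputs = llm.generate("Today's date is ", params)
print(outputs[0].outputs[0].text)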

tests/entrypoints/test_openai_server.py (+237)
@@ -9,12 +9,64 @@
 import openai  # use the official client for correctness check
 from huggingface_hub import snapshot_download  # downloading lora to test lora requests

+# imports for guided decoding tests
+import json
+import jsonschema
+import re
+
 from vllm.transformers_utils.tokenizer import get_tokenizer

 MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 60 seconds
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # any model with a chat template should work here
 LORA_NAME = "typeof/zephyr-7b-beta-lora"  # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here

+TEST_SCHEMA = {
+    "type": "object",
+    "properties": {
+        "name": {
+            "type": "string"
+        },
+        "age": {
+            "type": "integer"
+        },
+        "skills": {
+            "type": "array",
+            "items": {
+                "type": "string",
+                "maxLength": 10
+            },
+            "minItems": 3
+        },
+        "work history": {
+            "type": "array",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "company": {
+                        "type": "string"
+                    },
+                    "duration": {
+                        "type": "string"
+                    },
+                    "position": {
+                        "type": "string"
+                    }
+                },
+                "required": ["company", "position"]
+            }
+        }
+    },
+    "required": ["name", "age", "skills", "work history"]
+}
+
+TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \
+    r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
+
+TEST_CHOICE = [
+    "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby",
+    "Swift", "Kotlin"
+]
+
 pytestmark = pytest.mark.asyncio

@@ -325,6 +377,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
         max_tokens=max_tokens,
         temperature=0.0,
         logit_bias={str(token_id): 100},
+        seed=42,
     )
     assert completion.choices[0].text is not None and len(
         completion.choices[0].text) >= 5

@@ -358,5 +411,189 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
     assert first_response != completion.choices[0].text


+async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=
+        f"Give an example JSON for an employee profile that fits this schema: {TEST_SCHEMA}",
+        n=3,
+        temperature=1.0,
+        max_tokens=500,
+        extra_body=dict(guided_json=TEST_SCHEMA))
+
+    assert completion.id is not None
+    assert completion.choices is not None and len(completion.choices) == 3
+    for i in range(3):
+        assert completion.choices[i].text is not None
+        output_json = json.loads(completion.choices[i].text)
+        jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)
+
+
+async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "Give an example JSON for an employee profile that " + \
+            f"fits this schema: {TEST_SCHEMA}"
+    }]
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=500,
+        extra_body=dict(guided_json=TEST_SCHEMA))
+    message = chat_completion.choices[0].message
+    assert message.content is not None
+    json1 = json.loads(message.content)
+    jsonschema.validate(instance=json1, schema=TEST_SCHEMA)
+
+    messages.append({"role": "assistant", "content": message.content})
+    messages.append({
+        "role":
+        "user",
+        "content":
+        "Give me another one with a different name and age"
+    })
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=500,
+        extra_body=dict(guided_json=TEST_SCHEMA))
+    message = chat_completion.choices[0].message
+    assert message.content is not None
+    json2 = json.loads(message.content)
+    jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
+    assert json1["name"] != json2["name"]
+    assert json1["age"] != json2["age"]
+
+
+async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}",
+        n=3,
+        temperature=1.0,
+        max_tokens=20,
+        extra_body=dict(guided_regex=TEST_REGEX))
+
+    assert completion.id is not None
+    assert completion.choices is not None and len(completion.choices) == 3
+    for i in range(3):
+        assert completion.choices[i].text is not None
+        assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None
+
+
+async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        f"Give an example IP address with this regex: {TEST_REGEX}"
+    }]
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=20,
+        extra_body=dict(guided_regex=TEST_REGEX))
+    ip1 = chat_completion.choices[0].message.content
+    assert ip1 is not None
+    assert re.fullmatch(TEST_REGEX, ip1) is not None
+
+    messages.append({"role": "assistant", "content": ip1})
+    messages.append({"role": "user", "content": "Give me a different one"})
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=20,
+        extra_body=dict(guided_regex=TEST_REGEX))
+    ip2 = chat_completion.choices[0].message.content
+    assert ip2 is not None
+    assert re.fullmatch(TEST_REGEX, ip2) is not None
+    assert ip1 != ip2
+
+
+async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt="The best language for type-safe systems programming is ",
+        n=2,
+        temperature=1.0,
+        max_tokens=10,
+        extra_body=dict(guided_choice=TEST_CHOICE))
+
+    assert completion.id is not None
+    assert completion.choices is not None and len(completion.choices) == 2
+    for i in range(2):
+        assert completion.choices[i].text in TEST_CHOICE
+
+
+async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        "The best language for type-safe systems programming is "
+    }]
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=10,
+        extra_body=dict(guided_choice=TEST_CHOICE))
+    choice1 = chat_completion.choices[0].message.content
+    assert choice1 in TEST_CHOICE
+
+    messages.append({"role": "assistant", "content": choice1})
+    messages.append({
+        "role": "user",
+        "content": "I disagree, pick another one"
+    })
+    chat_completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        max_tokens=10,
+        extra_body=dict(guided_choice=TEST_CHOICE))
+    choice2 = chat_completion.choices[0].message.content
+    assert choice2 in TEST_CHOICE
+    assert choice1 != choice2
+
+
+async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
+    with pytest.raises(openai.BadRequestError):
+        _ = await client.completions.create(
+            model=MODEL_NAME,
+            prompt="Give an example JSON that fits this schema: 42",
+            extra_body=dict(guided_json=42))
+
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role":
+        "user",
+        "content":
+        "The best language for type-safe systems programming is "
+    }]
+    with pytest.raises(openai.BadRequestError):
+        _ = await client.chat.completions.create(model=MODEL_NAME,
+                                                 messages=messages,
+                                                 extra_body=dict(guided_regex={
+                                                     1: "Python",
+                                                     2: "C++"
+                                                 }))
+
+    with pytest.raises(openai.BadRequestError):
+        _ = await client.completions.create(
+            model=MODEL_NAME,
+            prompt="Give an example string that fits this regex",
+            extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
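The guided_json / guided_regex / guided_choice fields exercised above are not part of the official OpenAI schema; the openai client's extra_body simply merges them into the request JSON, so the same behavior can be reached with a raw HTTP call. A minimal sketch, assuming a vLLM server listening on localhost:8000:

# Sketch of an equivalent raw request; guided_choice travels as a top-level
# key in the JSON body.  Server address is an assumption.
import requests

response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "HuggingFaceH4/zephyr-7b-beta",
        "prompt": "The best language for type-safe systems programming is ",
        "max_tokens": 10,
        "guided_choice": ["Python", "Java", "JavaScript", "C++", "C#", "PHP",
                          "TypeScript", "Ruby", "Swift", "Kotlin"],
    },
)
print(response.json()["choices"][0]["text"])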

vllm/engine/async_llm_engine.py (+3)

@@ -333,6 +333,9 @@ def is_running(self) -> bool:
         return (self.background_loop is not None
                 and not self.background_loop.done())

+    def get_tokenizer(self):
+        return self.engine.tokenizer.tokenizer
+
     def start_background_loop(self) -> None:
         """Start the background loop."""
         if self.is_running:
0 commit comments

Comments
 (0)