Skip to content

Commit 43ed030

Browse files
committed
Add test for logit processor
1 parent 053b80e commit 43ed030

File tree

2 files changed: +47 −4 lines changed

outlines/serve/vllm.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import Callable, DefaultDict, List
66

77
import torch
8-
from vllm import LLMEngine
98

109
from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
1110
from outlines.fsm.json_schema import build_regex_from_object
@@ -113,7 +112,7 @@ def __call__(
113112

114113

115114
class RegexLogitsProcessor(FSMLogitsProcessor):
116-
def __init__(self, regex_string, llm: LLMEngine):
115+
def __init__(self, regex_string, llm):
117116
"""Compile the FSM that drives the regex-guided generation.
118117
119118
Parameters
@@ -130,7 +129,7 @@ def __init__(self, regex_string, llm: LLMEngine):
130129

131130

132131
class CFGLogitsProcessor(FSMLogitsProcessor):
133-
def __init__(self, cfg_string, llm: LLMEngine):
132+
def __init__(self, cfg_string, llm):
134133
"""Compile the FSM that drives the cfg-guided generation.
135134
136135
Parameters
@@ -147,7 +146,7 @@ def __init__(self, cfg_string, llm: LLMEngine):
147146

148147

149148
class JSONLogitsProcessor(RegexLogitsProcessor):
150-
def __init__(self, schema, llm: LLMEngine):
149+
def __init__(self, schema, llm):
151150
"""Compile the FSM that drives the JSON-guided generation.
152151
153152
Parameters

tests/test_vllm.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
import pytest
import torch
from transformers import AutoTokenizer

from outlines.serve.vllm import (
    CFGLogitsProcessor,
    JSONLogitsProcessor,
    RegexLogitsProcessor,
)

# JSON-style number pattern: optional sign, integer part, optional fraction,
# optional exponent.  NOTE(review): the scraped diff showed an unescaped dot
# (`(.[0-9]+)?`), which would match ANY character in the fraction position;
# `\.` restores the intended literal decimal point.
TEST_REGEX = r"(-)?(0|[1-9][0-9]*)(\.[0-9]+)?([eE][+-][0-9]+)?"

# Minimal Lark grammar accepting a decimal number (e.g. "3.14", ".5").
TEST_CFG = """
start: DECIMAL
DIGIT: "0".."9"
INT: DIGIT+
DECIMAL: INT "." INT? | "." INT
"""

# JSON Schema: a string of at most five characters.
TEST_SCHEMA = '{"type": "string", "maxLength": 5}'

# (processor class, FSM specification) pairs exercised by the parametrized test.
LOGIT_PROCESSORS = (
    (CFGLogitsProcessor, TEST_CFG),
    (RegexLogitsProcessor, TEST_REGEX),
    (JSONLogitsProcessor, TEST_SCHEMA),
)

# Tiny model so the tokenizer download stays fast in CI.
TEST_MODEL = "hf-internal-testing/tiny-random-GPTJForCausalLM"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda available")
@pytest.mark.parametrize("logit_processor, fsm_str", LOGIT_PROCESSORS)
def test_logit_processor(logit_processor, fsm_str: str):
    """Constructing each processor must adapt the engine's tokenizer.

    After construction, ``engine.tokenizer.decode`` is expected to return a
    list (the adapted per-token form) rather than a plain string, and
    constructing the processor a second time on the same engine must be safe.
    """

    class MockvLLMEngine:
        # Stand-in for vllm.LLMEngine: the processors only need a
        # ``tokenizer`` attribute and a callable engine.
        def __init__(self, tokenizer):
            self.tokenizer = tokenizer

        def __call__(*_):
            # Dummy (logits, metadata) pair; values are never inspected here.
            return torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.float), None

    tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL)
    engine = MockvLLMEngine(tokenizer)

    # First construction adapts the tokenizer in place.
    logit_processor(fsm_str, engine)
    assert isinstance(engine.tokenizer.decode([0, 1, 2, 3]), list)

    # Second construction must tolerate an already-adapted tokenizer.
    logit_processor(fsm_str, engine)
    assert isinstance(engine.tokenizer.decode([0, 1, 2, 3]), list)

0 commit comments

Comments
 (0)