Skip to content

Commit 609fc5d

Browse files
committed
Add test for logit processor
1 parent 9b092e5 commit 609fc5d

File tree

3 files changed

+46
-5
lines changed

3 files changed

+46
-5
lines changed

outlines/serve/serve.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
CFGLogitsProcessor,
2929
JSONLogitsProcessor,
3030
RegexLogitsProcessor,
31-
CFGLogitsProcessor,
3231
_patched_apply_logits_processors,
3332
)
3433

outlines/serve/vllm.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import Callable, DefaultDict, List
66

77
import torch
8-
from vllm import LLMEngine
98

109
from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
1110
from outlines.fsm.json_schema import build_regex_from_object
@@ -113,7 +112,7 @@ def __call__(
113112

114113

115114
class RegexLogitsProcessor(FSMLogitsProcessor):
116-
def __init__(self, regex_string, llm: LLMEngine):
115+
def __init__(self, regex_string, llm):
117116
"""Compile the FSM that drives the regex-guided generation.
118117
119118
Parameters
@@ -130,7 +129,7 @@ def __init__(self, regex_string, llm: LLMEngine):
130129

131130

132131
class CFGLogitsProcessor(FSMLogitsProcessor):
133-
def __init__(self, cfg_string, llm: LLMEngine):
132+
def __init__(self, cfg_string, llm):
134133
"""Compile the FSM that drives the cfg-guided generation.
135134
136135
Parameters
@@ -147,7 +146,7 @@ def __init__(self, cfg_string, llm: LLMEngine):
147146

148147

149148
class JSONLogitsProcessor(RegexLogitsProcessor):
150-
def __init__(self, schema, llm: LLMEngine):
149+
def __init__(self, schema, llm):
151150
"""Compile the FSM that drives the JSON-guided generation.
152151
153152
Parameters

tests/test_vllm.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import pytest
2+
import torch
3+
from transformers import AutoTokenizer
4+
5+
from outlines.serve.vllm import (
6+
CFGLogitsProcessor,
7+
JSONLogitsProcessor,
8+
RegexLogitsProcessor,
9+
)
10+
11+
# Number pattern: optional minus sign, integer part without leading zeros,
# optional fraction, optional exponent with an explicit sign. The decimal
# point is escaped so that it matches a literal "." rather than any character.
TEST_REGEX = r"(-)?(0|[1-9][0-9]*)(\.[0-9]+)?([eE][+-][0-9]+)?"

# Grammar text passed to CFGLogitsProcessor: a decimal number written either
# as digits, a dot and optional digits, or as a dot followed by digits.
TEST_CFG = """
start: DECIMAL
DIGIT: "0".."9"
INT: DIGIT+
DECIMAL: INT "." INT? | "." INT
"""

# JSON Schema (serialized as a string) passed to JSONLogitsProcessor.
TEST_SCHEMA = '{"type": "string", "maxLength": 5}'
19+
20+
# Parametrization table for the test below: each entry pairs a logits
# processor class with the definition string used to construct it.
LOGIT_PROCESSORS = (
    (CFGLogitsProcessor, TEST_CFG),
    (RegexLogitsProcessor, TEST_REGEX),
    (JSONLogitsProcessor, TEST_SCHEMA),
)

# Hugging Face model id whose tokenizer is loaded for the test.
TEST_MODEL = "hf-internal-testing/tiny-random-GPTJForCausalLM"
27+
28+
29+
@pytest.mark.parametrize("logit_processor, fsm_str", LOGIT_PROCESSORS)
def test_logit_processor(logit_processor, fsm_str: str):
    """Build a logits processor twice on one engine and check token decoding.

    A minimal stand-in for the vLLM engine exposes only a ``tokenizer``
    attribute. After each construction of ``logit_processor(fsm_str, engine)``
    the test asserts that ``engine.tokenizer.decode`` returns a ``list``.

    NOTE(review): the second construction presumably guards against the
    tokenizer being adapted twice -- confirm against ``outlines.serve.vllm``.
    """

    class MockvLLMEngine:
        """Engine stub carrying just the tokenizer the processor reads."""

        def __init__(self, tokenizer):
            self.tokenizer = tokenizer

        # Not invoked by this test; kept so the stub stays callable.
        def __call__(*_):
            return torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.float), None

    mock_engine = MockvLLMEngine(AutoTokenizer.from_pretrained(TEST_MODEL))

    # Two passes: construct the processor, then verify decode's return type.
    for _ in range(2):
        logit_processor(fsm_str, mock_engine)
        decoded = mock_engine.tokenizer.decode([0, 1, 2, 3])
        assert isinstance(decoded, list)

0 commit comments

Comments
 (0)