Skip to content

Commit 91c7b3d

Browse files
committed
Create TextLogitsProcessor class
1 parent 81678d0 commit 91c7b3d

File tree

5 files changed

+133
-22
lines changed

5 files changed

+133
-22
lines changed

outlines/generate/text.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,18 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator:
4040
@text.register(Transformers)
4141
@text.register(LlamaCpp)
4242
def text_unified(model, sampler: Sampler = multinomial()):
43-
return SequenceGeneratorAdapter(model, None, sampler)
43+
from outlines.processors import TextLogitsProcessor
44+
45+
logits_processor = TextLogitsProcessor(model.tokenizer)
46+
return SequenceGeneratorAdapter(model, logits_processor, sampler)
4447

4548

4649
@text.register(VLLM)
4750
def text_vllm(model: VLLM, sampler: Sampler = multinomial()):
48-
return SequenceGeneratorAdapter(model, None, sampler)
51+
from outlines.integrations.vllm import TextLogitsProcessor
52+
53+
logits_processor = TextLogitsProcessor(model)
54+
return SequenceGeneratorAdapter(model, logits_processor, sampler)
4955

5056

5157
@text.register(OpenAI)

outlines/integrations/llamacpp.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from numpy.typing import NDArray
3434
from pydantic import BaseModel
3535

36-
from outlines.fsm.guide import CFGGuide, Guide, RegexGuide
36+
from outlines.fsm.guide import CFGGuide, Guide, RegexGuide, StopAtEOSGuide
3737
from outlines.fsm.json_schema import build_regex_from_schema
3838
from outlines.integrations.utils import convert_json_schema_to_str
3939
from outlines.models.llamacpp import LlamaCppTokenizer
@@ -104,6 +104,30 @@ def copy(self) -> "LogitsProcessor":
104104
return LogitsProcessor(tokenizer=self.tokenizer, fsm=self.fsm.copy())
105105

106106

107+
class TextLogitsProcessor(LogitsProcessor):
108+
"""Bias vLLM generation for free text (required because of prompt alignment).
109+
110+
Attributes
111+
----------
112+
tokenizer
113+
The tokenizer used to convert tokens to ids.
114+
fsm
115+
The finite state machine which is used to bias the logits.
116+
"""
117+
118+
def __init__(self, llm: "Llama"):
119+
"""Compile the FSM that drives the regex-guided generation.
120+
121+
Parameters
122+
----------
123+
llm
124+
The Llama model.
125+
"""
126+
tokenizer = LlamaCppTokenizer(model=llm)
127+
fsm = StopAtEOSGuide(tokenizer)
128+
super().__init__(tokenizer=tokenizer, fsm=fsm)
129+
130+
107131
class RegexLogitsProcessor(LogitsProcessor):
108132
"""Bias LlamaCpp generation based on a regular expression.
109133

outlines/integrations/vllm.py

Lines changed: 75 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,39 +32,41 @@
3232
import torch
3333
from pydantic import BaseModel
3434

35-
from outlines.fsm.guide import RegexGuide
35+
from outlines.fsm.guide import Guide, RegexGuide, StopAtEOSGuide
3636
from outlines.fsm.json_schema import build_regex_from_schema
3737
from outlines.integrations.utils import adapt_tokenizer, convert_json_schema_to_str
3838

3939
if TYPE_CHECKING:
4040
from vllm import LLM
4141

42+
from outlines.models.tokenizer import Tokenizer
4243

43-
class RegexLogitsProcessor:
44-
"""Bias vLLM generation based on a regular expression.
44+
45+
class FSMLogitsProcessor:
46+
"""Bias vLLM generation based on a FSM.
4547
4648
Attributes
4749
----------
4850
fsm
4951
The finite state machine which is used to bias the logits.
5052
"""
5153

52-
def __init__(self, regex_string: str, llm: "LLM"):
54+
def __init__(self, fsm: Guide):
5355
"""Compile the FSM that drives the regex-structured generation.
5456
5557
Parameters
5658
----------
57-
regex_string
58-
A string that represents a regular expression.
59-
llm
60-
The vLLM model.
59+
fsm
60+
The finite state machine (a `Guide` instance) used to bias the logits.
6161
62-
Raises
63-
------
64-
ValueError
65-
If the provided LLM instance in `RegexLogitsProcessor` neither has a
66-
`tokenizer` attribute or a `get_tokenizer` method.
6762
"""
63+
self.fsm = fsm
64+
self.mask_cache: Dict[int, torch.Tensor] = {}
65+
self._fsm_state: DefaultDict[int, int] = defaultdict(int)
66+
67+
@staticmethod
68+
def get_llm_tokenizer(llm: "LLM") -> "Tokenizer":
69+
"""Give the tokenizer attached to the LLM provided"""
6870
if hasattr(llm, "get_tokenizer"):
6971
tokenizer = llm.get_tokenizer()
7072
elif hasattr(llm, "tokenizer"):
@@ -74,13 +76,10 @@ def __init__(self, regex_string: str, llm: "LLM"):
7476
tokenizer = llm.tokenizer
7577
else:
7678
raise ValueError(
77-
"The provided LLM instance in `RegexLogitsProcessor` neither has a "
79+
"The provided LLM instance in `FSMLogitsProcessor` neither has a "
7880
"`tokenizer` attribute or a `get_tokenizer` method."
7981
)
80-
tokenizer = adapt_tokenizer(tokenizer=tokenizer)
81-
self.mask_cache: Dict[int, torch.Tensor] = {}
82-
self.fsm = RegexGuide(regex_string, tokenizer)
83-
self._fsm_state: DefaultDict[int, int] = defaultdict(int)
82+
return adapt_tokenizer(tokenizer=tokenizer)
8483

8584
def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
8685
"""Use the FSM to bias the logits before sampling the next token.
@@ -125,6 +124,64 @@ def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
125124
return biased_scores
126125

127126

127+
class TextLogitsProcessor(FSMLogitsProcessor):
128+
"""Bias vLLM generation for free text (required because of prompt alignment).
129+
130+
Attributes
131+
----------
132+
fsm
133+
The finite state machine which is used to bias the logits.
134+
"""
135+
136+
def __init__(self, llm: "LLM"):
137+
"""Compile the FSM that drives the regex-structured generation.
138+
139+
Parameters
140+
----------
141+
llm
142+
The vLLM model.
143+
144+
Raises
145+
------
146+
ValueError
147+
If the provided LLM instance in `TextLogitsProcessor` neither has a
148+
`tokenizer` attribute or a `get_tokenizer` method.
149+
"""
150+
tokenizer = self.get_llm_tokenizer(llm)
151+
fsm = StopAtEOSGuide(tokenizer)
152+
super().__init__(fsm=fsm)
153+
154+
155+
class RegexLogitsProcessor(FSMLogitsProcessor):
156+
"""Bias vLLM generation based on a regular expression.
157+
158+
Attributes
159+
----------
160+
fsm
161+
The finite state machine which is used to bias the logits.
162+
"""
163+
164+
def __init__(self, regex_string: str, llm: "LLM"):
165+
"""Compile the FSM that drives the regex-structured generation.
166+
167+
Parameters
168+
----------
169+
regex_string
170+
A string that represents a regular expression.
171+
llm
172+
The vLLM model.
173+
174+
Raises
175+
------
176+
ValueError
177+
If the provided LLM instance in `RegexLogitsProcessor` neither has a
178+
`tokenizer` attribute or a `get_tokenizer` method.
179+
"""
180+
tokenizer = self.get_llm_tokenizer(llm)
181+
fsm = RegexGuide(regex_string, tokenizer)
182+
super().__init__(fsm=fsm)
183+
184+
128185
class JSONLogitsProcessor(RegexLogitsProcessor):
129186
"""Bias vLLM generation based on a JSON schema.
130187

outlines/processors/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@
44
JSONLogitsProcessor,
55
OutlinesLogitsProcessor,
66
RegexLogitsProcessor,
7+
TextLogitsProcessor,
78
)

outlines/processors/structured.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import torch
3030
from pydantic import BaseModel
3131

32-
from outlines.fsm.guide import CFGGuide, Guide, RegexGuide
32+
from outlines.fsm.guide import CFGGuide, Guide, RegexGuide, StopAtEOSGuide
3333
from outlines.fsm.json_schema import build_regex_from_schema
3434
from outlines.integrations.utils import convert_json_schema_to_str
3535

@@ -115,6 +115,29 @@ def copy(self) -> "FSMLogitsProcessor":
115115
return FSMLogitsProcessor(tokenizer=self.tokenizer, fsm=self.fsm.copy())
116116

117117

118+
class TextLogitsProcessor(FSMLogitsProcessor):
119+
"""Bias generation for free text (required because of prompt alignment).
120+
121+
Attributes
122+
----------
123+
tokenizer
124+
The tokenizer used to convert tokens to ids.
125+
fsm
126+
The finite state machine which is used to bias the logits.
127+
"""
128+
129+
def __init__(self, tokenizer: "Tokenizer"):
130+
"""Compile the FSM that drives the regex-guided generation.
131+
132+
Parameters
133+
----------
134+
tokenizer
135+
An Outlines tokenizer.
136+
"""
137+
fsm = StopAtEOSGuide(tokenizer)
138+
super().__init__(tokenizer=tokenizer, fsm=fsm)
139+
140+
118141
class RegexLogitsProcessor(FSMLogitsProcessor):
119142
"""Bias generation based on a regular expression.
120143

0 commit comments

Comments
 (0)