24
24
limitations under the License.
25
25
"""
26
26
import math
27
- from typing import TYPE_CHECKING , Dict , List , Optional , Type , Union
27
+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Type , Union
28
28
29
29
import torch
30
30
from pydantic import BaseModel
@@ -50,6 +50,11 @@ class GuideLogitsProcessor(OutlinesLogitsProcessor):
50
50
The `outlines.fsm.Guide` which is used to bias the logits.
51
51
"""
52
52
53
# Class-level attribute declarations (values are assigned in __init__).
tokenizer: "Tokenizer"  # tokenizer used to map tokens <-> ids
guide: Guide  # the outlines.fsm.Guide which is used to bias the logits
_guide_states: Dict[int, Any]  # hash(generated-token tuple) -> guide state
_seq_start_idx: Optional[int]  # prompt length; set lazily on first process_logits call
53
58
def __init__(self, tokenizer: "Tokenizer", guide: Guide):
    """A Guide-based logits processor.

    Parameters
    ----------
    tokenizer
        The tokenizer used to convert tokens to ids.
    guide
        The `outlines.fsm.Guide` which is used to bias the logits.
    """
    self.tokenizer = tokenizer
    self.guide = guide
    # Map hash(generated-token tuple) -> guide state, seeded with the
    # guide's initial state for the empty (nothing generated yet) sequence.
    self._guide_states = {hash(()): self.guide.initial_state}
    # Prompt length; filled in lazily on the first process_logits call.
    self._seq_start_idx = None
67
72
68
73
def process_logits (
69
74
self , input_ids : List [List [int ]], logits : torch .Tensor
@@ -181,6 +186,8 @@ class CFGLogitsProcessor(GuideLogitsProcessor):
181
186
The `outlines.fsm.CFGGuide. which is used to bias the logits.
182
187
"""
183
188
189
# Narrow the inherited `guide` annotation to the CFG-specific guide type.
guide: CFGGuide
184
191
def __init__(self, cfg_str: str, tokenizer: "Tokenizer"):
    """Compile the CFGGuide that drives the CFG-guided generation.

    Parameters
    ----------
    cfg_str
        The grammar string handed to `CFGGuide` as `cfg_string`.
    tokenizer
        The tokenizer used to convert tokens to ids.
    """
    compiled_guide = CFGGuide(cfg_string=cfg_str, tokenizer=tokenizer)
    super().__init__(tokenizer=tokenizer, guide=compiled_guide)
203
+
204
def process_logits(
    self, input_ids: List[List[int]], logits: torch.Tensor
) -> torch.Tensor:
    """Same behavior as GuideLogitsProcessor, but uses rejection sampling"""
    if self._seq_start_idx is None:
        # Everything before this index is the prompt; only generated
        # ids drive the guide state machine.
        self._seq_start_idx = len(input_ids[0])

    # One guide state per row of `input_ids`.
    row_states: List = []
    for seq_ids in input_ids:
        generated = seq_ids[self._seq_start_idx:]
        state_key = hash(tuple(generated))
        if state_key not in self._guide_states:
            # Advance by one token from the parent state (all generated
            # tokens except the newest one), memoizing the result.
            parent_state = self._guide_states[hash(tuple(generated[:-1]))]
            self._guide_states[state_key] = self.guide.get_next_state(
                parent_state, generated[-1]
            )
        row_states.append(self._guide_states[state_key])

    masked = torch.full_like(logits, -math.inf)
    for row, guide_state in enumerate(row_states):
        # Rejection sampling: walk candidates from most to least likely
        # and keep the first token the guide accepts.
        # NOTE(review): the bare next() raises StopIteration if the guide
        # accepts no token at all for this state — confirm that cannot
        # happen before relying on this in production.
        accepted = next(
            self.guide.iter_valid_token_ids(
                guide_state, torch.argsort(logits[row], descending=True)
            )
        )
        masked[row, [accepted]] = logits[row, [accepted]]

    return masked
0 commit comments