@@ -1,7 +1,6 @@
 from copy import deepcopy
-from typing import TYPE_CHECKING, List, NewType, Protocol
+from typing import TYPE_CHECKING, Dict, List, NewType, Protocol, Tuple
 
-import cloudpickle
 import interegular
 from lark import Lark
 
@@ -15,6 +14,9 @@
 
 
 class FSM(Protocol):
+    def align_prompt_tokens(self, prompt: str) -> str:
+        ...
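+    # (Every concrete FSM below implements this hook; CFGFSM's version is a
+    # deliberate no-op.)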
+
     def allowed_token_ids(self, state: FSMState) -> List[int]:
         ...
 
@@ -39,8 +41,23 @@ class StopAtTokenFSM(FSM):
 
     def __init__(self, tokenizer: "Tokenizer", stop_token_id: int):
         self.stop_token_id = stop_token_id
-        self.vocabulary = tokenizer.vocabulary.values()
-        self.final_states = {1}
+        self.tokenizer = tokenizer
+        self.vocabulary = tokenizer.vocabulary
+        self.final_states = {2}
+        self.valid_alignment_tokens: List[int] = []
+
+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Remove the last token from the prompt and set the value of self.valid_alignment_tokens"""
+        token_ids, _ = self.tokenizer.encode(prompt)
+        last_token_id = int(token_ids[0][-1])
+        last_token_text = self.tokenizer.decode([last_token_id])[0]
+        # select the tokens that start with the text removed from the prompt
+        self.valid_alignment_tokens = [
+            token
+            for text, token in self.vocabulary.items()
+            if text.startswith(last_token_text)
+        ]
+        return prompt[: -len(last_token_text)]
 
     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.
@@ -59,7 +76,9 @@ def allowed_token_ids(self, state: FSMState) -> List[int]:
 
         """
         if state == 0:
-            return list(self.vocabulary)
+            return self.valid_alignment_tokens
+        elif state == 1:
+            return list(self.vocabulary.values())
         else:
             return [self.stop_token_id]
 
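+    # state layout (sketch of this diff's convention): 0 = prompt-alignment
+    # step, 1 = free generation, 2 = stop token emitted (final)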
@@ -83,17 +102,17 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
 
         """
         if token_id == self.stop_token_id:
-            return FSMState(1)
+            return FSMState(2)
 
-        return FSMState(0)
+        return FSMState(1)
 
     def is_final_state(self, state: FSMState) -> bool:
         """Determine whether the current state of the FSM is a final state."""
         return state in self.final_states
 
     def copy(self) -> "StopAtTokenFSM":
         """Create a copy of the FSM."""
-        return self
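+        # a real copy matters now: align_prompt_tokens mutates
+        # valid_alignment_tokens, so FSM instances can no longer be shared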
+        return deepcopy(self)
 
 
 class RegexFSM(FSM):
@@ -122,41 +141,61 @@ def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
             -1
         }  # Include the EOS token in final states
         self.tokenizer = tokenizer
-        self.vocabulary = tokenizer.vocabulary.values()
+        self.vocabulary = tokenizer.vocabulary
         self.end_token_id = tokenizer.eos_token_id
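+        # keeping the full {token_text: token_id} mapping (rather than just the
+        # id view) lets the alignment code below inspect token strings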
 
     def align_prompt_tokens(self, prompt: str) -> str:
         """Remove the last token from the prompt and update the states_to_token_maps accordingly"""
         token_ids, _ = self.tokenizer.encode(prompt)
         last_token_id = int(token_ids[0][-1])
         last_token_text = self.tokenizer.decode([last_token_id])[0]
-        vocabulary = {
-            self.tokenizer.decode([token_id])[0]: token_id
-            for token_id in range(len(self.vocabulary))
-        }
-        starting_state_tokens = {
-            self.tokenizer.decode([token_id])[0]: self.states_to_token_maps[0][token_id]
-            for token_id in self.states_to_token_maps[0]
-        }
-        # select the tokens that start with the text removed from the prompt and whose text after the
-        # initial prompt corresponds to that of one of the allowed tokens of the starting state
-        possible_tokens = {
-            vocabulary[token_text]: starting_state_tokens[token_text[len(last_token_text):]]
-            for token_text in vocabulary
-            if (
-                token_text.startswith(last_token_text)
-                and starting_state_tokens.get(token_text[len(last_token_text):])
-            )
+        last_token_length = len(last_token_text)
+        # select the tokens that start with the text removed from the prompt
+        crossing_tokens = {
+            token: text
+            for text, token in self.vocabulary.items()
+            if text.startswith(last_token_text)
         }
+        # keep only the tokens whose text after the boundary matches the FSM
+        valid_tokens_states = self.find_valid_crossing_tokens(
+            crossing_tokens, last_token_length
+        )
         # update the states_to_token_maps in the following manner:
         # the value of the starting state is assigned to a new state, the starting state is now the
-        # possible_tokens found above + the last_token we removed (that leads to the new state)
-        additional_state_id = max(list(self.states_to_token_maps.keys()) + list(self.final_states)) + 1
+        # valid_tokens_states found above
+        additional_state_id = (
+            max(list(self.states_to_token_maps.keys()) + list(self.final_states)) + 1
+        )
         self.states_to_token_maps[additional_state_id] = self.states_to_token_maps[0]
-        self.states_to_token_maps[0] = {**possible_tokens, last_token_id: additional_state_id}
-
-        return prompt[:-len(last_token_text)]
-
+        self.states_to_token_maps[0] = {}
+        for token, state in valid_tokens_states:
+            if state == 0:
+                self.states_to_token_maps[0][token] = additional_state_id
+            else:
+                self.states_to_token_maps[0][token] = state
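+        # illustrative aside (hypothetical vocabulary): with regex "[0-9]+" and
+        # a prompt whose last token is "=", the crossing token "=1" is kept
+        # (its remainder "1" walks the FSM) while "=x" is dropped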
+        return prompt[: -len(last_token_text)]
+
+    def find_valid_crossing_tokens(
+        self, crossing_tokens: Dict[int, str], last_token_length: int
+    ) -> List[Tuple[int, int]]:
+        """For each crossing token, check that the characters after the boundary match the FSM
+        and find the state it would lead to. Return the valid tokens with the associated state.
+        """
+        valid_tokens = []
+        for token, text in crossing_tokens.items():
+            is_valid = True
+            crossing_text = text[last_token_length:]
+            state = 0
+            for char in crossing_text:
+                char_token = self.vocabulary.get(char)
+                try:
+                    state = self.states_to_token_maps[state][char_token]  # type: ignore
+                except KeyError:
+                    is_valid = False
+                    break
+            if is_valid:
+                valid_tokens.append((token, state))
+        return valid_tokens
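+    # note: the character walk above assumes every single character exists in
+    # the vocabulary as its own token; an unknown character (or a disallowed
+    # transition) raises KeyError and the crossing token is dropped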
 
     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.
@@ -222,12 +261,7 @@ def is_final_state(self, state: FSMState) -> bool:
 
     def copy(self) -> "RegexFSM":
         """Create a copy of the FSM."""
-        # temporary solution to the problem of unpickleable dict_values
-        self.vocabulary = cloudpickle.dumps(self.vocabulary)
-        copy = deepcopy(self)
-        self.vocabulary = cloudpickle.loads(self.vocabulary)
-        copy.vocabulary = cloudpickle.loads(copy.vocabulary)
-        return copy
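+        # the cloudpickle round-trip is obsolete: self.vocabulary is now a
+        # plain dict, which deepcopy handles directly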
+        return deepcopy(self)
 
 
 class CFGFSM(FSM):
@@ -257,6 +291,10 @@ def __init__(self, cfg_string: str, tokenizer: "Tokenizer"):
         self.done = False
         self.regex_fsm: RegexFSM
 
+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Not implemented for CFGFSM"""
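+        # left as a no-op, presumably because the regex FSM is rebuilt at every
+        # parser step, leaving no stable start-state map to patch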
+        return prompt
+
     def _set_next_regex_fsm(self) -> None:
         """Use the CFG incremental parser to set the next regex FSM.
 
@@ -278,7 +316,6 @@ def _set_next_regex_fsm(self) -> None:
             self.allow_eos = True
             options.add("")
             assert len(options) > 1
-
         regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
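+        # e.g. (illustrative) options {"if", "while"} yield the regex string
+        # "((if)|(while))", which seeds the next RegexFSM below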
         self.regex_fsm = RegexFSM(regex_string, self.tokenizer)
         self.reset_state = True