@@ -1,5 +1,7 @@
+from copy import deepcopy
 from typing import TYPE_CHECKING, List, NewType, Protocol
 
+import cloudpickle
 import interegular
 from lark import Lark
 
@@ -119,9 +121,43 @@ def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
         self.final_states = regex_fsm.finals | {
             -1
         }  # Include the EOS token in final states
+        self.tokenizer = tokenizer
         self.vocabulary = tokenizer.vocabulary.values()
         self.end_token_id = tokenizer.eos_token_id
 
+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Remove the last token from the prompt and update the states_to_token_maps accordingly"""
+        token_ids, _ = self.tokenizer.encode(prompt)
+        last_token_id = int(token_ids[0][-1])
+        last_token_text = self.tokenizer.decode([last_token_id])[0]
+        vocabulary = {
+            self.tokenizer.decode([token_id])[0]: token_id
+            for token_id in range(len(self.vocabulary))
+        }
+        starting_state_tokens = {
+            self.tokenizer.decode([token_id])[0]: self.states_to_token_maps[0][token_id]
+            for token_id in self.states_to_token_maps[0]
+        }
+        # Select the tokens that start with the text removed from the prompt and whose
+        # remainder (the part after last_token_text) is itself a token allowed from the
+        # starting state.
+        possible_tokens = {
+            vocabulary[token_text]: starting_state_tokens[token_text[len(last_token_text):]]
+            for token_text in vocabulary
+            if (
+                token_text.startswith(last_token_text)
+                and starting_state_tokens.get(token_text[len(last_token_text):])
+            )
+        }
+        # Update the states_to_token_maps as follows: the transitions of the starting state
+        # are moved to a new state, and the starting state now accepts the possible_tokens
+        # found above plus the removed last_token (which leads to the new state).
+        additional_state_id = max(list(self.states_to_token_maps.keys()) + list(self.final_states)) + 1
+        self.states_to_token_maps[additional_state_id] = self.states_to_token_maps[0]
+        self.states_to_token_maps[0] = {**possible_tokens, last_token_id: additional_state_id}
+
+        return prompt[:-len(last_token_text)]
+
+
     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.
 
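
For reference, a minimal sketch of the start-state rewrite that align_prompt_tokens performs on states_to_token_maps. The token strings, token ids and state numbers below are invented for illustration (they are not the library's real vocabulary or FSM): the last prompt token is stripped, and state 0 is rebuilt so that generation can either re-emit that token unchanged or jump directly to a longer token whose remainder matches one of the original transitions.

    # Toy data only: suppose the prompt ends with the token "hel" and the regex FSM's
    # start state 0 accepts "lo" (-> state 1) and "p" (-> state 2).
    vocabulary = {"hel": 5, "lo": 6, "p": 7, "hello": 8, "help": 9}  # token text -> token id
    states_to_token_maps = {0: {6: 1, 7: 2}}  # state 0: token 6 ("lo") -> 1, token 7 ("p") -> 2
    final_states = {1, 2}
    last_token_text, last_token_id = "hel", 5  # last prompt token, stripped from the prompt

    # Text of the tokens allowed from the start state, mapped to their destination state.
    starting_state_tokens = {"lo": 1, "p": 2}

    # Tokens that start with the removed text and whose remainder is itself allowed from state 0.
    possible_tokens = {
        token_id: starting_state_tokens[text[len(last_token_text):]]
        for text, token_id in vocabulary.items()
        if text.startswith(last_token_text)
        and starting_state_tokens.get(text[len(last_token_text):])
    }
    assert possible_tokens == {8: 1, 9: 2}  # "hello" acts as "hel" + "lo", "help" as "hel" + "p"

    # Move the original transitions of state 0 to a fresh state; state 0 now accepts either
    # the removed token (leading to that fresh state) or one of the merged tokens above.
    new_state = max(list(states_to_token_maps) + list(final_states)) + 1  # = 3
    states_to_token_maps[new_state] = states_to_token_maps[0]
    states_to_token_maps[0] = {**possible_tokens, last_token_id: new_state}
    assert states_to_token_maps == {0: {8: 1, 9: 2, 5: 3}, 3: {6: 1, 7: 2}}
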
@@ -186,7 +222,12 @@ def is_final_state(self, state: FSMState) -> bool:
 
     def copy(self) -> "RegexFSM":
         """Create a copy of the FSM."""
-        return self
+        # Temporary solution to the problem of unpickleable dict_values
+        self.vocabulary = cloudpickle.dumps(self.vocabulary)
+        copy = deepcopy(self)
+        self.vocabulary = cloudpickle.loads(self.vocabulary)
+        copy.vocabulary = cloudpickle.loads(copy.vocabulary)
+        return copy
 
 
 class CFGFSM(FSM):
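
The cloudpickle round-trip above exists because self.vocabulary is a dict_values view, which copy.deepcopy cannot handle (its pickling fallback raises a TypeError), while cloudpickle can serialize it. A minimal standalone sketch of the same pattern, using a hypothetical Holder class in place of RegexFSM:

    # Illustration only: Holder and its fields are made up, not the library's API.
    from copy import deepcopy

    import cloudpickle


    class Holder:
        def __init__(self, mapping: dict):
            # A dict_values view: copy.deepcopy(self) would raise
            # "TypeError: cannot pickle 'dict_values' object".
            self.vocabulary = mapping.values()

        def copy(self) -> "Holder":
            self.vocabulary = cloudpickle.dumps(self.vocabulary)  # make the instance deep-copyable
            duplicate = deepcopy(self)
            self.vocabulary = cloudpickle.loads(self.vocabulary)  # restore the original
            duplicate.vocabulary = cloudpickle.loads(duplicate.vocabulary)  # restore the copy
            return duplicate


    holder = Holder({"a": 1, "b": 2})
    clone = holder.copy()
    assert list(clone.vocabulary) == list(holder.vocabulary)

Storing the vocabulary as a plain list when it is first assigned would avoid the round-trip altogether, which is presumably why the diff labels this a temporary solution.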