Commit a04e8d4 (parent 37f53ca)

Basic proposition: implementation of token alignment

File tree: 2 files changed (+57, -4 lines)


outlines/fsm/fsm.py (+42, -1)
@@ -1,5 +1,7 @@
+from copy import deepcopy
 from typing import TYPE_CHECKING, List, NewType, Protocol

+import cloudpickle
 import interegular
 from lark import Lark


@@ -119,9 +121,43 @@ def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
         self.final_states = regex_fsm.finals | {
             -1
         }  # Include the EOS token in final states
+        self.tokenizer = tokenizer
         self.vocabulary = tokenizer.vocabulary.values()
         self.end_token_id = tokenizer.eos_token_id

+    def align_prompt_tokens(self, prompt: str) -> str:
+        """Remove the last token from the prompt and update the states_to_token_maps accordingly"""
+        token_ids, _ = self.tokenizer.encode(prompt)
+        last_token_id = int(token_ids[0][-1])
+        last_token_text = self.tokenizer.decode([last_token_id])[0]
+        vocabulary = {
+            self.tokenizer.decode([token_id])[0]: token_id
+            for token_id in range(len(self.vocabulary))
+        }
+        starting_state_tokens = {
+            self.tokenizer.decode([token_id])[0]: self.states_to_token_maps[0][token_id]
+            for token_id in self.states_to_token_maps[0]
+        }
+        # select the tokens that start with the text removed from the prompt and whose text after
+        # that removed text corresponds to one of the allowed tokens of the starting state
+        possible_tokens = {
+            vocabulary[token_text]: starting_state_tokens[token_text[len(last_token_text):]]
+            for token_text in vocabulary
+            if (
+                token_text.startswith(last_token_text)
+                and starting_state_tokens.get(token_text[len(last_token_text):])
+            )
+        }
+        # update the states_to_token_maps in the following manner:
+        # the transitions of the starting state are moved to a new state; the starting state now
+        # holds the possible_tokens found above + the removed last_token (leading to the new state)
+        additional_state_id = max(list(self.states_to_token_maps.keys()) + list(self.final_states)) + 1
+        self.states_to_token_maps[additional_state_id] = self.states_to_token_maps[0]
+        self.states_to_token_maps[0] = {**possible_tokens, last_token_id: additional_state_id}
+
+        return prompt[:-len(last_token_text)]
+
     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.

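The method above rewires the FSM so that, after the prompt's last token is removed, generation can start either with that exact token again or with any longer token that begins with its text. A minimal, self-contained sketch of that rewiring on hypothetical toy values (no outlines objects involved; token ids, texts and states are made up):

# Toy sketch of the start-state rewiring performed by align_prompt_tokens.
# Hypothetical setup: the prompt's last token is " (" (id 7) and the FSM's starting
# state only accepts "a" (id 10) and "ab" (id 11). Tokens whose text starts with " ("
# and whose remainder is an allowed token become valid start tokens.
states_to_token_maps = {0: {10: 1, 11: 2}}  # state -> {token_id: next_state}
final_states = {1, 2}
last_token_id, last_token_text = 7, " ("
vocabulary = {" (": 7, "a": 10, "ab": 11, " (a": 12, " (ab": 13}  # token_text -> token_id

starting_state_tokens = {"a": 1, "ab": 2}  # token_text -> next_state from state 0
possible_tokens = {
    token_id: starting_state_tokens[text[len(last_token_text):]]
    for text, token_id in vocabulary.items()
    if text.startswith(last_token_text)
    and starting_state_tokens.get(text[len(last_token_text):])
}
new_state_id = max(list(states_to_token_maps) + list(final_states)) + 1
states_to_token_maps[new_state_id] = states_to_token_maps[0]
states_to_token_maps[0] = {**possible_tokens, last_token_id: new_state_id}
print(states_to_token_maps)  # {0: {12: 1, 13: 2, 7: 3}, 3: {10: 1, 11: 2}}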
@@ -186,7 +222,12 @@ def is_final_state(self, state: FSMState) -> bool:

     def copy(self) -> "RegexFSM":
         """Create a copy of the FSM."""
-        return self
+        # temporary solution to the problem of unpickleable dict_values
+        self.vocabulary = cloudpickle.dumps(self.vocabulary)
+        copy = deepcopy(self)
+        self.vocabulary = cloudpickle.loads(self.vocabulary)
+        copy.vocabulary = cloudpickle.loads(copy.vocabulary)
+        return copy


 class CFGFSM(FSM):
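The cloudpickle round trip in copy() works around dict_values views not supporting the pickle protocol, which also makes deepcopy fail on them. A small reproduction sketch, assuming cloudpickle handles dict views (as the committed copy() already relies on):

# Why copy() needs the workaround: dict_values cannot be pickled, and deepcopy falls
# back to the pickle protocol for it, so copying an FSM holding a dict_values fails.
from copy import deepcopy

import cloudpickle

vocabulary = {"a": 1, "ab": 2}.values()  # like tokenizer.vocabulary.values()
try:
    deepcopy(vocabulary)
except TypeError as e:
    print(e)  # e.g. "cannot pickle 'dict_values' object"

# cloudpickle knows how to serialize dict views, so a dumps/loads round trip succeeds
restored = cloudpickle.loads(cloudpickle.dumps(vocabulary))
print(list(restored))  # [1, 2]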

outlines/generate/api.py (+15, -3)
@@ -76,6 +76,16 @@ def get_generated_token_ids(

         return token_ids

+    def get_generated_sequences(
+        self, generated_token_ids: List[torch.Tensor], initial_prompts: List[str], prompts: List[str]
+    ) -> List[str]:
+        """Return the generated text sequences, based on the generated tokens and the initial prompts"""
+        generated_tokens_text = self.tokenizer.decode(generated_token_ids)
+        return [
+            generated_tokens_text[i][len(initial_prompts[i]) - len(prompts[i]):]
+            for i in range(len(generated_tokens_text))
+        ]
+
     def is_stop_sequence_found(
         self, generated_sequences: List[str], stop_sequences: List[str]
     ) -> bool:
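Because the prompt actually fed to the model is the aligned (shortened) one, the decoded generated text starts by re-producing the characters of the removed token; the slice by len(initial_prompts[i]) - len(prompts[i]) drops them. A hypothetical walk-through with made-up strings:

# Hypothetical values illustrating the slicing in get_generated_sequences.
initial_prompt = "def add("  # prompt as given by the user
prompt = "def add"           # prompt after align_prompt_tokens removed the last token "("
generated_text = "(x, y):"   # decoded generated tokens; the "(" was re-generated

offset = len(initial_prompt) - len(prompt)  # 1, the length of the removed token text
print(generated_text[offset:])  # "x, y):"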
@@ -186,6 +196,7 @@ def __call__(

         if isinstance(prompts, str):
             prompts = [prompts]
+        initial_prompts = copy.deepcopy(prompts)

         if isinstance(stop_at, str):
             stop_at = [stop_at]
@@ -194,6 +205,7 @@ def __call__(
         max_tokens = max_tokens or self.max_tokens
         num_sequences = len(prompts)
         fsms = [self.fsm.copy() for _ in prompts]
+        prompts = [fsm.align_prompt_tokens(prompt) for fsm, prompt in zip(fsms, prompts)]

         if rng is None:
             rng = torch.Generator(device=self.device)
@@ -213,7 +225,7 @@ def __call__(
                 last_state = next(states)
                 if max_tokens or stop_sequences:
                     generated_token_ids = self.get_generated_token_ids(
-                        init_state, prompts, last_state
+                        init_state, initial_prompts, last_state
                     )
                     if max_tokens and len(generated_token_ids[0]) >= max_tokens:
                         break
@@ -225,9 +237,9 @@ def __call__(
                 break

         generated_token_ids = self.get_generated_token_ids(
-            init_state, prompts, last_state
+            init_state, initial_prompts, last_state
         )
-        generated = self.tokenizer.decode(generated_token_ids)
+        generated = self.get_generated_sequences(generated_token_ids, initial_prompts, prompts)
         stripped = [
             self.strip_stop_sequences(sequence, stop_sequences)
             for sequence in generated
