Skip to content

Commit 22cbed6

Browse files
committed
index token -> transition key sequence for efficient fsm walk
1 parent 2be1743 commit 22cbed6

File tree

1 file changed

+46
-21
lines changed

1 file changed

+46
-21
lines changed

outlines/fsm/regex.py

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def _walk_fsm(
416416
alphabet_anything_value: int,
417417
fsm_initial: int,
418418
fsm_finals: Set[int],
419-
input_string: str,
419+
token_trans_key_seq: Sequence[int],
420420
start_state: int,
421421
full_match: bool = True,
422422
) -> List[int]:
@@ -428,19 +428,7 @@ def _walk_fsm(
428428
# By default, each symbol is a unicode character
429429
# Except, if the character, input_string[i] == '\x00', then the next two
430430
# in input_string characters are a hex representation of the byte
431-
i = 0
432-
while i < len(input_string):
433-
# if null-byte prefixed its a hex representation
434-
# unless its the last character, then its a trailing null byte symbol
435-
if input_string[i] == "\x00" and i != len(input_string) - 1:
436-
symbol = input_string[i : i + 3]
437-
i += 3
438-
else:
439-
symbol = input_string[i]
440-
i += 1
441-
442-
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
443-
431+
for i, trans_key in enumerate(token_trans_key_seq):
444432
new_state = fsm_transitions.get((state, trans_key))
445433

446434
if new_state is None:
@@ -677,27 +665,26 @@ def state_scan_tokens(
677665
fsm_initial: int,
678666
fsm_finals: Set[int],
679667
vocabulary: List[Tuple[str, Sequence[int]]],
668+
token_trans_key_seqs: List[Sequence[int]],
680669
start_state: int,
681670
) -> Set[Tuple[int, int]]:
682671
res = set()
683672

684-
for token, token_ids in vocabulary:
673+
for (token, token_ids), token_trans_key_seq in zip(
674+
vocabulary, token_trans_key_seqs
675+
):
685676
state_seq = _walk_fsm(
686677
fsm_transitions,
687678
alphabet_symbol_mapping,
688679
alphabet_anything_value,
689680
fsm_initial,
690681
fsm_finals,
691-
token,
682+
token_trans_key_seq,
692683
start_state,
693684
False,
694685
)
695686

696-
if token == "\x00":
697-
token_length = 1
698-
else:
699-
token_length = len(token) - 2 * token.count("\x00")
700-
if state_seq is not None and len(state_seq) < token_length:
687+
if state_seq is not None and len(state_seq) < len(token_trans_key_seq):
701688
continue
702689

703690
for token_id in token_ids:
@@ -706,6 +693,37 @@ def state_scan_tokens(
706693
return res
707694

708695

696+
@numba.njit(cache=True, nogil=True)
697+
def get_tokens_trans_keys(
698+
alphabet_symbol_mapping: Dict[str, int],
699+
alphabet_anything_value: int,
700+
vocabulary: List[Tuple[str, Sequence[int]]],
701+
) -> List[Tuple[str, Sequence[int], Sequence[int]]]:
702+
tokens_trans_keys = numba.typed.List.empty_list(numba.int64[:])
703+
for token_str, _ in vocabulary:
704+
trans_key_seq = []
705+
i = 0
706+
while i < len(token_str):
707+
if token_str[i] == "\x00" and i != len(token_str) - 1:
708+
symbol = token_str[i : i + 3]
709+
i += 3
710+
else:
711+
symbol = token_str[i]
712+
i += 1
713+
714+
trans_key_seq.append(
715+
alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
716+
)
717+
718+
trans_key_seq_array = np.empty(len(trans_key_seq), dtype=np.int64)
719+
for j in range(len(trans_key_seq)):
720+
trans_key_seq_array[j] = trans_key_seq[j]
721+
722+
tokens_trans_keys.append(trans_key_seq_array)
723+
724+
return tokens_trans_keys
725+
726+
709727
def create_fsm_index_end_to_end(
710728
fsm_info: FSMInfo,
711729
vocabulary: List[Tuple[str, Sequence[int]]],
@@ -724,6 +742,12 @@ def create_fsm_index_end_to_end(
724742
desc="Compiling FSM index for all state transitions",
725743
)
726744

745+
tokens_trans_key_seqs = get_tokens_trans_keys(
746+
fsm_info.alphabet_symbol_mapping,
747+
fsm_info.alphabet_anything_value,
748+
vocabulary,
749+
)
750+
727751
while next_states:
728752
start_state = next_states.pop()
729753

@@ -734,6 +758,7 @@ def create_fsm_index_end_to_end(
734758
fsm_info.initial,
735759
fsm_info.finals,
736760
vocabulary,
761+
tokens_trans_key_seqs,
737762
start_state,
738763
)
739764

0 commit comments

Comments
 (0)