@@ -416,7 +416,7 @@ def _walk_fsm(
     alphabet_anything_value: int,
     fsm_initial: int,
     fsm_finals: Set[int],
-    input_string: str,
+    token_trans_key_seq: Sequence[int],
     start_state: int,
     full_match: bool = True,
 ) -> List[int]:
@@ -428,19 +428,7 @@ def _walk_fsm(
     # By default, each symbol is a unicode character
     # Except, if the character, input_string[i] == '\x00', then the next two
     # in input_string characters are a hex representation of the byte
-    i = 0
-    while i < len(input_string):
-        # if null-byte prefixed its a hex representation
-        # unless its the last character, then its a trailing null byte symbol
-        if input_string[i] == "\x00" and i != len(input_string) - 1:
-            symbol = input_string[i : i + 3]
-            i += 3
-        else:
-            symbol = input_string[i]
-            i += 1
-
-        trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
-
+    for i, trans_key in enumerate(token_trans_key_seq):
         new_state = fsm_transitions.get((state, trans_key))

         if new_state is None:
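With this change, `_walk_fsm` consumes a precomputed sequence of alphabet transition keys instead of re-decoding the token string on every walk. The string-to-key mapping treats a `\x00`-prefixed pair of hex characters as a single symbol. A minimal sketch of that mapping, assuming a toy `alphabet_symbol_mapping`; the names and values below are illustrative only, not the library's API:

    # Sketch: turn a token string into transition keys (illustrative only).
    alphabet_symbol_mapping = {"a": 0, "b": 1, "\x006e": 2}   # assumed toy mapping
    alphabet_anything_value = 99                              # fallback "anything" key

    def to_trans_keys(token: str) -> list:
        keys, i = [], 0
        while i < len(token):
            if token[i] == "\x00" and i != len(token) - 1:
                symbol, i = token[i : i + 3], i + 3   # "\x00" + two hex chars = one symbol
            else:
                symbol, i = token[i], i + 1           # a plain unicode character
            keys.append(alphabet_symbol_mapping.get(symbol, alphabet_anything_value))
        return keys

    print(to_trans_keys("a\x006eb"))  # [0, 2, 1]
    print(to_trans_keys("az"))        # [0, 99]  ("z" falls back to the anything key)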
@@ -677,27 +665,26 @@ def state_scan_tokens(
     fsm_initial: int,
     fsm_finals: Set[int],
     vocabulary: List[Tuple[str, Sequence[int]]],
+    token_trans_key_seqs: List[Sequence[int]],
     start_state: int,
 ) -> Set[Tuple[int, int]]:
     res = set()

-    for token, token_ids in vocabulary:
+    for (token, token_ids), token_trans_key_seq in zip(
+        vocabulary, token_trans_key_seqs
+    ):
         state_seq = _walk_fsm(
             fsm_transitions,
             alphabet_symbol_mapping,
             alphabet_anything_value,
             fsm_initial,
             fsm_finals,
-            token,
+            token_trans_key_seq,
             start_state,
             False,
         )

-        if token == "\x00":
-            token_length = 1
-        else:
-            token_length = len(token) - 2 * token.count("\x00")
-        if state_seq is not None and len(state_seq) < token_length:
+        if state_seq is not None and len(state_seq) < len(token_trans_key_seq):
             continue

         for token_id in token_ids:
@@ -706,6 +693,37 @@ def state_scan_tokens(
     return res


+@numba.njit(cache=True, nogil=True)
+def get_tokens_trans_keys(
+    alphabet_symbol_mapping: Dict[str, int],
+    alphabet_anything_value: int,
+    vocabulary: List[Tuple[str, Sequence[int]]],
+) -> List[Tuple[str, Sequence[int], Sequence[int]]]:
+    tokens_trans_keys = numba.typed.List.empty_list(numba.int64[:])
+    for token_str, _ in vocabulary:
+        trans_key_seq = []
+        i = 0
+        while i < len(token_str):
+            if token_str[i] == "\x00" and i != len(token_str) - 1:
+                symbol = token_str[i : i + 3]
+                i += 3
+            else:
+                symbol = token_str[i]
+                i += 1
+
+            trans_key_seq.append(
+                alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
+            )
+
+        trans_key_seq_array = np.empty(len(trans_key_seq), dtype=np.int64)
+        for j in range(len(trans_key_seq)):
+            trans_key_seq_array[j] = trans_key_seq[j]
+
+        tokens_trans_keys.append(trans_key_seq_array)
+
+    return tokens_trans_keys
+
+
 def create_fsm_index_end_to_end(
     fsm_info: FSMInfo,
     vocabulary: List[Tuple[str, Sequence[int]]],
@@ -724,6 +742,12 @@ def create_fsm_index_end_to_end(
         desc="Compiling FSM index for all state transitions",
     )

+    tokens_trans_key_seqs = get_tokens_trans_keys(
+        fsm_info.alphabet_symbol_mapping,
+        fsm_info.alphabet_anything_value,
+        vocabulary,
+    )
+
     while next_states:
         start_state = next_states.pop()
@@ -734,6 +758,7 @@ def create_fsm_index_end_to_end(
             fsm_info.initial,
             fsm_info.finals,
             vocabulary,
+            tokens_trans_key_seqs,
             start_state,
         )
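The point of `get_tokens_trans_keys` is to do the string-to-key decoding once per vocabulary, so that every subsequent `_walk_fsm` call (one per token per start state) reduces to integer-key lookups in `fsm_transitions`. A minimal sketch of that walk under an assumed toy FSM; this is a simplified illustration, not the numba-jitted `_walk_fsm` above:

    # Sketch of the walk over precomputed keys (plain Python, illustrative only).
    fsm_transitions = {(0, 0): 1, (1, 1): 2}   # assumed toy FSM: (state, trans_key) -> next state
    fsm_finals = {2}

    def walk(trans_key_seq, start_state, full_match=False):
        state, accepted = start_state, []
        for trans_key in trans_key_seq:
            new_state = fsm_transitions.get((state, trans_key))
            if new_state is None:                 # no edge for this key: stop walking
                return [] if full_match else accepted
            state = new_state
            accepted.append(state)
        if full_match and state not in fsm_finals:
            return []                             # a full match must end in a final state
        return accepted

    print(walk([0, 1], 0))                # [1, 2]
    print(walk([0, 1, 0], 0))             # [1, 2]  (partial walk; the dead key truncates)
    print(walk([0], 0, full_match=True))  # []  (state 1 is not final)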