@@ -887,23 +887,11 @@ def gpt2_unicode_to_bytes():
     return {v: k for k, v in gpt2_bytes_to_unicode().items()}
 
 
-# TODO: Cannot cache typed collections to disk, yet. See
-# https://github.com/numba/numba/issues/4698
-@lru_cache
-def reduced_vocabulary(
-    tokenizer: "Tokenizer",
-) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]:
-    """Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
+def get_normalized_vocab(tokenizer: "Tokenizer") -> Tuple[Dict[int, str], Set[int]]:
+    norm_vocab = {}
     empty_token_ids = set()
-    vocabulary: Dict[Union[str, Tuple[str, ...]], List[int]] = {}
     for token, token_idx in tokenizer.vocabulary.items():
-        if token in tokenizer.special_tokens:
-            continue
-
-        token_str: Union[str, Tuple[str, ...]] = tokenizer.convert_token_to_string(
-            token
-        )
-
+        token_str = tokenizer.convert_token_to_string(token)
         if token_str:
             # invalid utf-8 sequences are replaced with � (\ufffd), but there
             # might also be tokens specifically for �, ��, ���, etc.
@@ -927,22 +915,88 @@ def reduced_vocabulary(
                         )
                 token_str = "".join(byte_symbol(b) for b in token_bytes)
 
-            vocabulary.setdefault(token_str, []).append(token_idx)
+            norm_vocab[token_idx] = token_str
         else:
             empty_token_ids.add(numba.int64(token_idx))
 
-    vocabulary_nb = numba.typed.List.empty_list(
-        numba.types.Tuple(
-            (
-                nb_unicode_type,
-                numba.int64[:],
-            )
-        )
+    return norm_vocab, empty_token_ids
+
+
+@numba.njit(cache=True, nogil=True)
+def to_numba_dict(keys: List[int], values: List[str]):
+    """
+    Pure-python numba dict construction is extremely slow.
+    This helper accepts equal length key and value arrays, and constructs a numba dict
+    """
+    # Define the key and value types for the Numba dictionary
+    numba_dict = numba.typed.Dict.empty(
+        key_type=numba.types.int64,
+        value_type=numba.types.unicode_type,
     )
-    for token_str, token_ids in vocabulary.items():
-        token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
-        vocabulary_nb.append((token_str, token_ids_np))
 
+    # Fill the Numba dictionary with values from the input lists
+    for i in range(len(keys)):
+        numba_dict[keys[i]] = values[i]
+
+    return numba_dict
+
+
+token_id_str_pair = numba.types.Tuple((nb_unicode_type, numba.int64[:]))
+
+
+@numba.njit(
+    numba.types.ListType(token_id_str_pair)(
+        numba.types.DictType(numba.int64, nb_unicode_type)
+    ),
+    cache=True,
+    nogil=True,
+)
+def vocab_dict_to_inverted_vocab_list(
+    vocab_dict_nb: Dict[int, str]
+) -> List[Tuple[str, Sequence[int]]]:
+    """
+    Helper for `reduced_vocabulary`
+
+    Convert
+    - from `vocab_dict_nb`: Dict[token_id, token_str]
+    - to `vocab_nb`: List[token_str, token_id[:]]
+    """
+    inverse_vocab_dict = numba.typed.Dict.empty(
+        key_type=numba.types.unicode_type, value_type=numba.types.int64[:]
+    )
+
+    # Fill the temporary dictionary
+    for key in vocab_dict_nb:
+        value = vocab_dict_nb[key]
+        if value not in inverse_vocab_dict:
+            inverse_vocab_dict[value] = np.zeros(0, dtype=np.int64)
+        inverse_vocab_dict[value] = np.append(inverse_vocab_dict[value], key)
+
+    # Transfer data from the temporary dictionary to the final dictionary
+    vocab_nb = numba.typed.List.empty_list(token_id_str_pair)
+
+    for value in inverse_vocab_dict:
+        vocab_nb.append((value, inverse_vocab_dict[value]))
+
+    return vocab_nb
+
+
+# TODO: Cannot cache typed collections to disk, yet. See
+# https://github.com/numba/numba/issues/4698
+@lru_cache
+def reduced_vocabulary(
+    tokenizer: "Tokenizer",
+) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]:
+    """
+    Provided the tokenizer, calculate the
+    - vocabulary_nb: mapping of (normalized token str -> token_ids[:])
+    - empty token ids
+    """
+    norm_vocab, empty_token_ids = get_normalized_vocab(tokenizer)
+    norm_vocab_dict_nb = to_numba_dict(
+        np.fromiter(norm_vocab.keys(), dtype=np.int64), list(norm_vocab.values())
+    )
+    vocabulary_nb = vocab_dict_to_inverted_vocab_list(norm_vocab_dict_nb)
     return vocabulary_nb, empty_token_ids
 
 
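The data flow in the new helpers can be hard to follow through the numba typed-container boilerplate, so here is a minimal plain-Python sketch of the inversion step that `vocab_dict_to_inverted_vocab_list` performs: a token_id -> normalized-string mapping is grouped into (normalized string, token_ids[:]) pairs, which is the `vocabulary_nb` shape returned by `reduced_vocabulary`. The `invert_vocab` name and the toy vocabulary are illustrative only and do not appear in the commit.

from collections import defaultdict
from typing import Dict, List, Tuple

import numpy as np


def invert_vocab(norm_vocab: Dict[int, str]) -> List[Tuple[str, np.ndarray]]:
    # Group token ids by their normalized token string, mirroring the
    # inversion that vocab_dict_to_inverted_vocab_list performs with
    # numba typed containers.
    grouped: Dict[str, List[int]] = defaultdict(list)
    for token_id, token_str in norm_vocab.items():
        grouped[token_str].append(token_id)
    # Emit one (token_str, token_ids[:]) pair per distinct normalized string.
    return [(s, np.asarray(ids, dtype=np.int64)) for s, ids in grouped.items()]


# Toy vocabulary: ids 3 and 7 normalize to the same string, so they are grouped.
print(invert_vocab({3: "Ġthe", 7: "Ġthe", 11: "Ġcat"}))
# [('Ġthe', array([3, 7])), ('Ġcat', array([11]))]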