@@ -25,7 +25,7 @@ def identity(s):


 def to_bytes(s):
-    return [chr(b) if b < 0x80 else f"{b:02X}" for b in s.encode("utf-8")]
+    return [chr(b) if b < 0x80 else f"\x00{b:02X}" for b in s.encode("utf-8")]


 def walk_fsm_numba(
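Note on the hunk above: the only change is the `\x00` prefix on the hex form of non-ASCII bytes, which keeps two-character hex symbols from colliding with ordinary printable characters. A minimal sketch of what the helper produces, using the function exactly as it appears in the diff:

```python
# Sketch for illustration: to_bytes copied from the hunk above.
def to_bytes(s):
    # ASCII bytes stay as single characters; every other UTF-8 byte becomes
    # a null-prefixed two-character hex symbol such as "\x00F0".
    return [chr(b) if b < 0x80 else f"\x00{b:02X}" for b in s.encode("utf-8")]

# "a😂" encodes to b"a\xf0\x9f\x98\x82", so the symbol list is
# ['a', '\x00F0', '\x009F', '\x0098', '\x0082'].
print(to_bytes("a😂"))
```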
@@ -115,19 +115,27 @@ def test_walk_fsm_multi_bytes(function, transform):
     str_regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
     regex_fsm = make_byte_level_better_fsm(str_regex_fsm, keep_utf8=True)

-    res = tuple(function(regex_fsm, transform("😂"), regex_fsm.initial, full_match=True))
+    res = tuple(
+        function(regex_fsm, "".join(transform("😂")), regex_fsm.initial, full_match=True)
+    )
     assert res[-1:] == (1,)

     res = tuple(
-        function(regex_fsm, transform("😂😂"), regex_fsm.initial, full_match=False)
+        function(
+            regex_fsm, "".join(transform("😂😂")), regex_fsm.initial, full_match=False
+        )
     )
     assert res[-1:] == (1,)

-    res = tuple(function(regex_fsm, transform("!"), regex_fsm.initial, full_match=True))
+    res = tuple(
+        function(regex_fsm, "".join(transform("!")), regex_fsm.initial, full_match=True)
+    )
     assert res == tuple()

     res = tuple(
-        function(regex_fsm, transform("😂😂"), regex_fsm.initial, full_match=True)
+        function(
+            regex_fsm, "".join(transform("😂😂")), regex_fsm.initial, full_match=True
+        )
     )
     assert res == tuple()

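These changes feed the walker a single joined string instead of a list of byte symbols; the `\x00` markers keep the hex pairs separable after the join. A small standalone sketch of the joining step, reusing `to_bytes` from the first hunk (not part of the diff itself):

```python
# Sketch: join per-byte symbols into one string, as the updated test does
# with "".join(transform(...)). Each non-ASCII byte contributes exactly
# three characters: "\x00" plus two hex digits.
def to_bytes(s):
    return [chr(b) if b < 0x80 else f"\x00{b:02X}" for b in s.encode("utf-8")]

joined = "".join(to_bytes("😂"))
assert joined == "\x00F0\x009F\x0098\x0082"
assert len(joined) == 12  # 4 UTF-8 bytes * 3 characters per symbol
```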
@@ -304,15 +312,15 @@ def test_create_fsm_index_end_to_end():
     vocabulary_nb = numba.typed.List.empty_list(
         numba.types.Tuple(
             (
-                numba.types.UnicodeCharSeq(2)[:],
+                numba.types.unicode_type,
                 numba.int64[:],
             )
         )
     )
     for token_tuple, token_ids in vocabulary.items():
-        token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2"))
+        token = "".join(token_tuple)
         token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
-        vocabulary_nb.append((token_tuple_np, token_ids_np))
+        vocabulary_nb.append((token, token_ids_np))

     res = create_fsm_index_end_to_end(regex_fsm.fsm_info, vocabulary_nb)

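With the element type switched to `numba.types.unicode_type`, each vocabulary entry is now a plain Python string paired with an `int64` array rather than an array of fixed-width two-character sequences. A standalone sketch of building such a typed list (the example tokens and ids here are made up for illustration):

```python
import numba
import numpy as np

# Sketch: typed list of (unicode string, int64 array) tuples, mirroring the
# container type used in the updated test.
vocabulary_nb = numba.typed.List.empty_list(
    numba.types.Tuple((numba.types.unicode_type, numba.int64[:]))
)

# Hypothetical token -> token-id mapping, just to show the append pattern.
for token, token_ids in {"blah": [0, 1], "x": [2]}.items():
    vocabulary_nb.append((token, np.fromiter(token_ids, dtype=np.int64)))

print(len(vocabulary_nb))  # 2
```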
@@ -326,28 +334,34 @@ def test_create_fsm_index_end_to_end_multi_byte():
     regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
     byte_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True)

+    merge_symbols = lambda byte_hexs: "".join(
+        ["\x00" + b if len(b) == 2 else b for b in byte_hexs]
+    )
+
     vocabulary = {
         "blah": numba.typed.List([0]),
         "😈a": numba.typed.List([1]),
         "😇": numba.typed.List([2]),
         "😍": numba.typed.List([3]),
-        ("F0", "9F", "98", "8D"): numba.typed.List([4]),  # '😍'
+        merge_symbols(("F0", "9F", "98", "8D")): numba.typed.List([4]),  # '😍'
         " 😍": numba.typed.List([5]),
-        (" ", "F0", "9F", "98", "8D"): numba.typed.List([6]),  # ' 😍'
-        (" ", "F0", "9F", "98"): numba.typed.List([7]),  # ' 😍' incomplete
+        merge_symbols((" ", "F0", "9F", "98", "8D")): numba.typed.List([6]),  # ' 😍'
+        merge_symbols((" ", "F0", "9F", "98")): numba.typed.List(
+            [7]
+        ),  # ' 😍' incomplete
         "<EOS>": numba.typed.List([8]),
     }

     vocabulary_nb = numba.typed.List.empty_list(
         numba.types.Tuple(
             (
-                numba.types.UnicodeCharSeq(2)[:],
+                numba.types.unicode_type,
                 numba.int64[:],
             )
         )
     )
     for token_tuple, token_ids in vocabulary.items():
-        token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2"))
+        token_tuple_np = merge_symbols(token_tuple)
         token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
         vocabulary_nb.append((token_tuple_np, token_ids_np))

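`merge_symbols` collapses a tuple of byte-hex strings and ordinary characters into a single string key, matching the `\x00`-hex convention from the `to_bytes` change (the null prefix in the sketch below is an assumption carried over from that hunk). A quick sketch of what the new dictionary keys evaluate to:

```python
# Sketch: merging hex byte symbols into one string key, assuming the same
# "\x00" prefix convention introduced in to_bytes above.
merge_symbols = lambda byte_hexs: "".join(
    ["\x00" + b if len(b) == 2 else b for b in byte_hexs]
)

# ("F0", "9F", "98", "8D") are the UTF-8 bytes of '😍'.
assert merge_symbols(("F0", "9F", "98", "8D")) == "\x00F0\x009F\x0098\x008D"

# Single-character elements (like the leading space) pass through unchanged.
assert merge_symbols((" ", "F0", "9F", "98")) == " \x00F0\x009F\x0098"
```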
@@ -356,7 +370,16 @@ def test_create_fsm_index_end_to_end_multi_byte():
     assert res == {0: {(5, 3), (6, 3), (7, 7), (2, 2)}, 3: {(2, 3), (3, 3), (4, 3)}}


-def test_create_fsm_index_tokenizer():
+@pytest.mark.parametrize(
+    "hf_tokenizer_uri",
+    [
+        "gpt2",
+        "microsoft/phi-2",
+        "Qwen/Qwen1.5-0.5B-Chat",
+        "NousResearch/Hermes-2-Pro-Llama-3-8B",
+    ],
+)
+def test_create_fsm_index_tokenizer(hf_tokenizer_uri):
     # The combined regular expressions of a lexer state in a Python grammar
     regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~"

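Parametrizing the test means the same body now runs once per tokenizer, with pytest substituting each URI into `hf_tokenizer_uri`. A toy illustration of the mechanism (not part of the diff):

```python
import pytest

# Each value in the list becomes its own collected test case.
@pytest.mark.parametrize("hf_tokenizer_uri", ["gpt2", "microsoft/phi-2"])
def test_tokenizer_uri_is_string(hf_tokenizer_uri):
    assert isinstance(hf_tokenizer_uri, str)
```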
@@ -371,7 +394,7 @@ def test_create_fsm_index_tokenizer():
     num_bytes_fsm_states = len(bytes_fsm.states)
     assert num_bytes_fsm_states == 235

-    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_uri)
     tokenizer = TransformerTokenizer(tokenizer)

     states_to_token_subsets, empty_token_ids = create_fsm_index_tokenizer(