Commit 2be1743

ensure byte fsm unicode_type compatibility by prefixing hex-bytes with \x00
1 parent da8ecf7 commit 2be1743

2 files changed: +79 -28 lines

outlines/fsm/regex.py

Lines changed: 41 additions & 13 deletions
@@ -196,7 +196,7 @@ def transition_trie_setdefault(
 
 
 def byte_symbol(byte: int) -> str:
-    return f"{byte:02X}" if byte >= 0x80 else chr(byte)
+    return f"\x00{byte:02X}" if byte >= 0x80 else chr(byte)
 
 
 def make_byte_level_fsm(fsm: FSM, keep_utf8=False) -> FSM:
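
To illustrate the encoding this commit adopts (a sketch, not part of the diff): bytes below 0x80 stay ordinary characters, while each byte >= 0x80 becomes a three-character symbol, a null prefix followed by two hex digits, so an entire byte-level token fits in a single numba unicode_type string. The encode_token helper below is hypothetical and only mirrors the patched byte_symbol:

    def encode_token(token: str) -> str:
        # hypothetical helper: apply the byte_symbol() rule to every UTF-8 byte
        return "".join(
            f"\x00{b:02X}" if b >= 0x80 else chr(b) for b in token.encode("utf-8")
        )

    assert encode_token("a") == "a"                          # ASCII passes through
    assert encode_token("😂") == "\x00F0\x009F\x0098\x0082"  # 4 bytes -> 4 null-prefixed hex pairs
    assert len(encode_token("😂")) == 12                     # 3 characters per non-ASCII byte
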
@@ -416,15 +416,29 @@ def _walk_fsm(
     alphabet_anything_value: int,
     fsm_initial: int,
     fsm_finals: Set[int],
-    input_string: Sequence[str],
+    input_string: str,
     start_state: int,
     full_match: bool = True,
 ) -> List[int]:
     state = start_state
     accepted_states: List[int] = numba.typed.List.empty_list(numba.int64)
     last_final_idx: int = numba.uint64(0)
 
-    for i, symbol in enumerate(input_string):
+    # Iterate over symbols (characters and null-prefixed two-hex-character bytes)
+    # By default, each symbol is a unicode character
+    # Except, if the character, input_string[i] == '\x00', then the next two
+    # in input_string characters are a hex representation of the byte
+    i = 0
+    while i < len(input_string):
+        # if null-byte prefixed its a hex representation
+        # unless its the last character, then its a trailing null byte symbol
+        if input_string[i] == "\x00" and i != len(input_string) - 1:
+            symbol = input_string[i : i + 3]
+            i += 3
+        else:
+            symbol = input_string[i]
+            i += 1
+
         trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
 
         new_state = fsm_transitions.get((state, trans_key))
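
As a standalone sketch of the decode side (assuming only the encoding above, no outlines internals), the new while loop splits such a string back into symbols as shown below. Because i is advanced past the whole symbol before the final-state check, it already points one position beyond the consumed input, which is why the last_final_idx bookkeeping in the next hunk drops its former +1/-1 adjustments.

    def split_symbols(input_string: str):
        # hypothetical pure-Python mirror of the loop above
        i = 0
        while i < len(input_string):
            if input_string[i] == "\x00" and i != len(input_string) - 1:
                yield input_string[i : i + 3]  # null prefix + two hex digits
                i += 3
            else:
                yield input_string[i]  # plain character (or a trailing lone null byte)
                i += 1

    assert list(split_symbols("a\x00F0\x009F")) == ["a", "\x00F0", "\x009F"]
    assert list(split_symbols("\x00")) == ["\x00"]
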
@@ -438,19 +452,19 @@ def _walk_fsm(
         state = new_state
 
         if state in fsm_finals:
-            last_final_idx = numba.uint64(i + 1)
+            last_final_idx = numba.uint64(i)
 
         accepted_states.append(_nonoptional(state))
 
-    if full_match and last_final_idx - 1 != i:
+    if full_match and last_final_idx != i:
         return numba.typed.List.empty_list(numba.int64)
 
     return accepted_states
 
 
 def walk_fsm(
     fsm: BetterFSM,
-    input_string: Sequence[str],
+    input_string: str,
     start_state: int,
     full_match: bool = True,
 ) -> List[int]:
@@ -464,7 +478,17 @@ def walk_fsm(
     alphabet_anything_value = fsm.alphabet.anything_value
     fsm_transitions = fsm.flat_transition_map
 
-    for i, symbol in enumerate(input_string):
+    # See _walk_fsm() explanation of symbol iteration
+    i = 0
+    while i < len(input_string):
+        # if null-byte prefixed its a hex representation
+        # unless the input string itself is a null byte, then symbol is a lone null-byte
+        if input_string[i] == "\x00" and input_string != "\x00":
+            symbol = input_string[i : i + 3]
+            i += 3
+        else:
+            symbol = input_string[i]
+            i += 1
         trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
 
         new_state = fsm_transitions.get((state, trans_key))
@@ -478,11 +502,11 @@ def walk_fsm(
         state = new_state
 
         if state in fsm_finals:
-            last_final_idx = i + 1
+            last_final_idx = i
 
         accepted_states.append(state)
 
-    if full_match and last_final_idx - 1 != i:
+    if full_match and last_final_idx != i:
         return []
 
     return accepted_states
@@ -652,7 +676,7 @@ def state_scan_tokens(
     alphabet_anything_value: int,
     fsm_initial: int,
     fsm_finals: Set[int],
-    vocabulary: List[Tuple[Sequence[str], Sequence[int]]],
+    vocabulary: List[Tuple[str, Sequence[int]]],
     start_state: int,
 ) -> Set[Tuple[int, int]]:
     res = set()
@@ -669,7 +693,11 @@ def state_scan_tokens(
             False,
         )
 
-        if state_seq is not None and len(state_seq) < len(token):
+        if token == "\x00":
+            token_length = 1
+        else:
+            token_length = len(token) - 2 * token.count("\x00")
+        if state_seq is not None and len(state_seq) < token_length:
             continue
 
         for token_id in token_ids:
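
Under this encoding, token_length recovers the number of FSM symbols from the string length: every null prefix marks a symbol that occupies three characters instead of one, so two characters are subtracted per "\x00", with the lone null-byte token special-cased. A small worked example (symbol_count is a hypothetical helper with the same arithmetic):

    def symbol_count(token: str) -> int:
        if token == "\x00":
            return 1
        return len(token) - 2 * token.count("\x00")

    assert symbol_count("blah") == 4                      # 4 plain characters
    assert symbol_count("\x00F0\x009F\x0098\x0082") == 4  # 12 characters, 4 encoded bytes
    assert symbol_count(" \x00F0\x009F\x0098") == 4       # space + 3 encoded bytes
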
@@ -680,7 +708,7 @@
 
 def create_fsm_index_end_to_end(
     fsm_info: FSMInfo,
-    vocabulary: List[Tuple[Sequence[str], Sequence[int]]],
+    vocabulary: List[Tuple[str, Sequence[int]]],
 ) -> Dict[int, Set[Tuple[int, int]]]:
     """Create an FSM state-to-vocabulary map/index through end-to-end token parsing."""

@@ -768,7 +796,7 @@ def gpt2_unicode_to_bytes():
 @lru_cache
 def reduced_vocabulary(
     tokenizer: "Tokenizer",
-) -> Tuple[List[Tuple[Sequence[str], Sequence[int]]], Set[int]]:
+) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]:
     """Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
     empty_token_ids = set()
     vocabulary: Dict[Union[str, Tuple[str, ...]], List[int]] = {}

tests/fsm/test_regex.py

Lines changed: 38 additions & 15 deletions
@@ -25,7 +25,7 @@ def identity(s):
 
 
 def to_bytes(s):
-    return [chr(b) if b < 0x80 else f"{b:02X}" for b in s.encode("utf-8")]
+    return [chr(b) if b < 0x80 else f"\x00{b:02X}" for b in s.encode("utf-8")]
 
 
 def walk_fsm_numba(
@@ -115,19 +115,27 @@ def test_walk_fsm_multi_bytes(function, transform):
     str_regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
     regex_fsm = make_byte_level_better_fsm(str_regex_fsm, keep_utf8=True)
 
-    res = tuple(function(regex_fsm, transform("😂"), regex_fsm.initial, full_match=True))
+    res = tuple(
+        function(regex_fsm, "".join(transform("😂")), regex_fsm.initial, full_match=True)
+    )
     assert res[-1:] == (1,)
 
     res = tuple(
-        function(regex_fsm, transform("😂😂"), regex_fsm.initial, full_match=False)
+        function(
+            regex_fsm, "".join(transform("😂😂")), regex_fsm.initial, full_match=False
+        )
     )
     assert res[-1:] == (1,)
 
-    res = tuple(function(regex_fsm, transform("!"), regex_fsm.initial, full_match=True))
+    res = tuple(
+        function(regex_fsm, "".join(transform("!")), regex_fsm.initial, full_match=True)
+    )
     assert res == tuple()
 
     res = tuple(
-        function(regex_fsm, transform("😂😂"), regex_fsm.initial, full_match=True)
+        function(
+            regex_fsm, "".join(transform("😂😂")), regex_fsm.initial, full_match=True
+        )
     )
     assert res == tuple()
 
@@ -304,15 +312,15 @@ def test_create_fsm_index_end_to_end():
     vocabulary_nb = numba.typed.List.empty_list(
         numba.types.Tuple(
             (
-                numba.types.UnicodeCharSeq(2)[:],
+                numba.types.unicode_type,
                 numba.int64[:],
             )
         )
     )
     for token_tuple, token_ids in vocabulary.items():
-        token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2"))
+        token = "".join(token_tuple)
         token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
-        vocabulary_nb.append((token_tuple_np, token_ids_np))
+        vocabulary_nb.append((token, token_ids_np))
 
     res = create_fsm_index_end_to_end(regex_fsm.fsm_info, vocabulary_nb)
 
@@ -326,28 +334,34 @@ def test_create_fsm_index_end_to_end_multi_byte():
     regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
     byte_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True)
 
+    merge_symbols = lambda byte_hexs: "".join(
+        ["\x00" + b if len(b) == 2 else b for b in byte_hexs]
+    )
+
     vocabulary = {
         "blah": numba.typed.List([0]),
         "😈a": numba.typed.List([1]),
         "😇": numba.typed.List([2]),
         "😍": numba.typed.List([3]),
-        ("F0", "9F", "98", "8D"): numba.typed.List([4]),  # '😍'
+        merge_symbols(("F0", "9F", "98", "8D")): numba.typed.List([4]),  # '😍'
         " 😍": numba.typed.List([5]),
-        (" ", "F0", "9F", "98", "8D"): numba.typed.List([6]),  # ' 😍'
-        (" ", "F0", "9F", "98"): numba.typed.List([7]),  # ' 😍' incomplete
+        merge_symbols((" ", "F0", "9F", "98", "8D")): numba.typed.List([6]),  # ' 😍'
+        merge_symbols((" ", "F0", "9F", "98")): numba.typed.List(
+            [7]
+        ),  # ' 😍' incomplete
         "<EOS>": numba.typed.List([8]),
     }
 
     vocabulary_nb = numba.typed.List.empty_list(
         numba.types.Tuple(
             (
-                numba.types.UnicodeCharSeq(2)[:],
+                numba.types.unicode_type,
                 numba.int64[:],
             )
         )
     )
     for token_tuple, token_ids in vocabulary.items():
-        token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2"))
+        token_tuple_np = merge_symbols(token_tuple)
         token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
         vocabulary_nb.append((token_tuple_np, token_ids_np))
 
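
For reference, merge_symbols converts the hex-pair tuples the old fixtures used into the new single-string encoding: two-character entries are treated as hex bytes and receive the null prefix, everything else passes through unchanged. Roughly what the fixture keys become:

    assert merge_symbols(("F0", "9F", "98", "8D")) == "\x00F0\x009F\x0098\x008D"        # '😍'
    assert merge_symbols((" ", "F0", "9F", "98", "8D")) == " \x00F0\x009F\x0098\x008D"  # ' 😍'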

@@ -356,7 +370,16 @@ def test_create_fsm_index_end_to_end_multi_byte():
     assert res == {0: {(5, 3), (6, 3), (7, 7), (2, 2)}, 3: {(2, 3), (3, 3), (4, 3)}}
 
 
-def test_create_fsm_index_tokenizer():
+@pytest.mark.parametrize(
+    "hf_tokenizer_uri",
+    [
+        "gpt2",
+        "microsoft/phi-2",
+        "Qwen/Qwen1.5-0.5B-Chat",
+        "NousResearch/Hermes-2-Pro-Llama-3-8B",
+    ],
+)
+def test_create_fsm_index_tokenizer(hf_tokenizer_uri):
     # The combined regular expressions of a lexer state in a Python grammar
     regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~"
 
@@ -371,7 +394,7 @@ def test_create_fsm_index_tokenizer():
     num_bytes_fsm_states = len(bytes_fsm.states)
     assert num_bytes_fsm_states == 235
 
-    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_uri)
     tokenizer = TransformerTokenizer(tokenizer)
 
     states_to_token_subsets, empty_token_ids = create_fsm_index_tokenizer(
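
To run just the newly parametrized test locally, something like the following should work, assuming the repository's usual pytest setup (the non-gpt2 tokenizers are fetched from the Hugging Face Hub on first use):

    pytest tests/fsm/test_regex.py -k test_create_fsm_index_tokenizer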
