diff --git a/.github/workflows/asv_benchmark_pr.yml b/.github/workflows/asv_benchmark_pr.yml
new file mode 100644
index 000000000..3522772e9
--- /dev/null
+++ b/.github/workflows/asv_benchmark_pr.yml
@@ -0,0 +1,72 @@
+name: Benchmark PR
+
+on:
+  pull_request:
+    branches: [main]
+
+permissions:
+  contents: read # Read access for repository contents
+  pull-requests: write # Write access for pull requests
+
+env:
+  PYTHON_VERSION: "3.10"
+  WORKING_DIR: ${{ github.workspace }}/benchmarks
+
+jobs:
+  benchmark-pr:
+    runs-on: ubuntu-latest
+
+    defaults:
+      run:
+        working-directory: ${{ env.WORKING_DIR }}
+
+    steps:
+
+      - name: Checkout repository
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ env.PYTHON_VERSION }}
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install asv virtualenv lf-asv-formatter
+
+      - name: Create ASV machine config file
+        run: asv machine --machine gh-runner --yes
+
+      - name: Save comparison of PR against main branch
+        run: |
+          # prepare main branch for comparison
+          git remote add upstream https://github.com/${{ github.repository }}.git
+          git fetch upstream main
+
+          # Run benchmarks, writing comment contents to ./output
+          asv continuous upstream/main HEAD \
+            --factor 1.1 --sort ratio --split --interleave-rounds -a repeat=3
+          asv compare upstream/main HEAD --factor 1.1 --sort ratio --split | tee output
+          python -m lf_asv_formatter --asv_version "$(asv --version)"
+          printf "Benchmark Suite Results:\n\n" >> comment_body
+          cat output >> comment_body
+
+      # from https://github.com/hombit/load_ztfdr_for_tape/blob/9acf7c83/.github/workflows/asv-pr.yml
+      - name: Find benchmarks comment
+        uses: peter-evans/find-comment@v2
+        id: find-comment
+        with:
+          issue-number: ${{ github.event.pull_request.number }}
+          comment-author: 'github-actions[bot]'
+          body-includes: Benchmark Suite Results
+
+      - name: Create or update benchmarks comment
+        uses: peter-evans/create-or-update-comment@v3
+        with:
+          comment-id: ${{ steps.find-comment.outputs.comment-id }}
+          issue-number: ${{ github.event.pull_request.number }}
+          body-path: ${{ env.WORKING_DIR }}/comment_body
+          edit-mode: replace
diff --git a/.gitignore b/.gitignore
index 9e95a8732..9add6d8c4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,4 @@ docs/build
 .idea/
 *.gguf
 .venv
+benchmarks/results
diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json
new file mode 100644
index 000000000..d1aca8a86
--- /dev/null
+++ b/benchmarks/asv.conf.json
@@ -0,0 +1,20 @@
+{
+    "version": 1,
+    "project": "Outlines",
+    "project_url": "https://outlines-dev.github.io/outlines/",
+    "repo": "..",
+    "branches": [
+        "HEAD"
+    ],
+    "build_command": [
+        "pip install .[test]",
+        "python -m build --wheel -o {build_cache_dir} {build_dir}"
+    ],
+    "environment_type": "virtualenv",
+    "show_commit_url": "https://github.com/lapp0/outlines/commit/",
+    "benchmark_dir": ".",
+    "env_dir": "env",
+    "results_dir": "results",
+    "html_dir": "html",
+    "build_cache_size": 8
+}
diff --git a/tests/benchmark/test_benchmark_json_schema.py b/benchmarks/bench_json_schema.py
similarity index 70%
rename from tests/benchmark/test_benchmark_json_schema.py
rename to benchmarks/bench_json_schema.py
index 33f3f5b16..daa77510b 100644
--- a/tests/benchmark/test_benchmark_json_schema.py
+++ b/benchmarks/bench_json_schema.py
@@ -1,5 +1,3 @@
-import pytest
-
 import outlines
 
 outlines.disable_cache()
@@ -7,6 +5,12 @@
 from outlines.fsm.guide import RegexGuide  # noqa: E402
 from outlines.fsm.json_schema import build_regex_from_schema  # noqa: E402
 
+from .common import (  # noqa: E402
+    clear_outlines_cache,
+    ensure_numba_compiled,
+    setup_tokenizer,
+)
+
 simple_schema = """{
     "$defs": {
         "Armor": {
@@ -63,30 +67,21 @@
     "required": ["id", "work", "recording_artists"]
 }"""
 
-
 schemas = dict(simple_schema=simple_schema, complex_schema=complex_schema)
 
 
-@pytest.mark.parametrize("schema_name", schemas.keys())
-def test_benchmark_json_schema_to_regex(benchmark, ensure_numba_compiled, schema_name):
-    """Benchmark convert json schema to regex"""
-    schema = schemas[schema_name]
-    benchmark.pedantic(
-        build_regex_from_schema,
-        args=(schema,),
-        rounds=8,
-    )
+class JsonSchemaBenchmark:
+    params = schemas.keys()
+
+    def setup(self, schema_name):
+        clear_outlines_cache()
+        self.tokenizer = setup_tokenizer()
+        self.schema = schemas[schema_name]
+        ensure_numba_compiled(self.tokenizer)
 
+    def time_json_schema_to_regex(self, schema_name):
+        build_regex_from_schema(self.schema)
 
-@pytest.mark.parametrize("schema_name", schemas.keys())
-def test_benchmark_json_schema_to_fsm(
-    benchmark, tokenizer, ensure_numba_compiled, schema_name
-):
-    """Benchmark compile json schema as FSM"""
-    schema = schemas[schema_name]
-    regex = build_regex_from_schema(schema)
-    benchmark.pedantic(
-        RegexGuide,
-        args=(regex, tokenizer),
-        rounds=8,
-    )
+    def time_json_schema_to_fsm(self, schema_name):
+        regex = build_regex_from_schema(self.schema)
+        RegexGuide(regex, self.tokenizer)
diff --git a/benchmarks/bench_numba_compile.py b/benchmarks/bench_numba_compile.py
new file mode 100644
index 000000000..c0e9d87c4
--- /dev/null
+++ b/benchmarks/bench_numba_compile.py
@@ -0,0 +1,37 @@
+import importlib
+
+import interegular
+import numba
+
+import outlines
+
+from .common import clear_outlines_cache, setup_tokenizer
+
+outlines.disable_cache()
+
+
+class NumbaCompileBenchmark:
+    def setup(self):
+        clear_outlines_cache()
+        from outlines.fsm import regex
+
+        self.tokenizer = setup_tokenizer()
+        self.regex = regex
+        original_njit = numba.njit
+
+        def mock_njit(*args, **kwargs):
+            kwargs["cache"] = False
+            return original_njit(*args, **kwargs)
+
+        self.original_njit = original_njit
+        numba.njit = mock_njit
+        importlib.reload(self.regex)
+        self.regex_pattern, _ = self.regex.make_deterministic_fsm(
+            interegular.parse_pattern("a").to_fsm().reduce()
+        )
+
+    def teardown(self):
+        numba.njit = self.original_njit
+
+    def time_compile_numba(self):
+        self.regex.create_fsm_index_tokenizer(self.regex_pattern, self.tokenizer)
diff --git a/tests/benchmark/test_benchmark_regex_fsm.py b/benchmarks/bench_regex_guide.py
similarity index 68%
rename from tests/benchmark/test_benchmark_regex_fsm.py
rename to benchmarks/bench_regex_guide.py
index e9e45052a..efaea9e1f 100644
--- a/tests/benchmark/test_benchmark_regex_fsm.py
+++ b/benchmarks/bench_regex_guide.py
@@ -1,7 +1,7 @@
-import pytest
-
 import outlines
 
+from .common import clear_outlines_cache, ensure_numba_compiled, setup_tokenizer
+
 outlines.disable_cache()
 
 from outlines.fsm.guide import RegexGuide  # noqa: E402
@@ -19,14 +19,27 @@
 }
 
 
-@pytest.mark.parametrize("regex_name", regex_samples.keys())
-def test_benchmark_regex_to_fsm(
-    benchmark, tokenizer, ensure_numba_compiled, regex_name
-):
-    """Benchmark converting regex to FSM"""
-    regex_str = regex_samples[regex_name]
-    benchmark.pedantic(
-        RegexGuide,
-        args=(regex_str, tokenizer),
-        rounds=8,
-    )
+class RegexGuideBenchmark:
+    params = regex_samples.keys()
+
+    def setup(self, pattern_name):
+        clear_outlines_cache()
+        self.tokenizer = setup_tokenizer()
+        ensure_numba_compiled(self.tokenizer)
+        self.pattern = regex_samples[pattern_name]
+
+    def time_regex_to_guide(self, pattern_name):
+        RegexGuide(self.pattern, self.tokenizer)
+
+
+class MemoryRegexGuideBenchmark:
+    params = ["simple_phone", "complex_span_constrained_relation_extraction"]
+
+    def setup(self, pattern_name):
+        clear_outlines_cache()
+        self.tokenizer = setup_tokenizer()
+        ensure_numba_compiled(self.tokenizer)
+        self.pattern = regex_samples[pattern_name]
+
+    def peakmem_regex_to_guide(self, pattern_name):
+        RegexGuide(self.pattern, self.tokenizer)
diff --git a/tests/benchmark/conftest.py b/benchmarks/common.py
similarity index 74%
rename from tests/benchmark/conftest.py
rename to benchmarks/common.py
index 902d5d6eb..e0fe36f14 100644
--- a/tests/benchmark/conftest.py
+++ b/benchmarks/common.py
@@ -1,17 +1,19 @@
-import pytest
 from transformers import AutoTokenizer
 
+import outlines.caching
 from outlines.fsm.guide import RegexGuide
 from outlines.models.transformers import TransformerTokenizer
 
 
-@pytest.fixture
-def tokenizer():
+def clear_outlines_cache():
+    outlines.caching.clear_cache()
+
+
+def setup_tokenizer():
     tokenizer = AutoTokenizer.from_pretrained("gpt2")
     return TransformerTokenizer(tokenizer)
 
 
-@pytest.fixture
 def ensure_numba_compiled(tokenizer):
     RegexGuide("a", tokenizer)
     return True
diff --git a/outlines/fsm/regex.py b/outlines/fsm/regex.py
index b68e31897..29adc813b 100644
--- a/outlines/fsm/regex.py
+++ b/outlines/fsm/regex.py
@@ -87,14 +87,11 @@ def fsm_info(self):
             ((k, z) for k, v in self.trans_key_to_states.items() for z in v),
             dtype=np.dtype("int64, int64"),
         )
-        alphabet_symbol_mapping_items = np.fromiter(
-            (
-                it
-                for it in self.alphabet._symbol_mapping.items()
-                if it[0] != anything_else
-            ),
-            dtype=np.dtype("U2, int64"),
-        )
+        alphabet_symbol_mapping_items = [
+            (k, v)
+            for k, v in self.alphabet._symbol_mapping.items()
+            if k != anything_else
+        ]
         nb_finals = np.fromiter(self.finals, dtype=np.dtype("int64"))
         self.__dict__["_fsm_info"] = create_fsm_info(
             self.initial,
@@ -110,7 +107,7 @@
 nb_int_list_type = numba.types.ListType(numba.int64)
 nb_int_pair_type = numba.types.UniTuple(numba.int64, 2)
 
-nb_unichar_2_type = numba.types.UnicodeCharSeq(2)
+nb_unicode_type = numba.types.unicode_type
 
 
 @numba.njit(cache=True)
@@ -136,7 +133,7 @@ def create_fsm_info(
 
     # use 2-char strings so that we can represent incomplete utf-8 sequences
    # as 2-hex-digit pairs
-    alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_2_type, numba.int64)
+    alphabet_symbol_map = numba.typed.Dict.empty(nb_unicode_type, numba.int64)
 
     for symbol_and_trans_key in alphabet_symbol_mapping_items:
         alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1]
@@ -199,7 +196,7 @@ def transition_trie_setdefault(
 
 
 def byte_symbol(byte: int) -> str:
-    return f"{byte:02X}" if byte >= 0x80 else chr(byte)
+    return f"\x00{byte:02X}" if byte >= 0x80 else chr(byte)
 
 
 def make_byte_level_fsm(fsm: FSM, keep_utf8=False) -> FSM:
@@ -419,7 +416,7 @@ def _walk_fsm(
     alphabet_anything_value: int,
     fsm_initial: int,
     fsm_finals: Set[int],
-    input_string: Sequence[str],
+    token_trans_key_seq: Sequence[int],
     start_state: int,
     full_match: bool = True,
 ) -> List[int]:
@@ -427,9 +424,9 @@ def _walk_fsm(
     accepted_states: List[int] = numba.typed.List.empty_list(numba.int64)
     last_final_idx: int = numba.uint64(0)
 
-    for i, symbol in enumerate(input_string):
-        trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
-
+    # Iterate over the token's transition key sequence. The transition key
+    # sequence represents the FSM traversal rules for the token's symbols.
+    for i, trans_key in enumerate(token_trans_key_seq):
         new_state = fsm_transitions.get((state, trans_key))
 
         if new_state is None:
@@ -453,7 +450,7 @@ def walk_fsm(
 
 def walk_fsm(
     fsm: BetterFSM,
-    input_string: Sequence[str],
+    token_trans_key_seq: Sequence[int],
     start_state: int,
     full_match: bool = True,
 ) -> List[int]:
@@ -463,13 +460,11 @@ def walk_fsm(
     accepted_states: List[int] = []
     last_final_idx: int = 0
 
-    alphabet_symbol_mapping = fsm.alphabet._symbol_mapping
-    alphabet_anything_value = fsm.alphabet.anything_value
     fsm_transitions = fsm.flat_transition_map
 
-    for i, symbol in enumerate(input_string):
-        trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
-
+    # Iterate over the token's transition key sequence. The transition key
+    # sequence represents the FSM traversal rules for the token's symbols.
+    for i, trans_key in enumerate(token_trans_key_seq):
         new_state = fsm_transitions.get((state, trans_key))
 
         if new_state is None:
@@ -655,24 +650,27 @@ def state_scan_tokens(
     alphabet_anything_value: int,
     fsm_initial: int,
     fsm_finals: Set[int],
-    vocabulary: List[Tuple[Sequence[str], Sequence[int]]],
+    vocabulary: List[Tuple[str, Sequence[int]]],
+    token_trans_key_seqs: List[Sequence[int]],
     start_state: int,
 ) -> Set[Tuple[int, int]]:
     res = set()
 
-    for token, token_ids in vocabulary:
+    for (token, token_ids), token_trans_key_seq in zip(
+        vocabulary, token_trans_key_seqs
+    ):
         state_seq = _walk_fsm(
             fsm_transitions,
             alphabet_symbol_mapping,
             alphabet_anything_value,
             fsm_initial,
             fsm_finals,
-            token,
+            token_trans_key_seq,
             start_state,
             False,
         )
 
-        if state_seq is not None and len(state_seq) < len(token):
+        if state_seq is not None and len(state_seq) < len(token_trans_key_seq):
            continue
 
         for token_id in token_ids:
@@ -681,9 +679,40 @@ def state_scan_tokens(
     return res
 
 
+@numba.njit(cache=True, nogil=True)
+def get_tokens_trans_keys(
+    alphabet_symbol_mapping: Dict[str, int],
+    alphabet_anything_value: int,
+    vocabulary: List[Tuple[str, Sequence[int]]],
+) -> List[Sequence[int]]:
+    tokens_trans_keys = numba.typed.List.empty_list(numba.int64[:])
+    for token_str, _ in vocabulary:
+        trans_key_seq = []
+        i = 0
+        while i < len(token_str):
+            if token_str[i] == "\x00" and i != len(token_str) - 1:
+                symbol = token_str[i : i + 3]
+                i += 3
+            else:
+                symbol = token_str[i]
+                i += 1
+
+            trans_key_seq.append(
+                alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
+            )
+
+        trans_key_seq_array = np.empty(len(trans_key_seq), dtype=np.int64)
+        for j in range(len(trans_key_seq)):
+            trans_key_seq_array[j] = trans_key_seq[j]
+
+        tokens_trans_keys.append(trans_key_seq_array)
+
+    return tokens_trans_keys
+
+
 def create_fsm_index_end_to_end(
     fsm_info: FSMInfo,
-    vocabulary: List[Tuple[Sequence[str], Sequence[int]]],
+    vocabulary: List[Tuple[str, Sequence[int]]],
 ) -> Dict[int, Set[Tuple[int, int]]]:
     """Create an FSM state-to-vocabulary map/index through end-to-end token parsing."""
 
@@ -699,6 +728,12 @@ def create_fsm_index_end_to_end(
         desc="Compiling FSM index for all state transitions",
     )
 
+    tokens_trans_key_seqs = get_tokens_trans_keys(
+        fsm_info.alphabet_symbol_mapping,
+        fsm_info.alphabet_anything_value,
+        vocabulary,
+    )
+
     while next_states:
         start_state = next_states.pop()
 
@@ -709,6 +744,7 @@ def create_fsm_index_end_to_end(
             fsm_info.initial,
             fsm_info.finals,
             vocabulary,
+            tokens_trans_key_seqs,
             start_state,
         )
 
@@ -771,7 +807,7 @@ def gpt2_unicode_to_bytes():
 @lru_cache
 def reduced_vocabulary(
     tokenizer: "Tokenizer",
-) -> Tuple[List[Tuple[Sequence[str], Sequence[int]]], Set[int]]:
+) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]:
     """Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
     empty_token_ids = set()
     vocabulary: Dict[Union[str, Tuple[str, ...]], List[int]] = {}
@@ -804,7 +840,7 @@ def reduced_vocabulary(
                 raise RuntimeError(
                     f"Cannot convert token `{token}` ({token_idx}) to bytes: {token_str}"
                 )
-            token_str = tuple(byte_symbol(b) for b in token_bytes)
+            token_str = "".join(byte_symbol(b) for b in token_bytes)
             vocabulary.setdefault(token_str, []).append(token_idx)
         else:
@@ -813,15 +849,14 @@ def reduced_vocabulary(
 
     vocabulary_nb = numba.typed.List.empty_list(
         numba.types.Tuple(
             (
-                nb_unichar_2_type[:],
+                nb_unicode_type,
                 numba.int64[:],
             )
         )
     )
-    for token_tuple, token_ids in vocabulary.items():
-        token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2"))
+    for token_str, token_ids in vocabulary.items():
         token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
-        vocabulary_nb.append((token_tuple_np, token_ids_np))
+        vocabulary_nb.append((token_str, token_ids_np))
 
     return vocabulary_nb, empty_token_ids
diff --git a/tests/benchmark/test_benchmark_numba_compile.py b/tests/benchmark/test_benchmark_numba_compile.py
deleted file mode 100644
index 827d561bd..000000000
--- a/tests/benchmark/test_benchmark_numba_compile.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import importlib
-
-import interegular
-import numba
-
-import outlines
-
-outlines.disable_cache()
-
-
-def test_benchmark_compile_numba(benchmark, tokenizer, mocker):
-    """Compile a basic regex to benchmark the numba compilation time"""
-
-    def setup():
-        from outlines.fsm import regex
-
-        original_njit = numba.njit
-
-        def mock_njit(*args, **kwargs):
-            kwargs["cache"] = False
-            return original_njit(*args, **kwargs)
-
-        mocker.patch("numba.njit", new=mock_njit)
-        importlib.reload(regex)
-
-        regex_pattern, _ = regex.make_deterministic_fsm(
-            interegular.parse_pattern("a").to_fsm().reduce()
-        )
-        return (regex, regex_pattern, tokenizer), {}
-
-    benchmark.pedantic(
-        lambda r, *args: r.create_fsm_index_tokenizer(*args), rounds=2, setup=setup
-    )
diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py
index 2fc8a5384..2dc429f78 100644
--- a/tests/fsm/test_regex.py
+++ b/tests/fsm/test_regex.py
@@ -1,5 +1,3 @@
-from typing import Sequence
-
 import interegular
 import numba
 import numpy as np
@@ -12,6 +10,7 @@
     create_fsm_index_tokenizer,
     fsm_union,
     get_sub_fsms_from_seq,
+    get_tokens_trans_keys,
     make_byte_level_better_fsm,
     make_byte_level_fsm,
     make_deterministic_fsm,
@@ -25,12 +24,42 @@ def identity(s):
 
 
 def to_bytes(s):
-    return [chr(b) if b < 0x80 else f"{b:02X}" for b in s.encode("utf-8")]
+    return [chr(b) if b < 0x80 else f"\x00{b:02X}" for b in s.encode("utf-8")]
+
+
+def merge_symbols(byte_hexs):
+    return "".join(["\x00" + b if len(b) == 2 else b for b in byte_hexs])
+
+
+def token_str_to_trans_key(fsm, input_string):
+    vocabulary_nb = numba.typed.List.empty_list(
+        numba.types.Tuple((numba.types.unicode_type, numba.int64[:]))
+    )
+    vocabulary_nb.append((input_string, np.fromiter([1], dtype=np.dtype("int64"))))
+    return get_tokens_trans_keys(
+        fsm.fsm_info.alphabet_symbol_mapping,
+        fsm.fsm_info.alphabet_anything_value,
+        vocabulary_nb,
+    )[0]
 
 
-def walk_fsm_numba(
+def walk_fsm_from_token_str(
     fsm,
-    input_string: Sequence[str],
+    input_string: str,
+    start_state: int,
+    full_match: bool = True,
+):
+    return walk_fsm(
+        fsm,
+        token_str_to_trans_key(fsm, input_string),
+        start_state,
+        full_match,
+    )
+
+
+def walk_fsm_from_token_str_numba(
+    fsm,
+    input_string: str,
     start_state: int,
     full_match: bool = True,
 ):
@@ -40,7 +69,7 @@ def walk_fsm_numba(
         fsm.fsm_info.alphabet_anything_value,
         fsm.fsm_info.initial,
         fsm.fsm_info.finals,
-        input_string,
+        token_str_to_trans_key(fsm, input_string),
         start_state,
         full_match=full_match,
     )
@@ -49,8 +78,8 @@
 @pytest.mark.parametrize(
     "function",
     [
-        walk_fsm,
-        walk_fsm_numba,
+        walk_fsm_from_token_str,
+        walk_fsm_from_token_str_numba,
     ],
 )
 def test_walk_fsm(function):
@@ -99,8 +128,8 @@ def test_walk_fsm(function):
 @pytest.mark.parametrize(
     "function",
     [
-        walk_fsm,
-        walk_fsm_numba,
+        walk_fsm_from_token_str,
+        walk_fsm_from_token_str_numba,
     ],
 )
 @pytest.mark.parametrize(
@@ -115,19 +144,37 @@ def test_walk_fsm_multi_bytes(function, transform):
     str_regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
     regex_fsm = make_byte_level_better_fsm(str_regex_fsm, keep_utf8=True)
 
-    res = tuple(function(regex_fsm, transform("😂"), regex_fsm.initial, full_match=True))
+    res = tuple(
+        function(
+            regex_fsm, merge_symbols(transform("😂")), regex_fsm.initial, full_match=True
+        )
+    )
     assert res[-1:] == (1,)
 
     res = tuple(
-        function(regex_fsm, transform("😂😂"), regex_fsm.initial, full_match=False)
+        function(
+            regex_fsm,
+            merge_symbols(transform("😂😂")),
+            regex_fsm.initial,
+            full_match=False,
+        )
     )
     assert res[-1:] == (1,)
 
-    res = tuple(function(regex_fsm, transform("!"), regex_fsm.initial, full_match=True))
+    res = tuple(
+        function(
+            regex_fsm, merge_symbols(transform("!")), regex_fsm.initial, full_match=True
+        )
+    )
     assert res == tuple()
 
     res = tuple(
-        function(regex_fsm, transform("😂😂"), regex_fsm.initial, full_match=True)
+        function(
+            regex_fsm,
+            merge_symbols(transform("😂😂")),
+            regex_fsm.initial,
+            full_match=True,
+        )
     )
     assert res == tuple()
@@ -194,14 +241,14 @@ def test_get_sub_fsms_from_seq():
     assert fsm.accepts("+=")
     assert fsm.accepts("+")
 
-    state_seq = walk_fsm(fsm, "def", fsm.initial)
+    state_seq = walk_fsm_from_token_str(fsm, "def", fsm.initial)
     state_seq.insert(0, fsm.fsm_info.initial)
 
     res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
     assert res == [(0, False, True), (2, True, True)]
 
     # Make sure the old-to-new state map is correct
-    def_state_seq = walk_fsm(def_fsm, "def", fsm.initial)
+    def_state_seq = walk_fsm_from_token_str(def_fsm, "def", fsm.initial)
     def_state_seq.insert(0, fsm.fsm_info.initial)
 
     def_old_to_new_states = fsms_to_trans_finals[0][2]
@@ -210,13 +257,13 @@ def test_get_sub_fsms_from_seq():
         for old_state, new_state in zip(def_state_seq, state_seq)
     )
 
-    state_seq = walk_fsm(fsm, "ef", fsm.initial)
+    state_seq = walk_fsm_from_token_str(fsm, "ef", fsm.initial)
     state_seq.insert(0, fsm.initial)
 
     res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
     assert res == [(2, True, True)]
 
-    name_state_seq = walk_fsm(name_fsm, "ef", fsm.initial)
+    name_state_seq = walk_fsm_from_token_str(name_fsm, "ef", fsm.initial)
     name_state_seq.insert(0, fsm.initial)
 
     name_old_to_new_states = fsms_to_trans_finals[2][2]
@@ -225,13 +272,13 @@ def test_get_sub_fsms_from_seq():
         for old_state, new_state in zip(name_state_seq, state_seq)
     )
 
-    state_seq = walk_fsm(fsm, "match", fsm.initial)
+    state_seq = walk_fsm_from_token_str(fsm, "match", fsm.initial)
     state_seq.insert(0, fsm.initial)
 
     res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
     assert res == [(1, False, True), (2, True, True)]
 
-    match_state_seq = walk_fsm(match_fsm, "match", fsm.initial)
+    match_state_seq = walk_fsm_from_token_str(match_fsm, "match", fsm.initial)
     match_state_seq.insert(0, fsm.initial)
 
     match_old_to_new_states = fsms_to_trans_finals[1][2]
@@ -240,25 +287,25 @@ def test_get_sub_fsms_from_seq():
         for old_state, new_state in zip(match_state_seq, state_seq)
     )
 
-    state_seq = walk_fsm(fsm, "defa", fsm.initial)
+    state_seq = walk_fsm_from_token_str(fsm, "defa", fsm.initial)
     state_seq.insert(0, fsm.initial)
 
     res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
     assert res == [(2, True, True)]
 
-    state_seq = walk_fsm(fsm, "de", fsm.initial)
+    state_seq = walk_fsm_from_token_str(fsm, "de", fsm.initial)
     state_seq.insert(0, fsm.initial)
 
     res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
     assert res == [(0, True, False), (2, True, True)]
 
-    state_seq = walk_fsm(fsm, "+", fsm.initial, False)
+    state_seq = walk_fsm_from_token_str(fsm, "+", fsm.initial, False)
     state_seq.insert(0, fsm.initial)
 
     res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
     assert res == [(3, True, False), (4, False, True)]
 
-    state_seq = walk_fsm(fsm, "+=", fsm.initial)
+    state_seq = walk_fsm_from_token_str(fsm, "+=", fsm.initial)
     state_seq.insert(0, fsm.initial)
 
     res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
@@ -304,15 +351,15 @@ def test_create_fsm_index_end_to_end():
 
     vocabulary_nb = numba.typed.List.empty_list(
         numba.types.Tuple(
             (
-                numba.types.UnicodeCharSeq(2)[:],
+                numba.types.unicode_type,
                 numba.int64[:],
             )
         )
     )
     for token_tuple, token_ids in vocabulary.items():
-        token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2"))
+        token = merge_symbols(token_tuple)
         token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
-        vocabulary_nb.append((token_tuple_np, token_ids_np))
+        vocabulary_nb.append((token, token_ids_np))
 
     res = create_fsm_index_end_to_end(regex_fsm.fsm_info, vocabulary_nb)
@@ -331,23 +378,25 @@ def test_create_fsm_index_end_to_end_multi_byte():
         "😈a": numba.typed.List([1]),
         "😇": numba.typed.List([2]),
         "😍": numba.typed.List([3]),
-        ("F0", "9F", "98", "8D"): numba.typed.List([4]),  # '😍'
+        merge_symbols(("F0", "9F", "98", "8D")): numba.typed.List([4]),  # '😍'
         " 😍": numba.typed.List([5]),
-        (" ", "F0", "9F", "98", "8D"): numba.typed.List([6]),  # ' 😍'
-        (" ", "F0", "9F", "98"): numba.typed.List([7]),  # ' 😍' incomplete
+        merge_symbols((" ", "F0", "9F", "98", "8D")): numba.typed.List([6]),  # ' 😍'
+        merge_symbols((" ", "F0", "9F", "98")): numba.typed.List(
+            [7]
+        ),  # ' 😍' incomplete
         "": numba.typed.List([8]),
     }
 
     vocabulary_nb = numba.typed.List.empty_list(
         numba.types.Tuple(
             (
-                numba.types.UnicodeCharSeq(2)[:],
+                numba.types.unicode_type,
                 numba.int64[:],
             )
         )
     )
     for token_tuple, token_ids in vocabulary.items():
-        token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2"))
+        token_tuple_np = merge_symbols(token_tuple)
         token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
         vocabulary_nb.append((token_tuple_np, token_ids_np))
 
@@ -356,7 +405,16 @@ def test_create_fsm_index_end_to_end_multi_byte():
     res = create_fsm_index_end_to_end(regex_fsm.fsm_info, vocabulary_nb)
     assert res == {0: {(5, 3), (6, 3), (7, 7), (2, 2)}, 3: {(2, 3), (3, 3), (4, 3)}}
 
 
-def test_create_fsm_index_tokenizer():
+@pytest.mark.parametrize(
+    "hf_tokenizer_uri",
+    [
+        "gpt2",
+        "microsoft/phi-2",
+        "Qwen/Qwen1.5-0.5B-Chat",
+        "NousResearch/Hermes-2-Pro-Llama-3-8B",
+    ],
+)
+def test_create_fsm_index_tokenizer(hf_tokenizer_uri):
     # The combined regular expressions of a lexer state in a Python grammar
     regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~"
@@ -371,7 +429,7 @@ def test_create_fsm_index_tokenizer():
     num_bytes_fsm_states = len(bytes_fsm.states)
     assert num_bytes_fsm_states == 235
 
-    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_uri)
     tokenizer = TransformerTokenizer(tokenizer)
 
     states_to_token_subsets, empty_token_ids = create_fsm_index_tokenizer(