Commit 65ae158

Merge pull request #86 from lapp0/parallel-model-tokenizer-index-load
Parallel model tokenizer index load
2 parents d78041e + f2abefe commit 65ae158

File tree: 5 files changed, +140 -30 lines changed


benchmarks/bench_regex_fsm.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+import random
+
+from outlines.caching import cache_disabled
+from outlines.fsm.regex import reduced_vocabulary
+from outlines.models.tokenizer import Tokenizer
+
+from .common import ensure_numba_compiled
+
+
+class MockTokenizer(Tokenizer):
+    def __init__(self, token_strs):
+        self.eos_token = "<eos>"
+        self.eos_token_id = 0
+        self.pad_token_id = 1
+        self.special_tokens = {0, 1}
+
+        self.vocabulary = {"<eos>": 0, "<pad>": 1}
+
+        for i, tok in enumerate(token_strs):
+            self.vocabulary[tok] = i + 2
+
+    @classmethod
+    def from_random_tokens(cls, n_tokens, max_token_length=8, seed=42):
+        random.seed(seed)
+        tokens = [
+            "".join(
+                chr(random.randint(0, 4096))
+                for __ in range(random.randint(0, max_token_length))
+            )
+            for _ in range(n_tokens)
+        ]
+        return cls(tokens)
+
+    def convert_token_to_string(self, token):
+        return token
+
+    def __hash__(self):
+        return hash(tuple(sorted(self.vocabulary.items())))
+
+
+def reduced_vocabulary_uncached(*args, **kwargs):
+    return reduced_vocabulary.__wrapped__(*args, **kwargs)
+
+
+class RegexReducedVocabularyBenchmark:
+    params = [10000, 100000, 1000000]
+    param_names = ["vocab_size"]
+
+    def setup(self, vocab_size):
+        ensure_numba_compiled(MockTokenizer([chr(i) for i in range(128)]))
+
+        self.tokenizer = MockTokenizer.from_random_tokens(vocab_size)
+
+    @cache_disabled()
+    def time_reduced_vocabulary(self, _):
+        reduced_vocabulary_uncached(self.tokenizer)
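
The new class follows the `params` / `param_names` / `setup` / `time_*` layout used by airspeed-velocity style benchmark suites, so each `vocab_size` in `params` gets its own timed run of `time_reduced_vocabulary`. A rough standalone sketch of the same measurement follows; it assumes the benchmarks directory is importable as a package named `benchmarks` (not shown in this diff), and the vocabulary size and manual timer are illustrative, not part of the PR:

# Hypothetical standalone driver; reuses MockTokenizer and
# reduced_vocabulary_uncached from the new benchmark file above.
import time

from outlines.caching import cache_disabled

from benchmarks.bench_regex_fsm import MockTokenizer, reduced_vocabulary_uncached


@cache_disabled()  # bypass the disk cache, same as the benchmark method
def main():
    tokenizer = MockTokenizer.from_random_tokens(50_000)  # illustrative size
    start = time.perf_counter()
    vocabulary_nb, empty_token_ids = reduced_vocabulary_uncached(tokenizer)
    elapsed = time.perf_counter() - start
    print(f"{len(vocabulary_nb)} distinct token strings in {elapsed:.2f}s")


if __name__ == "__main__":
    main()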

mkdocs.yml

Lines changed: 1 addition & 1 deletion
@@ -138,7 +138,7 @@ nav:
           - TGI: reference/models/tgi.md
           - ExllamaV2: reference/models/exllamav2.md
           - MLX: reference/models/mlxlm.md
-          - Mamba: reference/models/mamba.md
+          - Mamba: reference/models/transformers.md
       - API:
           - OpenAI: reference/models/openai.md
   - API Reference:

outlines/caching.py

Lines changed: 2 additions & 2 deletions
@@ -153,8 +153,8 @@ def disable_cache():

     `outlines.cache.disable` should be called right after importing outlines:

-    >>> import outlines.cache as cache
-    >>> cache.disable()
+    >>> import outlines.caching as cache
+    >>> cache.disable_cache()

     """
     global _caching_enabled
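
With the corrected docstring, the two ways of turning the cache off line up with how the rest of this PR uses them: `disable_cache()` switches it off for the whole session, while the new benchmark opts out per callable with the `cache_disabled` helper it imports from `outlines.caching`. A minimal sketch; the `run_uncached` name is illustrative:

# Session-wide: turn the disk cache off right after importing outlines,
# exactly as the corrected doctest shows.
import outlines.caching as cache

cache.disable_cache()


# Scoped: opt out for a single callable, the way the new benchmark
# decorates time_reduced_vocabulary.
from outlines.caching import cache_disabled


@cache_disabled()
def run_uncached():
    ...  # calls here bypass the cache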

outlines/fsm/regex.py

Lines changed: 80 additions & 26 deletions
@@ -887,23 +887,11 @@ def gpt2_unicode_to_bytes():
     return {v: k for k, v in gpt2_bytes_to_unicode().items()}


-# TODO: Cannot cache typed collections to disk, yet. See
-# https://github.com/numba/numba/issues/4698
-@lru_cache
-def reduced_vocabulary(
-    tokenizer: "Tokenizer",
-) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]:
-    """Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
+def get_normalized_vocab(tokenizer: "Tokenizer") -> Tuple[Dict[int, str], Set[int]]:
+    norm_vocab = {}
     empty_token_ids = set()
-    vocabulary: Dict[Union[str, Tuple[str, ...]], List[int]] = {}
     for token, token_idx in tokenizer.vocabulary.items():
-        if token in tokenizer.special_tokens:
-            continue
-
-        token_str: Union[str, Tuple[str, ...]] = tokenizer.convert_token_to_string(
-            token
-        )
-
+        token_str = tokenizer.convert_token_to_string(token)
         if token_str:
             # invalid utf-8 sequences are replaced with � (\ufffd), but there
             # might also be tokens specifically for �, ��, ���, etc.
@@ -927,22 +915,88 @@ def reduced_vocabulary(
                     )
                 token_str = "".join(byte_symbol(b) for b in token_bytes)

-            vocabulary.setdefault(token_str, []).append(token_idx)
+            norm_vocab[token_idx] = token_str
         else:
             empty_token_ids.add(numba.int64(token_idx))

-    vocabulary_nb = numba.typed.List.empty_list(
-        numba.types.Tuple(
-            (
-                nb_unicode_type,
-                numba.int64[:],
-            )
-        )
+    return norm_vocab, empty_token_ids
+
+
+@numba.njit(cache=True, nogil=True)
+def to_numba_dict(keys: List[int], values: List[str]):
+    """
+    Pure-python numba dict construction is extremely slow.
+    This helper accepts equal length key and value arrays, and constructs a numba dict
+    """
+    # Define the key and value types for the Numba dictionary
+    numba_dict = numba.typed.Dict.empty(
+        key_type=numba.types.int64,
+        value_type=numba.types.unicode_type,
     )
-    for token_str, token_ids in vocabulary.items():
-        token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
-        vocabulary_nb.append((token_str, token_ids_np))

+    # Fill the Numba dictionary with values from the input lists
+    for i in range(len(keys)):
+        numba_dict[keys[i]] = values[i]
+
+    return numba_dict
+
+
+token_id_str_pair = numba.types.Tuple((nb_unicode_type, numba.int64[:]))
+
+
+@numba.njit(
+    numba.types.ListType(token_id_str_pair)(
+        numba.types.DictType(numba.int64, nb_unicode_type)
+    ),
+    cache=True,
+    nogil=True,
+)
+def vocab_dict_to_inverted_vocab_list(
+    vocab_dict_nb: Dict[int, str]
+) -> List[Tuple[str, Sequence[int]]]:
+    """
+    Helper for `reduced_vocabulary`
+
+    Convert
+    - from `vocab_dict_nb`: Dict[token_id, token_str]
+    - to `vocab_nb`: List[token_str, token_id[:]]
+    """
+    inverse_vocab_dict = numba.typed.Dict.empty(
+        key_type=numba.types.unicode_type, value_type=numba.types.int64[:]
+    )
+
+    # Fill the temporary dictionary
+    for key in vocab_dict_nb:
+        value = vocab_dict_nb[key]
+        if value not in inverse_vocab_dict:
+            inverse_vocab_dict[value] = np.zeros(0, dtype=np.int64)
+        inverse_vocab_dict[value] = np.append(inverse_vocab_dict[value], key)
+
+    # Transfer data from the temporary dictionary to the final dictionary
+    vocab_nb = numba.typed.List.empty_list(token_id_str_pair)
+
+    for value in inverse_vocab_dict:
+        vocab_nb.append((value, inverse_vocab_dict[value]))
+
+    return vocab_nb
+
+
+# TODO: Cannot cache typed collections to disk, yet. See
+# https://github.com/numba/numba/issues/4698
+@lru_cache
+def reduced_vocabulary(
+    tokenizer: "Tokenizer",
+) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]:
+    """
+    Provided the tokenizer, calculate the
+    - vocabulary_nb: mapping of (normalized token str -> token_ids[:])
+    - empty token ids
+    """
+    norm_vocab, empty_token_ids = get_normalized_vocab(tokenizer)
+    norm_vocab_dict_nb = to_numba_dict(
+        np.fromiter(norm_vocab.keys(), dtype=np.int64), list(norm_vocab.values())
+    )
+    vocabulary_nb = vocab_dict_to_inverted_vocab_list(norm_vocab_dict_nb)
     return vocabulary_nb, empty_token_ids

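Net effect: the old single `reduced_vocabulary` loop is split into a pure-Python normalization pass (`get_normalized_vocab`) plus two `numba.njit` kernels compiled with `cache=True, nogil=True`, so the compiled pieces can be reused from numba's on-disk cache and release the GIL while they run, which is presumably what lets the index build overlap with model and tokenizer loading, as the PR title suggests. The inversion kernel is equivalent to the old `vocabulary.setdefault(token_str, []).append(token_idx)` grouping; here is a pure-Python sketch of what it computes (illustrative only, the real code fills numba typed containers):

import numpy as np


def invert_vocab(norm_vocab):
    """Group token ids by their normalized string, mirroring
    vocab_dict_to_inverted_vocab_list but in plain Python."""
    inverse = {}
    for token_id, token_str in norm_vocab.items():
        inverse.setdefault(token_str, []).append(token_id)
    return [(s, np.asarray(ids, dtype=np.int64)) for s, ids in inverse.items()]


# Two token ids that normalize to the same string end up in one entry.
print(invert_vocab({2: "a", 3: "b", 4: "a"}))
# [('a', array([2, 4])), ('b', array([3]))]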

outlines/models/transformers.py

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ def __init__(self, tokenizer: "PreTrainedTokenizer", **kwargs):
         self.eos_token_id = self.tokenizer.eos_token_id
         self.eos_token = self.tokenizer.eos_token

-        if not self.tokenizer.pad_token_id:
+        if self.tokenizer.pad_token_id is None:
             self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
             self.pad_token_id = self.eos_token_id
         else:
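
This one-line change matters because a pad token id can legitimately be 0: `not 0` evaluates to True, so the old truthiness check would overwrite a valid pad token with the eos token, whereas the new `is None` check only fires when the tokenizer really has no pad token. A quick illustration:

pad_token_id = 0  # a perfectly valid token id for some tokenizers

print(not pad_token_id)      # True  -> old check would wrongly replace it
print(pad_token_id is None)  # False -> new check leaves it alone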
