
Commit 88d8eaa

make LlamaCppTokenizer an outlines Tokenizer
1 parent cb16b16 commit 88d8eaa

File tree

3 files changed: +123 -42 lines changed

outlines/integrations/llamacpp.py

Lines changed: 2 additions & 37 deletions
@@ -26,7 +26,7 @@
 """

 import math
-from typing import TYPE_CHECKING, Dict, Optional, Set, Type, Union
+from typing import TYPE_CHECKING, Optional, Type, Union

 import numpy as np
 import torch
@@ -36,47 +36,12 @@
 from outlines.fsm.guide import CFGGuide, Guide, RegexGuide
 from outlines.fsm.json_schema import build_regex_from_schema
 from outlines.integrations.utils import convert_json_schema_to_str
+from outlines.models.llamacpp import LlamaCppTokenizer

 if TYPE_CHECKING:
     from llama_cpp import Llama


-class LlamaCppTokenizer:
-    def __init__(self, model: "Llama"):
-        self.eos_token_id = model.token_eos()
-        self.eos_token = model.tokenizer().decode([self.eos_token_id])
-        self.pad_token_id = self.eos_token_id
-        self.special_tokens: Set[int] = set()
-
-        self.vocabulary: Dict[str, int] = dict()
-
-        tokenizer = model.tokenizer()
-
-        self.decode = tokenizer.decode
-
-        # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
-        try:
-            self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab()
-        except AttributeError:
-            # ###
-            for t in range(model.n_vocab()):
-                token_piece = model.tokenizer().decode([t])
-                self.vocabulary[token_piece] = t
-
-    def convert_token_to_string(self, token: str) -> str:
-        return token
-
-    def __getstate__(self):
-        """Allow tokenizer to be used as hash key by excluding self.decode"""
-        return (
-            self.vocabulary.items(),
-            self.eos_token_id,
-            self.eos_token,
-            self.pad_token_id,
-            sorted(self.special_tokens),
-        )
-
-
 class LogitsProcessor:
     """Bias LlamaCpp generation using a finite state machine.
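The added import re-exports LlamaCppTokenizer from its new home, so callers that still import it from outlines.integrations.llamacpp keep working. A minimal sketch of the two equivalent import paths after this commit (assuming the re-export is left in place):

    # Canonical location after this commit
    from outlines.models.llamacpp import LlamaCppTokenizer

    # Still resolves through the re-export kept in outlines/integrations/llamacpp.py
    from outlines.integrations.llamacpp import LlamaCppTokenizer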
outlines/models/llamacpp.py

Lines changed: 81 additions & 1 deletion
@@ -1,15 +1,95 @@
 import dataclasses
+import pickle
 import warnings
-from typing import TYPE_CHECKING, Iterator, List, Optional, TypedDict, Union
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypedDict,
+    Union,
+)

 from typing_extensions import Unpack

 from outlines.generate.api import GenerationParameters, SamplingParameters
+from outlines.models.tokenizer import Tokenizer

 if TYPE_CHECKING:
     from llama_cpp import Llama, LogitsProcessorList


+class LlamaCppTokenizer(Tokenizer):
+    def __init__(self, model: "Llama"):
+        self.eos_token_id = model.token_eos()
+        self.eos_token = model.tokenizer().decode([self.eos_token_id])
+        self.pad_token_id = self.eos_token_id
+        self.special_tokens: Set[int] = set()
+
+        self.vocabulary: Dict[str, int] = dict()
+
+        self.tokenizer = model.tokenizer()
+
+        # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved
+        try:
+            self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab()
+        except AttributeError:
+            # ###
+            for t in range(model.n_vocab()):
+                token_piece = model.tokenizer().decode([t])
+                self.vocabulary[token_piece] = t
+
+        self._hash = None
+
+    def decode(self, token_ids: List[int]) -> List[str]:
+        decoded_bytes = self.tokenizer.detokenize(token_ids)
+        return [decoded_bytes.decode("utf-8", errors="ignore")]
+
+    def encode(
+        self, prompt: Union[str, List[str]], add_bos: bool = True, special: bool = True
+    ) -> Tuple[List[int], List[int]]:
+        if isinstance(prompt, list):
+            raise NotImplementedError(
+                "llama-cpp-python tokenizer doesn't support batch tokenization"
+            )
+        token_ids = self.tokenizer.tokenize(
+            prompt.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
+        )
+        # generate attention mask, missing from llama-cpp-python
+        attention_mask = [
+            1 if token_id != self.pad_token_id else 0 for token_id in token_ids
+        ]
+        return token_ids, attention_mask
+
+    def convert_token_to_string(self, token: str) -> str:
+        return token
+
+    def __eq__(self, other):
+        return hash(self) == hash(other)
+
+    def __hash__(self):
+        # cache object hash
+        if self._hash is None:
+            self._hash = hash(pickle.dumps(self))
+        return self._hash
+
+    def __getstate__(self):
+        """Create a stable representation for outlines.caching"""
+        return (
+            self.vocabulary.items(),
+            self.eos_token_id,
+            self.eos_token,
+            self.pad_token_id,
+            sorted(self.special_tokens),
+        )
+
+    def __setstate__(self, state):
+        raise NotImplementedError("Cannot load a pickled llamacpp tokenizer")
+
+
 class LlamaCppParams(TypedDict, total=False):
     suffix: Optional[str]
     temperature: float

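For context, a minimal usage sketch of the relocated tokenizer. The GGUF path below is a placeholder, and decode() returns a one-element list because llama-cpp-python detokenizes the whole id sequence in a single call:

    from llama_cpp import Llama

    from outlines.models.llamacpp import LlamaCppTokenizer

    llm = Llama("./model.gguf")  # placeholder path to a local GGUF model
    tokenizer = LlamaCppTokenizer(llm)

    token_ids, attention_mask = tokenizer.encode("What is the IP address of the Google DNS servers? ")
    [text] = tokenizer.decode(token_ids)  # decode returns a one-element list of strings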
tests/generate/test_integration_llamacpp.py

Lines changed: 40 additions & 4 deletions
@@ -281,9 +281,45 @@ def test_llama_cpp_pre_tokenizer_remains_broken():
     generate.choice(model, ["skirt", "dress", "pen", "jacket"])


-def test_create_states_mapping_llamacpp_tokenizer_regression(model):
-    """Minimal reproducer for #922, error passing llamacpp tokenizer to create_states_mapping"""
+def test_RegexGuide_caching(temp_cache_dir):
+    import outlines.caching
     from outlines.fsm.guide import create_states_mapping
-    from outlines.integrations.llamacpp import LlamaCppTokenizer

-    create_states_mapping("a", LlamaCppTokenizer(model.model))
+    assert outlines.caching._caching_enabled
+
+    regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
+    prompt = "What is the IP address of the Google DNS servers? "
+
+    cache = outlines.caching.get_cache()
+
+    # Returns (hits, misses)
+    _ = cache.stats(enable=True)
+    assert cache.statistics
+
+    assert create_states_mapping.__memory__ is cache
+
+    model = models.transformers(
+        "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM"
+    )
+    generator = generate.regex(model, regex, sampler=samplers.greedy())
+    assert cache.stats() == (0, 1)
+
+    model_2 = models.transformers("hf-internal-testing/tiny-random-GPTJForCausalLM")
+    generator_2 = generate.regex(model_2, regex, sampler=samplers.greedy())
+    assert cache.stats() == (0, 2)
+
+    # These two different models and tokenizers should not have the same state
+    # mapping results
+    assert generator.fsm.states_to_token_maps != generator_2.fsm.states_to_token_maps
+
+    generator_3 = generate.regex(model_2, regex, sampler=samplers.greedy())
+    assert cache.stats() == (1, 2)
+    assert generator_2.fsm.states_to_token_maps == generator_3.fsm.states_to_token_maps
+
+    # Just for fun...
+    structured = generator(prompt, max_tokens=30)
+    structured_2 = generator_2(prompt, max_tokens=30)
+
+    assert re.fullmatch(regex, structured)
+    assert re.fullmatch(regex, structured_2)
+    assert structured != structured_2
