Skip to content

Commit 10a8842

Browse files
committed
Fix JSON inference example
1 parent b2c7cf2 commit 10a8842

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

examples/inference_json.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
66
from exllamav2.generator import ExLlamaV2DynamicGenerator
77
from exllamav2.generator.filters import ExLlamaV2PrefixFilter
8-
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
8+
from inference_json_lmfe_wrapper import ExLlamaV2TokenEnforcerFilter
99
from lmformatenforcer import JsonSchemaParser
1010
from pydantic import BaseModel, conlist
1111
from typing import Literal
@@ -61,7 +61,7 @@ class Superhero(BaseModel):
6161
filters.append(None)
6262
prompts.append(p)
6363
filters.append([
64-
ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer),
64+
ExLlamaV2TokenEnforcerFilter(model, tokenizer, schema_parser),
6565
ExLlamaV2PrefixFilter(model, tokenizer, ["{", " {"])
6666
])
6767

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
2+
import sys, os
3+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4+
5+
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
6+
from exllamav2.generator.filters import ExLlamaV2Filter
7+
from functools import lru_cache
8+
from lmformatenforcer.integrations.exllamav2 import build_token_enforcer_tokenizer_data
9+
from lmformatenforcer import TokenEnforcer, CharacterLevelParser
10+
from typing import List
11+
12+
13+
# Temporary wrapper for lm-format-enforcer, until the integration in LMFE itself is updated
14+
15+
16+
@lru_cache(10)
def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
    """Build and memoize lm-format-enforcer tokenizer data for a tokenizer.

    Building the token-enforcer vocabulary data is expensive, and the same
    tokenizer instance is typically shared by every filter in a generator,
    so results are cached (up to 10 distinct tokenizers) per instance.
    """
    tokenizer_data = build_token_enforcer_tokenizer_data(tokenizer)
    return tokenizer_data
19+
20+
21+
class ExLlamaV2TokenEnforcerFilter(ExLlamaV2Filter):
    """ExLlamaV2 sampling filter backed by lm-format-enforcer.

    Records the token IDs fed during generation and, at each step, queries a
    ``TokenEnforcer`` for the set of token IDs permitted next by the supplied
    character-level parser (e.g. a ``JsonSchemaParser``).
    """

    # Token IDs fed so far in the current generation; cleared by begin().
    token_sequence: List[int]

    def __init__(
        self,
        model: ExLlamaV2,
        tokenizer: ExLlamaV2Tokenizer,
        character_level_parser: CharacterLevelParser,
    ):
        super().__init__(model, tokenizer)
        self.token_enforcer = TokenEnforcer(
            _get_lmfe_tokenizer_data(tokenizer),
            character_level_parser,
        )
        self.token_sequence = []

    def begin(self, prefix_str: str) -> None:
        # New generation: discard any tokens recorded from a previous run.
        self.token_sequence = []

    def feed(self, token) -> None:
        # token is indexed as token[0][0] — presumably a 1x1 tensor holding
        # the sampled token ID (TODO confirm against the generator's call).
        self.token_sequence.append(int(token[0][0]))

    def next(self):
        # Second element of the pair (forced/end tokens) is always empty
        # for this filter; only the allowed set is constrained.
        allowed = self.token_enforcer.get_allowed_tokens(self.token_sequence)
        return sorted(allowed), []

    def use_background_worker(self):
        # Let the generator evaluate this filter on a background worker.
        return True

0 commit comments

Comments
 (0)