[WIP] Minimal Tokenizer Implementation #513

Draft
wants to merge 11 commits into base: main
14 changes: 5 additions & 9 deletions exo/api/chatgpt_api.py
@@ -3,7 +3,6 @@
import asyncio
import json
from pathlib import Path
-from transformers import AutoTokenizer
from typing import List, Literal, Union, Dict
from aiohttp import web
import aiohttp_cors
@@ -14,9 +13,8 @@
from exo import DEBUG, VERSION
from exo.download.download_progress import RepoProgressEvent
from exo.helpers import PrefixDict, shutdown
-from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration import Node
-from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
+from exo.models import build_base_shard, model_cards, pretty_name, get_supported_models
from typing import Callable, Optional

class Message:
@@ -228,7 +226,7 @@ async def handle_post_chat_token_encode(self, request):
    data = await request.json()
    shard = build_base_shard(self.default_model, self.inference_engine_classname)
    messages = [parse_message(msg) for msg in data.get("messages", [])]
-    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
+    tokenizer = await self.node.inference_engine.get_tokenizer(shard)
    return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})

  async def handle_get_download_progress(self, request):
@@ -257,8 +255,7 @@ async def handle_post_chat_completions(self, request):
{"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
status=400,
)

tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
tokenizer = await self.node.inference_engine.get_tokenizer(shard)
if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")

prompt = build_prompt(tokenizer, chat_request.messages)
@@ -307,8 +304,7 @@ async def stream_result(_request_id: str, tokens: List[int], is_finished: bool):
      self.prev_token_lens[_request_id] = max(prev_last_tokens_len, len(tokens))
      new_tokens = tokens[prev_last_tokens_len:]
      finish_reason = None
-      eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer, AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
+      eos_token_id = tokenizer.eos_token_id
      if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
        new_tokens = new_tokens[:-1]
        if is_finished:
@@ -354,7 +350,7 @@ def on_result(_request_id: str, tokens: List[int], is_finished: bool):
      )

      finish_reason = "length"
-      eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
+      eos_token_id = tokenizer.eos_token_id
      if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
      if tokens[-1] == eos_token_id:
        tokens = tokens[:-1]
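For reference, a minimal sketch (not part of the diff) of the new resolution path used by the handlers above: the tokenizer now comes from the node's inference engine rather than from resolve_tokenizer(get_repo(...)). The model id and engine class name defaults below are placeholders:

from exo.models import build_base_shard

async def count_prompt_tokens(node, prompt: str,
                              model_id: str = "llama-3.2-1b",
                              engine_classname: str = "MLXDynamicShardInferenceEngine") -> int:
  # build_base_shard returns None for unsupported model/engine pairs
  shard = build_base_shard(model_id, engine_classname)
  # the engine loads the shard first, then returns its tokenizer
  tokenizer = await node.inference_engine.get_tokenizer(shard)
  return len(tokenizer.encode(prompt))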
5 changes: 5 additions & 0 deletions exo/inference/mlx/sharded_inference_engine.py
@@ -10,6 +10,7 @@
from exo.download.shard_download import ShardDownloader
import asyncio
from concurrent.futures import ThreadPoolExecutor
+from exo.tokenizer.tokenizer import Tokenizer

def sample_logits(
  logits: mx.array,
@@ -58,6 +59,10 @@ async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarr
    await self.ensure_shard(shard)
    output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.model, mx.array(input_data), request_id))
    return output_data

+  async def get_tokenizer(self, shard: Shard) -> Tokenizer:
+    await self.ensure_shard(shard)
+    return self.tokenizer

  async def ensure_shard(self, shard: Shard):
    if self.shard == shard:
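A small sketch of calling the new engine method directly; the Shard field values are placeholders. get_tokenizer() awaits ensure_shard() first, so the shard's weights and tokenizer files are loaded before the tokenizer is returned:

from exo.inference.shard import Shard

async def tokenizer_roundtrip(engine) -> None:
  # placeholder shard: a single-node layout for a 16-layer model
  shard = Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=15, n_layers=16)
  tokenizer = await engine.get_tokenizer(shard)
  ids = tokenizer.encode("hello from exo")
  print(ids, tokenizer.decode(ids))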
8 changes: 4 additions & 4 deletions exo/inference/mlx/sharded_utils.py
@@ -18,10 +18,10 @@
import mlx.nn as nn
from transformers import AutoProcessor

-from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
+from exo.tokenizer.tokenizer import Tokenizer

from exo import DEBUG
-from exo.inference.tokenizers import resolve_tokenizer
+from exo.tokenizer.tokenizer import resolve_tokenizer
from ..shard import Shard


@@ -174,7 +174,7 @@ async def load_shard(
  model_config={},
  adapter_path: Optional[str] = None,
  lazy: bool = False,
-) -> Tuple[nn.Module, TokenizerWrapper]:
+) -> Tuple[nn.Module, Tokenizer]:
  model = load_model_shard(model_path, shard, lazy, model_config)

  # TODO: figure out a generic solution
@@ -184,7 +184,7 @@
    processor.encode = processor.tokenizer.encode
    return model, processor
  else:
-    tokenizer = await resolve_tokenizer(model_path)
+    tokenizer = resolve_tokenizer(shard.model_id, model_path)
    return model, tokenizer


43 changes: 24 additions & 19 deletions exo/inference/tokenizers.py
@@ -8,6 +8,7 @@
from exo.download.hf.hf_helpers import get_local_snapshot_dir
from exo.helpers import DEBUG

+from exo.tokenizer.tokenizer import Tokenizer

class DummyTokenizer:
  def __init__(self):
@@ -40,25 +41,29 @@ async def resolve_tokenizer(model_id: str):


async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
-  try:
-    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
-    processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in f"{model_id_or_local_path}" else False, trust_remote_code=True)
-    if not hasattr(processor, 'eos_token_id'):
-      processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
-    if not hasattr(processor, 'encode'):
-      processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
-    if not hasattr(processor, 'decode'):
-      processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
-    return processor
-  except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}")
-    if DEBUG >= 4: print(traceback.format_exc())
+  # try:
+  #   if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
+  #   processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in f"{model_id_or_local_path}" else False, trust_remote_code=True)
+  #   if not hasattr(processor, 'eos_token_id'):
+  #     processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
+  #   if not hasattr(processor, 'encode'):
+  #     processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
+  #   if not hasattr(processor, 'decode'):
+  #     processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
+  #   return processor
+  # except Exception as e:
+  #   if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}")
+  #   if DEBUG >= 4: print(traceback.format_exc())

+  # try:
+  #   if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
+  #   return AutoTokenizer.from_pretrained(model_id_or_local_path, trust_remote_code=True)
+  # except Exception as e:
+  #   if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
+  #   if DEBUG >= 4: print(traceback.format_exc())

+  # raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}")
  try:
    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
-    return AutoTokenizer.from_pretrained(model_id_or_local_path, trust_remote_code=True)
+    return Tokenizer(model_id_or_local_path)
  except Exception as e:
    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
    if DEBUG >= 4: print(traceback.format_exc())

-  raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}")
+    raise ValueError(f"Failed to load tokenizer for {model_id_or_local_path}. Error: {e}")
2 changes: 1 addition & 1 deletion exo/main.py
@@ -172,7 +172,7 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
  if not shard:
    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
    return
-  tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
+  tokenizer = await node.inference_engine.get_tokenizer(shard)
  request_id = str(uuid.uuid4())
  callback_id = f"cli-wait-response-{request_id}"
  callback = node.on_token.register(callback_id)
4 changes: 4 additions & 0 deletions exo/tokenizer/__init__.py
@@ -0,0 +1,4 @@
from .llama import LlamaTokenizer
from .tokenizer import Tokenizer

__all__ = ['LlamaTokenizer', 'Tokenizer']
106 changes: 106 additions & 0 deletions exo/tokenizer/llama.py
@@ -0,0 +1,106 @@
import os
import json
import re
import tiktoken
from typing import List, Dict, Any, Tuple
from jinja2 import Template
from datetime import datetime
from exo.tokenizer.tokenizer import Tokenizer

class LlamaTokenizer(Tokenizer):
  def __init__(self, model_path: str):
    with open(os.path.join(model_path, 'tokenizer.json'), 'r', encoding="utf-8") as f:
      tokenizer_data = json.load(f)
      vocab = tokenizer_data["model"]["vocab"]
      self.pattern = tokenizer_data["pre_tokenizer"]["pretokenizers"][0]["pattern"]["Regex"]
      self.special_tokens = {token["content"]: int(token["id"]) for token in tokenizer_data["added_tokens"]}

    with open(os.path.join(model_path, 'tokenizer_config.json'), 'r', encoding="utf-8") as f:
      tokenizer_config = json.load(f)
      self.chat_template = tokenizer_config["chat_template"]
      self.bos_token = tokenizer_config["bos_token"]
      self.eos_token = tokenizer_config["eos_token"]
      self.add_bos_token = bool(tokenizer_config.get("add_bos_token", False))
      self.add_eos_token = bool(tokenizer_config.get("add_eos_token", False))

    self._bos_token_id, self._eos_token_id = self.get_bos_and_eos_ids()

    self.vocab = {bytes(k, "utf-8"): v for k, v in vocab.items()}  # convert str keys to bytes

    self.encoding = tiktoken.Encoding(
      name="custom_encoding",
      pat_str=self.pattern,
      mergeable_ranks=self.vocab,
      special_tokens=self.special_tokens
    )

  def decode_chars(self, text: str) -> str:
    decoding_map = {'Ġ': ' ', 'ĉ': '\t', 'Ċ': '\n'}
    result = ''
    for char in text:
      result += decoding_map.get(char, char)
    return result

  def encode_chars(self, text: str) -> str:
    encoding_map = {' ': 'Ġ', '\t': 'ĉ', '\n': 'Ċ'}
    result = ''
    for char in text:
      result += encoding_map.get(char, char)
    return result

  def decode(self, tokens: List[int]) -> str:
    return self.decode_chars(self.encoding.decode(tokens))

  def encode(self, text: str, allow_special: bool = True) -> List[int]:
    allowed_special = set(self.special_tokens.keys()) if allow_special else set()
    preprocessed_text = self.encode_chars(text)
    if self.add_bos_token:
      preprocessed_text = self.bos_token + preprocessed_text
    if self.add_eos_token:
      preprocessed_text = preprocessed_text + self.eos_token
    return self.encoding.encode(
      preprocessed_text,
      allowed_special=allowed_special,
      disallowed_special=set()
    )

  def get_bos_and_eos_ids(self) -> Tuple[int, int]:
    bos_token_id = self.special_tokens.get(self.bos_token, None)
    eos_token_id = self.special_tokens.get(self.eos_token, None)

    return bos_token_id, eos_token_id

  def apply_chat_template(
    self,
    messages: List[Dict[str, Any]],
    add_generation_prompt: bool = True,
    **kwargs
  ) -> str:
    if 'strftime_now' not in kwargs:
      kwargs['strftime_now'] = datetime.now().strftime

    template = Template(self.chat_template)
    return template.render(
      messages=messages,
      add_generation_prompt=add_generation_prompt,
      bos_token=self.bos_token,
      eos_token=self.eos_token,
      **kwargs
    )

  @property
  def eos_token_id(self) -> int:
    return self._eos_token_id

class PostProcessor:
  def __init__(self, tokenizer_config: Dict[str, Any], bos_token_id: int, eos_token_id: int):
    self.add_bos_token = bool(tokenizer_config["add_bos_token"])
    self.add_eos_token = bool(tokenizer_config["add_eos_token"])
    # tokenizer_config.json only names the bos/eos token strings, not their ids,
    # so the ids used below are supplied explicitly
    self.bos_token_id = bos_token_id
    self.eos_token_id = eos_token_id

  def post_process(self, tokens: List[int]) -> List[int]:
    if self.add_bos_token:
      tokens.insert(0, self.bos_token_id)
    if self.add_eos_token:
      tokens.append(self.eos_token_id)
    return tokens
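A usage sketch for the class above, assuming a locally downloaded checkpoint (the path is a placeholder; the directory must contain the tokenizer.json and tokenizer_config.json files read in __init__):

from exo.tokenizer.llama import LlamaTokenizer

tokenizer = LlamaTokenizer("/models/llama-3.2-1b")  # placeholder local checkpoint directory
ids = tokenizer.encode("Hello, world!")
print(tokenizer.decode(ids))

prompt = tokenizer.apply_chat_template(
  [{"role": "user", "content": "What is 2 + 2?"}],
  add_generation_prompt=True,
)
print(len(tokenizer.encode(prompt)), tokenizer.eos_token_id)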
67 changes: 67 additions & 0 deletions exo/tokenizer/tokenizer.py
@@ -0,0 +1,67 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any
import importlib
class Tokenizer(ABC):
  @abstractmethod
  def encode(self, text: str, allow_special: bool = True) -> List[int]:
    pass

  @abstractmethod
  def decode(self, tokens: List[int]) -> str:
    pass

  @abstractmethod
  def apply_chat_template(self, messages: List[Dict[str, Any]], add_generation_prompt: bool = True, **kwargs) -> str:
    pass

  @property
  @abstractmethod
  def eos_token_id(self) -> int:
    pass

TOKENIZER_CLASSES = {
### llama
"llama-3.2-1b": "LlamaTokenizer",
"llama-3.2-3b": "LlamaTokenizer",
"llama-3.1-8b": "LlamaTokenizer",
"llama-3.1-70b": "LlamaTokenizer",
"llama-3.1-70b-bf16": "LlamaTokenizer",
"llama-3-8b": "LlamaTokenizer",
"llama-3-70b": "LlamaTokenizer",
"llama-3.1-405b": "LlamaTokenizer",
"llama-3.1-405b-8bit": "LlamaTokenizer",
### mistral
"mistral-nemo": "MistralTokenizer",
"mistral-large": "MistralTokenizer",
### deepseek
"deepseek-coder-v2-lite": "DeepSeekTokenizer",
"deepseek-coder-v2.5": "DeepSeekTokenizer",
### llava
"llava-1.5-7b-hf": "LlavaTokenizer",
### qwen
"qwen-2.5-0.5b": "QwenTokenizer",
"qwen-2.5-coder-1.5b": "QwenTokenizer",
"qwen-2.5-coder-3b": "QwenTokenizer",
"qwen-2.5-coder-7b": "QwenTokenizer",
"qwen-2.5-coder-14b": "QwenTokenizer",
"qwen-2.5-coder-32b": "QwenTokenizer",
"qwen-2.5-7b": "QwenTokenizer",
"qwen-2.5-math-7b": "QwenTokenizer",
"qwen-2.5-14b": "QwenTokenizer",
"qwen-2.5-72b": "QwenTokenizer",
"qwen-2.5-math-72b": "QwenTokenizer",
### nemotron
"nemotron-70b": "NemotronTokenizer",
"nemotron-70b-bf16": "NemotronTokenizer",
### gemma
"gemma2-9b": "GemmaTokenizer",
"gemma2-27b": "GemmaTokenizer",
### dummy
"dummy": "DummyTokenizer",
}

def resolve_tokenizer(model_id: str, model_path: str) -> Tokenizer:
  tokenizer_class = TOKENIZER_CLASSES[model_id]
  tokenizer_module = importlib.import_module("exo.tokenizer")
  tokenizer_class_obj = getattr(tokenizer_module, tokenizer_class)
  return tokenizer_class_obj(model_path)
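A hedged sketch of how the resolver is used: the exo model id selects a class name from TOKENIZER_CLASSES, importlib loads it from the exo.tokenizer package, and the class is constructed with the local model directory (the path below is a placeholder). Ids missing from the table raise KeyError, and since only LlamaTokenizer is exported from exo/tokenizer/__init__.py in this diff, the other ids listed above will not resolve yet:

from exo.tokenizer.tokenizer import resolve_tokenizer

tokenizer = resolve_tokenizer("llama-3.2-1b", "/models/llama-3.2-1b")  # dispatches to LlamaTokenizer
print(tokenizer.encode("exo"), tokenizer.eos_token_id)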