Add exaone-3.5 LLM model and apply unit test #580

Open · wants to merge 3 commits into main
80 changes: 80 additions & 0 deletions exo/inference/mlx/models/exaone.py
@@ -0,0 +1,80 @@
from dataclasses import dataclass, field
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import create_attention_mask
from mlx_lm.models.exaone import TransformerBlock, ModelArgs
from ...shard import Shard
from .base import IdentityBlock


@dataclass
class ModelArgs(ModelArgs):
  shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))

  def __post_init__(self):
    # super().__post_init__()  # Ensure parent initializations are respected

    if isinstance(self.shard, Shard):
      return
    if not isinstance(self.shard, dict):
      raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")

    self.shard = Shard(**self.shard)


class ExaoneModel(nn.Module):
  def __init__(self, args: ModelArgs):
    super().__init__()
    self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
    self.h = [TransformerBlock(args) for _ in range(args.num_layers)]
    self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

  def __call__(
    self,
    inputs: mx.array,
    cache=None,
  ):
    h = self.wte(inputs)
    mask = create_attention_mask(h, cache)

    if cache is None:
      cache = [None] * len(self.h)

    for layer, c in zip(self.h, cache):
      h = layer(h, mask, cache=c)

    return self.ln_f(h)


class Model(nn.Module):
  def __init__(self, args: ModelArgs):
    super().__init__()
    self.args = args
    self.model_type = args.model_type
    self.transformer = ExaoneModel(args)
    if not args.tie_word_embeddings:
      self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

  def __call__(
    self,
    inputs: mx.array,
    cache=None,
  ):
    out = self.transformer(inputs, cache)
    if self.args.tie_word_embeddings:
      out = self.transformer.wte.as_linear(out)
    else:
      out = self.lm_head(out)
    return out

  @property
  def layers(self):
    return self.transformer.h

  @property
  def head_dim(self):
    return self.args.head_dim

  @property
  def n_kv_heads(self):
    return self.args.num_key_value_heads
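
Note that the file imports IdentityBlock and carries a shard field, yet the model above instantiates every TransformerBlock and always applies the embedding and final norm. Below is a minimal sketch of a shard-aware variant following the IdentityBlock pattern used by the other models under exo/inference/mlx/models; the Shard fields start_layer/end_layer and the is_first_layer()/is_last_layer() methods are assumed from that pattern, not taken from this diff.

class ShardedExaoneModel(nn.Module):
  def __init__(self, args: ModelArgs):
    super().__init__()
    self.args = args
    # The embedding is only needed on the first shard (or the last one when weights are tied).
    if args.shard.is_first_layer() or (args.shard.is_last_layer() and args.tie_word_embeddings):
      self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
    self.h = []
    for i in range(args.num_layers):
      if args.shard.start_layer <= i <= args.shard.end_layer:
        self.h.append(TransformerBlock(args))  # layer owned by this shard
      else:
        self.h.append(IdentityBlock())  # pass-through for layers hosted on other nodes
    if args.shard.is_last_layer():
      self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

  def __call__(self, inputs: mx.array, cache=None):
    # Only the first shard embeds token ids; later shards receive hidden states.
    h = self.wte(inputs) if self.args.shard.is_first_layer() else inputs
    mask = create_attention_mask(h, cache)
    if cache is None:
      cache = [None] * len(self.h)
    for layer, c in zip(self.h, cache):
      h = layer(h, mask, cache=c)
    # Only the last shard applies the final norm.
    return self.ln_f(h) if self.args.shard.is_last_layer() else h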
4 changes: 4 additions & 0 deletions exo/models.py
@@ -118,6 +118,8 @@
"phi-4": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/phi-4-4bit", }, },
# dummy
"dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
"exaone-3.5-7.8b": {"layers": 32, "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/EXAONE-3.5-7.8B-Instruct-4bit"}, },
"exaone-3.5-2.4b": {"layers": 30, "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/EXAONE-3.5-2.4B-Instruct-4bit"}, },
}

pretty_name = {
@@ -158,6 +160,8 @@
"phi-4": "Phi-4",
"llama-3-8b": "Llama 3 8B",
"llama-3-70b": "Llama 3 70B",
"exaone-3.5-2.4b": "EXAONE-3.5 2.4B",
"exaone-3.5-7.8b": "EXAONE-3.5 7.8B",
"stable-diffusion-2-1-base": "Stable Diffusion 2.1",
}

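
With these entries in place, either model can be requested by its short name. A hedged usage sketch against exo's ChatGPT-compatible HTTP API; the endpoint path and default port are assumptions taken from exo's README, not from this diff.

import requests

# Assumes a local exo node is already running and serving the ChatGPT-compatible API.
resp = requests.post(
  "http://localhost:52415/v1/chat/completions",  # port is an assumption; adjust to your node
  json={
    "model": "exaone-3.5-2.4b",  # short name registered in exo/models.py above
    "messages": [{"role": "user", "content": "What is distributed inference?"}],
    "temperature": 0.7,
  },
)
print(resp.json()["choices"][0]["message"]["content"])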
2 changes: 1 addition & 1 deletion setup.py
@@ -36,7 +36,7 @@
  ],
  "apple_silicon": [
    "mlx==0.20.0",
    "mlx-lm==0.19.3",
    "mlx-lm==0.20.5",
  ],
}

7 changes: 5 additions & 2 deletions test/test_tokenizers.py
@@ -24,6 +24,7 @@ def test_tokenizer(name, tokenizer, verbose=False):
  strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
  assert text == strip_tokens(decoded) == strip_tokens(reconstructed)

enable_trust_remote_code_models = ['mlx-community/EXAONE-3.5-2.4B-Instruct-4bit', 'mlx-community/EXAONE-3.5-7.8B-Instruct-4bit']
ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-mini-instruct-4bit", "mlx-community/phi-4-4bit", "stabilityai/stable-diffusion-2-1-base"]
ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
models = []
@@ -34,8 +35,10 @@ def test_tokenizer(name, tokenizer, verbose=False):
models = list(set(models))

verbose = os.environ.get("VERBOSE", "0").lower() == "1"

for m in models:
  enable_trust_remote_code = m in enable_trust_remote_code_models
  # TODO: figure out why use_fast=False is giving inconsistent behaviour (no spaces decoding individual tokens) for Mistral-Large-Instruct-2407-4bit
  # test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=False), verbose)
  test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=True), verbose)
  test_tokenizer(m, AutoTokenizer.from_pretrained(m), verbose)
  test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=True, trust_remote_code=enable_trust_remote_code), verbose)
  test_tokenizer(m, AutoTokenizer.from_pretrained(m, trust_remote_code=enable_trust_remote_code), verbose)
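
For a quick standalone check of the new tokenizers outside the test harness, the round trip the test performs can be reproduced directly; trust_remote_code=True mirrors what the test now passes for the EXAONE repos.

from transformers import AutoTokenizer

repo = "mlx-community/EXAONE-3.5-2.4B-Instruct-4bit"
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)

text = "Hello, how are you?"
ids = tokenizer.encode(text)
decoded = tokenizer.decode(ids)
print(ids)
print(decoded)
# The test above additionally strips any BOS/EOS tokens before asserting that
# text, the decoded ids, and the re-encoded/decoded string all match.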