Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions mlx_lm/models/deepseek_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -2093,14 +2093,21 @@ def dequant_fp4(weight: mx.array, scale: mx.array, block_size: int = 32):
weights = new_weights

top_remap = {
"embed.weight": "model.embed_tokens.weight",
"norm.weight": "model.norm.weight",
"head.weight": "lm_head.weight",
"hc_head_fn": "model.hc_head.fn",
"hc_head_base": "model.hc_head.base",
"hc_head_scale": "model.hc_head.scale",
"embed": "model.embed_tokens",
"norm": "model.norm",
"head": "lm_head",
}
for old, new in top_remap.items():
for param in ("weight", "scales", "biases"):
old_key = f"{old}.{param}"
if old_key in weights:
weights[f"{new}.{param}"] = weights.pop(old_key)

for old, new in (
("hc_head_fn", "model.hc_head.fn"),
("hc_head_base", "model.hc_head.base"),
("hc_head_scale", "model.hc_head.scale"),
):
if old in weights:
weights[new] = weights.pop(old)

Expand All @@ -2124,14 +2131,31 @@ def dequant_fp4(weight: mx.array, scale: mx.array, block_size: int = 32):
("w2", "down_proj"),
("w3", "up_proj"),
):
dst_prefix = f"model.layers.{layer_idx}.ffn.switch_mlp.{dst}"
src_prefix = f"{prefix}.{src}"
for param in ("weight", "scales", "biases"):
key = f"{src_prefix}.{param}"
if key in weights:
weights[f"{dst_prefix}.{param}"] = weights.pop(key)

key0 = f"{prefix}.0.{src}.weight"
if key0 in weights:
stacked = [
weights.pop(f"{prefix}.{e}.{src}.weight")
for e in range(self.args.n_routed_experts)
]
weights[f"model.layers.{layer_idx}.ffn.switch_mlp.{dst}.weight"] = (
mx.stack(stacked)
weights[f"{dst_prefix}.weight"] = mx.stack(stacked)

prefix = f"model.layers.{layer_idx}.attn.wo_a"
for param in ("weight", "scales", "biases"):
key0 = f"{prefix}.0.{param}"
if key0 in weights:
weights[f"{prefix}.{param}"] = mx.concatenate(
[
weights.pop(f"{prefix}.{group}.{param}")
for group in range(self.args.o_groups)
],
axis=0,
)

return weights
Expand Down
3 changes: 2 additions & 1 deletion mlx_lm/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def _load(self, model_path, adapter_path=None, draft_model_path=None):
model_path,
adapter_path=adapter_path,
tokenizer_config=self._tokenizer_config,
lazy=True,
)

# Use the default chat template if needed
Expand All @@ -360,7 +361,7 @@ def _load(self, model_path, adapter_path=None, draft_model_path=None):
# Load the draft model for speculative decoding
draft_model = None
if draft_model_path is not None:
draft_model, draft_tokenizer = load(draft_model_path)
draft_model, draft_tokenizer = load(draft_model_path, lazy=True)
if draft_tokenizer.vocab_size != tokenizer.vocab_size:
logging.warning(
"Draft model tokenizer does not match model tokenizer. "
Expand Down
54 changes: 49 additions & 5 deletions mlx_lm/tokenizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,17 +610,61 @@ def load(

tokenizer_config_file = model_path / "tokenizer_config.json"
chat_template = None
tokenizer_config_content = {}
if tokenizer_config_file.exists():
with open(tokenizer_config_file, "r", encoding="utf-8") as fid:
tokenizer_config_content = json.load(fid)

if (
tokenizer_file.exists()
and tokenizer_config_content.get("tokenizer_class") == "TokenizersBackend"
):
tokenizer_kwargs = {
key: tokenizer_config_content[key]
for key in (
"bos_token",
"eos_token",
"pad_token",
"unk_token",
"chat_template",
"clean_up_tokenization_spaces",
"model_max_length",
)
if key in tokenizer_config_content
}
tokenizer_kwargs.update(tokenizer_config_extra or {})
tokenizer = PreTrainedTokenizerFast(
tokenizer_file=str(tokenizer_file), **tokenizer_kwargs
)
else:
tokenizer = AutoTokenizer.from_pretrained(
model_path, **(tokenizer_config_extra or {})
)

tokenizer = AutoTokenizer.from_pretrained(
model_path, **(tokenizer_config_extra or {})
)

tokenizer_config = tokenizer.init_kwargs
tokenizer_config = {**tokenizer_config_content, **tokenizer.init_kwargs}

if chat_template_type := tokenizer_config.get("chat_template_type", False):
chat_template = importlib.import_module(
f"mlx_lm.chat_templates.{chat_template_type}"
).apply_chat_template
elif tokenizer.chat_template is None:
config_file = model_path / "config.json"
if config_file.exists():
with open(config_file, "r", encoding="utf-8") as fid:
model_config = json.load(fid)
if model_config.get("model_type") == "deepseek_v4":
tokenizer_config["tool_parser_type"] = "deepseek_v32"
deepseek_v4_template = importlib.import_module(
"mlx_lm.chat_templates.deepseek_v32"
).apply_chat_template

def chat_template(*args, **kwargs):
enable_thinking = kwargs.pop("enable_thinking", None)
if enable_thinking is not None and "thinking_mode" not in kwargs:
kwargs["thinking_mode"] = (
"thinking" if enable_thinking else "chat"
)
return deepseek_v4_template(*args, **kwargs)

tool_parser_type = tokenizer_config.get(
"tool_parser_type", _infer_tool_parser(tokenizer.chat_template)
Expand Down
35 changes: 35 additions & 0 deletions mlx_lm/tool_parsers/deepseek_v32.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright © 2026 Apple Inc.

import json
import re
from typing import Any

# Literal marker strings for the opening and closing of a function-calls
# section. They are not used inside this module; presumably the caller uses
# them to locate the section in generated text -- confirm against the caller.
tool_call_start = "<|DSML|function_calls>"
tool_call_end = "</|DSML|function_calls>"

# The bar inside the DSML markers has to be matched as a literal character.
# Written bare, "|" is the regex alternation operator, which splits each
# pattern into unrelated alternatives (the first being a lone "<"), so a
# "match" would carry None for every capture group.
# NOTE(review): DeepSeek special tokens are often spelled with the fullwidth
# bar (U+FF5C). Both spellings are accepted here so the patterns work either
# way -- confirm which one the tokenizer emits and keep tool_call_start /
# tool_call_end aligned with it.
_BAR = r"[|\uFF5C]"

# One invoke element. Group 1 is the value of the name attribute, group 2 is
# everything between the opening and closing invoke tags (DOTALL lets the
# body span several lines).
_invoke_regex = re.compile(
    rf'<{_BAR}DSML{_BAR}invoke\s+name="([^"]+)">'
    rf"(.*?)"
    rf"</{_BAR}DSML{_BAR}invoke>",
    re.DOTALL,
)

# One parameter element. Group 1 is the parameter name, group 2 is the
# string attribute ("true" or "false"), group 3 is the raw text between the
# opening and closing parameter tags.
_parameter_regex = re.compile(
    rf'<{_BAR}DSML{_BAR}parameter\s+name="([^"]+)"\s+string="(true|false)">'
    rf"(.*?)"
    rf"</{_BAR}DSML{_BAR}parameter>",
    re.DOTALL,
)


def _parse_invoke(match: re.Match):
    """Turn one matched invoke element into a tool-call dictionary.

    Args:
        match: A match whose two groups are the function name and the text
            enclosed by the invoke element.

    Returns:
        A dict with the function name under ``"name"`` and, under
        ``"arguments"``, a dict mapping each parameter name to its value.

    Raises:
        json.JSONDecodeError: If a parameter not flagged as a string holds
            text that is not valid JSON.
    """
    func_name, inner = match.groups()
    # Parameters flagged string="true" keep their raw text; anything else is
    # decoded as JSON. A repeated parameter name keeps the last occurrence.
    parsed_args = {
        arg_name: (raw if string_flag == "true" else json.loads(raw))
        for arg_name, string_flag, raw in (
            found.groups() for found in _parameter_regex.finditer(inner)
        )
    }
    return {"name": func_name, "arguments": parsed_args}


def parse_tool_call(text: str, _: Any | None = None):
    """Extract the tool call(s) found in ``text``.

    Args:
        text: Text to scan for invoke elements.
        _: Accepted but never read by this function.

    Returns:
        A single tool-call dict when exactly one invoke element is found,
        otherwise a list of tool-call dicts in order of appearance.

    Raises:
        ValueError: If ``text`` contains no invoke element.
    """
    parsed = list(map(_parse_invoke, _invoke_regex.finditer(text)))
    if len(parsed) == 1:
        # A lone call is returned unwrapped rather than as a one-item list.
        return parsed[0]
    if parsed:
        return parsed
    raise ValueError("No function provided.")