diff --git a/mlx_lm/models/deepseek_v4.py b/mlx_lm/models/deepseek_v4.py
index c41731bf7..f56abe63b 100644
--- a/mlx_lm/models/deepseek_v4.py
+++ b/mlx_lm/models/deepseek_v4.py
@@ -2093,14 +2093,21 @@ def dequant_fp4(weight: mx.array, scale: mx.array, block_size: int = 32):
         weights = new_weights
 
         top_remap = {
-            "embed.weight": "model.embed_tokens.weight",
-            "norm.weight": "model.norm.weight",
-            "head.weight": "lm_head.weight",
-            "hc_head_fn": "model.hc_head.fn",
-            "hc_head_base": "model.hc_head.base",
-            "hc_head_scale": "model.hc_head.scale",
+            "embed": "model.embed_tokens",
+            "norm": "model.norm",
+            "head": "lm_head",
         }
         for old, new in top_remap.items():
+            for param in ("weight", "scales", "biases"):
+                old_key = f"{old}.{param}"
+                if old_key in weights:
+                    weights[f"{new}.{param}"] = weights.pop(old_key)
+
+        for old, new in (
+            ("hc_head_fn", "model.hc_head.fn"),
+            ("hc_head_base", "model.hc_head.base"),
+            ("hc_head_scale", "model.hc_head.scale"),
+        ):
             if old in weights:
                 weights[new] = weights.pop(old)
 
@@ -2124,14 +2131,31 @@ def dequant_fp4(weight: mx.array, scale: mx.array, block_size: int = 32):
                 ("w2", "down_proj"),
                 ("w3", "up_proj"),
             ):
+                dst_prefix = f"model.layers.{layer_idx}.ffn.switch_mlp.{dst}"
+                src_prefix = f"{prefix}.{src}"
+                for param in ("weight", "scales", "biases"):
+                    key = f"{src_prefix}.{param}"
+                    if key in weights:
+                        weights[f"{dst_prefix}.{param}"] = weights.pop(key)
+
                 key0 = f"{prefix}.0.{src}.weight"
                 if key0 in weights:
                     stacked = [
                         weights.pop(f"{prefix}.{e}.{src}.weight")
                         for e in range(self.args.n_routed_experts)
                     ]
-                    weights[f"model.layers.{layer_idx}.ffn.switch_mlp.{dst}.weight"] = (
-                        mx.stack(stacked)
+                    weights[f"{dst_prefix}.weight"] = mx.stack(stacked)
+
+            prefix = f"model.layers.{layer_idx}.attn.wo_a"
+            for param in ("weight", "scales", "biases"):
+                key0 = f"{prefix}.0.{param}"
+                if key0 in weights:
+                    weights[f"{prefix}.{param}"] = mx.concatenate(
+                        [
+                            weights.pop(f"{prefix}.{group}.{param}")
+                            for group in range(self.args.o_groups)
+                        ],
+                        axis=0,
                     )
 
         return weights
diff --git a/mlx_lm/server.py b/mlx_lm/server.py
index ce8d95817..a2070629b 100644
--- a/mlx_lm/server.py
+++ b/mlx_lm/server.py
@@ -350,6 +350,7 @@ def _load(self, model_path, adapter_path=None, draft_model_path=None):
             model_path,
             adapter_path=adapter_path,
             tokenizer_config=self._tokenizer_config,
+            lazy=True,
         )
 
         # Use the default chat template if needed
@@ -360,7 +361,7 @@ def _load(self, model_path, adapter_path=None, draft_model_path=None):
         # Load the draft model for speculative decoding
         draft_model = None
         if draft_model_path is not None:
-            draft_model, draft_tokenizer = load(draft_model_path)
+            draft_model, draft_tokenizer = load(draft_model_path, lazy=True)
             if draft_tokenizer.vocab_size != tokenizer.vocab_size:
                 logging.warning(
                     "Draft model tokenizer does not match model tokenizer. "
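Why the remap loop in deepseek_v4.py iterates over ("weight", "scales", "biases"): quantized checkpoints store each tensor as a weight/scales/biases triplet, so a table that only moved `.weight` keys would strand the quantization metadata under the old names. A minimal sketch of the new loop in isolation; the shapes, dtypes, and values below are invented for illustration:

import mlx.core as mx

# Fake quantized checkpoint entries; shapes and dtypes are placeholders.
weights = {
    "embed.weight": mx.zeros((8, 4), dtype=mx.uint32),
    "embed.scales": mx.ones((8, 1), dtype=mx.float16),
    "embed.biases": mx.zeros((8, 1), dtype=mx.float16),
}

top_remap = {"embed": "model.embed_tokens", "norm": "model.norm", "head": "lm_head"}
for old, new in top_remap.items():
    for param in ("weight", "scales", "biases"):
        old_key = f"{old}.{param}"
        if old_key in weights:
            weights[f"{new}.{param}"] = weights.pop(old_key)

# All three params now live under the new prefix.
assert sorted(weights) == [
    "model.embed_tokens.biases",
    "model.embed_tokens.scales",
    "model.embed_tokens.weight",
]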
" diff --git a/mlx_lm/tokenizer_utils.py b/mlx_lm/tokenizer_utils.py index c7e50fbe7..142777016 100644 --- a/mlx_lm/tokenizer_utils.py +++ b/mlx_lm/tokenizer_utils.py @@ -610,17 +610,61 @@ def load( tokenizer_config_file = model_path / "tokenizer_config.json" chat_template = None + tokenizer_config_content = {} + if tokenizer_config_file.exists(): + with open(tokenizer_config_file, "r", encoding="utf-8") as fid: + tokenizer_config_content = json.load(fid) + + if ( + tokenizer_file.exists() + and tokenizer_config_content.get("tokenizer_class") == "TokenizersBackend" + ): + tokenizer_kwargs = { + key: tokenizer_config_content[key] + for key in ( + "bos_token", + "eos_token", + "pad_token", + "unk_token", + "chat_template", + "clean_up_tokenization_spaces", + "model_max_length", + ) + if key in tokenizer_config_content + } + tokenizer_kwargs.update(tokenizer_config_extra or {}) + tokenizer = PreTrainedTokenizerFast( + tokenizer_file=str(tokenizer_file), **tokenizer_kwargs + ) + else: + tokenizer = AutoTokenizer.from_pretrained( + model_path, **(tokenizer_config_extra or {}) + ) - tokenizer = AutoTokenizer.from_pretrained( - model_path, **(tokenizer_config_extra or {}) - ) - - tokenizer_config = tokenizer.init_kwargs + tokenizer_config = {**tokenizer_config_content, **tokenizer.init_kwargs} if chat_template_type := tokenizer_config.get("chat_template_type", False): chat_template = importlib.import_module( f"mlx_lm.chat_templates.{chat_template_type}" ).apply_chat_template + elif tokenizer.chat_template is None: + config_file = model_path / "config.json" + if config_file.exists(): + with open(config_file, "r", encoding="utf-8") as fid: + model_config = json.load(fid) + if model_config.get("model_type") == "deepseek_v4": + tokenizer_config["tool_parser_type"] = "deepseek_v32" + deepseek_v4_template = importlib.import_module( + "mlx_lm.chat_templates.deepseek_v32" + ).apply_chat_template + + def chat_template(*args, **kwargs): + enable_thinking = kwargs.pop("enable_thinking", None) + if enable_thinking is not None and "thinking_mode" not in kwargs: + kwargs["thinking_mode"] = ( + "thinking" if enable_thinking else "chat" + ) + return deepseek_v4_template(*args, **kwargs) tool_parser_type = tokenizer_config.get( "tool_parser_type", _infer_tool_parser(tokenizer.chat_template) diff --git a/mlx_lm/tool_parsers/deepseek_v32.py b/mlx_lm/tool_parsers/deepseek_v32.py new file mode 100644 index 000000000..f7003c960 --- /dev/null +++ b/mlx_lm/tool_parsers/deepseek_v32.py @@ -0,0 +1,35 @@ +# Copyright © 2026 Apple Inc. + +import json +import re +from typing import Any + +tool_call_start = "<|DSML|function_calls>" +tool_call_end = "" + +_invoke_regex = re.compile( + r'<|DSML|invoke\s+name="([^"]+)">(.*?)', + re.DOTALL, +) +_parameter_regex = re.compile( + r'<|DSML|parameter\s+name="([^"]+)"\s+string="(true|false)">(.*?)', + re.DOTALL, +) + + +def _parse_invoke(match: re.Match): + name, body = match.groups() + arguments = {} + for parameter in _parameter_regex.finditer(body): + param_name, is_string, value = parameter.groups() + if is_string != "true": + value = json.loads(value) + arguments[param_name] = value + return {"name": name, "arguments": arguments} + + +def parse_tool_call(text: str, _: Any | None = None): + calls = [_parse_invoke(invoke) for invoke in _invoke_regex.finditer(text)] + if not calls: + raise ValueError("No function provided.") + return calls[0] if len(calls) == 1 else calls