diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py
index 3573b2640..649063809 100644
--- a/mlx_lm/generate.py
+++ b/mlx_lm/generate.py
@@ -5,8 +5,11 @@
 import copy
 import functools
 import json
+import math
+import random
 import sys
 import time
+import warnings
 from collections import deque
 from dataclasses import dataclass
 from functools import partial
@@ -219,6 +222,12 @@ def setup_arg_parser():
         help="Number of tokens to draft when using speculative decoding.",
         default=3,
     )
+    parser.add_argument(
+        "--mtp",
+        action="store_true",
+        help="Use native Multi-Token Prediction for speculative decoding "
+        "(requires a model with an MTP head, e.g. DeepSeek-V4).",
+    )
     return parser
@@ -654,12 +663,186 @@ def _draft_generate(y, num_draft):
             _rewind_cache(num_draft, n)


+def mtp_generate_step(
+    prompt: mx.array,
+    model: nn.Module,
+    *,
+    max_tokens: int = 256,
+    sampler: Optional[Callable[[mx.array], mx.array]] = None,
+    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
+    prompt_cache: Optional[Any] = None,
+    prefill_step_size: int = 2048,
+    kv_bits: Optional[int] = None,
+    kv_group_size: int = 64,
+    quantized_kv_start: int = 0,
+) -> Generator[Tuple[int, mx.array, bool], None, None]:
+    """A generator that uses the model's native MTP head for speculative decoding.
+
+    Each iteration runs one backbone forward pass over the current token and its
+    pending draft, then one MTP forward pass to propose the next draft. Up to two
+    tokens are emitted per backbone step: one always-accepted backbone token and
+    one conditionally accepted draft token.
+
+    The model must implement ``mtp_forward(hidden, next_tok, mtp_cache)`` and
+    support ``return_hidden=True`` in its ``__call__``.
+
+    Yields:
+        Tuple[int, mx.array, bool]: (token, log-probabilities, from_draft).
+        ``from_draft`` is ``True`` when the token came from the MTP head.
+    """
+    y = prompt.astype(mx.uint32)
+    prev_tokens = None
+
+    if prompt_cache is None:
+        model_cache = cache.make_prompt_cache(model)
+        mtp_cache = model.make_mtp_cache()
+    else:
+        n_main = len(model.layers)
+        model_cache = prompt_cache[:n_main]
+        mtp_cache = prompt_cache[n_main:] or model.make_mtp_cache()
+
+    _is_greedy = sampler is None
+    sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
+
+    quantize_cache_fn = functools.partial(
+        maybe_quantize_kv_cache,
+        quantized_kv_start=quantized_kv_start,
+        kv_group_size=kv_group_size,
+        kv_bits=kv_bits,
+    )
+
+    def _process_and_sample(tokens, logits):
+        if logits_processors:
+            for processor in logits_processors:
+                logits = processor(tokens, logits)
+        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
+        return sampler(logprobs), logprobs
+
+    def _step_backbone(y, prev_tokens, n_predict=1):
+        with mx.stream(generation_stream):
+            logits, hidden = model(
+                y[None], cache=model_cache, return_hidden=True
+            )
+            logits = logits[:, -n_predict:, :]
+            quantize_cache_fn(model_cache)
+        toks, lps = [], []
+        for i in range(n_predict):
+            if logits_processors:
+                prev_tokens = (
+                    mx.concatenate([prev_tokens, y[i : i + 1]])
+                    if prev_tokens is not None
+                    else y[i : i + 1]
+                )
+            tok, lp = _process_and_sample(prev_tokens, logits[:, i, :].squeeze(0))
+            toks.append(tok)
+            lps.append(lp)
+        return mx.stack(toks), mx.stack(lps), hidden, prev_tokens
+
+    def _step_mtp(hidden_last, main_tok, prev_tokens):
+        next_ids = main_tok.reshape(1, 1)
+        with mx.stream(generation_stream):
+            mtp_logits = model.mtp_forward(hidden_last, next_ids, mtp_cache)
+            quantize_cache_fn(mtp_cache)
+        mtp_logits = mtp_logits[:, -1, :].squeeze(0)
+        if logits_processors:
+            tokens_for_proc = (
+                mx.concatenate([prev_tokens, main_tok.reshape(-1)])
+                if prev_tokens is not None
+                else main_tok.reshape(-1)
+            )
+        else:
+            tokens_for_proc = prev_tokens
+        draft_tok, draft_lp = _process_and_sample(tokens_for_proc, mtp_logits)
+        return draft_tok, draft_lp
+
+    def _prefill(y):
+        while y.size > 1:
+            n = min(prefill_step_size, y.size - 1)
+            model(y[:n][None], cache=model_cache)
+            quantize_cache_fn(model_cache)
+            mx.eval([c.state for c in model_cache if hasattr(c, "state")])
+            y = y[n:]
+            mx.clear_cache()
+        return y
+
+    with mx.stream(generation_stream):
+        y = _prefill(y)
+
+    ntoks = 0
+    draft_tok = draft_lp = None
+
+    while ntoks < max_tokens:
+        if draft_tok is None:
+            toks, lps, hidden, prev_tokens = _step_backbone(y, prev_tokens, n_predict=1)
+            mx.eval(toks)
+            main_tok, main_lp = toks[0], lps[0]
+            ntoks += 1
+            yield main_tok.item(), main_lp, False
+            if ntoks >= max_tokens:
+                return
+            hidden_at_main = hidden[:, -1:, :]
+            draft_tok, draft_lp = _step_mtp(hidden_at_main, main_tok, prev_tokens)
+            mx.eval(draft_tok)
+            y = mx.array([main_tok.item()], mx.uint32)
+        else:
+            y_with_draft = mx.concatenate([y, mx.array([draft_tok.item()], mx.uint32)])
+            toks, lps, hidden, prev_tokens = _step_backbone(
+                y_with_draft, prev_tokens, n_predict=2
+            )
+            mx.eval(toks, draft_tok)
+
+            verify_pred, bonus_tok = toks[0], toks[1]
+            verify_lp, bonus_lp = lps[0], lps[1]
+            draft_tok_id = draft_tok.item()
+
+            if _is_greedy:
+                accept = verify_pred.item() == draft_tok_id
+            else:
+                log_accept = (verify_lp[draft_tok_id] - draft_lp[draft_tok_id]).item()
+                accept = log_accept >= 0 or random.random() < math.exp(log_accept)
+
+            hidden_at_confirmed = hidden[:, 0:1, :]
+            hidden_at_draft = hidden[:, 1:2, :]
+
+            if accept:
+                ntoks += 1
+                yield draft_tok_id, draft_lp, True
+                if ntoks >= max_tokens:
+                    return
+                ntoks += 1
+                yield bonus_tok.item(), bonus_lp, False
+                if ntoks >= max_tokens:
+                    return
+                draft_tok, draft_lp = _step_mtp(hidden_at_draft, bonus_tok, prev_tokens)
+                mx.eval(draft_tok)
+                y = mx.array([bonus_tok.item()], mx.uint32)
+            else:
+                # Reject draft: trim the draft token from both caches
+                for c in model_cache:
+                    c.trim(1)
+                for c in mtp_cache:
+                    c.trim(1)
+                if logits_processors and prev_tokens is not None:
+                    prev_tokens = prev_tokens[:-1]
+                verify_tok_id = verify_pred.item()
+                ntoks += 1
+                yield verify_tok_id, verify_lp, False
+                if ntoks >= max_tokens:
+                    return
+                draft_tok, draft_lp = _step_mtp(
+                    hidden_at_confirmed, verify_pred, prev_tokens
+                )
+                mx.eval(draft_tok)
+                y = mx.array([verify_tok_id], mx.uint32)
+
+
 def stream_generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
     prompt: Union[str, mx.array, List[int]],
     max_tokens: int = 256,
     draft_model: Optional[nn.Module] = None,
+    mtp: bool = False,
     **kwargs,
 ) -> Generator[GenerationResponse, None, None]:
     """
@@ -675,6 +858,8 @@ def stream_generate(
         draft_model (Optional[nn.Module]): An optional draft model. If provided
           then speculative decoding is used. The draft model must use the same
           tokenizer as the main model. Default: ``None``.
+        mtp (bool): Use native Multi-Token Prediction for speculative
+          decoding. Requires a model with an MTP head. Default: ``False``.
         kwargs: The remaining options get passed to :func:`generate_step`.
           See :func:`generate_step` for more details.
@@ -698,19 +883,30 @@
     kwargs["max_tokens"] = max_tokens

-    if draft_model is None:
+    if draft_model is not None:
+        kwargs.pop("max_kv_size", None)
+        kwargs.pop("prompt_progress_callback", None)
+        token_generator = speculative_generate_step(
+            prompt, model, draft_model, **kwargs
+        )
+    elif mtp and hasattr(model, "mtp"):
+        kwargs.pop("max_kv_size", None)
+        kwargs.pop("prompt_progress_callback", None)
+        kwargs.pop("num_draft_tokens", None)
+        token_generator = mtp_generate_step(prompt, model, **kwargs)
+    else:
+        if mtp:
+            warnings.warn(
+                "--mtp flag ignored: model does not have an MTP head. "
+                "Falling back to standard generation.",
+                stacklevel=2,
+            )
         kwargs.pop("num_draft_tokens", None)
         token_generator = generate_step(prompt, model, **kwargs)
         # from_draft always false for non-speculative generation
         token_generator = (
             (token, logprobs, False) for token, logprobs in token_generator
         )
-    else:
-        kwargs.pop("max_kv_size", None)
-        kwargs.pop("prompt_progress_callback", None)
-        token_generator = speculative_generate_step(
-            prompt, model, draft_model, **kwargs
-        )
     with wired_limit(model, [generation_stream]):
         tic = time.perf_counter()
         for n, (token, logprobs, from_draft) in enumerate(token_generator):
@@ -2083,6 +2279,7 @@ def main():
         quantized_kv_start=args.quantized_kv_start,
         draft_model=draft_model,
         num_draft_tokens=args.num_draft_tokens,
+        mtp=args.mtp,
     )
     if not args.verbose:
         print(response)
diff --git a/mlx_lm/models/deepseek_v4.py b/mlx_lm/models/deepseek_v4.py
index ec2d60f97..28b9b9d7c 100644
--- a/mlx_lm/models/deepseek_v4.py
+++ b/mlx_lm/models/deepseek_v4.py
@@ -65,9 +65,13 @@ def __post_init__(self):
             + [4 if i % 2 else 128 for i in range(max(n - 2, 0))]
             + ([0] if n >= 2 else [])
         )
-        self.compress_ratios = list(self.compress_ratios[: self.num_hidden_layers])
-        if len(self.compress_ratios) != self.num_hidden_layers:
+        total_layers = self.num_hidden_layers + self.num_nextn_predict_layers
+        self.compress_ratios = list(self.compress_ratios[:total_layers])
+        if len(self.compress_ratios) < self.num_hidden_layers:
             raise ValueError(
                 "`compress_ratios` must have one entry per hidden layer, "
                 f"got {len(self.compress_ratios)} for {self.num_hidden_layers} layers."
             )
+        # MTP layers default to compress_ratio=0 (no compression)
+        while len(self.compress_ratios) < total_layers:
+            self.compress_ratios.append(0)
@@ -1819,6 +1823,35 @@ def __call__(
         return self.ffn_hc.expand(x, residual, post, comb)


+class MTPBlock(nn.Module):
+    def __init__(self, config: ModelArgs, layer_idx: int):
+        super().__init__()
+        dim = config.hidden_size
+        self.block = DeepseekV4Block(config, layer_idx)
+        self.e_proj = nn.Linear(dim, dim, bias=False)
+        self.h_proj = nn.Linear(dim, dim, bias=False)
+        self.enorm = nn.RMSNorm(dim, eps=config.rms_norm_eps)
+        self.hnorm = nn.RMSNorm(dim, eps=config.rms_norm_eps)
+        self.norm = nn.RMSNorm(dim, eps=config.rms_norm_eps)
+        self.hc_head = HyperHead(config)
+
+    def __call__(
+        self,
+        h: mx.array,
+        embed_tokens: nn.Embedding,
+        input_ids: mx.array,
+        mask: Optional[mx.array],
+        cache: Optional[Any],
+    ) -> mx.array:
+        e = embed_tokens(input_ids)
+        e = self.enorm(e)
+        h_norm = self.hnorm(h)
+        x = self.e_proj(e)[:, :, None, :] + self.h_proj(h_norm)
+        x = mx.contiguous(x)
+        x = self.block(x, mask, cache, input_ids)
+        return x
+
+
 class DeepseekV4Model(PipelineMixin, nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
@@ -1831,7 +1864,12 @@ def __init__(self, config: ModelArgs):
         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hc_head = HyperHead(config)

-    def __call__(self, inputs: mx.array, cache: Optional[Any] = None) -> mx.array:
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache: Optional[Any] = None,
+        return_raw_hidden: bool = False,
+    ) -> mx.array:
         h = self.embed_tokens(inputs)
         h = mx.broadcast_to(
             h[:, :, None, :],
@@ -1875,7 +1913,10 @@ def __call__(self, inputs: mx.array, cache: Optional[Any] = None) -> mx.array:
         if pipeline_size > 1:
             h = mx.distributed.all_gather(h)[: h.shape[0]]

-        return self.norm(self.hc_head(h))
+        out = self.norm(self.hc_head(h))
+        if return_raw_hidden:
+            return out, h
+        return out


 class Model(nn.Module):
@@ -1885,9 +1926,24 @@ def __init__(self, config: ModelArgs):
         super().__init__()
         self.model_type = config.model_type
         self.model = DeepseekV4Model(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        if config.num_nextn_predict_layers > 0:
+            n = config.num_hidden_layers
+            self.mtp = [
+                MTPBlock(config, n + i)
+                for i in range(config.num_nextn_predict_layers)
+            ]

-    def __call__(self, inputs: mx.array, cache: Optional[Any] = None):
-        return self.lm_head(self.model(inputs, cache))
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache: Optional[Any] = None,
+        return_hidden: bool = False,
+    ):
+        if return_hidden:
+            h, h_raw = self.model(inputs, cache, return_raw_hidden=True)
+            return self.lm_head(h), h_raw
+        h = self.model(inputs, cache)
+        return self.lm_head(h)

     @property
     def layers(self):
@@ -1915,12 +1971,64 @@ def make_cache(self):
                 caches.append(RotatingKVCache(max_size=self.args.sliding_window))
         return caches

+    def make_mtp_cache(self):
+        if not hasattr(self, "mtp"):
+            return None
+        caches = []
+        for mtp_block in self.mtp:
+            attn = mtp_block.block.attn
+            if attn.compress_ratio:
+                caches.append(DeepseekV4Cache(self.args.sliding_window))
+            else:
+                caches.append(RotatingKVCache(max_size=self.args.sliding_window))
+        return caches
+
+    def mtp_forward(
+        self,
+        h: mx.array,
+        input_ids: mx.array,
+        cache: Optional[List[Any]] = None,
+    ) -> mx.array:
+        if cache is None:
+            cache = [None] * len(self.mtp)
+
+        first_cache = cache[0]
+        mask_cache = (
+            first_cache.local
+            if isinstance(first_cache, DeepseekV4Cache)
+            else first_cache
+        )
+        mask = create_attention_mask(
+            h[:, :, 0, :] if h.ndim == 4 else h,
+            mask_cache,
+            window_size=self.args.sliding_window,
+            return_array=True,
+        )
+
+        for mtp_block, layer_cache in zip(self.mtp, cache):
+            h = mtp_block(
+                h, self.model.embed_tokens, input_ids, mask, layer_cache
+            )
+
+        out = mtp_block.hc_head(h)
+        out = mtp_block.norm(out)
+        return self.lm_head(out)
+
     def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
         n_layers = self.args.num_hidden_layers
+        has_mtp = hasattr(self, "mtp")
+        has_mtp_weights = any(k.startswith("mtp.") for k in weights)
+        # Disable MTP module if weights are absent (e.g. quantized checkpoints)
+        if has_mtp and not has_mtp_weights:
+            del self.mtp
+            has_mtp = False
         new_weights = {}
         for k, v in weights.items():
             if k.startswith("mtp."):
+                if not has_mtp:
+                    continue
+                new_weights[k] = v
                 continue
             parts = k.split(".")
             if len(parts) >= 2 and parts[0] == "layers":
@@ -2022,8 +2130,23 @@ def dequant_fp4(weight: mx.array, scale: mx.array, block_size: int = 32):

         remapped = {}
         w_remap = {"w1": "gate_proj", "w2": "down_proj", "w3": "up_proj"}
+        mtp_block_subs = (
+            "attn.", "ffn.", "attn_norm.", "ffn_norm.",
+            "hc_attn_", "hc_ffn_",
+        )
         for k, v in weights.items():
             nk = "model." + k if k.startswith("layers.") else k
+            # MTP block: nest block-internal weights under .block.
+            if nk.startswith("mtp."):
+                parts = nk.split(".", 2)  # ["mtp", "0", "rest"]
+                if len(parts) == 3:
+                    rest = parts[2]
+                    if any(rest.startswith(s) for s in mtp_block_subs):
+                        nk = f"mtp.{parts[1]}.block.{rest}"
+                    # HC head weights for MTP block
+                    for param in ("fn", "base", "scale"):
+                        if rest == f"hc_head_{param}":
+                            nk = f"mtp.{parts[1]}.hc_head.{param}"
             nk = nk.replace(".ffn.gate.bias", ".ffn.gate.e_score_correction_bias")
             for sub in ("attn", "ffn"):
                 for param in ("fn", "base", "scale"):
@@ -2033,6 +2156,7 @@ def dequant_fp4(weight: mx.array, scale: mx.array, block_size: int = 32):
             remapped[nk] = v

         weights = remapped
+        # Stack routed expert weights for main model layers
         for layer_idx in range(n_layers):
             prefix = f"model.layers.{layer_idx}.ffn.experts"
             for src, dst in (
@@ -2050,6 +2174,25 @@ def dequant_fp4(weight: mx.array, scale: mx.array, block_size: int = 32):
                     mx.stack(stacked)
                 )

+        # Stack routed expert weights for MTP layers
+        if has_mtp:
+            for mtp_idx in range(self.args.num_nextn_predict_layers):
+                prefix = f"mtp.{mtp_idx}.block.ffn.experts"
+                for src, dst in (
+                    ("w1", "gate_proj"),
+                    ("w2", "down_proj"),
+                    ("w3", "up_proj"),
+                ):
+                    key0 = f"{prefix}.0.{src}.weight"
+                    if key0 in weights:
+                        stacked = [
+                            weights.pop(f"{prefix}.{e}.{src}.weight")
+                            for e in range(self.args.n_routed_experts)
+                        ]
+                        weights[
+                            f"mtp.{mtp_idx}.block.ffn.switch_mlp.{dst}.weight"
+                        ] = mx.stack(stacked)
+
         return weights

     def shard(self, group: Optional[mx.distributed.Group] = None):
diff --git a/mlx_lm/server.py b/mlx_lm/server.py
index ce8d95817..da4d93303 100644
--- a/mlx_lm/server.py
+++ b/mlx_lm/server.py
@@ -983,6 +983,7 @@ def progress(tokens_processed, tokens_total):
             prompt_cache=cache,
             draft_model=draft_model,
             num_draft_tokens=args.num_draft_tokens,
+            mtp=getattr(self.cli_args, "mtp", False),
             prompt_progress_callback=progress,
             prefill_step_size=self.cli_args.prefill_step_size,
         ):
@@ -1790,6 +1791,12 @@ def main():
         help="Number of tokens to draft when using speculative decoding.",
         default=3,
     )
+    parser.add_argument(
+        "--mtp",
+        action="store_true",
+        help="Use native Multi-Token Prediction for speculative decoding "
+        "(requires a model with an MTP head, e.g. DeepSeek-V4).",
+    )
     parser.add_argument(
         "--trust-remote-code",
        action="store_true",
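Reviewer note, not part of the patch: below is a minimal usage sketch of the new mtp switch. It assumes a hypothetical MLX-converted DeepSeek-V4 checkpoint that ships its MTP weights; load and stream_generate are the existing mlx_lm entry points, and the mtp keyword is the one added by this diff.

# Usage sketch (illustrative only; the model path is hypothetical).
# With mtp=True, stream_generate routes through mtp_generate_step when the
# model exposes an MTP head, and otherwise warns and falls back to generate_step.
from mlx_lm import load, stream_generate

model, tokenizer = load("path/to/DeepSeek-V4-mlx")  # hypothetical local path
for response in stream_generate(
    model,
    tokenizer,
    "Explain multi-token prediction in one paragraph.",
    max_tokens=128,
    mtp=True,
):
    print(response.text, end="", flush=True)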