211 changes: 204 additions & 7 deletions mlx_lm/generate.py
@@ -5,8 +5,11 @@
import copy
import functools
import json
import math
import random
import sys
import time
import warnings
from collections import deque
from dataclasses import dataclass
from functools import partial
@@ -219,6 +222,12 @@ def setup_arg_parser():
help="Number of tokens to draft when using speculative decoding.",
default=3,
)
parser.add_argument(
"--mtp",
action="store_true",
help="Use native Multi-Token Prediction for speculative decoding "
"(requires a model with an MTP head, e.g. DeepSeek-V3).",
)
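# Example invocation (paths are placeholders; any model exposing an MTP head
# should work):
#   mlx_lm.generate --model <mtp-capable-model> --prompt "Write a haiku" --mtp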
return parser


@@ -654,12 +663,186 @@ def _draft_generate(y, num_draft):
_rewind_cache(num_draft, n)


def mtp_generate_step(
prompt: mx.array,
model: nn.Module,
*,
max_tokens: int = 256,
sampler: Optional[Callable[[mx.array], mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
prompt_cache: Optional[Any] = None,
prefill_step_size: int = 2048,
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
"""A generator that uses the model's native MTP head for speculative decoding.

Each iteration runs one backbone forward pass over the current token and its
pending draft, then one MTP forward pass to propose the next draft. Up to 2
tokens are emitted per backbone step: one always-accepted backbone token and
one conditionally-accepted draft token.

The model must implement ``mtp_forward(hidden, next_tok, mtp_cache)`` and
support ``return_hidden=True`` in its ``__call__``.

Yields:
Tuple[mx.array, mx.array, bool]: (token, log-probabilities, from_draft).
``from_draft`` is ``True`` when the token came from the MTP head.
"""
y = prompt.astype(mx.uint32)
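# Illustrative only, not part of this change: a hypothetical sketch of the
# model-side hooks assumed above. A compatible model exposes an ``mtp`` module
# and roughly the following methods (internal names are placeholders):
#
#   def __call__(self, inputs, cache=None, return_hidden=False):
#       hidden = self.backbone(inputs, cache=cache)
#       logits = self.lm_head(hidden)
#       return (logits, hidden) if return_hidden else logits
#
#   def make_mtp_cache(self):
#       return [cache.KVCache() for _ in self.mtp.layers]
#
#   def mtp_forward(self, hidden, next_tok, mtp_cache):
#       # Fuse the backbone's hidden state with the embedding of the confirmed
#       # token, then score the token one position further ahead.
#       return self.mtp(hidden, self.embed_tokens(next_tok), cache=mtp_cache)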
prev_tokens = None

if prompt_cache is None:
model_cache = cache.make_prompt_cache(model)
mtp_cache = model.make_mtp_cache()
else:
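# A caller-supplied cache is expected to hold the backbone layer caches first,
# followed by the MTP cache entries; the MTP portion is created fresh if absent.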
n_main = len(model.layers)
model_cache = prompt_cache[:n_main]
mtp_cache = prompt_cache[n_main:] or model.make_mtp_cache()

_is_greedy = sampler is None
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))

quantize_cache_fn = functools.partial(
maybe_quantize_kv_cache,
quantized_kv_start=quantized_kv_start,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)

def _process_and_sample(tokens, logits):
if logits_processors:
for processor in logits_processors:
logits = processor(tokens, logits)
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
return sampler(logprobs), logprobs

def _step_backbone(y, prev_tokens, n_predict=1):
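# One backbone forward pass over the confirmed token plus any pending draft.
# Returns the sampled token(s), their log-probs, the hidden states needed by
# the MTP head, and the updated token history for the logits processors.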
with mx.stream(generation_stream):
logits, hidden = model(
y[None], cache=model_cache, return_hidden=True
)
logits = logits[:, -n_predict:, :]
quantize_cache_fn(model_cache)
toks, lps = [], []
for i in range(n_predict):
if logits_processors:
prev_tokens = (
mx.concatenate([prev_tokens, y[i : i + 1]])
if prev_tokens is not None
else y[i : i + 1]
)
tok, lp = _process_and_sample(prev_tokens, logits[:, i, :].squeeze(0))
toks.append(tok)
lps.append(lp)
return mx.stack(toks), mx.stack(lps), hidden, prev_tokens

def _step_mtp(hidden_last, main_tok, prev_tokens):
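# One MTP forward pass: given the backbone hidden state at a position and the
# token just confirmed there, propose a draft token for the following position.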
next_ids = main_tok.reshape(1, 1)
with mx.stream(generation_stream):
mtp_logits = model.mtp_forward(hidden_last, next_ids, mtp_cache)
quantize_cache_fn(mtp_cache)
mtp_logits = mtp_logits[:, -1, :].squeeze(0)
if logits_processors:
tokens_for_proc = (
mx.concatenate([prev_tokens, main_tok.reshape(-1)])
if prev_tokens is not None
else main_tok.reshape(-1)
)
else:
tokens_for_proc = prev_tokens
draft_tok, draft_lp = _process_and_sample(tokens_for_proc, mtp_logits)
return draft_tok, draft_lp

def _prefill(y):
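# Process the prompt in chunks of prefill_step_size, holding back the final
# token so the first generation step can sample from it.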
while y.size > 1:
n = min(prefill_step_size, y.size - 1)
model(y[:n][None], cache=model_cache)
quantize_cache_fn(model_cache)
mx.eval([c.state for c in model_cache if hasattr(c, "state")])
y = y[n:]
mx.clear_cache()
return y

with mx.stream(generation_stream):
y = _prefill(y)

ntoks = 0
draft_tok = draft_lp = None

while ntoks < max_tokens:
if draft_tok is None:
toks, lps, hidden, prev_tokens = _step_backbone(y, prev_tokens, n_predict=1)
mx.eval(toks)
main_tok, main_lp = toks[0], lps[0]
ntoks += 1
yield main_tok.item(), main_lp, False
if ntoks >= max_tokens:
return
hidden_at_main = hidden[:, -1:, :]
draft_tok, draft_lp = _step_mtp(hidden_at_main, main_tok, prev_tokens)
mx.eval(draft_tok)
y = mx.array([main_tok.item()], mx.uint32)
else:
y_with_draft = mx.concatenate([y, mx.array([draft_tok.item()], mx.uint32)])
toks, lps, hidden, prev_tokens = _step_backbone(
y_with_draft, prev_tokens, n_predict=2
)
mx.eval(toks, draft_tok)

verify_pred, bonus_tok = toks[0], toks[1]
verify_lp, bonus_lp = lps[0], lps[1]
draft_tok_id = draft_tok.item()

if _is_greedy:
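# Greedy decoding: accept the draft only if it matches the backbone prediction.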
accept = verify_pred.item() == draft_tok_id
else:
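# Accept the draft with probability min(1, p_backbone(draft) / p_mtp(draft)),
# the usual speculative-decoding acceptance rule.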
log_accept = (verify_lp[draft_tok_id] - draft_lp[draft_tok_id]).item()
accept = log_accept >= 0 or random.random() < math.exp(log_accept)

hidden_at_confirmed = hidden[:, 0:1, :]
hidden_at_draft = hidden[:, 1:2, :]

if accept:
ntoks += 1
yield draft_tok_id, draft_lp, True
if ntoks >= max_tokens:
return
ntoks += 1
yield bonus_tok.item(), bonus_lp, False
if ntoks >= max_tokens:
return
draft_tok, draft_lp = _step_mtp(hidden_at_draft, bonus_tok, prev_tokens)
mx.eval(draft_tok)
y = mx.array([bonus_tok.item()], mx.uint32)
else:
# Reject draft: trim the draft token from both caches
for c in model_cache:
c.trim(1)
for c in mtp_cache:
c.trim(1)
if logits_processors and prev_tokens is not None:
prev_tokens = prev_tokens[:-1]
verify_tok_id = verify_pred.item()
ntoks += 1
yield verify_tok_id, verify_lp, False
if ntoks >= max_tokens:
return
draft_tok, draft_lp = _step_mtp(
hidden_at_confirmed, verify_pred, prev_tokens
)
mx.eval(draft_tok)
y = mx.array([verify_tok_id], mx.uint32)


def stream_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: Union[str, mx.array, List[int]],
max_tokens: int = 256,
draft_model: Optional[nn.Module] = None,
mtp: bool = False,
**kwargs,
) -> Generator[GenerationResponse, None, None]:
"""
@@ -675,6 +858,8 @@ def stream_generate(
draft_model (Optional[nn.Module]): An optional draft model. If provided
then speculative decoding is used. The draft model must use the same
tokenizer as the main model. Default: ``None``.
mtp (bool): Use native Multi-Token Prediction for speculative
decoding. Requires a model with an MTP head. Default: ``False``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.

Expand All @@ -698,19 +883,30 @@ def stream_generate(

kwargs["max_tokens"] = max_tokens

if draft_model is None:
if draft_model is not None:
kwargs.pop("max_kv_size", None)
kwargs.pop("prompt_progress_callback", None)
token_generator = speculative_generate_step(
prompt, model, draft_model, **kwargs
)
elif mtp and hasattr(model, "mtp"):
kwargs.pop("max_kv_size", None)
kwargs.pop("prompt_progress_callback", None)
kwargs.pop("num_draft_tokens", None)
token_generator = mtp_generate_step(prompt, model, **kwargs)
else:
if mtp:
warnings.warn(
"MTP was requested but the model does not have an MTP head. "
"Falling back to standard generation.",
stacklevel=2,
)
kwargs.pop("num_draft_tokens", None)
token_generator = generate_step(prompt, model, **kwargs)
# from_draft always false for non-speculative generation
token_generator = (
(token, logprobs, False) for token, logprobs in token_generator
)
else:
kwargs.pop("max_kv_size", None)
kwargs.pop("prompt_progress_callback", None)
token_generator = speculative_generate_step(
prompt, model, draft_model, **kwargs
)
with wired_limit(model, [generation_stream]):
tic = time.perf_counter()
for n, (token, logprobs, from_draft) in enumerate(token_generator):
@@ -2083,6 +2279,7 @@ def main():
quantized_kv_start=args.quantized_kv_start,
draft_model=draft_model,
num_draft_tokens=args.num_draft_tokens,
mtp=args.mtp,
)
if not args.verbose:
print(response)
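
A minimal Python-side sketch of the new option (the model path is a placeholder;
any model exposing an MTP head should work):

    from mlx_lm import load, stream_generate

    model, tokenizer = load("<mtp-capable-model>")
    for response in stream_generate(model, tokenizer, "Hello", max_tokens=64, mtp=True):
        print(response.text, end="", flush=True)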