
Add MTP speculative decoding for DeepSeek-V4#15

Draft
0xClandestine wants to merge 2 commits into Blaizzy:pc/add-deepseekv4flash-model from 0xClandestine:feat/add-ds4-mtp

Conversation

@0xClandestine

Summary

  • Implement native Multi-Token Prediction (MTP) speculative decoding for DeepSeek-V4-Flash
  • MTPBlock wraps DeepseekV4Block with projection layers (e_proj, h_proj), norms, and per-block HyperHead — following the HF reference architecture
  • mtp_generate_step() in generate.py implements the draft/verify loop: backbone predicts token t+1, MTP drafts t+2, next backbone step verifies the draft — emitting up to 2 tokens per backbone forward pass
  • --mtp CLI flag for both generate and server commands, with graceful fallback warning for models without an MTP head
  • Weight sanitization remaps and stacks MTP expert weights from HF checkpoints

Reference: ml-explore/mlx-lm#990 (Qwen3.5 MTP pattern)

Test plan

  • Load DeepSeek-V4-Flash weights with MTP head and verify sanitize remapping
  • Run mlx_lm.generate --mtp and confirm speculative decoding produces coherent output
  • Verify draft acceptance rate and tokens/sec improvement over standard decoding
  • Test --mtp flag on a model without MTP head → confirm fallback warning
  • Test server with --mtp flag
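
With one draft token per backbone pass and greedy exact-match acceptance, each pass emits one verified token plus the accepted draft, so throughput per pass is bounded by 1 + α, where α is the draft acceptance rate. A small helper like this (names are illustrative, not part of the PR) can sanity-check the numbers from a test run:

```python
def mtp_stats(accepted_drafts: int, backbone_passes: int) -> dict:
    """Summarize a speculative-decoding test run.

    Each backbone forward pass emits one verified token, plus the
    draft token whenever the draft matched the backbone's choice.
    """
    emitted = backbone_passes + accepted_drafts
    alpha = accepted_drafts / backbone_passes
    return {
        "acceptance_rate": alpha,
        "tokens_per_pass": emitted / backbone_passes,  # equals 1 + alpha
        "emitted_tokens": emitted,
    }

print(mtp_stats(accepted_drafts=70, backbone_passes=100))
```

A 70% acceptance rate therefore yields at most a 1.7x speedup over standard decoding, ignoring the (small) cost of the MTP head itself.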

@Blaizzy force-pushed the pc/add-deepseekv4flash-model branch 3 times, most recently from fdede5d to 16f7205 on April 26, 2026 at 09:38
Implement native MTP support, following the HF reference architecture
and the patterns from ml-explore/mlx-lm#990:

Model (deepseek_v4.py):
- MTPBlock wrapping DeepseekV4Block with e_proj, h_proj, enorm, hnorm,
  norm, and per-block HyperHead
- return_hidden support in Model.__call__ for exposing raw 4D hidden state
- mtp_forward() and make_mtp_cache() on Model
- Weight sanitization: keep and remap MTP weights, stack MTP experts
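
The MTP block's data flow can be sketched framework-agnostically. Everything below is an assumption-laden illustration, not the PR's code: the `e_proj`/`h_proj`/`enorm`/`hnorm` names come from the bullet list above, but how the two branches are combined (summed here; the HF reference may concatenate before a single projection) and the internals of the decoder block and HyperHead (stubbed as plain matrices) are guesses.

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # Simple RMS normalization over the last axis.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

class MTPBlockSketch:
    """Illustrative sketch of the MTP block's data flow (assumed, not actual).

    The new token's embedding and the backbone's last hidden state are
    each normalized (enorm/hnorm), projected (e_proj/h_proj), combined,
    run through a decoder block, and fed to a head for draft logits.
    """

    def __init__(self, dims: int, vocab: int, rng):
        scale = 1.0 / np.sqrt(dims)
        self.e_proj = rng.standard_normal((dims, dims)) * scale
        self.h_proj = rng.standard_normal((dims, dims)) * scale
        self.block = rng.standard_normal((dims, dims)) * scale  # stub for DeepseekV4Block
        self.head = rng.standard_normal((dims, vocab)) * scale  # stub for HyperHead

    def __call__(self, token_emb, hidden):
        x = rms_norm(token_emb) @ self.e_proj + rms_norm(hidden) @ self.h_proj
        x = rms_norm(x @ self.block)
        return x @ self.head  # draft logits for position t+2
```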

Generation (generate.py):
- mtp_generate_step() speculative decoding loop with draft/verify cycle
- Greedy exact-match and probabilistic acceptance modes
- --mtp CLI flag with graceful fallback warning
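
The greedy exact-match variant of the draft/verify cycle can be sketched with pure-Python stubs. The interfaces below are assumptions for illustration, not the PR's actual signatures: `backbone(tokens)` returns a greedy next token plus the last hidden state, and `mtp_draft(token, hidden)` drafts one token ahead. The real implementation scores the draft inside the same backbone forward pass; the stub recomputes it for clarity.

```python
def mtp_generate_step(backbone, mtp_draft, prompt, max_tokens):
    """Sketch of greedy exact-match speculative decoding (illustrative).

    Each cycle: the backbone predicts token t+1, the MTP head drafts
    t+2, and the next backbone pass verifies the draft -- emitting up
    to 2 tokens per (real) backbone forward pass.
    """
    tokens = list(prompt)
    out = []
    nxt, hidden = backbone(tokens)          # backbone predicts t+1
    while len(out) < max_tokens:
        out.append(nxt)
        tokens.append(nxt)
        draft = mtp_draft(nxt, hidden)      # MTP drafts t+2
        verified, hidden = backbone(tokens) # next pass verifies the draft
        if verified == draft:
            # Accepted: the draft is emitted alongside the verified token.
            out.append(draft)
            tokens.append(draft)
            nxt, hidden = backbone(tokens)  # stub; real code reuses the verify pass
        else:
            # Rejected: fall back to the backbone's own prediction.
            nxt = verified
    return out[:max_tokens]
```

With a draft head that mirrors the backbone, every draft is accepted and two tokens land per cycle; a head that never matches degrades gracefully to ordinary one-token-per-pass decoding.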

Server (server.py):
- --mtp CLI flag and stream_generate integration
Quantized checkpoints (e.g. 4-bit) typically strip MTP weights to
save ~3.2 GB. Detect this in sanitize() and delete self.mtp so the
--mtp flag falls back gracefully with a warning.
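
The detection described above amounts to scanning the checkpoint's weight keys for the MTP module and dropping the attribute when none are present. A minimal sketch, assuming an `mtp.`-prefixed key layout and a `sanitize`-style hook (both illustrative, not the PR's exact names):

```python
def sanitize_mtp(weights: dict, model) -> dict:
    """Sketch: detect MTP weights stripped from a quantized checkpoint.

    If no checkpoint key belongs to the MTP head, delete model.mtp so
    downstream code (e.g. the --mtp flag) sees its absence and can warn
    and fall back instead of failing on missing parameters.
    """
    has_mtp = any(k.startswith("mtp.") for k in weights)
    if not has_mtp and hasattr(model, "mtp"):
        del model.mtp  # --mtp now falls back gracefully with a warning
    # Keep MTP keys only when the full head is actually present.
    return {k: v for k, v in weights.items() if has_mtp or not k.startswith("mtp.")}
```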
