
Add DeepSeek-v4 (Flash/Pro)#1192

Open
Blaizzy wants to merge 52 commits into ml-explore:main from Blaizzy:pc/add-deepseekv4flash-model

Conversation

@Blaizzy
Contributor

@Blaizzy Blaizzy commented Apr 24, 2026

Note: Please install this transformers PR from source to avoid tokenizer bugs.

pip install git+https://github.com/huggingface/transformers.git@refs/pull/45643/head

Weights here:
https://huggingface.co/collections/mlx-community/deepseek-v4


@Blaizzy Blaizzy changed the title Add DeepSeekv4 (Flash/Pro) Add DeepSeek-v4 (Flash/Pro) Apr 24, 2026
@Blaizzy
Contributor Author

Blaizzy commented Apr 24, 2026

You can now run it on a 256GB Mac by keeping the experts in 4bit!

We could do 5bit since it's much better than 4bit right now. I'm open to opinions @angeloskath


@Blaizzy
Contributor Author

Blaizzy commented Apr 24, 2026

It's faster now!

Comment thread mlx_lm/utils.py
Comment thread mlx_lm/models/deepseek_v4.py Outdated
@machiabeli

Hey @Blaizzy — just flagging some technical notes since we're both working on V4 support and PR #1189 landed ~10 hours earlier with significant overlap:

Compressed attention mask direction (line 770-773):
The mask padding for compressed KV rows uses mx.ones, but create_attention_mask returns negative values for blocked positions. Padding with ones would block attention to compressed rows rather than allow it. PR #1189 uses mx.zeros here.
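
For reference, a minimal sketch (shapes and names are illustrative, not the PR's actual code) of zero-padding the extra compressed-KV columns when the mask is additive, i.e. 0 = attend and a large negative value = blocked:

import mlx.core as mx

L, n_compressed = 4, 2
idx = mx.arange(L)
# Additive causal mask: 0.0 where attention is allowed, -inf where it is blocked.
causal = mx.where(idx[:, None] >= idx[None, :], 0.0, float("-inf"))
# Pad with zeros so every query can still attend to the compressed rows.
pad = mx.zeros((L, n_compressed))
mask = mx.concatenate([causal, pad], axis=-1)  # shape (L, L + n_compressed)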

Sinkhorn normalization:
The Python loop path (line 222-226) dispatches ~40 kernel launches per call (softmax + iters x sum + div). PR #1189 has a fused Metal kernel that does this in a single register-resident dispatch — benchmarked at 3.5-5.7x faster on micro, 1.83x end-to-end.
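
For context, the unfused path is roughly the following (a sketch with an assumed iteration count, not the PR's exact code); each softmax, sum, and divide is its own kernel dispatch, which is where the ~40 launches per call come from:

import mlx.core as mx

def sinkhorn_normalize(scores: mx.array, n_iters: int = 10, eps: float = 1e-6) -> mx.array:
    p = mx.softmax(scores, axis=-1)
    for _ in range(n_iters):
        p = p / (p.sum(axis=-1, keepdims=True) + eps)  # normalize rows
        p = p / (p.sum(axis=-2, keepdims=True) + eps)  # normalize columns
    return p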

sqrt-softplus numerical stability:
nn.softplus(x) can overflow for large scores. PR #1189 uses mx.logaddexp(scores, zeros) which is log-sum-exp stable.
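
The stable form relies on the identity softplus(x) = log(1 + e^x) = logaddexp(x, 0), so the square root can be taken without ever materializing exp(x). A small sketch (the wrapper name is mine, not the PR's):

import mlx.core as mx

def sqrt_softplus(scores: mx.array) -> mx.array:
    return mx.sqrt(mx.logaddexp(scores, mx.zeros_like(scores)))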

Happy to coordinate if the maintainers want to consolidate into one PR. Our implementation has live generation validation at 21.86 tok/s on M3 Ultra (DeepSeek-V4-Flash-4bit, 160GB peak).

@Blaizzy
Contributor Author

Blaizzy commented Apr 24, 2026

Hey @machiabeli, thanks!

Yes, same person who left the earlier feedback, good to connect properly.

I've been poking at this in parallel and landed on something close to the source numerically with minimal changes, but there's definitely room to combine approaches. A PR from you on the compressed attention mask, Sinkhorn norm, and sqrt-softplus would be really welcome; happy to review and merge what works best.

Or I can cherry pick and add you as a co-author.

Comment thread mlx_lm/utils.py Outdated
Comment on lines +395 to +411
if (
    config.get("quantization", None) is None
    and getattr(model_args, "quantization", None) is not None
    and any(k.endswith(".scales") for k in weights)
):
    config["quantization"] = model_args.quantization

def _quantize(quantization):
    def class_predicate(p, m):
        if not hasattr(m, "to_quantized"):
            return False
        if f"{p}.scales" not in weights:
            return False
        # Handle custom per layer quantizations
        if p in config["quantization"]:
            return config["quantization"][p]
        return True
Contributor Author

@Blaizzy Blaizzy Apr 24, 2026


The goal here is to preserve mxfp4 expert quant since MLX supports it. So I made the quantize_config key in the config class default to that, and these changes help prequantized models load properly.

It can be done via a predicate, but I couldn't find an elegant way of doing it.

Note: it doesn't affect any model.

Contributor Author


An alternative is to dequant -> requant, similar to how we do with FP8.
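
For illustration, a minimal sketch of that dequant -> requant round trip using the public mx.quantize / mx.dequantize ops (the helper name and the specific group size / bit widths are assumptions, not this PR's code, and it only covers the plain affine case):

import mlx.core as mx

def requantize(wq, scales, biases, *, from_bits=4, to_bits=8, group_size=64):
    # Expand the existing quantized weights, then re-quantize in the target format.
    w = mx.dequantize(wq, scales, biases, group_size=group_size, bits=from_bits)
    return mx.quantize(w, group_size=group_size, bits=to_bits)  # (wq, scales, biases)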

@Blaizzy
Contributor Author

Blaizzy commented Apr 26, 2026

Fixed, could you try again @adurham?

@adurham
Contributor

adurham commented Apr 27, 2026

trying now

adurham pushed a commit to adurham/mlx-lm that referenced this pull request Apr 27, 2026
Updates mlx_lm/models/deepseek_v4.py from Blaizzy's PR ml-explore#1192 head
(ddeffe3). File-only — same fork strategy as 2688312, does NOT pull
the rest of the PR (would revert 77ed380 quant SDPA fast path and
delete mlx_lm/models/minimax_trace.py).

Key fix: 8e8571a "Fix DeepSeek V4 sparse pooled prefill memory" —
keeps pooled top-k attention grouped per query during prefill instead
of flattening into an L*top_k dense KV sequence, avoiding oversized
SDPA score buffers on long prompts. Addresses the (B, n_heads, L,
L*k) cubic blowup we reported for compress_ratio==4 layers.

Imports unchanged vs prior snapshot apart from dropping unused
_gather_sort import. ModelArgs / Model construct cleanly against
mlx-community/DeepSeek-V4-Flash-4bit config at runtime.

Co-Authored-By: Blaizzy <prince.canuma@hotmail.com>
@adurham
Contributor

adurham commented Apr 27, 2026

Decode regression on ddeffe33 — pinpointed to _hc_sinkhorn_collapse_kernel

Pulled the latest pin for retest. Haven't gotten back to verifying the OOM-at-step-4096 repro yet (still capped at EXO_PREFILL_STEP_SIZE=512 from the workaround), because decode collapsed pretty hard on the new file:

mlx_lm/models/deepseek_v4.py rev    | Prefill (15-tok warmup) | Decode (30-tok, 836 ctx)
15de79d8 (snapshot before #26f49f5) | 65 tok/s                | 34.8 tok/s
ddeffe33 (current PR head)          | 1.9 tok/s               | 0.07 tok/s
ddeffe33 + fused HC kernel disabled | 39 tok/s                | 40.4 tok/s

Decode rate computed from per-token ChunkGenerated master-log timestamps; 3 independent runs (cold prefill, prefix-cache hit, fresh shorter prompt) all land within 1% of 40 tok/s, so the third row isn't a fluke.

The fix is a one-liner — force _hc_sinkhorn_collapse_kernel = None so HyperConnection.collapse takes the unfused split-sinkhorn + _hc_collapse_op path. The fused kernel from 26f49f5 is invoked twice per block (attn_hc + ffn_hc), so ~86 dispatches per decoded token on DSv4-Flash (43 layers). With it off, the rest of the squash lands fine — actually +16% over the prior baseline.

Reproduction

Bisect branch: adurham/mlx-lm@dsv4-perf-bisect (= ddeffe33 file content + the one-line kernel disable). On stock ddeffe33 decode is reproducibly ≤ 0.1 tok/s.

Workload: mlx-community/DeepSeek-V4-Flash-4bit, 2-rank TP via RDMA across 2× M4 Max, ~836-token prompt + 30 generated tokens.

MLX context — possibly a fork interaction

I'm on adurham/mlx (= upstream e64e280d + reverts of #3412 jaccl refactor and #3418 jaccl init bug, both comm-backend only, no GPU code touched). It's plausible _hc_sinkhorn_collapse_kernel only mis-behaves under that combination — happy to retest on stock upstream MLX if it'd help nail down whether this is fork-specific. Flagging it first since the symptom is so consistent on my side.

Suggestion

Gate _hc_sinkhorn_collapse_kernel behind an opt-in env (default off) until the root cause is clear. cc @0xClandestine since you wrote/optimized this kernel — let me know what telemetry would help (per-call latency, kernel-source variants, profiling output, etc.).
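
Something like this is what I have in mind (the env-var name and the kernel-builder call are placeholders, not identifiers from this PR):

import os

if os.environ.get("MLXLM_DSV4_FUSED_HC_SINKHORN", "0") == "1":
    # Hypothetical builder for the fused kernel from 26f49f5.
    _hc_sinkhorn_collapse_kernel = _build_fused_hc_sinkhorn_kernel()
else:
    # Default: HyperConnection.collapse takes the unfused split-sinkhorn + _hc_collapse_op path.
    _hc_sinkhorn_collapse_kernel = None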

@Shinka-Man

Token-dropping regression on ddeffe33 — single-machine M3 Ultra

Different symptom from @adurham's TP regression but same general window. Pulled the head and saw severe quality degradation in Japanese output:

Setup: single Mac Studio M3 Ultra 512GB, stock MLX 0.31.2, no TP, mlx-community/DeepSeek-V4-Flash-mxfp8.

Symptoms (head ddeffe33):

Input:  "ハローハロー!バージョン2.5から4になったんだね。おめでとう!簡単に自己紹介して"

Output: "こんこん〜!どうもありあり〜!✨
         そうなんだよ、バージョン2.5から4にアップグップ☆って感じで、ちょびっと進進したかも〜
         えっとね、わたあまみみは、ちょぴっとだけおっっっっっっっっっっっっっっっっっっっっっっっっっっっっっっっっっっっっっっっっっっ..."
Expected             | Got
こんにちは           | こんこん (drop + duplicate)
ありがとう           | ありあり (drop + duplicate)
アップグレード       | アップグップ (mid-word collapse)
進化                 | 進進 (single-char repeat)
(long string) normal | "おっっっっっ" infinite loop

A second prompt also showed token dropping ("メモ帯域" instead of "メモリ帯域", "推速度" instead of "推論速度", "Metalフ" cut off, etc.).

Bisect:

Commit                              | Quality
910b120f (perf-optimize-ds4 merged) | ✅ clean
ddeffe33 (current head)             | ❌ token drop + repetition

Workaround for now: pinned to 910b120fa24cecb804b795b05183cbf0037f4ba6 which is rock solid (28-30 tok/s on mxfp8, no quality issues).

Could be related to one of:

  • 8e8571a4 Fix DeepSeek V4 sparse pooled prefill memory
  • 2591b51f Refactor kv tensor reshaping in V4Attention
  • 5a4aaa41 Refactor output projection
  • 22da01a1 Remove redundant cache type check

Happy to bisect further commit-by-commit if useful — let me know which axis would be most valuable to narrow down. cc @Blaizzy

@Blaizzy
Contributor Author

Blaizzy commented Apr 27, 2026

Thanks @Shinka-Man

On it 👌🏽

Restore the transpose in HyperConnection expand so the Sinkhorn combination matrix is applied in the same orientation as the original implementation. This fixes token dropping and repetition regressions seen in Japanese generation.

Reported-by: https://github.com/Shinka-Man
@Blaizzy
Contributor Author

Blaizzy commented Apr 27, 2026

Fixed @Shinka-Man, the culprit was the HyperConnection expand orientation change from comb.T @ residual to comb @ residual.
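
For anyone curious why it slipped through: with a non-symmetric combination matrix both orientations still produce well-formed activations, they just mix the residual streams with the wrong weights. A toy check (shapes are illustrative only):

import mlx.core as mx

n_streams, dim = 4, 8
comb = mx.random.uniform(shape=(n_streams, n_streams))  # Sinkhorn-normalized in the real code
residual = mx.random.normal(shape=(n_streams, dim))

expanded_correct = comb.T @ residual  # original orientation, restored by the fix
expanded_wrong = comb @ residual      # the regressing orientation
print(mx.max(mx.abs(expanded_correct - expanded_wrong)))  # nonzero unless comb is symmetric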

@Blaizzy
Contributor Author

Blaizzy commented Apr 27, 2026

Thanks @adurham, looking into it 👌🏽

@Shinka-Man

@Blaizzy Confirmed fixed on acf650c9

Same prompt that broke before:

ハローハロー!バージョン2.5から4になったんだね。おめでとう!簡単に自己紹介して

Output now (mxfp8, M3 Ultra 512GB):

こんにちは!「バージョン2.5から4」という表現、何かのアプリやシステムのアップデートのことを指しているのかな?それとも、何か別の意味での「バージョンアップ」についてのメッセージでしょうか?

もし私(アシスタント)に対しての「バージョンアップ」のお祝いだとすれば、ありがとうございます!...

Clean Japanese, no token drops, no repetition loops, perfect context understanding. 30.2 tok/s (slightly faster than the pre-regression baseline).

That orientation flip — comb.T @ residual → comb @ residual — silently producing semantically-plausible-but-corrupted output is the kind of bug that's terrifying to debug. Nice catch and lightning-fast turnaround 🚀

@Blaizzy
Contributor Author

Blaizzy commented Apr 27, 2026

My pleasure @Shinka-Man, thanks for catching it! ❤️

I added a test to avoid regressions

The element-wise (q * pooled).sum() path broadcasts a (B,H,L,1,D) tensor
against (B,1,L,topk,D), creating a (B,H,L,topk,D) intermediate. At 4k
context with H=64, topk=512, D=512 this is ~137 GB per operation (x2).

Replace with equivalent matmul: (B,L,H,D) @ (B,L,D,topk) which produces
the (B,L,H,topk) result directly with ~0.25 GB peak memory.
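
For comparison, a sketch of the two forms (array layouts follow the description above; the function names are illustrative, not the actual code):

import mlx.core as mx

# q:      (B, L, H, D)    queries
# pooled: (B, L, topk, D) pooled keys selected per query position

def pooled_scores_broadcast(q, pooled):
    # Materializes a (B, L, H, topk, D) intermediate before the reduction over D.
    return (q[:, :, :, None, :] * pooled[:, :, None, :, :]).sum(axis=-1)

def pooled_scores_matmul(q, pooled):
    # (B, L, H, D) @ (B, L, D, topk) -> (B, L, H, topk), no huge intermediate.
    return q @ pooled.transpose(0, 1, 3, 2)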
@ivanfioravanti
Contributor

The PR from @0xClandestine fixes the 4K context issue! Tested here up to 64K! 🔥


@ivanfioravanti
Contributor

I've found a regression (not in the @0xClandestine PR): quantization is failing for 4, 8, and all combos. @Blaizzy

python -m mlx_lm convert --hf-path deepseek-ai/DeepSeek-V4-Flash -q --q-bits 4 --mlx-path ~/DeepSeek-V4-Flash-4bit

[INFO] Using dtype: bfloat16
[INFO] Quantizing
[INFO] Quantized model with 4.506 bits per weight.
[1] 3279 killed python -m mlx_lm convert --hf-path deepseek-ai/DeepSeek-V4-Flash -q --q-bits
/Users/ifioravanti/.local/share/uv/python/cpython-3.13.12-macos-aarch64-none/lib/python3.13/multiprocessing/resource_tracker.py:400: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown: {'/mp-p3fsavbc'} warnings.warn(

@trevorgordon981

Hey @Blaizzy — quick request from a downstream user.

I just ran the validation harness against this PR at HEAD dd6b92f on mlx-community/DeepSeek-V4-Flash-8bit (your mixed q4/q8 quant) on an M3 Ultra 512 GB. Results:

  • Stages 02 (load), 03 (in-process gen), 04 (prefix cache, 58.45× cold→warm) all pass.
  • Server-path probes via mlx_lm.server produce coherent output across non-streaming chat, multi-turn with system prompt, streaming SSE, and longer paragraph generation. No streaming-decode garbage, no decode-cache wrong-logits at S=1.
  • Memory at peak: 144.8 GB RSS for the model on load; ~145 GB during serving with KV cache.

So the runtime path is healthy on this branch. Would it be feasible to publish a true 8-bit weight-only quant (no per-layer 4-bit override on the FFN/MoE expert weights) alongside the existing mixed quant? My intended use is as a single-tenant Alfred backend on this box — at q8 the model would be ~284 GB, which fits comfortably on a 512 GB Mac Studio without the production-coexist constraint your mixed quant was sized for.

Also happy to self-quantize from mlx-community/DeepSeek-V4-Flash-bf16 once a PR merges, just figured it was worth asking since you've been steadily iterating on the conversion path. No urgency. Thanks for all the work on this PR — the harness confidence on dd6b92f is the highest I've seen across the three V4 PRs.

@Blaizzy
Contributor Author

Blaizzy commented Apr 27, 2026

Yes, I can do that. The reason the 8bit has experts in 4bit is that the main model comes with experts in MXFP4.
And at the time I was developing the PR it seemed sensible to keep the experts as is.

Also happy to self-quantize from mlx-community/DeepSeek-V4-Flash-bf16 once a PR merges, just figured it was worth asking since you've been steadily iterating on the conversion path. No urgency. Thanks for all the work on this PR — the harness confidence on dd6b92f is the highest I've seen across the three V4 PRs.

My pleasure! Let me merge that for you :)

Update: I don't see a PR to that repo.

@trevorgordon981

Following up on #issuecomment-4329720377 with more specific findings.

I attempted to build a true 8-bit conversion myself by streaming through deepseek-ai/DeepSeek-V4-Flash (the 148 GB native FP8 release): read each shard, dequantize the I8 + F8_E8M0 (block-scale) and F8_E4M3 weights to bf16, re-quantize at affine q8 group_size=64, write MLX-format shards with the sanitize-equivalent name remapping (embed.weight → model.embed_tokens.weight, hc_attn_X → attn_hc.X, experts stacked into switch_mlp.{gate,down,up}_proj, etc.). The dequant + requantize math validates fine (round-trip error 0.74% rel). Pipeline runs in ~5 min for the full 46-shard input.
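
For concreteness, the round-trip check is essentially the following (affine q8 at group_size 64; the function name is mine):

import mlx.core as mx

def roundtrip_rel_error(w_bf16: mx.array, group_size: int = 64, bits: int = 8) -> float:
    w = w_bf16.astype(mx.float32)
    wq, scales, biases = mx.quantize(w, group_size=group_size, bits=bits)
    w_back = mx.dequantize(wq, scales, biases, group_size=group_size, bits=bits)
    return float(mx.linalg.norm(w_back - w) / mx.linalg.norm(w))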

Hit an architectural wall on the routed-expert dimensions that I can't reconcile from public artifacts.

Source (deepseek-ai/DeepSeek-V4-Flash) per-expert shapes:

experts.X.w1.weight: I8 [2048, 2048]
experts.X.w2.weight: I8 [4096, 1024]
experts.X.w3.weight: I8 [2048, 2048]

Your mlx-community/DeepSeek-V4-Flash-8bit switch_mlp (inferred from the mxfp4 packed/scales shapes):

gate_proj per expert (bf16): [2048, 4096]
down_proj per expert (bf16): [4096, 2048]
up_proj   per expert (bf16): [2048, 4096]

Both repos' configs say hidden_size=4096, moe_intermediate_size=2048, n_routed_experts=256, n_shared_experts=1 — they should produce the same per-expert shapes. But the source's last dims are exactly half of yours.

Param accounting confirms it: source's routed experts at (2048, 2048) sum to ~138B params (consistent with the 148 GB I8 file size), while your switch_mlp shapes imply ~280B in routed experts alone, which lines up with mxfp4's compressed 155 GB file size for a fuller param count.
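
Back-of-the-envelope version of that accounting (assuming all 43 layers carry routed experts; the 43-layer figure comes from earlier in this thread, not from the config, so treat it as an assumption):

layers, experts = 43, 256
per_expert_src = 2048 * 2048 + 4096 * 1024 + 2048 * 2048  # w1 + w2 + w3
per_expert_mlx = 2048 * 4096 + 4096 * 2048 + 2048 * 4096  # gate + down + up
print(per_expert_src * experts * layers / 1e9)  # ~138B routed-expert params
print(per_expert_mlx * experts * layers / 1e9)  # ~277B, in line with the ~280B above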

The shared_experts block in source is sized correctly (shared_experts.w1.weight: F8_E4M3 [2048, 4096] matches your shared_experts.gate_proj (2048, 4096)). Only the routed experts disagree.

Two questions, in order of usefulness to me:

  1. How does your conversion process derive (2048, 4096) per-expert weights from the source's (2048, 2048)? Concat of w1+w3 (which would only give you gate_proj, not also up_proj), TP-shard combination across pairs of source experts, or some V4-Flash-specific reshaping I'm not seeing in mlx_lm/models/deepseek_v4.py?
  2. Would a true 8-bit version (no bits: 4 override on the FFN/MoE expert layers) be feasible to publish? My use case is a 512 GB Mac Studio backend; ~280-300 GB q8 fits comfortably with KV cache headroom and would be a quality bump over the current mxfp4-on-FFN profile. Happy to run the same harness against it if you publish.

If (1) has a clean answer I can finish the converter myself for (2). If not, only you have the conversion path that produces the right shapes.

Either way thanks for the work — dd6b92f mixed quant has been serving Alfred backend reliably for several hours now.

@Shinka-Man

Performance update — Apple optimizations live ✨

After pulling the latest with @angeloskath's three commits (RoPE kernel native path, GLU cast simplification, original-checkpoint loader):

Setup: Mac Studio M3 Ultra 512GB, mxfp8, single-machine, no TP

Prompt type                 | TTFT            | TPS        | Notes
Short JP (~50 tok output)   | 2793 ms (cold)  | 33.8 tok/s | first run, model warmup
Long JP (~500 char essay)   | 232 ms (cached) | 33.6 tok/s | flat
Code (Python w/ docstrings) | 199 ms          | 33.6 tok/s | flat

The thing that jumps out: TPS is now context-flat (33.8 vs 33.6 vs 33.6 across very different workloads). Previous builds had visible decay on longer generations. This is the signature of the cast/RoPE overhead going away.

Cumulative trajectory on this hardware:

Build                                      | TPS
Initial mxfp8 (with token-drop regression) | 22.4
Sinkhorn / HC fusion (0xClandestine)       | 26.7
acf650c9 orientation fix (Blaizzy)         | 30.2
Apple kernel optimizations (angeloskath)   | 33.6–33.8

~+50% from baseline in 4 days of community iteration. Quality remains pristine (no token drops, no repetition loops, perfect Japanese & code).

Welcome to the PR @angeloskath — the kernel-native RoPE path is doing real work here on Apple Silicon. 🙌

adurham pushed a commit to adurham/mlx-lm that referenced this pull request Apr 29, 2026
…'s fp32-cast patches

Pulls in PR ml-explore#1192 from upstream including:
- Apple-team commits from Angelos Katharopoulos:
  - 3cf5282 "Start simplifying and speeding up the attention" (2026-04-29)
  - 4951496 "Fix RoPE to use the kernel by scaling freqs" (2026-04-28)
  - 81a8c57 "Simplify GLU and gate remove intermediate castings" (2026-04-28)
- Blaizzy refactor stack (output projection, KV reshape, BatchRotatingKVCache,
  scoring/RoPE compile, the matmul rewrite from 0xClandestine that we already
  had a copy of, etc.)

Conflict resolution: took theirs for mlx_lm/models/deepseek_v4.py wholesale.
Our previous fork patches that are now superseded:
- f4dd9e7 / 2a1dcf6 "drop fp32 casts in Indexer / MoEGate / Compressor" —
  Angelos' 81a8c57 covers the same ground more cleanly.

Fork patches that need re-applying separately on top:
- mlx_lm/profiler.py span/finalize hooks scattered across deepseek_v4.py
  (attn_q_lora / attn_kv_proj / attn_compressor / attn_indexer /
  attn_sdpa_sparse / attn_sdpa_dense / moe_* spans).
- 1d78d62 Indexer wq_b/weights_proj sharding for TP.