Merged

Commits (43)

- bb6978a: Optimize HC sinkhorn Metal kernel with float4 SIMD and bounds guard (0xClandestine, Apr 24, 2026)
- 4d33a92: Round grid up to 256-multiple for Metal dispatch safety (0xClandestine, Apr 24, 2026)
- 1740f5c: Fix HC sinkhorn kernel: drop dead #if preprocessor guard (0xClandestine, Apr 24, 2026)
- 0c76952: Fix sinkhorn kernel bugs and add split sparse attention for prefill (0xClandestine, Apr 24, 2026)
- fbad2b7: Fix DS4 sanitize: stack grouped wo_a, remap quant metadata keys (0xClandestine, Apr 24, 2026)
- b2e0456: Fix sanitize for pre-quantized DS4: stacked experts, biases, group_size (0xClandestine, Apr 24, 2026)
- 43b8430: Fix tokenizer loading for custom model types with rope_scaling (0xClandestine, Apr 24, 2026)
- e7f62b6: Replace metal::fast::recip with 1/x for mlx 0.31 compatibility (0xClandestine, Apr 24, 2026)
- 6755555: Add fused Metal sparse attention kernel for DS4 prefill (0xClandestine, Apr 25, 2026)
- e4687ca: Fix fused sparse attn kernel: address space and reference cast bugs (0xClandestine, Apr 25, 2026)
- 71e337b: Fix sanitize: don't drop expert .biases keys needed by QuantizedSwitc… (0xClandestine, Apr 25, 2026)
- fef630b: Fix fused sparse attention kernel: update address space for topk_idxs… (Blaizzy, Apr 25, 2026)
- d64e3fb: Fix fused sparse attn: use device int32_t* for topk_idxs pointer (0xClandestine, Apr 25, 2026)
- 0597549: Fix fused sparse attn: topk_idxs is constant address space, not device (0xClandestine, Apr 25, 2026)
- 9bd7cd8: Fix fused sparse attention Metal kernel: address space, dispatch, out… (0xClandestine, Apr 25, 2026)
- 981f2fa: Fix generate_step crash: flatten extra batch dims in model __call__ (0xClandestine, Apr 25, 2026)
- 40341ea: Fix decode regression: skip indexer for L==1, use full pooled KV dire… (0xClandestine, Apr 25, 2026)
- 5ee3d96: Add fused partial RoPE Metal kernel to eliminate split/concat interme… (0xClandestine, Apr 25, 2026)
- 40b1c02: Optimize decode path: q-norm kernel, overlap_transform, HC fn bf16 (0xClandestine, Apr 25, 2026)
- cd74c13: Defer Compressor wkv/wgate GEMVs until a full window is ready (0xClandestine, Apr 25, 2026)
- 11c5397: Three more decode optimizations: collapse/expand kernels + MoE gate bf16 (0xClandestine, Apr 25, 2026)
- 7579a71: Revert "Three more decode optimizations: collapse/expand kernels + Mo… (0xClandestine, Apr 25, 2026)
- afdbaf1: Fix ModelArgs: __post_init__ references quantization_config not quant… (0xClandestine, Apr 25, 2026)
- 09c0874: Remove unused quantization_config from ModelArgs and update __post_in… (Blaizzy, Apr 25, 2026)
- 2cf48d2: Add @mx.compile decorator to fused_sparse_attention and _split_sparse… (Blaizzy, Apr 25, 2026)
- 8726c2e: format (Blaizzy, Apr 25, 2026)
- 3737312: Fix prefill slowdown: skip fused sparse attn kernel for L > 1 (0xClandestine, Apr 25, 2026)
- ea99374: Revert rope kernel and cache x-buffering per maintainer feedback (0xClandestine, Apr 25, 2026)
- 9126d7c: Address review feedback: revert n_rows guard, fused attn kernel, HC f… (0xClandestine, Apr 25, 2026)
- eecab43: Vectorize HC sinkhorn kernel: float4 loads, unroll inner loops (0xClandestine, Apr 26, 2026)
- a3c04af: Restore cast_predicate HC exclusions accidentally dropped in review c… (0xClandestine, Apr 26, 2026)
- da913ff: Replace _split_sparse_attention with standard SDPA for prefill (0xClandestine, Apr 26, 2026)
- de4c25c: Revert SwitchGLU, HyperConnection, and compress path to match upstream (0xClandestine, Apr 26, 2026)
- 1e7fcf8: Revert README to match upstream (0xClandestine, Apr 26, 2026)
- a0bcde7: Remove dead code: q-norm kernel and _split_sparse_attention (0xClandestine, Apr 26, 2026)
- c55bc53: Compile MoE expert selection + skip empty pooled tensor processing (0xClandestine, Apr 26, 2026)
- fba00da: Fuse HyperHead into single compiled graph (0xClandestine, Apr 26, 2026)
- 28a5a1f: Cache repeated reshape/cast ops in attention hot path (0xClandestine, Apr 26, 2026)
- ea8ea24: Fix _ensure_cached None check before dtype comparison (0xClandestine, Apr 26, 2026)
- 49e37b0: Fuse HC compute_weights into single compiled graph (0xClandestine, Apr 26, 2026)
- 9929953: Fuse sinkhorn + collapse into single Metal kernel dispatch (0xClandestine, Apr 26, 2026)
- 9a69e7b: Use explicit adds and ILP-friendly layout in sinkhorn kernels (0xClandestine, Apr 26, 2026)
- 245f6cb: Branchless sinkhorn + native bfloat4 loads in fused collapse kernel (0xClandestine, Apr 26, 2026)

README.md (74 additions, 0 deletions)
**Review comment (Owner):** Let’s remove the readme changes for now

@@ -257,6 +257,80 @@ model, tokenizer = load(
)
```

### DeepSeek V4 / DeepSeek V4 Flash

This fork adds native support for **DeepSeek-V4** and **DeepSeek-V4-Flash** on Apple Silicon, including full Metal kernel acceleration.

#### Usage

```python
from mlx_lm import load, generate

model, tokenizer = load("path/to/deepseek-v4-flash")
text = generate(model, tokenizer, prompt="Explain attention sinks.", verbose=True)
```

Or from the command line:

```bash
mlx_lm.generate --model path/to/deepseek-v4-flash --prompt "Explain attention sinks."
```

The model type is `deepseek_v4`. Pre-quantized checkpoints (FP8 or FP4 experts) are loaded and dequantized automatically — no manual conversion step is required.

#### Architecture

DeepSeek V4 Flash introduces several architectural innovations that this implementation fully supports:

**HyperConnection** replaces standard residual connections. Each layer maintains `hc_mult=4` parallel hidden streams that are combined through a learned Sinkhorn-normalized mixing matrix. A custom Metal kernel (`_make_hc_split_sinkhorn_kernel`) computes the 4×4 doubly-stochastic combination weights using float4 SIMD and online Sinkhorn iterations entirely on-GPU.
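
For intuition, here is a pure-MLX sketch of the Sinkhorn normalisation the kernel performs; the float4 and online-iteration details live in Metal, and the iteration count below is an assumption:

```python
import mlx.core as mx

def sinkhorn(logits: mx.array, n_iters: int = 8) -> mx.array:
    # Alternately normalise rows and columns of exp(logits); the result
    # converges to a doubly-stochastic matrix (rows and columns sum to 1).
    p = mx.exp(logits - logits.max())  # softmax-style stabilisation
    for _ in range(n_iters):
        p = p / p.sum(axis=-1, keepdims=True)  # rows sum to 1
        p = p / p.sum(axis=-2, keepdims=True)  # columns sum to 1
    return p

# Mixing weights for hc_mult=4 parallel streams:
w = sinkhorn(mx.random.normal((4, 4)))
```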

**Compressed attention (Compressor + Indexer)** provides long-range context without quadratic cost. At every layer with `compress_ratio > 0`, hidden states are pooled into a compressed KV sequence (ratio 4 with overlap, or ratio 128 for long range). During decode, an x-buffer defers the expensive `wkv`/`wgate` projections until a full compression window is ready, skipping those GEMVs on `(ratio−1)/ratio` of decode steps. An Indexer then selects the top-k most relevant compressed KV entries per query head using learned index projections.
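
A hypothetical sketch of the x-buffer idea (class name and buffer layout invented for illustration):

```python
import mlx.core as mx

class XBuffer:
    def __init__(self, ratio: int, wkv: mx.array, wgate: mx.array):
        self.ratio, self.wkv, self.wgate = ratio, wkv, wgate
        self.buf = []  # hidden states awaiting a full window

    def push(self, x: mx.array):
        # x: [B, 1, D] hidden state for the newly decoded token
        self.buf.append(x)
        if len(self.buf) < self.ratio:
            return None  # window incomplete: defer, no GEMV this step
        window = mx.concatenate(self.buf, axis=1)  # [B, ratio, D]
        self.buf.clear()
        # run both projections once for the whole window
        return window @ self.wkv, window @ self.wgate
```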

**Sparse attention paths** handle prefill and decode differently:
- Prefill (`L > 1`): a fused Metal kernel (`ds4_fused_sparse_attn`) computes online softmax over the local sliding-window KV and the top-k sparse compressed KV in a single pass, avoiding materialising the `[B, L, topk, D]` gather intermediate (see the merge sketch after this list).
- Decode (`L = 1`): the indexer is skipped (compressed pool fits within `index_topk` anyway); standard SDPA is used over `[local KV ∥ pooled KV]`.
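
The single-pass trick rests on the standard online-softmax merge: partial results from the two KV segments can be combined exactly. A simplified sketch of the maths (not the kernel itself), where each segment contributes an un-normalised output `o`, running max `m`, and denominator `l`:

```python
import mlx.core as mx

def merge(o1, m1, l1, o2, m2, l2):
    # oX: un-normalised sum of exp(score - mX) * value for segment X
    # mX: running max score; lX: running softmax denominator
    m = mx.maximum(m1, m2)
    a1, a2 = mx.exp(m1 - m), mx.exp(m2 - m)  # rescale to the shared max
    return (a1 * o1 + a2 * o2) / (a1 * l1 + a2 * l2)
```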

**Attention sinks** add to every attention layer a learnable virtual token whose score is a per-head bias and whose value contribution is zero, stabilising attention distributions over long contexts.
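
A minimal MLX sketch of the mechanism, with shapes and names assumed:

```python
import mlx.core as mx

def attend_with_sink(scores: mx.array, v: mx.array, sink_bias: mx.array):
    # scores: [B, H, L, S] pre-softmax logits; v: [B, H, S, D]
    # sink_bias: [H] learnable per-head sink score
    sink = mx.broadcast_to(
        sink_bias.reshape(1, -1, 1, 1), (*scores.shape[:-1], 1)
    )
    # the sink enters the softmax denominator...
    w = mx.softmax(mx.concatenate([scores, sink], axis=-1), axis=-1)
    # ...but contributes no value: drop its column before the weighted sum
    return w[..., :-1] @ v
```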

**Mixture of Experts** uses 256 routed experts plus 1 shared expert per MoE layer. Expert routing uses `sqrtsoftplus` scoring with auxiliary-loss-free top-k selection (`noaux_tc`). The first `num_hash_layers` layers use hash-based routing. `LimitedSwiGLU` clamps gate and up projections to prevent activation overflow.
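
A hypothetical sketch of the scoring and selection, assuming `sqrtsoftplus` means `sqrt(softplus(x))` and that a correction bias (`e_bias` here, name invented) biases selection only, in the `noaux_tc` style; `k=8` is also an assumption:

```python
import mlx.core as mx

def route(x, router_w, e_bias, k=8):
    logits = x @ router_w                        # [tokens, n_experts]
    scores = mx.sqrt(mx.logaddexp(logits, 0.0))  # sqrt(softplus(x))
    # the bias affects which experts are *selected*, not the output weights
    idx = mx.argpartition(-(scores + e_bias), kth=k - 1, axis=-1)[..., :k]
    weights = mx.take_along_axis(scores, idx, axis=-1)
    return idx, weights / weights.sum(axis=-1, keepdims=True)
```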

**Grouped output projection** splits the large O-projection into 8 groups (`o_groups=8`), each with its own low-rank A matrix (`wo_a`), reducing peak memory during the projection step.
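
A hypothetical sketch of the shape bookkeeping (`wo_b`, the shapes, and the summation are assumptions, not the fork's exact code):

```python
import mlx.core as mx

def grouped_o_proj(x, wo_a, wo_b, o_groups=8):
    # x:    [B, L, o_groups, d_group]  attention output, split into groups
    # wo_a: [o_groups, d_group, rank]  per-group low-rank A
    # wo_b: [o_groups, rank, d_model]  per-group B (name assumed)
    out = None
    for g in range(o_groups):
        # low-rank path: only [*, rank]-sized intermediates live per group
        y = (x[..., g, :] @ wo_a[g]) @ wo_b[g]
        out = y if out is None else out + y
    return out
```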

**Partial RoPE** applies rotary embeddings only to the `qk_rope_head_dim`-sized suffix of each head, leaving the `nope` prefix unrotated. A dedicated Metal kernel (`ds4_partial_rope`) performs the rotation in place, eliminating the split/concat intermediates a naive implementation would require.
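
A pure-MLX reference for what the kernel fuses; the half-split rotation convention is an assumption:

```python
import mlx.core as mx

def partial_rope(q, cos, sin, rope_dim):
    # rotate only the trailing rope_dim features; the "nope" prefix passes through
    nope, rope = q[..., :-rope_dim], q[..., -rope_dim:]
    x1, x2 = mx.split(rope, 2, axis=-1)
    rotated = mx.concatenate(
        [x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1
    )
    # this split/concat is exactly what the fused kernel avoids
    return mx.concatenate([nope, rotated], axis=-1)
```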

**Per-head query RMS norm** (`ds4_q_norm`) normalises each query head in a fused Metal kernel before the RoPE step, replacing the `mx.rsqrt` + elementwise pattern.
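
The unfused equivalent, as a sketch (`eps` assumed):

```python
import mlx.core as mx

def q_rms_norm(q, weight, eps=1e-6):
    # q: [B, H, L, D]; normalise each head's query vector independently
    var = mx.mean(q.astype(mx.float32) ** 2, axis=-1, keepdims=True)
    return (q * mx.rsqrt(var + eps)).astype(q.dtype) * weight
```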

#### Metal Kernels

| Kernel | Purpose |
|---|---|
| `ds4_partial_rope` | Fused partial RoPE — eliminates split/concat intermediate |
| `ds4_q_norm` | Per-head query RMS normalisation |
| `_make_hc_split_sinkhorn_kernel` | float4 SIMD Sinkhorn for HyperConnection mixing weights |
| `ds4_fused_sparse_attn` | Online-softmax prefill over local + sparse KV + attention sink |
| `_split_sparse_attention` | MLX fallback for `ds4_fused_sparse_attn` on CPU or older Metal |

All kernels fall back gracefully to pure MLX when Metal is unavailable.

#### Quantization Support

The `sanitize` method handles pre-quantized checkpoints transparently:

- **FP8 weights** (E4M3/E5M2, block-scaled) are dequantized to BF16 on load (see the sketch after this list).
- **FP4/MXFP4 expert weights** are unpacked from the 4-bit lookup table and dequantized to BF16, then re-quantized using MLX's native group-quantized matmul format.
- The safetensors loader is extended to reinterpret the `F8_E8M0` dtype used by some HuggingFace checkpoints that standard MLX cannot parse.
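
A hedged sketch of the block-scaled dequantisation referenced above, assuming a square block of 128 and weights already upcast to BF16; the bit-level FP8 decode and `F8_E8M0` handling are elided:

```python
import mlx.core as mx

def dequant_blockwise(w, scales, block=128):
    # w:      [out, in] weight values already upcast from FP8 to bfloat16
    # scales: [out // block, in // block], one scale per (block, block) tile
    s = mx.repeat(mx.repeat(scales.astype(w.dtype), block, axis=0), block, axis=1)
    return w * s
```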

Precision-sensitive parameters (attention sinks, HyperConnection base/scale, expert correction biases) are excluded from any subsequent `cast` operations via `cast_predicate`.

#### Cache

`DeepseekV4Cache` wraps `RotatingKVCache` for the local sliding window and adds two parallel state buffers (compressor and indexer). It implements the full `BatchRotatingKVCache` interface — supporting `extract`, `extend`, `merge`, `filter`, and `trim` — so batch generation and prompt caching work out of the box.
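
A structural sketch only (constructor arguments and buffer names assumed; method bodies elided):

```python
from mlx_lm.models.cache import RotatingKVCache

class DeepseekV4CacheSketch:
    """Structural sketch; the real class also implements extract, extend,
    merge, filter, and trim to match the BatchRotatingKVCache interface."""

    def __init__(self, window_size: int):
        # sliding-window KV for local attention
        self.local = RotatingKVCache(max_size=window_size)
        # two parallel state buffers, updated alongside the local cache
        self.compressor_state = None  # pooled / compressed KV
        self.indexer_state = None     # indexer key projections
```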

#### Infrastructure Changes

- **`tokenizer_utils.py`**: adds a fallback to `PreTrainedTokenizerFast` when `AutoTokenizer` raises `AttributeError` on custom model types whose config triggers transformers' `rope_scaling` standardisation before `max_position_embeddings` is available (sketched after this list).
- **`utils.py`**: adds `_load_safetensors` which patches `F8_E8M0` dtype headers in-place before loading, allowing FP8-quantized checkpoints to be loaded without a separate conversion step.
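
A minimal sketch of the tokenizer fallback mentioned above (helper name invented):

```python
from transformers import AutoTokenizer, PreTrainedTokenizerFast

def load_tokenizer_with_fallback(model_path):
    # AutoTokenizer can raise AttributeError while standardising rope_scaling
    # for custom model types; the fast tokenizer skips that config resolution.
    try:
        return AutoTokenizer.from_pretrained(model_path)
    except AttributeError:
        return PreTrainedTokenizerFast.from_pretrained(model_path)
```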

### Large Models

> [!NOTE]