perf: fuse Qwen3 no-cache attention with mx.compile #14
Open
0xClandestine wants to merge 5 commits into
Conversation
mlx-lm 0.31.3 requires mlx>=0.31.2 on Darwin per its published metadata. Bump the lower bounds to match what's actually needed at runtime. Installed versions unchanged (mlx 0.31.2, mlx-lm 0.31.3). Test suite: 43 passed, 1 skipped. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Previously dflash-benchmark would overwrite benchmark/results/<chip>/<name>.json on every run. Append UTC YYYYMMDDTHHMMSSZ to the basename so repeated runs never lose data and deltas across branches/commits stay traceable. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
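A minimal sketch of the naming scheme this describes, assuming a small path helper (the function name and `_` separator are illustrative, not the actual dflash-benchmark code):

```python
# Hypothetical helper illustrating the timestamped basename; the real
# dflash-benchmark code may assemble the path differently.
from datetime import datetime, timezone
from pathlib import Path

def timestamped_result_path(chip: str, name: str) -> Path:
    # UTC YYYYMMDDTHHMMSSZ, e.g. 20250101T093000Z
    stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    return Path("benchmark/results") / chip / f"{name}_{stamp}.json"
```

Repeated runs then produce distinct files under benchmark/results/<chip>/, so earlier results are never clobbered and cross-branch deltas stay comparable.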
- New archs/ directory with pluggable architecture system
- Supports Qwen3 (dense + MoE), Llama, Gemma architectures
- Handles both standard config and Gemma-style speculator config
- Updated DRAFT_REGISTRY with 16 models from z-lab and RedHatAI
- Backward compatibility maintained via model.py wrapper

Models now supported:
- z-lab: Qwen3.5-4B/9B/27B/35B-A3B/122B-A10B, Qwen3-4B/8B, Qwen3.6-27B/35B-A3B, Qwen3-Coder-Next/30B-A3B, Kimi-K2.5, Llama-3.1-8B-Instruct, GPT-OSS-20B/120B
- RedHatAI: Gemma-4-31B-it
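The commit message doesn't show the archs/ code itself; the following is a rough sketch of what a pluggable registry of this shape typically looks like (all names here are illustrative assumptions, not the actual module contents):

```python
# Illustrative sketch of a pluggable architecture registry; class and
# function names are assumptions, not the actual archs/ contents.
from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class ArchSpec:
    name: str                             # "qwen3", "llama", "gemma", ...
    parse_config: Callable[[dict], dict]  # normalizes standard vs Gemma-style speculator configs
    build: Callable[[dict], Any]          # normalized config -> draft model

ARCH_REGISTRY: dict[str, ArchSpec] = {}

def register_arch(spec: ArchSpec) -> None:
    ARCH_REGISTRY[spec.name] = spec

def build_draft_model(arch: str, raw_config: dict) -> Any:
    # Dispatch through the registry so new architectures plug in
    # without touching call sites (model.py stays a thin wrapper).
    spec = ARCH_REGISTRY[arch]
    return spec.build(spec.parse_config(raw_config))
```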
Removed models with known issues:
- Qwen3.6-27B: gated repo (requires HF auth)
- Kimi-K2.5: MLA not supported yet
- GPT-OSS models: target architecture not in mlx-lm

Kept 12 verified working models.
Adds make_qwen3_no_cache_attn() to kernels.py — a compiled factory that fuses Q/K/V projection, RMSNorm, RoPE, GQA head expansion, and SDPA into a single mx.compile trace. Wired into Qwen3DFlashAttention as a fast-exit for the no-cache inference path (cache=None, no dflash_cross_attention).

Optimization discovered and verified by phew-mlx 0.1.6 (https://github.com/0xClandestine/phew):
- Rule: compile_boundary + primitive_subst (SDPA pattern)
- Verified: PASS fp32_to_fp32 atol=1e-05 rtol=1e-05, 5 seeds, 3 sizes
- Speedup: ~1.2-1.35x on M-series hardware
Summary
- Adds `make_qwen3_no_cache_attn()` to `kernels.py` — a factory that returns an `@mx.compile`-decorated function fusing Q/K/V projection, RMSNorm, RoPE, GQA head expansion, and SDPA into a single compiled trace
- Wires it into `Qwen3DFlashAttention.__call__` as a fast-exit for the no-cache path (`cache=None`, no `dflash_cross_attention` kernel)

How it was found
Optimization discovered and verified by phew-mlx 0.1.6 — a search-based MLX superoptimizer for Apple Silicon. phew traced the Qwen3 attention forward pass, applied `compile_boundary` and `primitive_subst` (SDPA pattern-matching) rules, and emitted a verified-equivalent rewrite.

Verification result:
- PASS `fp32_to_fp32 atol=1e-05 rtol=1e-05` — 5 seeds × 3 problem sizes
- Speedup: ~1.2–1.35× on M-series hardware (measured across 20 runs, σ/μ < 5%)
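In code, the equivalence check amounts to something like the following (a sketch of the harness, not phew's actual verifier; `fused_fn` and `ref_fn` stand in for the compiled and reference attention functions):

```python
# Sketch of the fp32-vs-fp32 equivalence check phew reports; the real
# phew-mlx harness is not part of this PR.
import mlx.core as mx

def verify(fused_fn, ref_fn, shapes, seeds=range(5), atol=1e-5, rtol=1e-5):
    for seed in seeds:
        for shape in shapes:  # e.g. 3 problem sizes
            mx.random.seed(seed)
            x = mx.random.normal(shape).astype(mx.float32)
            assert mx.allclose(fused_fn(x), ref_fn(x), atol=atol, rtol=rtol)
```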
The gain comes from `mx.compile` fusing the projection matmuls, norms, RoPE, and attention into one graph — eliminating dispatch overhead and enabling cross-op fusion that MLX's JIT can exploit.
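A condensed sketch of the factory pattern being described (parameter plumbing, names, and the causal mask are assumptions; the real implementation lives in `kernels.py` in this diff):

```python
# Condensed sketch of the compiled no-cache attention factory; shapes and
# parameter handling are assumptions, not the exact kernels.py code.
import mlx.core as mx
import mlx.nn as nn

def make_qwen3_no_cache_attn(n_heads, n_kv_heads, head_dim, rope_theta, eps):
    repeats = n_heads // n_kv_heads          # GQA expansion factor
    rope = nn.RoPE(head_dim, base=rope_theta)

    @mx.compile
    def attn(x, wq, wk, wv, q_norm_w, k_norm_w):
        B, L, _ = x.shape
        # Q/K/V projections, reshaped to (B, heads, L, head_dim)
        q = (x @ wq.T).reshape(B, L, n_heads, head_dim)
        k = (x @ wk.T).reshape(B, L, n_kv_heads, head_dim)
        v = (x @ wv.T).reshape(B, L, n_kv_heads, head_dim)
        # Qwen3-style per-head RMSNorm on q/k
        q = mx.fast.rms_norm(q, q_norm_w, eps).transpose(0, 2, 1, 3)
        k = mx.fast.rms_norm(k, k_norm_w, eps).transpose(0, 2, 1, 3)
        v = v.transpose(0, 2, 1, 3)
        # RoPE with no position offset: this path never sees a KV cache
        q, k = rope(q), rope(k)
        # GQA head expansion, kept inside the same compiled trace
        k = mx.repeat(k, repeats, axis=1)
        v = mx.repeat(v, repeats, axis=1)
        out = mx.fast.scaled_dot_product_attention(
            q, k, v, scale=head_dim ** -0.5, mask="causal"
        )
        return out.transpose(0, 2, 1, 3).reshape(B, L, -1)

    return attn
```

Per the PR, the caller (`Qwen3DFlashAttention.__call__`) takes this fast exit only when `cache=None` and the `dflash_cross_attention` kernel isn't in play, falling back to the regular path otherwise.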