perf: fuse Qwen3 no-cache attention with mx.compile #14

Open

0xClandestine wants to merge 5 commits into bstnxbt:main from 0xClandestine:perf/phew-optimizations

Conversation

@0xClandestine
Contributor

Summary

  • Adds make_qwen3_no_cache_attn() to kernels.py — a factory that returns a @mx.compile-decorated function fusing the Q/K/V projections, RMSNorm, RoPE, GQA head expansion, and SDPA into a single compiled trace (see the sketch after this list)
  • Wires it into Qwen3DFlashAttention.__call__ as a fast path for the no-cache case (cache=None, no dflash_cross_attention kernel)
  • Leaves the existing cache/DFlash paths unchanged as fallbacks
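
A minimal sketch of what such a factory could look like, assuming mlx-lm-style (out_features, in_features) projection weights and MLX's mx.fast kernels. The parameter names and signature here are illustrative, not the PR's actual API:

```python
import mlx.core as mx

def make_qwen3_no_cache_attn(wq, wk, wv, wo, q_norm_w, k_norm_w,
                             n_heads, n_kv_heads, head_dim,
                             rope_base=1_000_000.0, eps=1e-6):
    # Hypothetical signature -- the factory in kernels.py may differ.
    scale = head_dim ** -0.5

    @mx.compile
    def attn(x):
        B, L, _ = x.shape
        # Q/K/V projections, split into heads
        q = (x @ wq.T).reshape(B, L, n_heads, head_dim).transpose(0, 2, 1, 3)
        k = (x @ wk.T).reshape(B, L, n_kv_heads, head_dim).transpose(0, 2, 1, 3)
        v = (x @ wv.T).reshape(B, L, n_kv_heads, head_dim).transpose(0, 2, 1, 3)
        # Qwen3 applies per-head RMSNorm to Q and K before RoPE
        q = mx.fast.rms_norm(q, q_norm_w, eps)
        k = mx.fast.rms_norm(k, k_norm_w, eps)
        # RoPE at offset 0 -- valid only because there is no KV cache
        q = mx.fast.rope(q, head_dim, traditional=False, base=rope_base,
                         scale=1.0, offset=0)
        k = mx.fast.rope(k, head_dim, traditional=False, base=rope_base,
                         scale=1.0, offset=0)
        # Fused SDPA; MLX's fast kernel handles GQA when n_kv_heads < n_heads
        o = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale,
                                                 mask="causal")
        return o.transpose(0, 2, 1, 3).reshape(B, L, -1) @ wo.T

    return attn
```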

How it was found

The optimization was discovered and verified by phew-mlx 0.1.6 — a search-based MLX superoptimizer for Apple Silicon. phew traced the Qwen3 attention forward pass, applied its compile_boundary and primitive_subst (SDPA pattern-matching) rules, and emitted a verified-equivalent rewrite.

Verification result: PASS fp32_to_fp32 atol=1e-05 rtol=1e-05 — 5 seeds × 3 problem sizes

Speedup: ~1.2–1.35× on M-series hardware (measured across 20 runs, σ/μ < 5%)

The gain comes from mx.compile fusing the projection matmuls, norms, RoPE, and attention into one graph — eliminating dispatch overhead and enabling cross-op fusion that MLX's JIT can exploit.
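
For context, speedups like this are typically measured with a warm-up call (so the first invocation pays the tracing/compile cost) followed by averaged forced evaluations. The harness below is an illustrative sketch, not the PR's benchmark code:

```python
import time
import mlx.core as mx

def bench(fn, x, iters=20):
    mx.eval(fn(x))                     # warm-up; first call builds the trace
    t0 = time.perf_counter()
    for _ in range(iters):
        mx.eval(fn(x))                 # force evaluation of the lazy graph
    return (time.perf_counter() - t0) / iters

# speedup = bench(eager_attn, x) / bench(compiled_attn, x)
```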

bstnxbt and others added 5 commits April 24, 2026 12:32
mlx-lm 0.31.3 requires mlx>=0.31.2 on Darwin per its published
metadata. Bump the lower bounds to match what's actually needed
at runtime. Installed versions unchanged (mlx 0.31.2, mlx-lm 0.31.3).
Test suite: 43 passed, 1 skipped.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Previously dflash-benchmark would overwrite benchmark/results/<chip>/<name>.json
on every run. Append a UTC YYYYMMDDTHHMMSSZ timestamp to the basename so repeated
runs never lose data and deltas across branches/commits stay traceable (see the
sketch below).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
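
A minimal sketch of that naming scheme; the helper name and directory layout are assumptions, not the actual dflash-benchmark code:

```python
from datetime import datetime, timezone
from pathlib import Path

def timestamped_result_path(chip: str, name: str,
                            root: str = "benchmark/results") -> Path:
    # e.g. benchmark/results/<chip>/<name>_20260424T123200Z.json
    stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    return Path(root) / chip / f"{name}_{stamp}.json"
```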
- New archs/ directory with a pluggable architecture system (see the registry sketch below)
- Supports Qwen3 (dense + MoE), Llama, Gemma architectures
- Handles both standard config and Gemma-style speculator config
- Updated DRAFT_REGISTRY with 16 models from z-lab and RedHatAI
- Backward compatibility maintained via model.py wrapper

Models now supported:
- z-lab: Qwen3.5-4B/9B/27B/35B-A3B/122B-A10B, Qwen3-4B/8B,
  Qwen3.6-27B/35B-A3B, Qwen3-Coder-Next/30B-A3B, Kimi-K2.5,
  Llama-3.1-8B-Instruct, GPT-OSS-20B/120B
- RedHatAI: Gemma-4-31B-it
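
One common shape for such a pluggable system is a decorator-based registry; the sketch below is purely illustrative and does not claim to match the actual archs/ code:

```python
from typing import Dict, Type

ARCH_REGISTRY: Dict[str, Type] = {}

def register_arch(name: str):
    """Register an architecture class under a config-derived name."""
    def deco(cls):
        ARCH_REGISTRY[name] = cls
        return cls
    return deco

@register_arch("qwen3")
class Qwen3Arch:
    ...  # projection shapes, norm placement, RoPE params, etc.

def get_arch(name: str) -> Type:
    return ARCH_REGISTRY[name]
```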
Removed models with known issues:
- Qwen3.6-27B: Gated repo (requires HF auth)
- Kimi-K2.5: MLA not supported yet
- GPT-OSS models: Target architecture not in mlx-lm

Kept 12 verified working models.

Adds make_qwen3_no_cache_attn() to kernels.py — a compiled factory that
fuses Q/K/V projection, RMSNorm, RoPE, GQA head expansion, and SDPA into
a single mx.compile trace. Wired into Qwen3DFlashAttention as a fast-exit
for the no-cache inference path (cache=None, no dflash_cross_attention).

Optimization discovered and verified by phew-mlx 0.1.6
(https://github.com/0xClandestine/phew):
  - Rule: compile_boundary + primitive_subst (SDPA pattern)
  - Verified: PASS fp32_to_fp32 atol=1e-05 rtol=1e-05, 5 seeds, 3 sizes
  - Speedup: ~1.2-1.35x on M-series hardware