Pipeline parallelism patches for Llama, Qwen2, and Mixtral (with benchmark data) #1051
guruswami-ai started this conversation in Show and tell
We have working pipeline parallelism implementations for three models that currently only support tensor parallelism (or no distributed inference at all):
- Llama (PP via `PipelineMixin`)
- Qwen2 (PP via `PipelineMixin`)
- Mixtral (PP via `PipelineMixin`)

The patches follow the established patterns from DeepSeek V3 and Ministral3 in the codebase: TP uses `shard_linear`; PP uses `PipelineMixin` with `send`/`recv` at layer boundaries.

Full source files: patches/ (written against mlx-lm 0.30.8)
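To make the send/recv-at-layer-boundaries structure concrete, here is a minimal sketch of a pipeline stage. This is an illustration, not the patch itself: it assumes MLX's `mx.distributed` `send`/`recv_like` primitives, and the `PipelineStage` class, the even layer split, and all attribute names are hypothetical. The real implementations live in patches/.

```python
import mlx.core as mx
from mlx.core import distributed as dist


class PipelineStage:
    """Illustrative pipeline stage owning a contiguous slice of decoder layers."""

    def __init__(self, layers, group: dist.Group):
        self.group = group
        self.rank = group.rank()
        self.size = group.size()
        # Evenly split the layer stack; each rank keeps only its slice.
        # (Hypothetical policy -- the actual patches may split differently.)
        per_rank = len(layers) // self.size
        start = self.rank * per_rank
        self.layers = layers[start : start + per_rank]

    def __call__(self, h: mx.array) -> mx.array:
        # Every rank except the first blocks on one recv per forward pass --
        # the only sync point PP introduces at this boundary.
        if self.rank > 0:
            h = dist.recv_like(h, self.rank - 1, group=self.group)
        for layer in self.layers:
            h = layer(h)
        # Every rank except the last forwards its activations downstream.
        # Note: MLX communication ops are lazy; the returned array must be
        # evaluated (e.g. via mx.eval) for the transfer to actually happen.
        if self.rank < self.size - 1:
            h = dist.send(h, self.rank + 1, group=self.group)
        return h
```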
Why PP matters
PP preserves nearly all single-node generation speed because it has far fewer sync points than TP: TP needs an all-reduce after every attention block and every MLP block (two full-cluster syncs per layer, per token), while PP exchanges activations only once per stage boundary per token.
For models that fit on a single node but need more memory headroom for long context, PP across two nodes doubles your memory budget with minimal speed penalty. TP is still needed for models that genuinely require multi-node compute (Llama 405B, DeepSeek V3, Kimi K2.5).
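A back-of-the-envelope count shows the size of the gap. The numbers below are illustrative (a hypothetical 80-layer dense model on two nodes), not taken from the benchmarks in this post, and assume Megatron-style TP with two all-reduces per layer:

```python
# Per-token sync counts during decode (illustrative assumptions: 80 decoder
# layers, 2 nodes, one all-reduce after attention and one after the MLP).
layers, nodes = 80, 2

tp_syncs_per_token = 2 * layers   # 160 all-reduces, each a full-cluster sync
pp_syncs_per_token = nodes - 1    # 1 point-to-point send/recv at the stage boundary

print(tp_syncs_per_token, pp_syncs_per_token)  # 160 1
```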
Important caveat: PP hits Metal's ~60-second GPU timeout on large dense models. Llama 405B PP2 (63 layers per node) crashes. PP works for models up to ~200B dense parameters. Above that, TP is required.
Benchmark data
These patches were validated across 290 benchmark configurations on a 5-node M3 Ultra TB5 cluster.
Happy to submit these as PRs if there is interest. They are complete, tested, and follow existing codebase conventions.