[mxfp8 moe training] add CUDA kernel to quantize 3d tensor colwise #3002
Base: danielvegamyhre/stack/68
Conversation
🔗 Helpful links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3002. Note: links to docs will display an error until the docs builds have completed.
✅ No failures as of commit b3b709c with merge base 66384a9. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Force pushes to danielvegamyhre/stack/69 (stack-info: PR: #3002):
- 2b1b340 → 146b42a
- 146b42a → 9921d5e
- 9921d5e → b3b709c
@slayton58 I would be curious to get your thoughts on ways to improve this memory-bandwidth-bound kernel for quantizing 3D expert weights (E,N,K) along the N dim, as well as the similar kernel for 2D weights. We hit around 70% of peak memory bandwidth (~5600 GB/s of the ~8000 GB/s on B200), but ideally we could reach at least 80% (or as close to speed of light as possible!).

Context: torch.compile and handwritten Triton kernels were both slow for mxfp8 quantization of RHS operands, where we scale colwise (32x1 granularity); e.g., Triton hit ~3700 GB/s (46% of peak memory bandwidth). So I added a CUDA kernel, derived from a TE kernel, which achieves ~5600 GB/s (#2513). Basically, we stripped out internal TE types, added support for different scale calculation modes (floor, rceil) to align with torchao numerics, then resolved some perf issues resulting from those changes to get reasonable perf (see PR for details).

Now I'm finding that quantizing 3D expert weights along dim1 scales extremely poorly as the number of experts increases (see this PR's description for details, and see #2999 for benchmarks). So I added a similar CUDA kernel to our mxfp8_cuda extension specifically for quantizing 3D expert weights colwise and writing directly to the col-major format we need. The first approach I tried was updating the 2D kernel to handle 3D tensors by treating them as a 2D tensor of shape (E*N, K), but the coordinate mapping / pointer arithmetic became a complicated mess that wasn't working. So I made a new kernel, similar to the 2D kernel, that passes in separate input/output TMA descriptors for each expert; the kernel then operates on each 2D expert with logical separation, in parallel. A host-side sketch of the per-expert descriptor idea follows below.
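For illustration, here is a minimal host-side sketch of the per-expert TMA descriptor setup (not the PR's actual code): one `CUtensorMap` per (N, K) slice, so each thread block can treat its expert as an independent 2D problem. It assumes a CUDA 12+ driver API with `cuInit` already called, bf16 input, K a multiple of 8 (so the row stride is 16-byte aligned), and an illustrative 64x64 tile with no swizzle.

```cpp
#include <cuda.h>
#include <cstdint>
#include <vector>

// Hedged sketch: build one TMA descriptor per expert so the kernel can treat
// each (N, K) slice of the (E, N, K) bf16 tensor as an independent 2D problem.
std::vector<CUtensorMap> make_per_expert_input_descs(
    void* base, uint64_t E, uint64_t N, uint64_t K) {
  constexpr uint64_t kElemBytes = 2;  // bf16
  std::vector<CUtensorMap> descs(E);
  for (uint64_t e = 0; e < E; ++e) {
    // Each expert's 2D slice starts at a fixed byte offset into the 3D tensor,
    // so no cross-expert pointer arithmetic is needed inside the kernel.
    void* expert_base = static_cast<char*>(base) + e * N * K * kElemBytes;
    uint64_t global_dim[2] = {K, N};                // innermost dim first
    uint64_t global_strides[1] = {K * kElemBytes};  // row stride in bytes
    uint32_t box_dim[2] = {64, 64};                 // tile moved per TMA copy
    uint32_t elem_strides[2] = {1, 1};
    cuTensorMapEncodeTiled(
        &descs[e], CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, /*tensorRank=*/2,
        expert_base, global_dim, global_strides, box_dim, elem_strides,
        CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE,
        CU_TENSOR_MAP_L2_PROMOTION_L2_128B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
    // (error handling omitted for brevity)
  }
  // The descriptor array is then copied to device memory; each thread block
  // looks up descs[expert_id] and runs the same 2D colwise-quant loop as the
  // 2D kernel. Output descriptors for the col-major fp8 result are built the
  // same way.
  return descs;
}
```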
Stacked PRs:
[mxfp8 moe training] add CUDA kernel to quantize 3d tensor colwise
Summary

- Adds a CUDA kernel that quantizes a 3D tensor (E,N,K) colwise and writes directly to the column-major output format, avoiding extra .contiguous() calls: the to_mx + torch.compile path requires transposing the contiguous tensor (E,N,K) -> (E,K,N) and then calling .contiguous() to scale along the N dim (needed for backwards). A sketch of that baseline path is shown below.
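For reference, a minimal sketch of that baseline path, assuming `to_mx` is importable from `torchao.prototype.mx_formats.mx_tensor` with signature `to_mx(data_hp, elem_dtype, block_size) -> (scales, data)`; the exact import path and signature may vary across torchao versions:

```python
import torch
from torchao.prototype.mx_formats.mx_tensor import to_mx  # assumed import path

# Hedged sketch of the baseline path: to scale (E, N, K) RHS expert weights
# along the N dim, the tensor is first transposed to (E, K, N) and
# materialized contiguously (an extra full copy), since to_mx scales along
# the last dim.
@torch.compile
def quantize_colwise_baseline(w: torch.Tensor, block_size: int = 32):
    # w: contiguous (E, N, K) bf16 expert weights
    w_t = w.transpose(-2, -1).contiguous()  # (E, K, N), extra copy
    scales, w_fp8 = to_mx(w_t, torch.float8_e4m3fn, block_size)
    return scales, w_fp8
```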
Test plan

Kernel microbenchmarks
NOTE: devgpu is having problems right now, so all kernels are slower than usual. I'm still sorting that out, but for now focus on the RELATIVE differences in the benchmarks below, rather than the absolute numbers:
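For context on how GB/s figures like the ones above can be derived, here is a hypothetical microbenchmark sketch (not the PR's actual harness), where `quantize_fn` is a stand-in for the kernel under test:

```python
import torch

# Hedged sketch: computes achieved memory bandwidth for a colwise
# quantization op on a (E, N, K) bf16 tensor from bytes moved / kernel time.
def bench_mem_bw(quantize_fn, E=8, N=8192, K=5120, iters=50):
    x = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16)
    for _ in range(5):  # warmup so compilation/caching doesn't skew timing
        quantize_fn(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        quantize_fn(x)
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end) / iters
    # Bytes moved: bf16 read (2 B/elem) + fp8 write (1 B/elem)
    # + one e8m0 scale byte per 32-element block.
    n = E * N * K
    gb_s = (n * 2 + n + n // 32) / (ms * 1e-3) / 1e9
    print(f"{ms:.3f} ms/iter, {gb_s:.0f} GB/s achieved")
```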