
Conversation

danielvegamyhre (Contributor) commented Sep 14, 2025

Stacked PRs:

[mxfp8 moe training] add CUDA kernel to quantize 3d tensor colwise

Summary

  • This PR adds a new CUDA kernel specifically for quantizing 3d expert weights of shape (E, N, K) along the N dimension, writing the output directly in column-major format.
    • Design: I create separate input/output TMA descriptors for each expert, and process each 2d expert weight matrix in parallel.
  • The existing methods for quantizing 3d expert weights both scale very poorly. I have verified this via benchmarking and traces (see previous PR), and hypothesize that the cause is the required .contiguous() calls:
    • Using to_mx + torch.compile requires transposing the contiguous (E, N, K) tensor to (E, K, N) and then calling .contiguous() in order to scale along the N dim (needed for the backward pass).
    • Using the existing CUDA kernel for casting along dim1 is possible by treating the 3d input tensor as a 2d tensor of shape (E*N, K). However, this produces a 2d output tensor in column-major format, and there is no way to reshape and restride that tensor back to 3d while preserving a column-major layout that keeps numerics intact. Thus, we have to convert the reshaped 3d output back to column-major format afterwards, requiring a .contiguous() call (see the layout sketch after this list).
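For intuition, here is a minimal PyTorch sketch of the layout problem. The sizes and the empty placeholder tensors are illustrative only, not the actual torchao code path:

```python
import torch

# Small sizes for illustration; the real expert weights are (E, 8192, 5120).
E, N, K = 4, 64, 32
w = torch.randn(E, N, K)

# Path 1: to_mx + torch.compile. The transpose itself is just a view, but the
# .contiguous() that follows materializes a full copy of the weights before
# quantization even starts.
w_t = w.transpose(-2, -1).contiguous()          # (E, K, N), full memory copy

# Path 2: treat (E, N, K) as a 2d (E*N, K) tensor and reuse the existing
# colwise kernel, which writes its 2d output in column-major format.
# Model that output layout with an empty tensor:
out_2d = torch.empty(K, E * N).t()              # (E*N, K), strides (1, E*N)

# The layout the 3d output actually needs: each expert's (N, K) slice is
# column-major inside its own contiguous N*K block -> strides (N*K, 1, N).
want = torch.empty(E, K, N).transpose(-2, -1)   # (E, N, K)
print(want.stride())                            # (N*K, 1, N)

# Viewing the 2d kernel's output as 3d instead gives strides (N, 1, E*N):
# every logical column of length E*N interleaves all experts, so one expert's
# data is not contiguous and no restriding can fix it.
got = out_2d.view(E, N, K)
print(got.stride())                             # (N, 1, E*N)

# Repairing the layout means moving data (a .contiguous()-style copy),
# which is the extra memory traffic the new 3d kernel avoids.
```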

Test plan

  • Added tests that verify numerical accuracy

Kernel microbenchmarks

NOTE: the devgpu is having problems right now and all kernels are slower than usual. I'm still sorting that out, so for now focus on the RELATIVE differences in the benchmarks below rather than the absolute numbers:

| input_shape      | to_mx_us | cuda_2d_us | cuda_3d_us | to_mx_gbps | cuda_2d_gbps | cuda_3d_gbps |
|------------------|----------|------------|------------|------------|--------------|--------------|
| (1, 8192, 5120)  | 806.208  | 585.152    | 645.6      | 157.701    | 217.277      | 196.933      |
| (2, 8192, 5120)  | 1154.74  | 907.52     | 682.4      | 220.206    | 280.192      | 372.626      |
| (4, 8192, 5120)  | 1615.97  | 1141.38    | 802.784    | 314.709    | 445.567      | 633.495      |
| (8, 8192, 5120)  | 2577.06  | 1551.81    | 968.336    | 394.682    | 655.441      | 1050.38      |
| (16, 8192, 5120) | 4538.69  | 2679.42    | 1414.02    | 448.199    | 759.207      | 1438.62      |
| (64, 8192, 5120) | 15856.5  | 8560.64    | 4288.61    | 513.16     | 950.507      | 1897.34      |
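For reference, the throughput columns appear to be total bytes read plus bytes written, divided by kernel time. The sketch below reproduces the table's numbers under that assumption (bf16 input, fp8 output, one e8m0 scale byte per 32 elements); the exact accounting in the benchmark script may differ:

```python
def achieved_gbps(E: int, N: int, K: int, time_us: float) -> float:
    """Estimate achieved memory bandwidth for mxfp8 colwise quantization.

    Assumes the bf16 input (2 bytes/elem) is read once and the fp8 output
    (1 byte/elem) plus one e8m0 scale per 32 elements is written once.
    """
    in_bytes = E * N * K * 2
    out_bytes = E * N * K + (E * N * K // 32)
    return (in_bytes + out_bytes) / (time_us * 1e-6) / 1e9

# Example: the (8, 8192, 5120) row for the 3d kernel.
print(achieved_gbps(8, 8192, 5120, 968.336))   # ~1050 GB/s, matching the table
```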


pytorch-bot bot commented Sep 14, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3002

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b3b709c with merge base 66384a9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

danielvegamyhre added a commit that referenced this pull request Sep 14, 2025
stack-info: PR: #3002, branch: danielvegamyhre/stack/69
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 14, 2025
@danielvegamyhre danielvegamyhre added mx moe topic: not user facing Use this tag if you don't want this PR to show up in release notes labels Sep 14, 2025
@danielvegamyhre danielvegamyhre marked this pull request as draft September 14, 2025 23:49
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/68 to main September 14, 2025 23:51
danielvegamyhre added a commit that referenced this pull request Sep 14, 2025
stack-info: PR: #3002, branch: danielvegamyhre/stack/69
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/68 September 14, 2025 23:51
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/68 to main September 15, 2025 00:16
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/68 September 15, 2025 00:16
@danielvegamyhre danielvegamyhre marked this pull request as ready for review September 15, 2025 00:17
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/68 to main September 15, 2025 02:18
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/68 September 15, 2025 02:18
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/68 to main September 15, 2025 02:19
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/68 September 15, 2025 02:19
danielvegamyhre (Contributor, Author) commented Sep 15, 2025

@slayton58 I would be curious to get your thoughts on ways to improve this memory-bandwidth-bound kernel for quantizing 3d expert weights (E, N, K) along the N dim, as well as the similar kernel for 2d weights. We hit around 70% of peak memory bandwidth (~5600 GB/s of the ~8000 GB/s on B200), but ideally we could get to at least 80% (or as close to speed of light as possible!).

Context: torch.compile and handwritten Triton kernels were both slow for mxfp8 quantization of RHS operands, where we scale colwise (32x1 granularity); e.g., Triton hit ~3700 GB/s (46% of peak memory bandwidth). So I added a CUDA kernel here which I derived from a TE kernel and which achieves ~5600 GB/s (#2513). Basically we stripped out internal TE types, added support for different scale calculation modes (floor, rceil) to align with torchao numerics (sketched below), then resolved some perf issues resulting from those changes to get reasonable perf (see that PR for details).
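For anyone unfamiliar with the two scale modes, here is a rough sketch of the distinction as I understand it; this is illustrative only, not the kernel's implementation, and it omits e8m0 encoding, clamping, and edge cases like zero blocks:

```python
import math

F8E4M3_MAX = 448.0  # max normal value of float8_e4m3fn

def scale_floor(amax: float) -> float:
    """'floor' mode: power-of-two scale built from truncated exponents,
    i.e. floor(log2(amax)) minus the element dtype's exponent budget."""
    e = math.floor(math.log2(amax)) - math.floor(math.log2(F8E4M3_MAX))
    return 2.0 ** e

def scale_rceil(amax: float) -> float:
    """'rceil' mode: round amax / max_fp8 up to the next power of two,
    so scaled values never exceed the fp8 range."""
    return 2.0 ** math.ceil(math.log2(amax / F8E4M3_MAX))

amax = 500.0
print(scale_floor(amax), scale_rceil(amax))   # 1.0 vs 2.0: floor mode needs clamping here, rceil does not
```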

Now I'm finding that quantizing 3d expert weights along dim1 scales extremely poorly as the number of experts increases (see this PR's description for details, and #2999 for benchmarks). So I added a similar CUDA kernel to our mxfp8_cuda extension specifically for quantizing 3d expert weights colwise and writing directly to the column-major format we need them in.

The first approach I tried was just updating the 2d kernel to handle 3d tensors by treating the input as a 2d tensor of shape (E*N, K), but the coordinate mapping / pointer arithmetic became a complicated mess that wasn't working. So I made a new kernel that is similar to the 2d kernel but passes in separate input/output TMA descriptors for each expert; the kernel then operates on each 2d expert with logical separation, in parallel (a rough PyTorch reference of that per-expert decomposition is sketched below).
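To make the per-expert decomposition concrete, here is a hedged pure-PyTorch reference of what the kernel computes logically. This is not the CUDA implementation; quantize_colwise_ref is a simplified stand-in for the real mxfp8 cast (floor scale mode, N divisible by 32 assumed):

```python
import torch

def quantize_colwise_ref(x_2d: torch.Tensor, block: int = 32):
    """Simplified stand-in for the mxfp8 colwise cast of one (N, K) expert:
    one scale per 32x1 column block, elements cast to float8_e4m3fn."""
    N, K = x_2d.shape
    blocks = x_2d.float().reshape(N // block, block, K)
    amax = blocks.abs().amax(dim=1)                                    # (N // 32, K)
    scales = torch.exp2(torch.floor(torch.log2(amax)) - 8)             # floor mode, illustrative
    q = (blocks / scales.unsqueeze(1)).clamp(-448.0, 448.0).reshape(N, K)
    return q.to(torch.float8_e4m3fn), scales

def quantize_3d_colwise_ref(w: torch.Tensor):
    """Reference for the new 3d kernel: each expert is handled independently
    (the CUDA kernel gives each expert its own TMA descriptors), and each
    expert's output slice is written in column-major layout."""
    E, N, K = w.shape
    out = torch.empty(E, K, N, dtype=torch.float8_e4m3fn).transpose(-2, -1)   # col-major per expert
    scales = torch.empty(E, N // 32, K, dtype=torch.float32)
    for e in range(E):                        # the CUDA kernel does these in parallel
        q, s = quantize_colwise_ref(w[e])
        out[e].copy_(q)                       # lands inside the expert's col-major block
        scales[e].copy_(s)
    return out, scales

w = torch.randn(4, 64, 32, dtype=torch.bfloat16)
q, s = quantize_3d_colwise_ref(w)
print(q.shape, q.stride())                    # (4, 64, 32) with strides (K*N, 1, N)
```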
