This is a tracking issue for the efforts needed to optimize the performance of MoE training, with a focus on D2H sync-free MoE: all problem sizes should be supplied from device buffers rather than from the host.
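For context, the D2H sync this issue aims to eliminate looks roughly like the sketch below: the per-expert token counts produced by the router live on the GPU, and turning them into the host-side split list that grouped kernels currently expect forces a blocking device-to-host copy. The tensor names and sizes are illustrative only, not taken from the TE code base.

```python
import torch

# Per-expert token counts produced by the router; they live on the GPU.
tokens_per_expert = torch.tensor([4096, 1024, 2048, 1024], device="cuda")

# Today: grouped kernels typically take host-side integer splits, so the
# counts must be copied to the CPU, which blocks until the router (and
# everything queued before it) has finished executing.
m_splits_host = tokens_per_expert.tolist()   # implicit D2H sync

# Goal of this issue: keep the problem sizes in a device buffer and hand
# that buffer directly to grouped quantize / grouped GEMM kernels, so the
# CPU never has to wait on the router's output.
m_splits_device = tokens_per_expert          # stays on the GPU
```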
TE/common:
- GroupedTensor type ([Common] NVTEGroupedTensor class and helpers #2388)
- Quantization support for the GroupedTensor
  - FP8 per-tensor
  - MXFP8
  - NVFP4
    - Grouped amax kernel (unofficial API) ([PyTorch][NVFP4][MOE] NVFP4 Grouped Hadamard Amax Kernel #2351)
    - Grouped quantization kernel and grouped RHT quantization kernel, separately (unofficial API) ([PyTorch][NVFP4][MOE] NVFP4 Grouped Quantize with Hadamard Transform #2411)
    - Grouped_amax and Grouped_quantize APIs with NVTEGroupedTensor
    - Grouped rowwise quantization fused with the RHT + quantization fusion
- Grouped swizzle kernel
- GroupedGemm with GroupedTensor support (see the sketch after this list)
  - FP8 per-tensor
    - via cuBLAS
    - via a CUTLASS/cuDNN kernel
  - MXFP8
    - via cuBLAS (needs cuBLAS support)
    - via a CUTLASS/cuDNN kernel (Add device-Initiated Grouped GEMM supporting m_splits on device #2360)
  - NVFP4
    - via cuBLAS (needs cuBLAS support)
    - via a CUTLASS/cuDNN kernel
- Fused router optimization
- Activation support for the GroupedTensor
  - BF16 -> BF16 version
  - BF16 -> quantized format (should be mostly covered by the effort needed to bring up quantization support)
  - Ability to return the amax together with the result for NVFP4 and FP8 current-scaling quantization
- Padding in permutation
- Fusion of swizzle with quantization
  - Changes to the NVTETensor/GroupedTensor
  - MXFP8 quantization support
  - NVFP4 quantization support
  - Unswizzle kernel + grouped version (for checkpointing/debugging)
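A rough sketch of how these pieces are meant to fit together at runtime is given below, written from the caller's perspective in Python for readability. The helpers passed in (quantize, swizzle, gemm) are placeholders for the grouped kernels tracked above, not existing TE symbols.

```python
import torch

def grouped_mlp_gemm(x, weights, tokens_per_expert, quantize, swizzle, gemm):
    """Sketch of the intended device-side flow; every helper passed in here
    stands in for one of the grouped kernels tracked in this issue."""
    # 1. Group sizes stay in a device buffer; no .tolist()/.cpu() anywhere.
    assert tokens_per_expert.is_cuda

    # 2. Grouped amax + quantization: one launch covers all groups,
    #    producing a quantized GroupedTensor (FP8/MXFP8/NVFP4) plus scales.
    x_q = quantize(x, group_sizes=tokens_per_expert)
    w_q = quantize(weights, group_sizes=None)  # weights have static shapes

    # 3. Swizzle the scaling factors into the layout the GEMM expects
    #    (eventually fused into the quantization kernel itself).
    x_q = swizzle(x_q)
    w_q = swizzle(w_q)

    # 4. Grouped GEMM reads the per-expert sizes directly from the device
    #    buffer (cuBLAS, or a device-initiated CUTLASS/cuDNN kernel).
    return gemm(x_q, w_q, group_sizes=tokens_per_expert)
```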
TE/PyTorch:
- Expose the grouped tensor type internally in PyTorch modules
- Expose the grouped tensor type externally [@timmoon10 is doubtful of its feasibility]
  - Expose the grouped tensor type as a PyTorch tensor
- Enable grouped tensor input to GroupedLinear (see the sketch after this list)
- Enable a single grouped-tensor weight option in GroupedLinear
- Utilize pre-swizzled inputs in the GEMM
- Changes to te.Sequential to enable grouped tensors
- End-to-end MoE support in TransformerLayer
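To make the GroupedLinear items concrete, the sketch below contrasts the current calling convention, where m_splits is a host-side Python list, with the device-buffer / grouped-tensor input tracked here. It assumes GroupedLinear's existing num_gemms/in_features/out_features constructor and its m_splits forward argument; the commented-out variants at the end are the hypothetical target of this issue, not current API.

```python
import torch
import transformer_engine.pytorch as te

num_experts, hidden, ffn = 8, 4096, 16384
layer = te.GroupedLinear(num_experts, hidden, ffn, bias=False)

# Per-expert token counts as produced by a router, living on the GPU.
tokens_per_expert = torch.tensor([512] * num_experts, device="cuda")
x = torch.randn(int(tokens_per_expert.sum()), hidden, device="cuda")

# Today: m_splits is a host-side list of ints, so deriving it from the
# router output forces a blocking D2H copy (tolist()).
out = layer(x, m_splits=tokens_per_expert.tolist())

# Tracked in this issue (hypothetical API): accept the splits as a device
# buffer, or accept a grouped tensor that already carries its group sizes,
# so the forward pass never synchronizes with the host.
# out = layer(x, m_splits=tokens_per_expert)   # device-resident splits
# out = layer(grouped_x)                       # grouped tensor input
```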
TE/JAX:
- Triton binding ([JAX] Triton binding #2437)
- Router / TopK custom call ([Draft] TopK Fusion to JAX #2385)
- Permutation custom call ([JAX] Wrapper for Permutation Triton kernel #2419)
- Custom partitioning for grouped_quantize and grouped_gemm
- Code refactor:
  - GroupedScaledTensor/GroupedGEMM with first_dims and last_dims instead of group_sizes
  - Remove the D2H sync in GroupedQuantizeFFI/GroupedGemmFFI
  - Utilize pre-swizzled inputs in the GEMM
- E2E (see the sketch after this list)
  - GroupedMLP
  - MaxText integration
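As a reference point for the E2E GroupedMLP work, a minimal stand-in can already be written in plain JAX with jax.lax.ragged_dot (assuming a recent JAX version where it is available), which takes the per-expert group sizes as a device array rather than host integers; TE's grouped_quantize/grouped_gemm custom calls would slot into the same shape of computation. This is an illustrative sketch, not the planned TE/JAX API.

```python
import jax
import jax.numpy as jnp

def grouped_mlp(x, w_up, w_down, tokens_per_expert):
    """Minimal grouped MLP over experts; group sizes stay on the device.

    x:                 (num_tokens, hidden), tokens already sorted by expert
    w_up / w_down:     (num_experts, hidden, ffn) / (num_experts, ffn, hidden)
    tokens_per_expert: (num_experts,) int32 device array
    """
    # Grouped GEMM 1: per-expert up-projection, sizes read from the device.
    h = jax.lax.ragged_dot(x, w_up, tokens_per_expert)
    h = jax.nn.gelu(h)
    # Grouped GEMM 2: per-expert down-projection.
    return jax.lax.ragged_dot(h, w_down, tokens_per_expert)

# Example: 4 experts, 64 hidden, 128 ffn, 256 tokens sorted by expert.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (256, 64))
w_up = jax.random.normal(key, (4, 64, 128))
w_down = jax.random.normal(key, (4, 128, 64))
sizes = jnp.array([64, 96, 32, 64], dtype=jnp.int32)
y = jax.jit(grouped_mlp)(x, w_up, w_down, sizes)
```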