
MoE training optimization #2438

@ptrendx

Description

This is a tracking issue for the multiple efforts needed to optimize the performance of MoE training, with a focus on D2H-sync-free MoE: all problem sizes should be supplied from device buffers rather than from the host.
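
For context, a minimal PyTorch sketch of the device-to-host sync this effort aims to remove; the shapes, expert count, and variable names are made up for illustration:

```python
import torch

num_experts = 8
router_logits = torch.randn(4096, num_experts, device="cuda")

# The per-expert token counts (the "problem sizes") are produced on the GPU.
expert_idx = router_logits.argmax(dim=-1)
m_splits_device = torch.bincount(expert_idx, minlength=num_experts)

# Grouped kernels that expect host-side integers force a blocking device-to-host
# copy like this one; the goal is to consume m_splits_device directly instead.
m_splits_host = m_splits_device.tolist()
```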

TE/common:

  • GroupedTensor type ([Common] NVTEGroupedTensor class and helpers #2388); see the layout sketch after this list
  • Quantization support for the GroupedTensor
  • Grouped swizzle kernel
  • GroupedGemm with GroupedTensor support
  • Fused router optimization
  • Activation support for the GroupedTensor
    • BF16 -> BF16 version
    • BF16 -> quantized formats (should be mostly covered by the quantization support effort above)
    • Ability to return amax together with the result for NVFP4 and FP8 current-scaling quantization
  • Padding in permutation
  • Fusion of swizzle with quantization
    • Changes to the NVTETensor/GroupedTensor
    • MXFP8 quantization support
    • NVFP4 quantization support
    • Unswizzle kernel + grouped version (for checkpointing/debugging)
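
As a rough illustration of what the grouped tensor in the list above needs to carry (one contiguous buffer, device-resident per-group sizes/offsets, and optional quantization metadata), here is a conceptual Python sketch; it is an assumption for illustration only, and the actual layout is defined by #2388:

```python
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class GroupedTensorSketch:
    """Illustrative stand-in, not the NVTEGroupedTensor API from #2388."""
    data: torch.Tensor             # one contiguous buffer holding every group's rows
    group_rows: torch.Tensor       # device int tensor: rows per group (no host copy needed)
    group_offsets: torch.Tensor    # device int tensor: starting row of each group
    columns: int                   # inner dimension shared by all groups
    # Quantization metadata, populated only for FP8/MXFP8/NVFP4 data:
    scale_inv: Optional[torch.Tensor] = None   # per-group (or per-block) scale inverses
    amax: Optional[torch.Tensor] = None        # per-group amax, e.g. for current-scaling recipes
```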

TE/PyTorch:

  • Expose the grouped tensor type internally in PyTorch modules
  • Expose the grouped tensor type externally [@timmoon10 is doubtful of its feasibility]
    • Expose the grouped tensor type as a PyTorch tensor
    • Enable grouped tensor input to GroupedLinear (see the usage sketch after this list)
  • Enable a single grouped-tensor weight option in GroupedLinear
  • Utilize preswizzled inputs in the GEMM
  • Changes to te.Sequential to enable grouped tensors
  • End-to-end MoE support in TransformerLayer
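
For reference, a rough usage sketch of GroupedLinear as it works today, taking host-side m_splits (the exact signature may vary between TE versions, and the shapes are arbitrary); the items above aim to let a single grouped tensor with device-side group sizes replace the Python-int splits:

```python
import torch
import transformer_engine.pytorch as te

num_experts, hidden_size, ffn_size = 8, 1024, 4096
layer = te.GroupedLinear(num_experts, hidden_size, ffn_size,
                         bias=False, params_dtype=torch.bfloat16)

x = torch.randn(4096, hidden_size, device="cuda", dtype=torch.bfloat16)
m_splits = [512] * num_experts   # host-side Python ints: the D2H sync this issue targets
out = layer(x, m_splits)         # one GEMM per expert, grouped under the hood
```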

TE/JAX:
