From 5576dad490f329d204cdba8a539ae9faf564083a Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 18 Feb 2026 00:52:38 -0800 Subject: [PATCH 01/11] Initial implementation for `cutlass_fused_moe` mxfp8 --- .../cutlass_fused_moe_kernels.cuh | 99 +++++++---- .../flashinfer_cutlass_fused_moe_binding.cu | 76 ++++++++- .../cutlass_kernels/include/moe_kernels.h | 20 ++- .../launchers/moe_gemm_tma_ws_launcher.inl | 16 +- .../moe_gemm/moe_gemm_template_dispatch.h | 7 +- .../moe_gemm_template_dispatch_tma_ws.h | 110 ++++++++---- .../jit/gemm/cutlass/generate_kernels.py | 67 +++++--- tests/moe/test_trtllm_cutlass_fused_moe.py | 160 ++++++++++++++++++ 8 files changed, 456 insertions(+), 99 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 0cead70f37..9a4f428682 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -2534,17 +2534,14 @@ CutlassMoeFCRunner:: template -std::map> -CutlassMoeFCRunner::getWorkspaceDeviceBufferSizes(int64_t const num_rows, - int64_t const hidden_size, - int64_t const inter_size, - int const num_experts_per_node, - int const experts_per_token, - ActivationType activation_type, - bool use_lora, - bool use_deepseek_fp8_block_scale, - bool min_latency_mode, bool use_awq) { +std::map> CutlassMoeFCRunner< + T, WeightType, OutputType, InputType, BackBoneType, + Enable>::getWorkspaceDeviceBufferSizes(int64_t const num_rows, int64_t const hidden_size, + int64_t const inter_size, int const num_experts_per_node, + int const experts_per_token, + ActivationType activation_type, bool use_lora, + bool use_deepseek_fp8_block_scale, bool min_latency_mode, + bool use_awq, bool use_mxfp8_fp8_block_scaling) { size_t num_moe_inputs = min_latency_mode ? 
num_experts_per_node * num_rows : experts_per_token * num_rows; size_t const permuted_elems = num_moe_inputs * hidden_size; @@ -2598,21 +2595,27 @@ CutlassMoeFCRunner(num_rows * num_experts_per_node)); + auto fpX_scaling_type = getScalingType(); + if constexpr (use_fp8) { + if (use_mxfp8_fp8_block_scaling) { + fpX_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; + } + } size_t const sf_size = - getScalingType() == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX + fpX_scaling_type == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX ? sizeof(TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF) : sizeof(TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF); size_t const fc1_fp4_act_scale_size = - getOffsetActivationSF(num_experts_per_node, act_sf_rows, hidden_size, getScalingType()) * + getOffsetActivationSF(num_experts_per_node, act_sf_rows, hidden_size, fpX_scaling_type) * sf_size; size_t const fc2_fp4_act_scale_size = - getOffsetActivationSF(num_experts_per_node, act_sf_rows, inter_size, getScalingType()) * + getOffsetActivationSF(num_experts_per_node, act_sf_rows, inter_size, fpX_scaling_type) * sf_size; size_t const fp4_act_scale_size = std::max(fc1_fp4_act_scale_size, fc2_fp4_act_scale_size); size_t const tma_ws_size = using_tma_ws ? 
TmaWarpSpecializedGroupedGemmInput::workspaceSize( - num_experts_per_node, getScalingType()) + num_experts_per_node, fpX_scaling_type) : 0; size_t const gemm_workspace_size = moe_gemm_runner_.getMaxWorkspaceSize(num_experts_per_node); @@ -2722,7 +2725,8 @@ CutlassMoeFCRunner:: "Number of experts must be a multiple of ep size"); auto sizes_map = getWorkspaceDeviceBufferSizes( num_rows, hidden_size, inter_size, num_experts / ep_size, experts_per_token, activation_type, - use_lora, use_deepseek_fp8_block_scale, min_latency_mode, use_awq); + use_lora, use_deepseek_fp8_block_scale, min_latency_mode, use_awq, + /*use_mxfp8_fp8_block_scaling=*/use_fp8); std::vector sizes(sizes_map.size()); std::transform(sizes_map.begin(), sizes_map.end(), sizes.begin(), [](auto& v) { return v.second.first; }); @@ -2742,10 +2746,18 @@ void CutlassMoeFCRunner(quant_params.wo.fc2_weight_scales); - auto const* fc1_fp8_dequant = quant_params.fp8.dequant_fc1; + bool const use_mxfp8_weight_block_scales = + fp8_scales_required && quant_params.mxfp8_mxfp4.fc1.weight_block_scale; + auto const* fc1_fp8_dequant = use_mxfp8_weight_block_scales + ? quant_params.mxfp8_mxfp4.fc1.global_scale + : quant_params.fp8.dequant_fc1; auto const* fc2_fp8_quant = quant_params.fp8.quant_fc2; - auto const* fc2_fp8_dequant = quant_params.fp8.dequant_fc2; + auto const* fc2_fp8_dequant = use_mxfp8_weight_block_scales + ? 
quant_params.mxfp8_mxfp4.fc2.global_scale + : quant_params.fp8.dequant_fc2; auto const* input_fp8_dequant = quant_params.fp8.dequant_input; auto const* fc2_wfp4afp8_quant_scale = quant_params.fp8_mxfp4.fc2.act_global_scale; @@ -3651,7 +3669,8 @@ void CutlassMoeFCRunner && quant_params.mxfp8_mxfp4.fc1.weight_block_scale; configureWsPtrs(workspace_ptr, num_rows, hidden_size, inter_size, num_experts_per_node, experts_per_token, fc1_activation_type, parallelism_config, use_lora, - use_deepseek_fp8_block_scale, min_latency_mode, use_awq); + use_deepseek_fp8_block_scale, min_latency_mode, use_awq, + use_mxfp8_fp8_block_scaling); int start_expert = num_experts_per_node * parallelism_config.ep_rank; int end_expert = start_expert + num_experts_per_node; @@ -3905,8 +3930,14 @@ CutlassMoeFCRunner:: layout_info1.int4_groupwise_params.use_wfp4a16 = use_wfp4a16; layout_info2.int4_groupwise_params.use_wfp4a16 = use_wfp4a16; - layout_info1.fpX_block_scaling_type = getScalingType(); - layout_info2.fpX_block_scaling_type = getScalingType(); + auto fpX_block_scaling_type = getScalingType(); + if constexpr (std::is_same_v) { + if (quant_params.mxfp8_mxfp4.fc1.weight_block_scale) { + fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; + } + } + layout_info1.fpX_block_scaling_type = fpX_block_scaling_type; + layout_info2.fpX_block_scaling_type = fpX_block_scaling_type; int const threads = std::min(1024, num_experts_per_node); int const blocks = (num_experts_per_node + threads - 1) / threads; @@ -4074,9 +4105,17 @@ CutlassMoeFCRunner:: expert_first_token_offset_, gemm1_tma_ws_input, gemm2_tma_ws_input, num_rows, expanded_num_rows, fc1_out_size, hidden_size, hidden_size, inter_size, num_experts_per_node, reinterpret_cast(gemm1_input), reinterpret_cast(gemm2_input), - fc1_expert_weights, fc2_expert_weights, quant_params.fp8.dequant_fc1, - quant_params.fp8.dequant_fc2, fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, - fc1_expert_biases, fc2_bias, 
reinterpret_cast(gemm1_output), + fc1_expert_weights, fc2_expert_weights, + (std::is_same_v && + quant_params.mxfp8_mxfp4.fc1.weight_block_scale) + ? quant_params.mxfp8_mxfp4.fc1.global_scale + : quant_params.fp8.dequant_fc1, + (std::is_same_v && + quant_params.mxfp8_mxfp4.fc1.weight_block_scale) + ? quant_params.mxfp8_mxfp4.fc2.global_scale + : quant_params.fp8.dequant_fc2, + fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, fc1_expert_biases, fc2_bias, + reinterpret_cast(gemm1_output), reinterpret_cast(fc2_result_), permuted_token_final_scales_, permuted_row_to_unpermuted_row_, enable_pdl, stream); } diff --git a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu index 48c79ad858..8f0872a045 100644 --- a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu +++ b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu @@ -145,7 +145,7 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { #endif #ifdef ENABLE_FP8 - if (isFp8Quant()) { + if (isFp8Quant() || isWFp8AMxfp8Quant()) { mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp8_e4m3>(mOutputDtype); } #endif @@ -490,8 +490,10 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { << "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size."; } - TVM_FFI_ICHECK(!input_sf.has_value() || isWMxfp4AMxfp8Quant() || isNvfp4Quant()) + TVM_FFI_ICHECK(!input_sf.has_value() || isMxfp8ActScalingQuant() || isNvfp4Quant()) << "Block-scaling factors provided for non block-scaling quantization"; + TVM_FFI_ICHECK(!isMxfp8ActScalingQuant() || input_sf.has_value()) + << "input_sf must be provided when use_mxfp8_act_scaling=True"; int experts_per_token = token_selected_experts.size(1); int64_t num_rows = input.size(0); @@ -862,7 +864,66 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { int64_t num_experts_on_rank, int64_t hidden_size, int64_t inter_size, Optional> 
quant_scales, ActivationType base_activation_type = ActivationType::Swiglu) const { - if (isFp8Quant()) { + if (isWFp8AMxfp8Quant()) { +#ifdef USING_OSS_CUTLASS_MOE_GEMM + TVM_FFI_ICHECK(quant_scales.has_value()) + << "Expecting quant scales for MXFP8xMXFP8 quantization"; + TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 4) + << "Expecting 4 quant scales for MXFP8xMXFP8 quantization"; + + TensorView fc1_weight_block = quant_scales.value()[0]; + TensorView fc1_global = quant_scales.value()[1]; + TensorView fc2_weight_block = quant_scales.value()[2]; + TensorView fc2_global = quant_scales.value()[3]; + + // The input for scale fc1_weight_block / fc2_weight_block is packed into INT32 + constexpr int FP8_PER_INT32 = 4; + CHECK_INPUT_TYPE(fc1_weight_block, dl_int32); + CHECK_INPUT_TYPE(fc1_global, dl_float32); + CHECK_INPUT_TYPE(fc2_weight_block, dl_int32); + CHECK_INPUT_TYPE(fc2_global, dl_float32); + CHECK_DIM(3, fc1_weight_block); + CHECK_DIM(1, fc1_global); + CHECK_DIM(3, fc2_weight_block); + CHECK_DIM(1, fc2_global); + TVM_FFI_ICHECK( + fc1_weight_block.size(0) == num_experts_on_rank && + fc1_weight_block.size(1) == + TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) * + 2 && + fc1_weight_block.size(2) * FP8_PER_INT32 * + TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize == + TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX)) + << "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 " + "// block_scale_vector_size)"; + TVM_FFI_ICHECK_EQ(fc1_global.size(0), num_experts_on_rank) + << "fc1 global size must be (num_experts_on_rank,)"; + TVM_FFI_ICHECK( + fc2_weight_block.size(0) == num_experts_on_rank && + fc2_weight_block.size(1) == + TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + hidden_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) && + 
fc2_weight_block.size(2) * FP8_PER_INT32 * + TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize == + TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX)) + << "fc2 weight block size must be (num_experts_on_rank, hidden_size, inter_size // 4 // " + "block_scale_vector_size)"; + TVM_FFI_ICHECK_EQ(fc2_global.size(0), num_experts_on_rank) + << "fc2 global size must be (num_experts_on_rank,)"; + + return kernels::QuantParams::MXFP8MXFP8( + static_cast(fc1_weight_block.data_ptr()), + static_cast(fc1_global.data_ptr()), + static_cast(fc2_weight_block.data_ptr()), + static_cast(fc2_global.data_ptr())); +#else + TVM_FFI_ICHECK(false) + << "MXFP8 x MXFP8 quantization is not supported in OSS Cutlass Moe Gemm"; +#endif + } else if (isFp8Quant()) { TVM_FFI_ICHECK(quant_scales.has_value()) << "Expecting quant scales for fp8 quantization"; TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 4) << "Expecting 4 quant scales for fp8 quantization"; @@ -1168,9 +1229,16 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { bool isFp8Quant() const { return !mUseDeepSeekFP8BlockScaling && mActivationDtype == dl_float8_e4m3fn && - mWeightDtype == dl_float8_e4m3fn; + mWeightDtype == dl_float8_e4m3fn && !mUseMxfp8ActScaling; + } + + bool isWFp8AMxfp8Quant() const { + return !mUseDeepSeekFP8BlockScaling && mActivationDtype == dl_float8_e4m3fn && + mWeightDtype == dl_float8_e4m3fn && mUseMxfp8ActScaling; } + bool isMxfp8ActScalingQuant() const { return isWFp8AMxfp8Quant() || isWMxfp4AMxfp8Quant(); } + bool isNvfp4Quant() const { return mWeightDtype == dl_int64 && mActivationDtype != dl_float8_e4m3fn; // FP8 activation does not use FP4 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index e278269b97..37840a4a50 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h 
+++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -257,8 +257,9 @@ struct QuantParams { GemmInputs fc2; } fp8_mxfp4; - // MXFP8 MXFP4 quantization params - // This mode uses block scaled MXFP8 and MXFP4 weights + // MXFP8 block-scaled quantization params. + // Historical note: this payload shape is also reused by MXFP8xMXFP8 (FP8 weights with MXFPX + // block scales), so this field name is legacy. struct MXFP8MXFP4Inputs { struct GemmInputs { TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* weight_block_scale = @@ -357,6 +358,15 @@ struct QuantParams { return qp; } + static QuantParams MXFP8MXFP8( + TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* fc1_weight_block_scale, + float const* fc1_global_scale, // + TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* fc2_weight_block_scale, + float const* fc2_global_scale) { + return MXFP8MXFP4(fc1_weight_block_scale, fc1_global_scale, fc2_weight_block_scale, + fc2_global_scale); + } + static QuantParams FP4( float const* fc1_act_global_scale, TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* fc1_weight_block_scale, @@ -838,12 +848,14 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { std::map> getWorkspaceDeviceBufferSizes( int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int const experts_per_token, ActivationType activation_type, - bool use_lora, bool use_deepseek_fp8_block_scale, bool min_latency_mode, bool use_awq); + bool use_lora, bool use_deepseek_fp8_block_scale, bool min_latency_mode, bool use_awq, + bool use_mxfp8_fp8_block_scaling); void configureWsPtrs(char* ws_ptr, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int const experts_per_token, ActivationType activation_type, MOEParallelismConfig parallelism_config, bool use_lora, - bool use_deepseek_fp8_block_scale, bool min_latency_mode, bool use_awq); + bool 
use_deepseek_fp8_block_scale, bool min_latency_mode, bool use_awq, + bool use_mxfp8_fp8_block_scaling); private: bool mayHaveDifferentGEMMOutputType() const { diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl index db5788bfdd..26e8459d67 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl @@ -255,7 +255,7 @@ using namespace cutlass::epilogue; static_assert(cutlass::platform::is_same::value || IsWFP4AFP8, \ "TMA warp specialized MOE implementation does not support mixed input types"); \ \ - constexpr static bool IsBlockScaled = IsFP4 || IsWFP4AFP8; \ + constexpr static bool IsBlockScaled = IsFP4 || IsWFP4AFP8 || IsMXFPX; \ static_assert(!IsBlockScaled || IsBlackwell, "Block scaled is only implemented for SM100"); \ \ static_assert(FUSION == EpilogueFusion::NONE || FUSION == EpilogueFusion::FINALIZE, \ @@ -311,11 +311,15 @@ using namespace cutlass::epilogue; std::conditional_t, \ cutlass::nv_float4_t>, \ cute::tuple>; \ - using ElementWeightBlockScaled = \ - std::conditional_t, \ - cutlass::nv_float4_t>, \ - cute::tuple>; \ + using ElementWeightBlockScaled = std::conditional_t< \ + IsSM120, \ + std::conditional_t< \ + IsMXFPX, \ + std::conditional_t::value, \ + cutlass::mx_float4_t, \ + cutlass::mx_float8_t>, \ + cutlass::nv_float4_t>, \ + cute::tuple>; \ \ /* Activation matrix alignment */ \ constexpr static int AlignmentAct = \ diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h index dded5014bf..da8824d09b 100644 --- 
a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h @@ -949,8 +949,13 @@ size_t MoeGemmRunner::calcMaxWorkspace auto fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE; if constexpr (use_wfp4afp8) { fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; - } else if (use_fp4) { + } else if constexpr (use_fp4) { fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4; + } else if constexpr (use_fp8) { + // FP8 runners can be used in MXFP8 mode (UE8M0 block scales), which needs larger TMA WS. + // Allocate using MXFPX requirements so workspace is sufficient for both regular FP8 and + // MXFP8. + fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; } size_t max_size = 0; bool has_config = false; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h index 5adacd0ce2..72c6d3538b 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h @@ -65,27 +65,27 @@ using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput using EpilogueFusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion; template + EpilogueFusion FUSION, typename TileShape, typename ClusterShape, bool IsMXFPX> auto getDispatchFunctionForSM100(cutlass_extensions::EpilogueScheduleType epilogue_schedule, bool dynamic_cga, bool swap_ab) { auto select_swap_ab = [dynamic_cga, epilogue_schedule](auto swap_ab_t) { auto select_dynamic_cga = [epilogue_schedule](auto dynamic_cga_t) { #if defined(ENABLE_FP4) 
constexpr bool is_block_scaled = - std::is_same_v || std::is_same_v; + IsMXFPX || std::is_same_v || std::is_same_v; #else - constexpr bool is_block_scaled = false; + constexpr bool is_block_scaled = IsMXFPX; #endif if constexpr ((!is_block_scaled || Arch::kMinComputeCapability == 103) && FUSION != EpilogueFusion::FINALIZE) { auto func_map = std::array{ &kernels::cutlass_kernels_oss::tma_warp_specialized_generic_moe_gemm_kernelLauncher< Arch, T, WeightType, OutputType, cutlass::epilogue::PtrArrayNoSmemWarpSpecialized, - EpilogueTag, FUSION, TileShape, ClusterShape, is_wfp4afp8, + EpilogueTag, FUSION, TileShape, ClusterShape, IsMXFPX, decltype(dynamic_cga_t)::value, false, decltype(swap_ab_t)::value>, &kernels::cutlass_kernels_oss::tma_warp_specialized_generic_moe_gemm_kernelLauncher< Arch, T, WeightType, OutputType, cutlass::epilogue::PtrArrayTmaWarpSpecialized, - EpilogueTag, FUSION, TileShape, ClusterShape, is_wfp4afp8, + EpilogueTag, FUSION, TileShape, ClusterShape, IsMXFPX, decltype(dynamic_cga_t)::value, false, decltype(swap_ab_t)::value> }; @@ -100,8 +100,8 @@ auto getDispatchFunctionForSM100(cutlass_extensions::EpilogueScheduleType epilog "No Smem epilogue schedule is not supported for block scaled types or finalize fusion"); return &kernels::cutlass_kernels_oss::tma_warp_specialized_generic_moe_gemm_kernelLauncher< Arch, T, WeightType, OutputType, cutlass::epilogue::PtrArrayTmaWarpSpecialized, - EpilogueTag, FUSION, TileShape, ClusterShape, is_wfp4afp8, - decltype(dynamic_cga_t)::value, false, decltype(swap_ab_t)::value>; + EpilogueTag, FUSION, TileShape, ClusterShape, IsMXFPX, decltype(dynamic_cga_t)::value, + false, decltype(swap_ab_t)::value>; } }; return dynamic_cga ? 
select_dynamic_cga(tensorrt_llm::common::ConstBool{}) @@ -166,14 +166,31 @@ void dispatchMoeGemmFinalDispatchTmaWarpSpecialized( #else constexpr static bool is_wfp4afp8 = false; #endif +#if defined(ENABLE_FP8) + constexpr static bool is_wfp8amxfp8 = + std::is_same_v && std::is_same_v; +#else + constexpr static bool is_wfp8amxfp8 = false; +#endif + constexpr static int tile_m = cute::size<0>(TileShape{}); + constexpr static int tile_n = cute::size<1>(TileShape{}); + // Keep this in sync with generate_kernels.py is_gemm_op_valid_sm100() for op.is_mx_fpx. + constexpr static bool is_wfp8amxfp8_tile_supported = + tile_m == 128 && (tile_n == 64 || tile_n == 128 || tile_n == 256); + constexpr static bool supports_mxfpx = is_wfp4afp8 || is_wfp8amxfp8; + constexpr static bool supports_mxfpx_tile = + is_wfp4afp8 || !is_wfp8amxfp8 || is_wfp8amxfp8_tile_supported; + bool const use_mxfpx = hopper_input.fpX_block_scaling_type == + TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; if constexpr (is_wfp4afp8) { - TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type == - TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, - "MXFPX is the only supported scaling type for WFP4AFP8"); - } else { - TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type != - TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, + TLLM_CHECK_WITH_INFO(use_mxfpx, "MXFPX is the only supported scaling type for WFP4AFP8"); + } else if (use_mxfpx) { + TLLM_CHECK_WITH_INFO(supports_mxfpx, "MXFPX is not supported for the selected weight combination"); + TLLM_CHECK_WITH_INFO( + supports_mxfpx_tile, + "MXFPX is not supported for this tile shape; expected fp8 tile_m=128 and tile_n in {64, " + "128, 256}"); } if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120) { @@ -198,26 +215,59 @@ void dispatchMoeGemmFinalDispatchTmaWarpSpecialized( // << ", fallback_cluster_shape=" // << static_cast(gemm_config.fallback_cluster_shape) << std::endl; - 
auto selected_func = - getDispatchFunctionForSM100( - gemm_config.epilogue_schedule, dynamic_cga, swap_ab); - selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, - workspace_size, cluster_shape_cute, cluster_shape_cute_fallback); + if (use_mxfpx) { + if constexpr (supports_mxfpx && supports_mxfpx_tile) { + auto selected_func = + getDispatchFunctionForSM100( + gemm_config.epilogue_schedule, dynamic_cga, swap_ab); + selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, + workspace_size, cluster_shape_cute, cluster_shape_cute_fallback); + } else { + TLLM_THROW("MXFPX not supported by this tile shape"); + } + } else { + auto selected_func = + getDispatchFunctionForSM100( + gemm_config.epilogue_schedule, dynamic_cga, swap_ab); + selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, + workspace_size, cluster_shape_cute, cluster_shape_cute_fallback); + } } else if constexpr (Arch::kMinComputeCapability >= 120 || Arch::kMinComputeCapability == 90) { using EpilogueSchedule = void; // These are hardcoded in the launcher constexpr bool dynamic_cga = false; - auto selected_func = - hopper_input.swap_ab - ? kernels::cutlass_kernels_oss::tma_warp_specialized_generic_moe_gemm_kernelLauncher< - Arch, T, WeightType, OutputType, EpilogueSchedule, EpilogueTag, FUSION, - TileShape, ClusterShape, is_wfp4afp8, dynamic_cga, false, true> - : kernels::cutlass_kernels_oss::tma_warp_specialized_generic_moe_gemm_kernelLauncher< - Arch, T, WeightType, OutputType, EpilogueSchedule, EpilogueTag, FUSION, - TileShape, ClusterShape, is_wfp4afp8, dynamic_cga, false, false>; - - selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, - workspace_size, {}, {}); + if (use_mxfpx) { + if constexpr (supports_mxfpx && supports_mxfpx_tile) { + auto selected_func = + hopper_input.swap_ab + ? 
kernels::cutlass_kernels_oss:: + tma_warp_specialized_generic_moe_gemm_kernelLauncher< + Arch, T, WeightType, OutputType, EpilogueSchedule, EpilogueTag, FUSION, + TileShape, ClusterShape, true, dynamic_cga, false, true> + : kernels::cutlass_kernels_oss:: + tma_warp_specialized_generic_moe_gemm_kernelLauncher< + Arch, T, WeightType, OutputType, EpilogueSchedule, EpilogueTag, FUSION, + TileShape, ClusterShape, true, dynamic_cga, false, false>; + selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, + workspace_size, {}, {}); + } else { + TLLM_THROW("MXFPX not supported by this tile shape"); + } + } else { + auto selected_func = + hopper_input.swap_ab + ? kernels::cutlass_kernels_oss:: + tma_warp_specialized_generic_moe_gemm_kernelLauncher< + Arch, T, WeightType, OutputType, EpilogueSchedule, EpilogueTag, FUSION, + TileShape, ClusterShape, false, dynamic_cga, false, true> + : kernels::cutlass_kernels_oss:: + tma_warp_specialized_generic_moe_gemm_kernelLauncher< + Arch, T, WeightType, OutputType, EpilogueSchedule, EpilogueTag, FUSION, + TileShape, ClusterShape, false, dynamic_cga, false, false>; + selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, + workspace_size, {}, {}); + } } } } diff --git a/flashinfer/jit/gemm/cutlass/generate_kernels.py b/flashinfer/jit/gemm/cutlass/generate_kernels.py index f7a87bedbd..ae9199930b 100644 --- a/flashinfer/jit/gemm/cutlass/generate_kernels.py +++ b/flashinfer/jit/gemm/cutlass/generate_kernels.py @@ -407,6 +407,16 @@ def is_gemm_op_valid_sm100(op): ): return False + # MXFP block-scaled paths currently follow the same shape/schedule limits as FP4 block scaling. 
+ if op.is_mx_fpx: + if tile_n not in [64, 128, 256] or tile_m != 128: + return False + if ( + op.arch == 100 + and op.epi_schedule == EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm + ): + return False + # Shapes for fp8 small N shapes if ( (op.act_type == DataType.e4m3) @@ -913,31 +923,40 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled, arch): if dtype in [DataType.e4m3, e2m1]: otypes = [DataType.f16, DataType.bf16] - for otype in otypes: - moe_gemm_operation = TrtLlm_GemmLauncher( - GemmKind.Grouped, - arch, - dtype, - weight_type, - otype, - otype, - otype, - quant_op, - epi_tag, - cta_shape_mnk, - warp_shape, - stages, - cga_shape, - mainloop_schedule, - epi_schedule, - epi_fusion, - is_mx_fpx=(dtype == DataType.e4m3 and weight_type == e2m1), - dynamic_cga=dynamic_cga, - swap_ab=swap_ab, - ) + mxfp_modes = [False] + if dtype == DataType.e4m3 and weight_type == e2m1: + # MXFP8 x MXFP4 path. + mxfp_modes = [True] + elif dtype == DataType.e4m3 and weight_type == DataType.e4m3: + # Emit both regular FP8xFP8 and MXFP8xMXFP8 variants. 
+ mxfp_modes = [False, True] - if is_op_valid(moe_gemm_operation): - operations.append(moe_gemm_operation) + for otype in otypes: + for is_mx_fpx in mxfp_modes: + moe_gemm_operation = TrtLlm_GemmLauncher( + GemmKind.Grouped, + arch, + dtype, + weight_type, + otype, + otype, + otype, + quant_op, + epi_tag, + cta_shape_mnk, + warp_shape, + stages, + cga_shape, + mainloop_schedule, + epi_schedule, + epi_fusion, + is_mx_fpx=is_mx_fpx, + dynamic_cga=dynamic_cga, + swap_ab=swap_ab, + ) + + if is_op_valid(moe_gemm_operation): + operations.append(moe_gemm_operation) return operations diff --git a/tests/moe/test_trtllm_cutlass_fused_moe.py b/tests/moe/test_trtllm_cutlass_fused_moe.py index e9b40db7c6..6c1dc4e0f2 100644 --- a/tests/moe/test_trtllm_cutlass_fused_moe.py +++ b/tests/moe/test_trtllm_cutlass_fused_moe.py @@ -1179,6 +1179,35 @@ def quant_mxfp4_batches(a, num_experts): return result_quant_a, result_sfs +def quant_mxfp8_batches(a, num_experts): + quant_a = [] + sfs = [] + for i in range(num_experts): + a_fp8, a_sf = mxfp8_quantize(a[i].cuda(), True, 32) + quant_a.append(a_fp8) + sfs.append(a_sf) + + result_quant_a = torch.stack(quant_a) + result_sfs = torch.stack(sfs) + + return result_quant_a, result_sfs + + +def pack_mxfp8_scales_u8_to_int32_batches( + scale_u8: torch.Tensor, rows: int, cols: int +) -> torch.Tensor: + num_experts = scale_u8.size(0) + aligned_rows = ceil_div(rows, 128) * 128 + k_scales = cols // 32 + aligned_k_scales = ceil_div(k_scales, 4) * 4 + return ( + scale_u8.contiguous() + .view(num_experts, aligned_rows, aligned_k_scales) + .view(torch.int32) + .contiguous() + ) + + def dequant_mxfp4_batches( mat_fp4: torch.Tensor, scale_tensor: torch.Tensor, @@ -1195,6 +1224,26 @@ def dequant_mxfp4_batches( ) +def dequant_mxfp8_batches( + mat_fp8: torch.Tensor, + scale_tensor: torch.Tensor, +): + num_batches = mat_fp8.size(0) + + scale_tensor = scale_tensor.view(num_batches, -1) + + return torch.stack( + [ + mxfp8_dequantize_host( + mat_fp8[b, :, 
:].cpu().view(torch.uint8), + scale_tensor[b, :].cpu().view(torch.uint8).reshape(-1), + True, + ) + for b in range(num_batches) + ] + ) + + @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("num_experts", NUM_EXPERTS) @@ -1328,6 +1377,117 @@ def test_moe_mxfp8_mxfp4( torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("num_experts", NUM_EXPERTS) +@pytest.mark.parametrize("top_k", TOP_K_VALUES) +@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) +@pytest.mark.parametrize("otype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize( + ("alpha", "beta", "limit"), [(None, None, None), (0.5, 0.0, 7.0), (1.702, 1.0, 7.0)] +) +@pytest.mark.skipif( + torch.cuda.get_device_capability()[0] not in [10, 11, 12], + reason="MXFP8xMXFP8 is only supported on SM100, SM110 and SM120", +) +def test_moe_mxfp8_mxfp8( + batch_size, + hidden_size, + num_experts, + top_k, + intermediate_size, + otype, + alpha, + beta, + limit, +): + """Test MoE with MXFP8 activations and MXFP8 weights.""" + if top_k > num_experts: + pytest.skip( + f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})" + ) + + torch.manual_seed(42) + e = num_experts + m = batch_size + n = intermediate_size + k = hidden_size + + x = torch.randn(m, k, dtype=otype).cuda() + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=otype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=otype) / 10 + + mxfp8_x, mxfp8_x_sf = mxfp8_quantize(x, True, 32) + mxfp8_w1, mxfp8_w1_scale = quant_mxfp8_batches(w1, e) + mxfp8_w2, mxfp8_w2_scale = quant_mxfp8_batches(w2, e) + mxfp8_w1_scale_i32 = pack_mxfp8_scales_u8_to_int32_batches(mxfp8_w1_scale, 2 * n, k) + mxfp8_w2_scale_i32 = pack_mxfp8_scales_u8_to_int32_batches(mxfp8_w2_scale, k, n) + + router_logits = 
torch.randn(m, e, dtype=otype).cuda() + routing_weights, selected_experts = compute_routing(router_logits, top_k) + + fake_input_scale = torch.ones(e, device=x.device, dtype=torch.float32) + quant_scales = [ + mxfp8_w1_scale_i32, + fake_input_scale, + mxfp8_w2_scale_i32, + fake_input_scale, + ] + + flash_output = torch.zeros_like(x) + + if alpha is not None and limit is not None and beta is not None: + alpha_t = torch.ones(e, device=x.device) * alpha + limit_t = torch.ones(e, device=x.device) * limit + beta_t = torch.ones(e, device=x.device) * beta + else: + alpha_t = None + limit_t = None + beta_t = None + + _ = fused_moe.cutlass_fused_moe( + mxfp8_x, + selected_experts.to(torch.int), + routing_weights, + mxfp8_w1.contiguous(), + mxfp8_w2.contiguous(), + otype, + swiglu_alpha=alpha_t, + swiglu_limit=limit_t, + swiglu_beta=beta_t, + quant_scales=quant_scales, + input_sf=mxfp8_x_sf, + use_mxfp8_act_scaling=True, + output=flash_output, + ) + + dq_mxfp8_x = ( + mxfp8_dequantize_host( + mxfp8_x.cpu().view(torch.uint8), + mxfp8_x_sf.cpu().view(torch.uint8).reshape(-1), + True, + ) + .cuda() + .to(otype) + ) + dq_mxfp8_w1 = dequant_mxfp8_batches(mxfp8_w1, mxfp8_w1_scale).cuda().to(otype) + dq_mxfp8_w2 = dequant_mxfp8_batches(mxfp8_w2, mxfp8_w2_scale).cuda().to(otype) + + ref_output = compute_with_experts( + e, + dq_mxfp8_x, + dq_mxfp8_w1, + dq_mxfp8_w2, + selected_experts, + routing_weights, + alpha, + beta, + limit, + ) + + torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1) + + def dequant_mxfp4_batches_host( mat_fp4: torch.Tensor, scale_tensor: torch.Tensor, From 083e3db19d047f73c7efa4a25fd01ea1bd7c0e70 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 21 Feb 2026 22:49:32 -0800 Subject: [PATCH 02/11] Clean up interface changes --- .../cutlass_fused_moe_kernels.cuh | 65 +++++++------------ .../cutlass_kernels/include/moe_kernels.h | 6 +- 2 files changed, 24 insertions(+), 47 deletions(-) diff --git 
a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 9a4f428682..9389e1a283 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -2534,14 +2534,17 @@ CutlassMoeFCRunner:: template -std::map> CutlassMoeFCRunner< - T, WeightType, OutputType, InputType, BackBoneType, - Enable>::getWorkspaceDeviceBufferSizes(int64_t const num_rows, int64_t const hidden_size, - int64_t const inter_size, int const num_experts_per_node, - int const experts_per_token, - ActivationType activation_type, bool use_lora, - bool use_deepseek_fp8_block_scale, bool min_latency_mode, - bool use_awq, bool use_mxfp8_fp8_block_scaling) { +std::map> +CutlassMoeFCRunner::getWorkspaceDeviceBufferSizes(int64_t const num_rows, + int64_t const hidden_size, + int64_t const inter_size, + int const num_experts_per_node, + int const experts_per_token, + ActivationType activation_type, + bool use_lora, + bool use_deepseek_fp8_block_scale, + bool min_latency_mode, bool use_awq) { size_t num_moe_inputs = min_latency_mode ? num_experts_per_node * num_rows : experts_per_token * num_rows; size_t const permuted_elems = num_moe_inputs * hidden_size; @@ -2595,27 +2598,21 @@ std::map> CutlassMoeFCRunner< min_latency_mode ? num_moe_inputs : std::min(num_moe_inputs, static_cast(num_rows * num_experts_per_node)); - auto fpX_scaling_type = getScalingType(); - if constexpr (use_fp8) { - if (use_mxfp8_fp8_block_scaling) { - fpX_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; - } - } size_t const sf_size = - fpX_scaling_type == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX + getScalingType() == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX ? 
sizeof(TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF) : sizeof(TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF); size_t const fc1_fp4_act_scale_size = - getOffsetActivationSF(num_experts_per_node, act_sf_rows, hidden_size, fpX_scaling_type) * + getOffsetActivationSF(num_experts_per_node, act_sf_rows, hidden_size, getScalingType()) * sf_size; size_t const fc2_fp4_act_scale_size = - getOffsetActivationSF(num_experts_per_node, act_sf_rows, inter_size, fpX_scaling_type) * + getOffsetActivationSF(num_experts_per_node, act_sf_rows, inter_size, getScalingType()) * sf_size; size_t const fp4_act_scale_size = std::max(fc1_fp4_act_scale_size, fc2_fp4_act_scale_size); size_t const tma_ws_size = using_tma_ws ? TmaWarpSpecializedGroupedGemmInput::workspaceSize( - num_experts_per_node, fpX_scaling_type) + num_experts_per_node, getScalingType()) : 0; size_t const gemm_workspace_size = moe_gemm_runner_.getMaxWorkspaceSize(num_experts_per_node); @@ -2725,8 +2722,7 @@ CutlassMoeFCRunner:: "Number of experts must be a multiple of ep size"); auto sizes_map = getWorkspaceDeviceBufferSizes( num_rows, hidden_size, inter_size, num_experts / ep_size, experts_per_token, activation_type, - use_lora, use_deepseek_fp8_block_scale, min_latency_mode, use_awq, - /*use_mxfp8_fp8_block_scaling=*/use_fp8); + use_lora, use_deepseek_fp8_block_scale, min_latency_mode, use_awq); std::vector sizes(sizes_map.size()); std::transform(sizes_map.begin(), sizes_map.end(), sizes.begin(), [](auto& v) { return v.second.first; }); @@ -2746,18 +2742,10 @@ void CutlassMoeFCRunner && quant_params.mxfp8_mxfp4.fc1.weight_block_scale; - configureWsPtrs(workspace_ptr, num_rows, hidden_size, inter_size, num_experts_per_node, experts_per_token, fc1_activation_type, parallelism_config, use_lora, - use_deepseek_fp8_block_scale, min_latency_mode, use_awq, - use_mxfp8_fp8_block_scaling); + use_deepseek_fp8_block_scale, min_latency_mode, use_awq); int start_expert = num_experts_per_node * parallelism_config.ep_rank; int 
end_expert = start_expert + num_experts_per_node; @@ -3931,11 +3915,6 @@ CutlassMoeFCRunner:: layout_info2.int4_groupwise_params.use_wfp4a16 = use_wfp4a16; auto fpX_block_scaling_type = getScalingType(); - if constexpr (std::is_same_v) { - if (quant_params.mxfp8_mxfp4.fc1.weight_block_scale) { - fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; - } - } layout_info1.fpX_block_scaling_type = fpX_block_scaling_type; layout_info2.fpX_block_scaling_type = fpX_block_scaling_type; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index 37840a4a50..e46ff9f6b5 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -848,14 +848,12 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { std::map> getWorkspaceDeviceBufferSizes( int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int const experts_per_token, ActivationType activation_type, - bool use_lora, bool use_deepseek_fp8_block_scale, bool min_latency_mode, bool use_awq, - bool use_mxfp8_fp8_block_scaling); + bool use_lora, bool use_deepseek_fp8_block_scale, bool min_latency_mode, bool use_awq); void configureWsPtrs(char* ws_ptr, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int const experts_per_token, ActivationType activation_type, MOEParallelismConfig parallelism_config, bool use_lora, - bool use_deepseek_fp8_block_scale, bool min_latency_mode, bool use_awq, - bool use_mxfp8_fp8_block_scaling); + bool use_deepseek_fp8_block_scale, bool min_latency_mode, bool use_awq); private: bool mayHaveDifferentGEMMOutputType() const { From 8ac78b230af38ee84a5b93e73b70112fbd7408c5 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 21 
Feb 2026 23:09:11 -0800 Subject: [PATCH 03/11] Minor clean up --- .../fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 9389e1a283..6e64a8921f 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -3691,6 +3691,7 @@ void CutlassMoeFCRunner:: layout_info1.int4_groupwise_params.use_wfp4a16 = use_wfp4a16; layout_info2.int4_groupwise_params.use_wfp4a16 = use_wfp4a16; - auto fpX_block_scaling_type = getScalingType(); - layout_info1.fpX_block_scaling_type = fpX_block_scaling_type; - layout_info2.fpX_block_scaling_type = fpX_block_scaling_type; + layout_info1.fpX_block_scaling_type = getScalingType(); + layout_info2.fpX_block_scaling_type = getScalingType(); int const threads = std::min(1024, num_experts_per_node); int const blocks = (num_experts_per_node + threads - 1) / threads; From 4dead9d1ed66afb1fe4935b29f49d858c0128cb0 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 21 Feb 2026 23:31:57 -0800 Subject: [PATCH 04/11] Minor rename --- .../flashinfer_cutlass_fused_moe_binding.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu index 8f0872a045..add251a86e 100644 --- a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu +++ b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu @@ -145,7 +145,7 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { #endif #ifdef ENABLE_FP8 - if (isFp8Quant() || isWFp8AMxfp8Quant()) { + if (isFp8Quant() || isWMxfp8AMxfp8Quant()) { mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp8_e4m3>(mOutputDtype); } #endif @@ -864,7 +864,7 @@ 
class FusedMoeRunner : public tvm::ffi::ModuleObj { int64_t num_experts_on_rank, int64_t hidden_size, int64_t inter_size, Optional> quant_scales, ActivationType base_activation_type = ActivationType::Swiglu) const { - if (isWFp8AMxfp8Quant()) { + if (isWMxfp8AMxfp8Quant()) { #ifdef USING_OSS_CUTLASS_MOE_GEMM TVM_FFI_ICHECK(quant_scales.has_value()) << "Expecting quant scales for MXFP8xMXFP8 quantization"; @@ -1232,12 +1232,12 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { mWeightDtype == dl_float8_e4m3fn && !mUseMxfp8ActScaling; } - bool isWFp8AMxfp8Quant() const { + bool isWMxfp8AMxfp8Quant() const { return !mUseDeepSeekFP8BlockScaling && mActivationDtype == dl_float8_e4m3fn && mWeightDtype == dl_float8_e4m3fn && mUseMxfp8ActScaling; } - bool isMxfp8ActScalingQuant() const { return isWFp8AMxfp8Quant() || isWMxfp4AMxfp8Quant(); } + bool isMxfp8ActScalingQuant() const { return isWMxfp8AMxfp8Quant() || isWMxfp4AMxfp8Quant(); } bool isNvfp4Quant() const { return mWeightDtype == dl_int64 && From 44e7b596b18b4c23e096a97bc035921d17ca4510 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sun, 22 Feb 2026 13:45:01 -0800 Subject: [PATCH 05/11] Clean up --- .../launchers/moe_gemm_tma_ws_launcher.inl | 14 +++++--------- .../moe_gemm/moe_gemm_template_dispatch.h | 4 ++-- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl index 26e8459d67..6c71c6acd8 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl @@ -311,15 +311,11 @@ using namespace cutlass::epilogue; std::conditional_t, \ cutlass::nv_float4_t>, \ cute::tuple>; \ - using ElementWeightBlockScaled = std::conditional_t< \ - IsSM120, \ - 
std::conditional_t< \ - IsMXFPX, \ - std::conditional_t::value, \ - cutlass::mx_float4_t, \ - cutlass::mx_float8_t>, \ - cutlass::nv_float4_t>, \ - cute::tuple>; \ + using ElementWeightBlockScaled = \ + std::conditional_t, \ + cutlass::nv_float4_t>, \ + cute::tuple>; \ \ /* Activation matrix alignment */ \ constexpr static int AlignmentAct = \ diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h index da8824d09b..4c1164db98 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h @@ -949,13 +949,13 @@ size_t MoeGemmRunner::calcMaxWorkspace auto fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE; if constexpr (use_wfp4afp8) { fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; - } else if constexpr (use_fp4) { - fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4; } else if constexpr (use_fp8) { // FP8 runners can be used in MXFP8 mode (UE8M0 block scales), which needs larger TMA WS. // Allocate using MXFPX requirements so workspace is sufficient for both regular FP8 and // MXFP8. 
fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; + } else if (use_fp4) { + fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4; } size_t max_size = 0; bool has_config = false; From 9393a7abe902fdb952d3dde37d123b6d98280f82 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 23 Feb 2026 15:47:24 -0800 Subject: [PATCH 06/11] Refactor interfaces and centralize tile size checking --- .../cutlass_fused_moe_kernels.cuh | 291 ++++++++++------- .../flashinfer_cutlass_fused_moe_binding.cu | 19 +- .../include/moe_gemm_kernels.h | 14 +- .../cutlass_kernels/include/moe_kernels.h | 43 ++- .../moe_gemm/moe_gemm_template_dispatch.h | 17 +- .../moe_gemm_template_dispatch_tma_ws.h | 302 ++++++++---------- 6 files changed, 363 insertions(+), 323 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 6e64a8921f..7b3e80f1ff 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -2526,16 +2526,17 @@ void dequantFP8(OutputType* output, InputType const* input, int64_t const* num_v } template -CutlassMoeFCRunner::CutlassMoeFCRunner() + bool IsMXFPX, class Enable> +CutlassMoeFCRunner::CutlassMoeFCRunner() : blockscale_gemm_runner_{ std::make_unique>()} {} template + bool IsMXFPX, class Enable> std::map> -CutlassMoeFCRunner::getWorkspaceDeviceBufferSizes(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, @@ -2544,6 +2545,7 @@ CutlassMoeFCRunner(num_rows * num_experts_per_node)); + auto const workspace_scaling_type = + use_mxfp8_act_scaling ? 
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX + : getScalingType(); size_t const sf_size = - getScalingType() == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX + workspace_scaling_type == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX ? sizeof(TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF) : sizeof(TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF); - size_t const fc1_fp4_act_scale_size = - getOffsetActivationSF(num_experts_per_node, act_sf_rows, hidden_size, getScalingType()) * - sf_size; + size_t const fc1_fp4_act_scale_size = getOffsetActivationSF(num_experts_per_node, act_sf_rows, + hidden_size, workspace_scaling_type) * + sf_size; size_t const fc2_fp4_act_scale_size = - getOffsetActivationSF(num_experts_per_node, act_sf_rows, inter_size, getScalingType()) * + getOffsetActivationSF(num_experts_per_node, act_sf_rows, inter_size, workspace_scaling_type) * sf_size; size_t const fp4_act_scale_size = std::max(fc1_fp4_act_scale_size, fc2_fp4_act_scale_size); size_t const tma_ws_size = using_tma_ws ? 
TmaWarpSpecializedGroupedGemmInput::workspaceSize( - num_experts_per_node, getScalingType()) + num_experts_per_node, workspace_scaling_type) : 0; - size_t const gemm_workspace_size = moe_gemm_runner_.getMaxWorkspaceSize(num_experts_per_node); + size_t const gemm_workspace_size = + moe_gemm_runner_.getMaxWorkspaceSize(num_experts_per_node, use_mxfp8_act_scaling); // lora related size_t const lora_input_size = @@ -2710,19 +2716,24 @@ CutlassMoeFCRunner -size_t -CutlassMoeFCRunner::getWorkspaceSize( - int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts, int const experts_per_token, ActivationType activation_type, - MOEParallelismConfig parallelism_config, bool use_lora, bool use_deepseek_fp8_block_scale, - bool min_latency_mode, bool use_awq) { + bool IsMXFPX, class Enable> +size_t CutlassMoeFCRunner::getWorkspaceSize(int64_t const num_rows, + int64_t const hidden_size, + int64_t const inter_size, int const num_experts, + int const experts_per_token, + ActivationType activation_type, + MOEParallelismConfig parallelism_config, + bool use_lora, + bool use_deepseek_fp8_block_scale, + bool use_mxfp8_act_scaling, + bool min_latency_mode, bool use_awq) { int const ep_size = parallelism_config.ep_size; TLLM_CHECK_WITH_INFO(num_experts % ep_size == 0, "Number of experts must be a multiple of ep size"); auto sizes_map = getWorkspaceDeviceBufferSizes( num_rows, hidden_size, inter_size, num_experts / ep_size, experts_per_token, activation_type, - use_lora, use_deepseek_fp8_block_scale, min_latency_mode, use_awq); + use_lora, use_deepseek_fp8_block_scale, use_mxfp8_act_scaling, min_latency_mode, use_awq); std::vector sizes(sizes_map.size()); std::transform(sizes_map.begin(), sizes_map.end(), sizes.begin(), [](auto& v) { return v.second.first; }); @@ -2732,8 +2743,8 @@ CutlassMoeFCRunner:: } template -void CutlassMoeFCRunner +void CutlassMoeFCRunner::configureWsPtrs(char* ws_ptr, int64_t const num_rows, int64_t const hidden_size, 
int64_t const inter_size, @@ -2742,10 +2753,11 @@ void CutlassMoeFCRunner + bool IsMXFPX, class Enable> kernels::fp8_blockscale_gemm::CutlassFp8BlockScaleGemmRunnerInterface* -CutlassMoeFCRunner::getDeepSeekBlockScaleGemmRunner() const { TLLM_CHECK_WITH_INFO( (std::is_same_v && std::is_same_v), @@ -2868,15 +2883,16 @@ CutlassMoeFCRunner -void CutlassMoeFCRunner::BlockScaleFC1( - DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, T* const output, - void* const gemm_output, int64_t const* const expert_first_token_offset, - WeightType const* const fc1_expert_weights, ScaleBiasType const* const fc1_expert_biases, - float const* const fc2_fp8_quant, int64_t const num_rows, int64_t const expanded_num_rows, - int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, - ActivationParams fc1_activation_type, QuantParams& quant_params, bool enable_pdl, - cudaStream_t stream) { + bool IsMXFPX, class Enable> +void CutlassMoeFCRunner:: + BlockScaleFC1(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, T* const output, + void* const gemm_output, int64_t const* const expert_first_token_offset, + WeightType const* const fc1_expert_weights, + ScaleBiasType const* const fc1_expert_biases, float const* const fc2_fp8_quant, + int64_t const num_rows, int64_t const expanded_num_rows, + int64_t const hidden_size, int64_t const inter_size, + int const num_experts_per_node, ActivationParams fc1_activation_type, + QuantParams& quant_params, bool enable_pdl, cudaStream_t stream) { bool const is_gated_activation = isGatedActivation(fc1_activation_type); int shape_n = is_gated_activation ? 
inter_size * 2 : inter_size; @@ -2900,18 +2916,20 @@ void CutlassMoeFCRunner -void CutlassMoeFCRunner::BlockScaleFC2( - DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output, - OutputType* const final_output, int64_t const* const expert_first_token_offset, - WeightType const* const fc2_expert_weights, ScaleBiasType const* const fc2_expert_biases, - float const* const unpermuted_final_scales, int const* const unpermuted_row_to_permuted_row, - int const* const permuted_row_to_unpermuted_row, int const* const token_selected_experts, - int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, - int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const unpadded_hidden_size, - int64_t const inter_size, int64_t const num_experts_per_node, int64_t const k, - MOEParallelismConfig parallelism_config, bool const enable_alltoall, QuantParams& quant_params, - bool enable_pdl, cudaStream_t stream) { + bool IsMXFPX, class Enable> +void CutlassMoeFCRunner:: + BlockScaleFC2( + DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output, + OutputType* const final_output, int64_t const* const expert_first_token_offset, + WeightType const* const fc2_expert_weights, ScaleBiasType const* const fc2_expert_biases, + float const* const unpermuted_final_scales, int const* const unpermuted_row_to_permuted_row, + int const* const permuted_row_to_unpermuted_row, int const* const token_selected_experts, + int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, + int64_t const expanded_num_rows, int64_t const hidden_size, + int64_t const unpadded_hidden_size, int64_t const inter_size, + int64_t const num_experts_per_node, int64_t const k, + MOEParallelismConfig parallelism_config, bool const enable_alltoall, + QuantParams& quant_params, bool enable_pdl, cudaStream_t stream) { int shape_n = hidden_size; int shape_k = inter_size; @@ -2931,12 +2949,15 @@ void CutlassMoeFCRunner -T const* 
-CutlassMoeFCRunner::applyPrequantScale( - void* smoothed_act, void const* permuted_data, void const* prequant_scales, - int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, - bool const use_awq, cudaStream_t stream) { + bool IsMXFPX, class Enable> +T const* CutlassMoeFCRunner::applyPrequantScale(void* smoothed_act, + void const* permuted_data, + void const* prequant_scales, + int64_t const* num_valid_tokens_ptr, + int64_t const expanded_num_rows, + int64_t const seq_len, bool const use_awq, + cudaStream_t stream) { T const* gemm_input; bool use_prequant_scale_kernel = use_awq && !std::is_same_v; if (use_prequant_scale_kernel) { @@ -2958,8 +2979,8 @@ CutlassMoeFCRunner: } template -void CutlassMoeFCRunner::gemm1( + bool IsMXFPX, class Enable> +void CutlassMoeFCRunner::gemm1( MoeGemmRunner& gemm_runner, DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, T* const output, void* const intermediate_result, int64_t const* const expert_first_token_offset, @@ -3188,8 +3209,8 @@ void CutlassMoeFCRunner -void CutlassMoeFCRunner::gemm2( + bool IsMXFPX, class Enable> +void CutlassMoeFCRunner::gemm2( MoeGemmRunner& gemm_runner, DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, void* const gemm_output, OutputType* const final_output, @@ -3317,8 +3338,8 @@ void CutlassMoeFCRunner -bool CutlassMoeFCRunner +bool CutlassMoeFCRunner::setupLoraWorkspace(int64_t expanded_num_rows, int64_t num_rows, int64_t inter_size, int64_t hidden_size, int start_expert, bool is_gated_activation, @@ -3411,12 +3432,15 @@ bool CutlassMoeFCRunner -auto CutlassMoeFCRunner::loraFC1( - int64_t expanded_num_rows, int64_t inter_size, int64_t hidden_size, int num_experts_per_node, - int start_expert, int64_t const* num_valid_tokens_ptr, bool is_gated_activation, - ScaleBiasType const* fc1_expert_biases, LoraParams& lora_params, float const* input_fp8_dequant, - cudaStream_t stream) -> ScaleBiasType const* { + 
bool IsMXFPX, class Enable> +auto CutlassMoeFCRunner::loraFC1(int64_t expanded_num_rows, int64_t inter_size, + int64_t hidden_size, int num_experts_per_node, + int start_expert, int64_t const* num_valid_tokens_ptr, + bool is_gated_activation, + ScaleBiasType const* fc1_expert_biases, + LoraParams& lora_params, float const* input_fp8_dequant, + cudaStream_t stream) -> ScaleBiasType const* { TLLM_CHECK_WITH_INFO(!act_fp4, "LoRA does not support FP4 activations"); std::vector& host_permuted_fc1_weight_ptrs = host_lora_workspace_.host_permuted_fc1_weight_ptrs; @@ -3490,11 +3514,13 @@ auto CutlassMoeFCRunner -void CutlassMoeFCRunner::loraFC2( - int64_t inter_size, int64_t hidden_size, int num_experts_per_node, int start_expert, - int64_t const* num_valid_tokens_ptr, int64_t num_tokens, LoraParams& lora_params, - float const* fc2_fp8_quant, cudaStream_t stream) { + bool IsMXFPX, class Enable> +void CutlassMoeFCRunner::loraFC2(int64_t inter_size, int64_t hidden_size, + int num_experts_per_node, int start_expert, + int64_t const* num_valid_tokens_ptr, int64_t num_tokens, + LoraParams& lora_params, float const* fc2_fp8_quant, + cudaStream_t stream) { std::vector& host_permuted_fc2_weight_ptrs = host_lora_workspace_.host_permuted_fc2_weight_ptrs; std::vector& host_permuted_fc2_lora_ranks = @@ -3531,24 +3557,26 @@ void CutlassMoeFCRunner -void CutlassMoeFCRunner::runMoe( - void const* input_activations_void, void const* input_sf_void, bool const swizzled_input_sf, - int const* token_selected_experts, float const* token_final_scales, - void const* fc1_expert_weights_void, void const* fc1_expert_biases_void, - ActivationParams fc1_activation_type, void const* fc2_expert_weights_void, - void const* fc2_expert_biases_void, QuantParams quant_params, int64_t const num_rows, - int64_t const hidden_size, int64_t const unpadded_hidden_size, int64_t const inter_size, - int const full_num_experts, int const experts_per_token, char* workspace_ptr, - void* final_output_void, int* 
unpermuted_row_to_permuted_row, - MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool use_lora, - LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, - MoeMinLatencyParams& min_latency_params, bool enable_pdl, cudaStream_t stream) { + bool IsMXFPX, class Enable> +void CutlassMoeFCRunner:: + runMoe(void const* input_activations_void, void const* input_sf_void, + bool const swizzled_input_sf, int const* token_selected_experts, + float const* token_final_scales, void const* fc1_expert_weights_void, + void const* fc1_expert_biases_void, ActivationParams fc1_activation_type, + void const* fc2_expert_weights_void, void const* fc2_expert_biases_void, + QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, + int64_t const unpadded_hidden_size, int64_t const inter_size, int const full_num_experts, + int const experts_per_token, char* workspace_ptr, void* final_output_void, + int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, + bool const enable_alltoall, bool use_lora, LoraParams& lora_params, + bool use_deepseek_fp8_block_scale, bool use_mxfp8_act_scaling, bool min_latency_mode, + MoeMinLatencyParams& min_latency_params, bool enable_pdl, cudaStream_t stream) { static constexpr bool int_scales_required = std::is_same::value || std::is_same::value || use_wfp4a16; static constexpr bool fp8_scales_required = std::is_same::value || std::is_same::value; + bool const is_mxfp8_mode = fp8_scales_required && use_mxfp8_act_scaling; auto const* input_activations = static_cast(input_activations_void); auto const* input_sf = @@ -3563,15 +3591,15 @@ void CutlassMoeFCRunner(quant_params.wo.fc2_weight_scales); - bool const use_mxfp8_weight_block_scales = - fp8_scales_required && quant_params.mxfp8_mxfp4.fc1.weight_block_scale; - auto const* fc1_fp8_dequant = use_mxfp8_weight_block_scales - ? 
quant_params.mxfp8_mxfp4.fc1.global_scale - : quant_params.fp8.dequant_fc1; + auto const* fc1_fp8_dequant = quant_params.fp8.dequant_fc1; auto const* fc2_fp8_quant = quant_params.fp8.quant_fc2; - auto const* fc2_fp8_dequant = use_mxfp8_weight_block_scales - ? quant_params.mxfp8_mxfp4.fc2.global_scale - : quant_params.fp8.dequant_fc2; + auto const* fc2_fp8_dequant = quant_params.fp8.dequant_fc2; + if (is_mxfp8_mode) { + TLLM_CHECK_WITH_INFO(quant_params.mxfp8_mxfp4.fc2.weight_block_scale != nullptr, + "WMXFP8AMXFP8 requires FC2 weight_block_scale to be non-null"); + fc1_fp8_dequant = quant_params.mxfp8_mxfp4.fc1.global_scale; + fc2_fp8_dequant = quant_params.mxfp8_mxfp4.fc2.global_scale; + } auto const* input_fp8_dequant = quant_params.fp8.dequant_input; auto const* fc2_wfp4afp8_quant_scale = quant_params.fp8_mxfp4.fc2.act_global_scale; @@ -3593,14 +3621,14 @@ void CutlassMoeFCRunner::value) == 0, - "Hidden size %d does not meet minimum alignment requirements for MXFP8_MXFP4 MOE GEMM %d", + "Hidden size %d does not meet minimum alignment requirements for WMXFP8AMXFP8 MOE GEMM %d", (int)hidden_size, (int)(64 * 8 / sizeof_bits::value)); TLLM_CHECK_WITH_INFO( inter_size % (64 * 8 / sizeof_bits::value) == 0, - "Inter size %d does not meet minimum alignment requirements for MXFP8_MXFP4 MOE GEMM %d", + "Inter size %d does not meet minimum alignment requirements for WMXFP8AMXFP8 MOE GEMM %d", (int)inter_size, (int)(64 * 8 / sizeof_bits::value)); } else { // For NoSmem epilogue schedule, we need to align the output of the GEMM to 256 bits, for gated @@ -3657,8 +3685,16 @@ void CutlassMoeFCRunner + bool IsMXFPX, class Enable> std::pair -CutlassMoeFCRunner:: +CutlassMoeFCRunner:: computeStridesTmaWarpSpecialized( int64_t const* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput layout_info1, TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, @@ -3915,8 +3948,15 @@ CutlassMoeFCRunner:: layout_info1.int4_groupwise_params.use_wfp4a16 = use_wfp4a16; 
layout_info2.int4_groupwise_params.use_wfp4a16 = use_wfp4a16; - layout_info1.fpX_block_scaling_type = getScalingType(); - layout_info2.fpX_block_scaling_type = getScalingType(); + auto runtime_scaling_type = getScalingType(); + if constexpr (use_fp8) { + if (quant_params.mxfp8_mxfp4.fc1.weight_block_scale && + quant_params.mxfp8_mxfp4.fc2.weight_block_scale) { + runtime_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; + } + } + layout_info1.fpX_block_scaling_type = runtime_scaling_type; + layout_info2.fpX_block_scaling_type = runtime_scaling_type; int const threads = std::min(1024, num_experts_per_node); int const blocks = (num_experts_per_node + threads - 1) / threads; @@ -3945,9 +3985,9 @@ CutlassMoeFCRunner:: } template + bool IsMXFPX, class Enable> std::pair -CutlassMoeFCRunner:: +CutlassMoeFCRunner:: computeStridesTmaWarpSpecializedLowLatency( TmaWarpSpecializedGroupedGemmInput layout_info1, TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, @@ -3964,9 +4004,9 @@ CutlassMoeFCRunner:: } template + bool IsMXFPX, class Enable> std::pair -CutlassMoeFCRunner:: +CutlassMoeFCRunner:: setupTmaWarpSpecializedInputs(int64_t num_rows, int64_t expanded_num_rows, ActivationParams fc1_activation_type, int64_t hidden_size, int64_t unpadded_hidden_size, int64_t inter_size, @@ -4080,21 +4120,30 @@ CutlassMoeFCRunner:: } TLLM_CHECK_WITH_INFO(gemm1_input != gemm1_output, "Input and output buffers are overlapping"); + bool use_mxfp8_mode = false; + if constexpr (use_fp8) { + use_mxfp8_mode = quant_params.mxfp8_mxfp4.fc1.weight_block_scale && + quant_params.mxfp8_mxfp4.fc2.weight_block_scale; + } + if (use_mxfp8_mode) { + return Self::computeStridesTmaWarpSpecialized( + expert_first_token_offset_, gemm1_tma_ws_input, gemm2_tma_ws_input, num_rows, + expanded_num_rows, fc1_out_size, hidden_size, hidden_size, inter_size, + num_experts_per_node, reinterpret_cast(gemm1_input), + reinterpret_cast(gemm2_input), 
fc1_expert_weights, fc2_expert_weights, + quant_params.mxfp8_mxfp4.fc1.global_scale, quant_params.mxfp8_mxfp4.fc2.global_scale, + fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, fc1_expert_biases, fc2_bias, + reinterpret_cast(gemm1_output), + reinterpret_cast(fc2_result_), permuted_token_final_scales_, + permuted_row_to_unpermuted_row_, enable_pdl, stream); + } return Self::computeStridesTmaWarpSpecialized( expert_first_token_offset_, gemm1_tma_ws_input, gemm2_tma_ws_input, num_rows, expanded_num_rows, fc1_out_size, hidden_size, hidden_size, inter_size, num_experts_per_node, reinterpret_cast(gemm1_input), reinterpret_cast(gemm2_input), - fc1_expert_weights, fc2_expert_weights, - (std::is_same_v && - quant_params.mxfp8_mxfp4.fc1.weight_block_scale) - ? quant_params.mxfp8_mxfp4.fc1.global_scale - : quant_params.fp8.dequant_fc1, - (std::is_same_v && - quant_params.mxfp8_mxfp4.fc1.weight_block_scale) - ? quant_params.mxfp8_mxfp4.fc2.global_scale - : quant_params.fp8.dequant_fc2, - fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, fc1_expert_biases, fc2_bias, - reinterpret_cast(gemm1_output), + fc1_expert_weights, fc2_expert_weights, quant_params.fp8.dequant_fc1, + quant_params.fp8.dequant_fc2, fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, + fc1_expert_biases, fc2_bias, reinterpret_cast(gemm1_output), reinterpret_cast(fc2_result_), permuted_token_final_scales_, permuted_row_to_unpermuted_row_, enable_pdl, stream); } diff --git a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu index add251a86e..9c98a2825c 100644 --- a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu +++ b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu @@ -397,8 +397,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { static_cast(experts_per_token), static_cast(workspace_info.workspace.data_ptr()), output.data_ptr(), 
static_cast(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, - use_lora, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, - enable_pdl, stream); + use_lora, lora_params, mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, min_latency_mode, + min_latency_params, enable_pdl, stream); #else mKernelRunner->runMoe( input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr, @@ -414,7 +414,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { static_cast(experts_per_token), static_cast(workspace_info.workspace.data_ptr()), output.data_ptr(), static_cast(workspace_info.src_to_dest_map), parallelism_config, false, lora_params, - mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, stream); + mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, min_latency_mode, min_latency_params, + enable_pdl, stream); #endif } @@ -583,8 +584,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { static_cast(experts_per_token), static_cast(workspace_info.workspace.data_ptr()), output.data_ptr(), static_cast(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, - use_lora_ml, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, - enable_pdl, stream); + use_lora_ml, lora_params, mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, + min_latency_mode, min_latency_params, enable_pdl, stream); #else mKernelRunner->runMoe( input.data_ptr(), input_sf.has_value() ? 
input_sf.value().data_ptr() : nullptr, @@ -600,8 +601,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { static_cast(experts_per_token), static_cast(workspace_info.workspace.data_ptr()), output.data_ptr(), static_cast(workspace_info.src_to_dest_map), parallelism_config, false, use_lora_ml, - lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, - stream); + lora_params, mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, min_latency_mode, + min_latency_params, enable_pdl, stream); #endif } @@ -840,8 +841,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { bool min_latency_mode) { size_t moe_workspace_size = mKernelRunner->getWorkspaceSize( num_rows, hidden_size, inter_size, num_experts, experts_per_token, activation_type, - parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling, min_latency_mode, - mUseW4GroupScaling); + parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, + min_latency_mode, mUseW4GroupScaling); size_t src_to_dest_map_size = experts_per_token * num_rows * sizeof(int); std::vector workspaces{moe_workspace_size, src_to_dest_map_size}; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h index b77efbcac1..476b99f7d8 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h @@ -239,11 +239,10 @@ constexpr bool isGatedActivation(ActivationType activation_type) { activation_type == ActivationType::SwigluBias; } -template +template class MoeGemmRunner { public: MoeGemmRunner(); @@ -309,7 +308,7 @@ class MoeGemmRunner { [[nodiscard]] bool supportsFusedGatedActivation(ActivationType activation_type, int gemm_n, int gemm_k) const; - size_t getMaxWorkspaceSize(int num_experts) const; + size_t getMaxWorkspaceSize(int 
num_experts, bool use_mxfp8_act_scaling = false) const; [[nodiscard]] int getSM() const; @@ -326,8 +325,9 @@ class MoeGemmRunner { int sm_{}; int multi_processor_count_{}; mutable int num_experts_ = 0; + mutable bool use_mxfp8_act_scaling_ = false; mutable size_t gemm_workspace_size_ = 0; - size_t calcMaxWorkspaceSize(int num_experts) const; + size_t calcMaxWorkspaceSize(int num_experts, bool use_mxfp8_act_scaling) const; }; } // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index e46ff9f6b5..d17b69a081 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -436,8 +436,8 @@ class CutlassMoeFCRunnerInterface { int64_t const inter_size, int const num_experts, int const experts_per_token, ActivationType activation_type, MOEParallelismConfig parallelism_config, bool use_lora, - bool use_deepseek_fp8_block_scale, bool min_latency_mode, - bool use_awq) = 0; + bool use_deepseek_fp8_block_scale, bool use_mxfp8_act_scaling, + bool min_latency_mode, bool use_awq) = 0; virtual void setTactic(std::optional gemm1_config, std::optional gemm2_config) = 0; virtual std::vector getTactics(MoeGemmId gemm_id) = 0; @@ -453,8 +453,9 @@ class CutlassMoeFCRunnerInterface { void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, - bool min_latency_mode, MoeMinLatencyParams& min_latency_params, - bool enable_pdl, cudaStream_t stream) = 0; + bool use_mxfp8_act_scaling, bool min_latency_mode, + MoeMinLatencyParams& min_latency_params, bool enable_pdl, + cudaStream_t stream) = 0; // Aliases for profiling the gemms virtual void gemm1(void const* const input, void* 
const output, void* const intermediate_result, @@ -534,12 +535,12 @@ template + bool IsMXFPX = false, typename Enable = void> class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { using DeepSeekBlockScaleGemmRunner = tensorrt_llm::kernels::fp8_blockscale_gemm::CutlassFp8BlockScaleGemmRunnerInterface; using ScaleBiasType = BackBoneType; - using Self = CutlassMoeFCRunner; + using Self = CutlassMoeFCRunner; #if defined(ENABLE_FP4) #if defined(ENABLE_BF16) @@ -584,7 +585,15 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { static constexpr bool use_fp4 = false; #endif - static constexpr bool use_block_scaling = use_fp4 || use_wfp4afp8; + static constexpr bool use_mxfp8 = use_fp8 && IsMXFPX; + static constexpr bool use_block_scaling = use_fp4 || use_wfp4afp8 || use_mxfp8; +#if defined(ENABLE_FP8) + static_assert(!IsMXFPX || + (std::is_same_v && std::is_same_v), + "IsMXFPX requires FP8xFP8 (E4M3) runner types"); +#else + static_assert(!IsMXFPX, "IsMXFPX requires FP8 support"); +#endif // This should leave the variable unchanged in any currently supported configuration using UnfusedGemmOutputType = BackBoneType; @@ -606,8 +615,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { int64_t const fc1_output_size, int const num_experts, int const experts_per_token, ActivationType activation_type, MOEParallelismConfig parallelism_config, bool use_lora, - bool use_deepseek_fp8_block_scale, bool min_latency_mode, - bool use_awq) override; + bool use_deepseek_fp8_block_scale, bool use_mxfp8_act_scaling, + bool min_latency_mode, bool use_awq) override; void setTactic(std::optional gemm1_config, std::optional gemm2_config) override { @@ -634,7 +643,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall, 
bool use_lora, - LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, + LoraParams& lora_params, bool use_deepseek_fp8_block_scale, + bool use_mxfp8_act_scaling, bool min_latency_mode, MoeMinLatencyParams& min_latency_params, bool enable_pdl, cudaStream_t stream) override; @@ -848,12 +858,14 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { std::map> getWorkspaceDeviceBufferSizes( int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int const experts_per_token, ActivationType activation_type, - bool use_lora, bool use_deepseek_fp8_block_scale, bool min_latency_mode, bool use_awq); + bool use_lora, bool use_deepseek_fp8_block_scale, bool use_mxfp8_act_scaling, + bool min_latency_mode, bool use_awq); void configureWsPtrs(char* ws_ptr, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int const experts_per_token, ActivationType activation_type, MOEParallelismConfig parallelism_config, bool use_lora, - bool use_deepseek_fp8_block_scale, bool min_latency_mode, bool use_awq); + bool use_deepseek_fp8_block_scale, bool use_mxfp8_act_scaling, + bool min_latency_mode, bool use_awq); private: bool mayHaveDifferentGEMMOutputType() const { @@ -875,9 +887,10 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { // TODO: This should eventually take the quant params to give more flexibility static auto getScalingType() { - return use_wfp4afp8 ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX - : use_fp4 ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 - : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE; + return (use_wfp4afp8 || use_mxfp8) + ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX + : use_fp4 ? 
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 + : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE; } bool setupLoraWorkspace(int64_t expanded_num_rows, int64_t num_rows, int64_t inter_size, diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h index 4c1164db98..122b418e4c 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h @@ -919,19 +919,20 @@ void MoeGemmRunner::dispatchToArch( template size_t MoeGemmRunner::getMaxWorkspaceSize( - int num_experts) const { - if (num_experts != num_experts_) { + int num_experts, bool use_mxfp8_act_scaling) const { + if (num_experts != num_experts_ || use_mxfp8_act_scaling != use_mxfp8_act_scaling_) { TLLM_LOG_TRACE("Calling getMaxWorkspaceSize() with a new expert count %d vs %d", num_experts, num_experts_); num_experts_ = num_experts; - gemm_workspace_size_ = calcMaxWorkspaceSize(num_experts); + use_mxfp8_act_scaling_ = use_mxfp8_act_scaling; + gemm_workspace_size_ = calcMaxWorkspaceSize(num_experts, use_mxfp8_act_scaling); } return gemm_workspace_size_; } template size_t MoeGemmRunner::calcMaxWorkspaceSize( - int num_experts) const { + int num_experts, bool use_mxfp8_act_scaling) const { if constexpr (use_w4_groupwise) { return cutlass_kernels_oss::calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput( @@ -949,12 +950,10 @@ size_t MoeGemmRunner::calcMaxWorkspace auto fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE; if constexpr (use_wfp4afp8) { fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; - } else if constexpr (use_fp8) { - // FP8 runners can be used in MXFP8 mode (UE8M0 block scales), which needs larger TMA WS. 
- // Allocate using MXFPX requirements so workspace is sufficient for both regular FP8 and - // MXFP8. + } else if (use_mxfp8_act_scaling) { + // Runtime MXFP8 act scaling reserves MXFPX workspace requirements. fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; - } else if (use_fp4) { + } else if constexpr (use_fp4) { fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4; } size_t max_size = 0; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h index 72c6d3538b..baf82d262c 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h @@ -112,7 +112,7 @@ auto getDispatchFunctionForSM100(cutlass_extensions::EpilogueScheduleType epilog } template + EpilogueFusion FUSION, typename TileShape, typename ClusterShape, bool IsMXFPX> void dispatchMoeGemmFinalDispatchTmaWarpSpecialized( TmaWarpSpecializedGroupedGemmInput hopper_input, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, @@ -167,104 +167,59 @@ void dispatchMoeGemmFinalDispatchTmaWarpSpecialized( constexpr static bool is_wfp4afp8 = false; #endif #if defined(ENABLE_FP8) - constexpr static bool is_wfp8amxfp8 = + constexpr static bool is_wfp8afp8 = std::is_same_v && std::is_same_v; #else - constexpr static bool is_wfp8amxfp8 = false; + constexpr static bool is_wfp8afp8 = false; #endif - constexpr static int tile_m = cute::size<0>(TileShape{}); - constexpr static int tile_n = cute::size<1>(TileShape{}); - // Keep this in sync with generate_kernels.py is_gemm_op_valid_sm100() for op.is_mx_fpx. 
- constexpr static bool is_wfp8amxfp8_tile_supported = - tile_m == 128 && (tile_n == 64 || tile_n == 128 || tile_n == 256); - constexpr static bool supports_mxfpx = is_wfp4afp8 || is_wfp8amxfp8; - constexpr static bool supports_mxfpx_tile = - is_wfp4afp8 || !is_wfp8amxfp8 || is_wfp8amxfp8_tile_supported; - bool const use_mxfpx = hopper_input.fpX_block_scaling_type == - TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; - if constexpr (is_wfp4afp8) { - TLLM_CHECK_WITH_INFO(use_mxfpx, "MXFPX is the only supported scaling type for WFP4AFP8"); - } else if (use_mxfpx) { - TLLM_CHECK_WITH_INFO(supports_mxfpx, - "MXFPX is not supported for the selected weight combination"); - TLLM_CHECK_WITH_INFO( - supports_mxfpx_tile, - "MXFPX is not supported for this tile shape; expected fp8 tile_m=128 and tile_n in {64, " - "128, 256}"); - } + constexpr static bool supports_mxfpx = is_wfp4afp8 || is_wfp8afp8; + if constexpr (IsMXFPX && !supports_mxfpx) { + TLLM_THROW("MXFPX is not supported for the selected weight combination"); + } else if constexpr (!IsMXFPX && is_wfp4afp8) { + TLLM_THROW("MXFPX is the only supported scaling type for WFP4AFP8"); + } else { + if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120) { + bool const dynamic_cga = + gemm_config.dynamic_cluster_shape != cutlass_extensions::ClusterShape::Undefined; + bool const swap_ab = hopper_input.swap_ab; + auto cluster_shape = + cutlass_extensions::enum_to_shape_tuple(gemm_config.dynamic_cluster_shape); + auto cluster_shape_cute = cute::Shape{ + std::get<0>(cluster_shape), std::get<1>(cluster_shape), cute::_1{}}; + auto cluster_shape_fallback = + cutlass_extensions::enum_to_shape_tuple(gemm_config.fallback_cluster_shape); + auto cluster_shape_cute_fallback = cute::Shape{ + std::get<0>(cluster_shape_fallback), std::get<1>(cluster_shape_fallback), cute::_1{}}; + + // HACK debug the gemm_config used to produce selected_func + // std::cout << "[SM100 gemm_config] sm_version=" << 
gemm_config.sm_version + // << ", tile_config_sm100=" << static_cast(gemm_config.tile_config_sm100) + // << ", epilogue_schedule=" << static_cast(gemm_config.epilogue_schedule) + // << ", dynamic_cluster_shape=" << + // static_cast(gemm_config.dynamic_cluster_shape) + // << ", fallback_cluster_shape=" + // << static_cast(gemm_config.fallback_cluster_shape) << std::endl; - if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120) { - bool const dynamic_cga = - gemm_config.dynamic_cluster_shape != cutlass_extensions::ClusterShape::Undefined; - bool const swap_ab = hopper_input.swap_ab; - auto cluster_shape = - cutlass_extensions::enum_to_shape_tuple(gemm_config.dynamic_cluster_shape); - auto cluster_shape_cute = cute::Shape{ - std::get<0>(cluster_shape), std::get<1>(cluster_shape), cute::_1{}}; - auto cluster_shape_fallback = - cutlass_extensions::enum_to_shape_tuple(gemm_config.fallback_cluster_shape); - auto cluster_shape_cute_fallback = cute::Shape{ - std::get<0>(cluster_shape_fallback), std::get<1>(cluster_shape_fallback), cute::_1{}}; - - // HACK debug the gemm_config used to produce selected_func - // std::cout << "[SM100 gemm_config] sm_version=" << gemm_config.sm_version - // << ", tile_config_sm100=" << static_cast(gemm_config.tile_config_sm100) - // << ", epilogue_schedule=" << static_cast(gemm_config.epilogue_schedule) - // << ", dynamic_cluster_shape=" << - // static_cast(gemm_config.dynamic_cluster_shape) - // << ", fallback_cluster_shape=" - // << static_cast(gemm_config.fallback_cluster_shape) << std::endl; - - if (use_mxfpx) { - if constexpr (supports_mxfpx && supports_mxfpx_tile) { - auto selected_func = - getDispatchFunctionForSM100( - gemm_config.epilogue_schedule, dynamic_cga, swap_ab); - selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, - workspace_size, cluster_shape_cute, cluster_shape_cute_fallback); - } else { - TLLM_THROW("MXFPX not supported by this tile shape"); - } - } else { 
auto selected_func = getDispatchFunctionForSM100( + TileShape, ClusterShape, IsMXFPX>( gemm_config.epilogue_schedule, dynamic_cga, swap_ab); selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size, cluster_shape_cute, cluster_shape_cute_fallback); - } - } else if constexpr (Arch::kMinComputeCapability >= 120 || Arch::kMinComputeCapability == 90) { - using EpilogueSchedule = void; // These are hardcoded in the launcher - constexpr bool dynamic_cga = false; - if (use_mxfpx) { - if constexpr (supports_mxfpx && supports_mxfpx_tile) { - auto selected_func = - hopper_input.swap_ab - ? kernels::cutlass_kernels_oss:: - tma_warp_specialized_generic_moe_gemm_kernelLauncher< - Arch, T, WeightType, OutputType, EpilogueSchedule, EpilogueTag, FUSION, - TileShape, ClusterShape, true, dynamic_cga, false, true> - : kernels::cutlass_kernels_oss:: - tma_warp_specialized_generic_moe_gemm_kernelLauncher< - Arch, T, WeightType, OutputType, EpilogueSchedule, EpilogueTag, FUSION, - TileShape, ClusterShape, true, dynamic_cga, false, false>; - selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, - workspace_size, {}, {}); - } else { - TLLM_THROW("MXFPX not supported by this tile shape"); - } - } else { + } else if constexpr (Arch::kMinComputeCapability >= 120 || + Arch::kMinComputeCapability == 90) { + using EpilogueSchedule = void; // These are hardcoded in the launcher + constexpr bool dynamic_cga = false; auto selected_func = hopper_input.swap_ab ? 
kernels::cutlass_kernels_oss:: tma_warp_specialized_generic_moe_gemm_kernelLauncher< Arch, T, WeightType, OutputType, EpilogueSchedule, EpilogueTag, FUSION, - TileShape, ClusterShape, false, dynamic_cga, false, true> + TileShape, ClusterShape, IsMXFPX, dynamic_cga, false, true> : kernels::cutlass_kernels_oss:: tma_warp_specialized_generic_moe_gemm_kernelLauncher< Arch, T, WeightType, OutputType, EpilogueSchedule, EpilogueTag, FUSION, - TileShape, ClusterShape, false, dynamic_cga, false, false>; + TileShape, ClusterShape, IsMXFPX, dynamic_cga, false, false>; selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size, {}, {}); } @@ -273,7 +228,7 @@ void dispatchMoeGemmFinalDispatchTmaWarpSpecialized( } template + typename WeightType, bool IsMXFPX> constexpr bool are_tile_shapes_supported_sm100() { // We use a runtime cluster shape for SM100, so we only support 1x1x1 and 2x1x1 cluster shapes. if (cute::size<0>(ClusterShape{}) > 2 || cute::size<1>(ClusterShape{}) != 1 || @@ -312,6 +267,15 @@ constexpr bool are_tile_shapes_supported_sm100() { } #endif +#ifdef ENABLE_FP8 + if constexpr (IsMXFPX && std::is_same_v && + std::is_same_v) { + if ((TileN != 64 && TileN != 128 && TileN != 256) || TileM != 128) { + return false; + } + } +#endif + if constexpr (std::is_same_v) { if constexpr ((TileN == 16 || TileN == 8) && cute::size<0>(ClusterShape{}) == 1 && cute::size<1>(ClusterShape{}) == 1) { @@ -361,10 +325,11 @@ constexpr bool are_tile_shapes_supported_sm120() { that may not be very useful in practice. 
*/ template + typename WeightType, bool IsMXFPX> constexpr bool are_tile_shapes_supported() { if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120) { - return are_tile_shapes_supported_sm100(); + return are_tile_shapes_supported_sm100(); } else if constexpr (Arch::kMinComputeCapability == 120 || Arch::kMinComputeCapability == 121) { return are_tile_shapes_supported_sm120(); } @@ -389,7 +354,7 @@ constexpr bool are_tile_shapes_supported() { } template + EpilogueFusion FUSION, typename TileShape, bool IsMXFPX> void dispatchMoeGemmSelectClusterShapeTmaWarpSpecialized( TmaWarpSpecializedGroupedGemmInput hopper_input, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, @@ -400,9 +365,10 @@ void dispatchMoeGemmSelectClusterShapeTmaWarpSpecialized( #define SHAPE_CASE(M, N, K) \ case cutlass_extensions::ClusterShape::ClusterShape_##M##x##N##x##K: { \ using ClusterShape = Shape<_##M, _##N, _##K>; \ - if constexpr (are_tile_shapes_supported()) { \ + if constexpr (are_tile_shapes_supported()) { \ dispatchMoeGemmFinalDispatchTmaWarpSpecialized( \ + FUSION, TileShape, ClusterShape, IsMXFPX>( \ hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, \ workspace_size); \ break; \ @@ -437,20 +403,23 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized( cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy, size_t* workspace_size) { using namespace cute; - -#define SHAPE_CASE(SMVERSION, M, N, K) \ - case cutlass_extensions::CutlassTileConfigSM##SMVERSION::CtaShape##M##x##N##x##K##B: { \ - constexpr int KtileBytes = \ - (K * 8) / \ - cutlass::sizeof_bits< \ - typename kernels::cutlass_kernels::TllmToCutlassTypeAdapter::type>::value; \ - using KTileDim = Int; \ - using TileShape = Shape<_##M, _##N, KTileDim>; \ - dispatchMoeGemmSelectClusterShapeTmaWarpSpecialized< \ - cutlass::arch::Sm##SMVERSION, T, WeightType, 
OutputType, EpilogueTag, FUSION, TileShape>( \ - hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, \ - workspace_size); \ - break; \ + auto dispatch_by_mxfpx = [&](auto is_mxfpx_t) { + constexpr bool IsMXFPX = decltype(is_mxfpx_t)::value; + +#define SHAPE_CASE(SMVERSION, M, N, K) \ + case cutlass_extensions::CutlassTileConfigSM##SMVERSION::CtaShape##M##x##N##x##K##B: { \ + constexpr int KtileBytes = \ + (K * 8) / \ + cutlass::sizeof_bits< \ + typename kernels::cutlass_kernels::TllmToCutlassTypeAdapter::type>::value; \ + using KTileDim = Int; \ + using TileShape = Shape<_##M, _##N, KTileDim>; \ + dispatchMoeGemmSelectClusterShapeTmaWarpSpecialized( \ + hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, \ + workspace_size); \ + break; \ } #define DEFAULT_CASE(SMVERSION) \ case cutlass_extensions::CutlassTileConfigSM##SMVERSION::Undefined: \ @@ -464,76 +433,85 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized( (int)gemm_config.tile_config_sm##SMVERSION); \ break; - if (gemm_config.sm_version == 90) { - if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation()) { - switch (gemm_config.tile_config_sm90) { - SHAPE_CASE(90, 128, 16, 128) - SHAPE_CASE(90, 128, 32, 128) - SHAPE_CASE(90, 128, 64, 128) - SHAPE_CASE(90, 128, 128, 128) - SHAPE_CASE(90, 128, 256, 128) - SHAPE_CASE(90, 256, 128, 128) - DEFAULT_CASE(90) + if (gemm_config.sm_version == 90) { + if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation< + T, WeightType, EpilogueTag, FUSION>()) { + switch (gemm_config.tile_config_sm90) { + SHAPE_CASE(90, 128, 16, 128) + SHAPE_CASE(90, 128, 32, 128) + SHAPE_CASE(90, 128, 64, 128) + SHAPE_CASE(90, 128, 128, 128) + SHAPE_CASE(90, 128, 256, 128) + SHAPE_CASE(90, 256, 128, 128) + DEFAULT_CASE(90) + } + } else { + TLLM_THROW("Unsupported SM90 configuration requested"); } - } else { - TLLM_THROW("Unsupported SM90 configuration requested"); } - } #if defined(ENABLE_FP4) && 
defined(COMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS) - // Check this before SM100 because we fall back to SM100 if not NVFP4 - else if (gemm_config.sm_version == 103 && std::is_same_v && - std::is_same_v) { - if constexpr (kernels::cutlass_kernels::isValidBlackwellMOESpecialisation< - T, WeightType, EpilogueTag, FUSION>()) { - switch (gemm_config.tile_config_sm100) { - SHAPE_CASE(103, 128, 128, 128) - SHAPE_CASE(103, 128, 256, 128) - - DEFAULT_CASE(100) // 100 because we use the same member variable for SM100 and SM103 + // Check this before SM100 because we fall back to SM100 if not NVFP4 + else if (gemm_config.sm_version == 103 && std::is_same_v && + std::is_same_v) { + if constexpr (kernels::cutlass_kernels::isValidBlackwellMOESpecialisation< + T, WeightType, EpilogueTag, FUSION>()) { + switch (gemm_config.tile_config_sm100) { + SHAPE_CASE(103, 128, 128, 128) + SHAPE_CASE(103, 128, 256, 128) + + DEFAULT_CASE(100) // 100 because we use the same member variable for SM100 and SM103 + } + } else { + TLLM_THROW("Unsupported SM103 configuration requested"); } - } else { - TLLM_THROW("Unsupported SM103 configuration requested"); } - } #endif - else if (gemm_config.sm_version >= 100 && gemm_config.sm_version < 120) { - if constexpr (kernels::cutlass_kernels::isValidBlackwellMOESpecialisation< - T, WeightType, EpilogueTag, FUSION>()) { - switch (gemm_config.tile_config_sm100) { - SHAPE_CASE(100, 64, 32, 128) - SHAPE_CASE(100, 64, 64, 128) - SHAPE_CASE(100, 64, 128, 128) - SHAPE_CASE(100, 64, 256, 128) - - SHAPE_CASE(100, 128, 16, 128) - SHAPE_CASE(100, 128, 32, 128) - SHAPE_CASE(100, 128, 64, 128) - SHAPE_CASE(100, 128, 128, 128) - SHAPE_CASE(100, 128, 256, 128) - - // SHAPE_CASE(100, 128, 128, 64) - // SHAPE_CASE(100, 128, 256, 64) - DEFAULT_CASE(100) + else if (gemm_config.sm_version >= 100 && gemm_config.sm_version < 120) { + if constexpr (kernels::cutlass_kernels::isValidBlackwellMOESpecialisation< + T, WeightType, EpilogueTag, FUSION>()) { + switch 
(gemm_config.tile_config_sm100) { + SHAPE_CASE(100, 64, 32, 128) + SHAPE_CASE(100, 64, 64, 128) + SHAPE_CASE(100, 64, 128, 128) + SHAPE_CASE(100, 64, 256, 128) + + SHAPE_CASE(100, 128, 16, 128) + SHAPE_CASE(100, 128, 32, 128) + SHAPE_CASE(100, 128, 64, 128) + SHAPE_CASE(100, 128, 128, 128) + SHAPE_CASE(100, 128, 256, 128) + + // SHAPE_CASE(100, 128, 128, 64) + // SHAPE_CASE(100, 128, 256, 64) + DEFAULT_CASE(100) + } + } else { + TLLM_THROW("Unsupported SM100 configuration requested"); } - } else { - TLLM_THROW("Unsupported SM100 configuration requested"); - } - } else if (gemm_config.sm_version == 120 || gemm_config.sm_version == 121) { - TLLM_LOG_TRACE("At %s, SM120 config=%d", __PRETTY_FUNCTION__, - (int)gemm_config.tile_config_sm120); - if constexpr (kernels::cutlass_kernels::isValidSM120MOESpecialisation()) { - switch (gemm_config.tile_config_sm120) { - SHAPE_CASE(120, 128, 128, 64) - SHAPE_CASE(120, 128, 128, 128) - SHAPE_CASE(120, 128, 256, 64) - SHAPE_CASE(120, 256, 128, 64) - DEFAULT_CASE(120) + } else if (gemm_config.sm_version == 120 || gemm_config.sm_version == 121) { + TLLM_LOG_TRACE("SM120 config=%d", (int)gemm_config.tile_config_sm120); + if constexpr (kernels::cutlass_kernels::isValidSM120MOESpecialisation< + T, WeightType, EpilogueTag, FUSION>()) { + switch (gemm_config.tile_config_sm120) { + SHAPE_CASE(120, 128, 128, 64) + SHAPE_CASE(120, 128, 128, 128) + SHAPE_CASE(120, 128, 256, 64) + SHAPE_CASE(120, 256, 128, 64) + DEFAULT_CASE(120) + } } } - } #undef SHAPE_CASE +#undef DEFAULT_CASE + }; + + bool const use_mxfpx = hopper_input.fpX_block_scaling_type == + TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; + if (use_mxfpx) { + dispatch_by_mxfpx(tensorrt_llm::common::ConstBool{}); + } else { + dispatch_by_mxfpx(tensorrt_llm::common::ConstBool{}); + } } template From 8e1091b79a257dda57146d9318c07551421e8c20 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 23 Feb 2026 17:20:35 -0800 Subject: [PATCH 07/11] Make `use_mxfp8` compile time 
derivable --- .../cutlass_fused_moe_instantiation.cu | 3 + .../cutlass_fused_moe_kernels.cuh | 7 +- .../flashinfer_cutlass_fused_moe_binding.cu | 19 ++-- .../include/moe_gemm_kernels.h | 22 +++-- .../cutlass_kernels/include/moe_kernels.h | 6 +- .../moe_gemm/moe_gemm_kernels_fp8_fp8.cu | 2 + .../moe_gemm/moe_gemm_template_dispatch.h | 95 +++++++++++-------- .../moe_gemm_template_dispatch_tma_ws.h | 2 +- 8 files changed, 93 insertions(+), 63 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu index 6469b9a0cd..e8d78329b7 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu @@ -32,9 +32,12 @@ template class CutlassMoeFCRunner; #ifdef ENABLE_FP8 // template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half>; +template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half, __nv_fp8_e4m3, half, true>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>; #ifdef ENABLE_BF16 template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>; +template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16, __nv_fp8_e4m3, + __nv_bfloat16, true>; template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_fp8_e4m3>; diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 7b3e80f1ff..40bfc16b14 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -2620,8 +2620,7 @@ 
CutlassMoeFCRunner void CutlassMoeFCRunner::gemm1( - MoeGemmRunner& gemm_runner, + MoeGemmRunner& gemm_runner, DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, T* const output, void* const intermediate_result, int64_t const* const expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput const tma_ws_input_template, @@ -3211,7 +3210,7 @@ void CutlassMoeFCRunner void CutlassMoeFCRunner::gemm2( - MoeGemmRunner& gemm_runner, + MoeGemmRunner& gemm_runner, DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, void* const gemm_output, OutputType* const final_output, int64_t const* const expert_first_token_offset, diff --git a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu index 9c98a2825c..a59204ef9a 100644 --- a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu +++ b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu @@ -82,7 +82,7 @@ class DtypeUtils { class FusedMoeRunner : public tvm::ffi::ModuleObj { public: - template + template std::unique_ptr switch_output_type(DLDataType output_type) { switch (encode_dlpack_dtype(output_type)) { case int64_code: // INT64 == FP4 @@ -94,19 +94,20 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { // return std::make_unique>(); case float16_code: if constexpr (NeedQuant) { - return std::make_unique>(); + return std::make_unique< + kernels::CutlassMoeFCRunner>(); } else { return std::make_unique< - kernels::CutlassMoeFCRunner>(); + kernels::CutlassMoeFCRunner>(); } #ifdef ENABLE_BF16 case bfloat16_code: if constexpr (NeedQuant) { - return std::make_unique< - kernels::CutlassMoeFCRunner>(); + return std::make_unique>(); } else { - return std::make_unique< - kernels::CutlassMoeFCRunner>(); + return std::make_unique>(); } #endif default: @@ -145,7 +146,9 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { #endif #ifdef ENABLE_FP8 - 
if (isFp8Quant() || isWMxfp8AMxfp8Quant()) { + if (isWMxfp8AMxfp8Quant()) { + mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp8_e4m3, false, true>(mOutputDtype); + } else if (isFp8Quant()) { mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp8_e4m3>(mOutputDtype); } #endif diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h index 476b99f7d8..962f60a688 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h @@ -239,10 +239,11 @@ constexpr bool isGatedActivation(ActivationType activation_type) { activation_type == ActivationType::SwigluBias; } -template +template class MoeGemmRunner { public: MoeGemmRunner(); @@ -271,6 +272,14 @@ class MoeGemmRunner { #else static constexpr bool use_fp8 = false; static constexpr bool use_w4afp8 = false; +#endif + static constexpr bool use_mxfp8 = use_fp8 && IsMXFPX; +#if defined(ENABLE_FP8) + static_assert(!IsMXFPX || + (std::is_same_v && std::is_same_v), + "IsMXFPX requires FP8xFP8 (E4M3) runner types"); +#else + static_assert(!IsMXFPX, "IsMXFPX requires FP8 support"); #endif static constexpr bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; @@ -308,7 +317,7 @@ class MoeGemmRunner { [[nodiscard]] bool supportsFusedGatedActivation(ActivationType activation_type, int gemm_n, int gemm_k) const; - size_t getMaxWorkspaceSize(int num_experts, bool use_mxfp8_act_scaling = false) const; + size_t getMaxWorkspaceSize(int num_experts) const; [[nodiscard]] int getSM() const; @@ -325,9 +334,8 @@ class MoeGemmRunner { int sm_{}; int multi_processor_count_{}; mutable int num_experts_ = 0; - mutable bool use_mxfp8_act_scaling_ = false; mutable size_t gemm_workspace_size_ = 0; - size_t calcMaxWorkspaceSize(int num_experts, bool use_mxfp8_act_scaling) const; + size_t 
calcMaxWorkspaceSize(int num_experts) const; }; } // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index d17b69a081..d64ad380a5 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -649,7 +649,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { cudaStream_t stream) override; // We make these GEMM1 & GEMM2 static because they need to be stateless for the profiler to work - static void gemm1(MoeGemmRunner& gemm_runner, + static void gemm1(MoeGemmRunner& gemm_runner, // This argument must not be null if fp8 block scaling is being used. // The gemm_runner will be ignored in that case. NOTE: it would // be great if we could consolidate gemm_runner and fp8_blockscale_gemm_runner. @@ -675,7 +675,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { int* num_active_experts_per, int* active_expert_global_ids, bool enable_pdl); static void gemm2( - MoeGemmRunner& gemm_runner, + MoeGemmRunner& gemm_runner, DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, void* const gemm_output, OutputType* const final_output, int64_t const* const expert_first_token_offset, @@ -939,7 +939,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, cudaStream_t stream); - MoeGemmRunner moe_gemm_runner_; + MoeGemmRunner moe_gemm_runner_; std::unique_ptr blockscale_gemm_runner_; std::optional gemm1_config_; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu index 08b7ce1930..e78f8aa823 100644 --- 
a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu @@ -19,8 +19,10 @@ namespace tensorrt_llm::kernels::cutlass_kernels { #ifdef ENABLE_FP8 template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half>; +template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half, half, true>; #ifdef ENABLE_BF16 template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>; +template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16, true>; #endif // template class MoeGemmRunner<__nv_fp8_e5m2, __nv_fp8_e5m2>; #endif diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h index 122b418e4c..b8051ef7fc 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h @@ -525,17 +525,19 @@ void dispatchMoeGemmToCutlass( namespace tensorrt_llm::kernels::cutlass_kernels { -template +template std::vector -MoeGemmRunner::getConfigs( +MoeGemmRunner::getConfigs( bool supports_finalize_fusion) const { return getConfigs(sm_, supports_finalize_fusion); } -template +template std::vector -MoeGemmRunner::getConfigs(int sm, - bool supports_finalize_fusion) { +MoeGemmRunner::getConfigs( + int sm, bool supports_finalize_fusion) { std::vector candidate_configs = getTmaWarpSpecializedConfigs(sm, supports_finalize_fusion); std::vector ampere_configs = getAmpereConfigs(sm); @@ -543,9 +545,10 @@ MoeGemmRunner::getConfigs(int sm, return candidate_configs; } -template +template std::vector -MoeGemmRunner::getAmpereConfigs(int sm) { +MoeGemmRunner::getAmpereConfigs(int sm) { using tensorrt_llm::cutlass_extensions::CutlassGemmConfig; static constexpr 
auto weight_only_flag = std::is_same::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY; @@ -571,9 +574,10 @@ MoeGemmRunner::getAmpereConfigs(int sm return ampere_configs; } -template +template std::vector -MoeGemmRunner::getTmaWarpSpecializedConfigs( +MoeGemmRunner::getTmaWarpSpecializedConfigs( int sm, bool supports_finalize_fusion) { using tensorrt_llm::cutlass_extensions::CutlassGemmConfig; static constexpr auto weight_only_flag = @@ -668,15 +672,18 @@ MoeGemmRunner::getTmaWarpSpecializedCo return tma_ws_configs; } -template -bool MoeGemmRunner::isTmaWarpSpecialized( +template +bool MoeGemmRunner::isTmaWarpSpecialized( cutlass_extensions::CutlassGemmConfig gemm_config) const { bool config_is_tma_warp_specialized = gemm_config.is_tma_warp_specialized; return supportsTmaWarpSpecialized() && config_is_tma_warp_specialized; } -template -bool MoeGemmRunner::supportsTmaWarpSpecialized(int sm) { +template +bool MoeGemmRunner::supportsTmaWarpSpecialized( + int sm) { return (sm == 90 && tensorrt_llm::kernels::cutlass_kernels::isValidHopperMOESpecialisation()) || @@ -687,14 +694,16 @@ bool MoeGemmRunner::supportsTmaWarpSpe tensorrt_llm::kernels::cutlass_kernels::isValidSM120MOESpecialisation()); } -template -int MoeGemmRunner::getSM() const { +template +int MoeGemmRunner::getSM() const { return this->sm_; } // currently support sm80 bf16/fp16 gate activation, only set predication tensor for m direction -template -bool MoeGemmRunner::supportsFusedGatedActivation( +template +bool MoeGemmRunner::supportsFusedGatedActivation( ActivationType activation_type, int gemm_n, int gemm_k) const { constexpr bool ENABLE_FUSED_GATED_ACTIVATION = true; return (activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu) && @@ -703,16 +712,18 @@ bool MoeGemmRunner::supportsFusedGated ENABLE_FUSED_GATED_ACTIVATION; } -template -bool MoeGemmRunner::isFusedGatedActivation( +template +bool MoeGemmRunner::isFusedGatedActivation( 
cutlass_extensions::CutlassGemmConfig gemm_config, ActivationType activation_type, int gemm_n, int gemm_k) const { return supportsFusedGatedActivation(activation_type, gemm_n, gemm_k) && !gemm_config.is_tma_warp_specialized; } -template -MoeGemmRunner::MoeGemmRunner() { +template +MoeGemmRunner::MoeGemmRunner() { int device{-1}; tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device)); sm_ = tensorrt_llm::common::getSMVersion(); @@ -720,9 +731,10 @@ MoeGemmRunner::MoeGemmRunner() { cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); } -template +template template -void MoeGemmRunner::dispatchToArch( +void MoeGemmRunner::dispatchToArch( GroupedGemmInput inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs) { static_assert( @@ -917,22 +929,23 @@ void MoeGemmRunner::dispatchToArch( } } -template -size_t MoeGemmRunner::getMaxWorkspaceSize( - int num_experts, bool use_mxfp8_act_scaling) const { - if (num_experts != num_experts_ || use_mxfp8_act_scaling != use_mxfp8_act_scaling_) { +template +size_t MoeGemmRunner::getMaxWorkspaceSize( + int num_experts) const { + if (num_experts != num_experts_) { TLLM_LOG_TRACE("Calling getMaxWorkspaceSize() with a new expert count %d vs %d", num_experts, num_experts_); num_experts_ = num_experts; - use_mxfp8_act_scaling_ = use_mxfp8_act_scaling; - gemm_workspace_size_ = calcMaxWorkspaceSize(num_experts, use_mxfp8_act_scaling); + gemm_workspace_size_ = calcMaxWorkspaceSize(num_experts); } return gemm_workspace_size_; } -template -size_t MoeGemmRunner::calcMaxWorkspaceSize( - int num_experts, bool use_mxfp8_act_scaling) const { +template +size_t MoeGemmRunner::calcMaxWorkspaceSize( + int num_experts) const { if constexpr (use_w4_groupwise) { return cutlass_kernels_oss::calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput( @@ -950,8 +963,7 @@ size_t MoeGemmRunner::calcMaxWorkspace auto fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE; if constexpr 
(use_wfp4afp8) { fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; - } else if (use_mxfp8_act_scaling) { - // Runtime MXFP8 act scaling reserves MXFPX workspace requirements. + } else if constexpr (use_mxfp8) { fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; } else if constexpr (use_fp4) { fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4; @@ -990,16 +1002,18 @@ size_t MoeGemmRunner::calcMaxWorkspace } } -template +template template -void MoeGemmRunner::runGemm( +void MoeGemmRunner::runGemm( GroupedGemmInput inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs) { dispatchToArch(inputs, hopper_inputs); } -template -void MoeGemmRunner::moeGemmBiasAct( +template +void MoeGemmRunner::moeGemmBiasAct( GroupedGemmInput inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs) { switch (inputs.activation_type) { @@ -1030,8 +1044,9 @@ void MoeGemmRunner::moeGemmBiasAct( } } -template -void MoeGemmRunner::moeGemm( +template +void MoeGemmRunner::moeGemm( GroupedGemmInput inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs) { runGemm(inputs, hopper_inputs); diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h index baf82d262c..455caa5113 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h @@ -489,7 +489,7 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized( TLLM_THROW("Unsupported SM100 configuration requested"); } } else if (gemm_config.sm_version == 120 || gemm_config.sm_version == 121) { - TLLM_LOG_TRACE("SM120 config=%d", (int)gemm_config.tile_config_sm120); + TLLM_LOG_TRACE("At %s, SM120 config=%d", 
(int)gemm_config.tile_config_sm120); if constexpr (kernels::cutlass_kernels::isValidSM120MOESpecialisation< T, WeightType, EpilogueTag, FUSION>()) { switch (gemm_config.tile_config_sm120) { From 4bedd311854de7fc3ba1c704ca4e7939250bfec6 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 23 Feb 2026 18:10:50 -0800 Subject: [PATCH 08/11] Only run tests for sm100 for now --- tests/moe/test_trtllm_cutlass_fused_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/moe/test_trtllm_cutlass_fused_moe.py b/tests/moe/test_trtllm_cutlass_fused_moe.py index 6c1dc4e0f2..c3b2138b45 100644 --- a/tests/moe/test_trtllm_cutlass_fused_moe.py +++ b/tests/moe/test_trtllm_cutlass_fused_moe.py @@ -1387,8 +1387,8 @@ def test_moe_mxfp8_mxfp4( ("alpha", "beta", "limit"), [(None, None, None), (0.5, 0.0, 7.0), (1.702, 1.0, 7.0)] ) @pytest.mark.skipif( - torch.cuda.get_device_capability()[0] not in [10, 11, 12], - reason="MXFP8xMXFP8 is only supported on SM100, SM110 and SM120", + torch.cuda.get_device_capability()[0] not in [10], + reason="MXFP8xMXFP8 is only supported on SM100 for now", ) def test_moe_mxfp8_mxfp8( batch_size, From d45dcc6877218c1780e8f1467dfc212c974bf9f5 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 23 Feb 2026 19:20:26 -0800 Subject: [PATCH 09/11] Refactor internals --- .../cutlass_fused_moe_kernels.cuh | 66 +++++-------------- .../include/moe_gemm_kernels.h | 8 +-- .../cutlass_kernels/include/moe_kernels.h | 12 +--- .../moe_gemm_template_dispatch_tma_ws.h | 4 +- 4 files changed, 24 insertions(+), 66 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 40bfc16b14..3ea7860080 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -2600,24 +2600,21 @@ CutlassMoeFCRunner(num_rows * num_experts_per_node)); - auto const workspace_scaling_type = - 
use_mxfp8_act_scaling ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX - : getScalingType(); size_t const sf_size = - workspace_scaling_type == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX + getScalingType() == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX ? sizeof(TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF) : sizeof(TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF); - size_t const fc1_fp4_act_scale_size = getOffsetActivationSF(num_experts_per_node, act_sf_rows, - hidden_size, workspace_scaling_type) * - sf_size; + size_t const fc1_fp4_act_scale_size = + getOffsetActivationSF(num_experts_per_node, act_sf_rows, hidden_size, getScalingType()) * + sf_size; size_t const fc2_fp4_act_scale_size = - getOffsetActivationSF(num_experts_per_node, act_sf_rows, inter_size, workspace_scaling_type) * + getOffsetActivationSF(num_experts_per_node, act_sf_rows, inter_size, getScalingType()) * sf_size; size_t const fp4_act_scale_size = std::max(fc1_fp4_act_scale_size, fc2_fp4_act_scale_size); size_t const tma_ws_size = using_tma_ws ? 
TmaWarpSpecializedGroupedGemmInput::workspaceSize( - num_experts_per_node, workspace_scaling_type) + num_experts_per_node, getScalingType()) : 0; size_t const gemm_workspace_size = moe_gemm_runner_.getMaxWorkspaceSize(num_experts_per_node); @@ -2818,10 +2815,7 @@ void CutlassMoeFCRunner::value || std::is_same::value; - bool const is_mxfp8_mode = fp8_scales_required && use_mxfp8_act_scaling; auto const* input_activations = static_cast(input_activations_void); auto const* input_sf = @@ -3593,8 +3586,9 @@ void CutlassMoeFCRunner::value) == 0, - "Hidden size %d does not meet minimum alignment requirements for WMXFP8AMXFP8 MOE GEMM %d", + "Hidden size %d does not meet minimum alignment requirements for MXFP8 MOE GEMM %d", (int)hidden_size, (int)(64 * 8 / sizeof_bits::value)); TLLM_CHECK_WITH_INFO( inter_size % (64 * 8 / sizeof_bits::value) == 0, - "Inter size %d does not meet minimum alignment requirements for WMXFP8AMXFP8 MOE GEMM %d", + "Inter size %d does not meet minimum alignment requirements for MXFP8 MOE GEMM %d", (int)inter_size, (int)(64 * 8 / sizeof_bits::value)); } else { // For NoSmem epilogue schedule, we need to align the output of the GEMM to 256 bits, for gated @@ -3684,7 +3678,7 @@ void CutlassMoeFCRunner(gemm1_input), - reinterpret_cast(gemm2_input), fc1_expert_weights, fc2_expert_weights, - quant_params.mxfp8_mxfp4.fc1.global_scale, quant_params.mxfp8_mxfp4.fc2.global_scale, - fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, fc1_expert_biases, fc2_bias, - reinterpret_cast(gemm1_output), - reinterpret_cast(fc2_result_), permuted_token_final_scales_, - permuted_row_to_unpermuted_row_, enable_pdl, stream); - } return Self::computeStridesTmaWarpSpecialized( expert_first_token_offset_, gemm1_tma_ws_input, gemm2_tma_ws_input, num_rows, expanded_num_rows, fc1_out_size, hidden_size, hidden_size, inter_size, num_experts_per_node, diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h 
b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h index 962f60a688..4e85da745c 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h @@ -274,13 +274,7 @@ class MoeGemmRunner { static constexpr bool use_w4afp8 = false; #endif static constexpr bool use_mxfp8 = use_fp8 && IsMXFPX; -#if defined(ENABLE_FP8) - static_assert(!IsMXFPX || - (std::is_same_v && std::is_same_v), - "IsMXFPX requires FP8xFP8 (E4M3) runner types"); -#else - static_assert(!IsMXFPX, "IsMXFPX requires FP8 support"); -#endif + static constexpr bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; #if defined(ENABLE_FP4) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index d64ad380a5..067060a184 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -257,9 +257,8 @@ struct QuantParams { GemmInputs fc2; } fp8_mxfp4; - // MXFP8 block-scaled quantization params. - // Historical note: this payload shape is also reused by MXFP8xMXFP8 (FP8 weights with MXFPX - // block scales), so this field name is legacy. 
+ // MXFP8 MXFP4 quantization params + // This mode uses block scaled MXFP8 and MXFP4 weights struct MXFP8MXFP4Inputs { struct GemmInputs { TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* weight_block_scale = @@ -587,13 +586,6 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { static constexpr bool use_mxfp8 = use_fp8 && IsMXFPX; static constexpr bool use_block_scaling = use_fp4 || use_wfp4afp8 || use_mxfp8; -#if defined(ENABLE_FP8) - static_assert(!IsMXFPX || - (std::is_same_v && std::is_same_v), - "IsMXFPX requires FP8xFP8 (E4M3) runner types"); -#else - static_assert(!IsMXFPX, "IsMXFPX requires FP8 support"); -#endif // This should leave the variable unchanged in any currently supported configuration using UnfusedGemmOutputType = BackBoneType; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h index 455caa5113..c70596ac53 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h @@ -489,7 +489,9 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized( TLLM_THROW("Unsupported SM100 configuration requested"); } } else if (gemm_config.sm_version == 120 || gemm_config.sm_version == 121) { - TLLM_LOG_TRACE("At %s, SM120 config=%d", (int)gemm_config.tile_config_sm120); + char const* const pretty_function = __PRETTY_FUNCTION__; + TLLM_LOG_TRACE("At %s, SM120 config=%d", pretty_function, + static_cast(gemm_config.tile_config_sm120)); if constexpr (kernels::cutlass_kernels::isValidSM120MOESpecialisation< T, WeightType, EpilogueTag, FUSION>()) { switch (gemm_config.tile_config_sm120) { From f28fcfb66ef2bd2419fafe006c4b8240017d3f4b Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 23 Feb 2026 21:25:58 -0800 Subject: [PATCH 10/11] 
Add `MXFP8MXFP8Inputs` and refactor --- .../cutlass_fused_moe_kernels.cuh | 32 ++++++++++++------- .../cutlass_kernels/include/moe_kernels.h | 19 +++++++++-- .../moe_gemm/moe_gemm_template_dispatch.h | 2 +- .../moe_gemm_template_dispatch_tma_ws.h | 2 +- 4 files changed, 39 insertions(+), 16 deletions(-) diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 3ea7860080..a7f03548aa 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1297,6 +1297,8 @@ __global__ void computeStridesTmaWarpSpecializedKernel( quant_params.fp8_mxfp4); setupIfSelected(TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaledConfig{}, quant_params.mxfp8_mxfp4); + setupIfSelected(TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaledConfig{}, + quant_params.mxfp8_mxfp8); assert(gemm_m <= INT32_MAX); assert(gemm1_n > 0 && gemm1_n <= INT32_MAX); @@ -1585,9 +1587,9 @@ void expandInputRowsKernelLauncher( // Always MXFP8 if constexpr (std::is_same_v && !std::is_same_v) { - TLLM_CHECK_WITH_INFO(quant_params.mxfp8_mxfp4.fc1.weight_block_scale || prequant_scales, - "MXFP8xMXFP4 block scaling or prequant_scales or prequant_scales " - "parameters not provided"); + TLLM_CHECK_WITH_INFO(quant_params.mxfp8_mxfp4.fc1.weight_block_scale || + quant_params.mxfp8_mxfp8.fc1.weight_block_scale || prequant_scales, + "MXFP8 block scaling or prequant_scales parameters not provided"); return prequant_scales ? &expandInputRowsKernel< InputActivationsType, ExpandedActivationsType, @@ -1600,7 +1602,8 @@ void expandInputRowsKernelLauncher( else if constexpr (std::is_same_v && std::is_same_v) { TLLM_CHECK_WITH_INFO(!prequant_scales, "FP8 is not supported for AWQ"); - return quant_params.mxfp8_mxfp4.fc1.weight_block_scale + return (quant_params.mxfp8_mxfp4.fc1.weight_block_scale || + quant_params.mxfp8_mxfp8.fc1.weight_block_scale) ? 
&expandInputRowsKernel< InputActivationsType, ExpandedActivationsType, TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, false> @@ -2329,7 +2332,10 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 "NVFP4 block scaling is expected for FP4xFP4"); return fn(NVFP4); } else if constexpr (std::is_same_v) { - return quant_params.mxfp8_mxfp4.fc2.weight_block_scale ? fn(MXFPX) : fn(NONE); + return (quant_params.mxfp8_mxfp4.fc2.weight_block_scale || + quant_params.mxfp8_mxfp8.fc2.weight_block_scale) + ? fn(MXFPX) + : fn(NONE); } else #endif { @@ -2815,7 +2821,7 @@ void CutlassMoeFCRunner::value) == 0, "Hidden size %d does not meet minimum alignment requirements for MXFP8 MOE GEMM %d", diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index 067060a184..a5e786b148 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -270,6 +270,19 @@ struct QuantParams { GemmInputs fc2; } mxfp8_mxfp4; + // MXFP8 MXFP8 quantization params + // This mode uses block scaled MXFP8 activations and MXFP8 weights. 
+ struct MXFP8MXFP8Inputs { + struct GemmInputs { + TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* weight_block_scale = + nullptr; // (experts, n, k / 32) + float const* global_scale = nullptr; // (num_experts_per_node, ) + }; + + GemmInputs fc1; + GemmInputs fc2; + } mxfp8_mxfp8; + // FP4 quantization params struct FP4Inputs { struct GemmInputs { @@ -362,8 +375,10 @@ struct QuantParams { float const* fc1_global_scale, // TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* fc2_weight_block_scale, float const* fc2_global_scale) { - return MXFP8MXFP4(fc1_weight_block_scale, fc1_global_scale, fc2_weight_block_scale, - fc2_global_scale); + QuantParams qp; + qp.mxfp8_mxfp8.fc1 = {fc1_weight_block_scale, fc1_global_scale}; + qp.mxfp8_mxfp8.fc2 = {fc2_weight_block_scale, fc2_global_scale}; + return qp; } static QuantParams FP4( diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h index b8051ef7fc..5a4e01ad12 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h @@ -965,7 +965,7 @@ size_t MoeGemmRunner::calcMax fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; } else if constexpr (use_mxfp8) { fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; - } else if constexpr (use_fp4) { + } else if (use_fp4) { fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4; } size_t max_size = 0; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h index c70596ac53..3c2d31fd82 100644 --- 
a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h @@ -491,7 +491,7 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized( } else if (gemm_config.sm_version == 120 || gemm_config.sm_version == 121) { char const* const pretty_function = __PRETTY_FUNCTION__; TLLM_LOG_TRACE("At %s, SM120 config=%d", pretty_function, - static_cast(gemm_config.tile_config_sm120)); + (int)(gemm_config.tile_config_sm120)); if constexpr (kernels::cutlass_kernels::isValidSM120MOESpecialisation< T, WeightType, EpilogueTag, FUSION>()) { switch (gemm_config.tile_config_sm120) { From 6ab67b45c922c4919d67314d1412172ad289215c Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Fri, 6 Mar 2026 17:04:48 -0800 Subject: [PATCH 11/11] Disallow `IsMXFPX` for sm90 --- .../moe_gemm/moe_gemm_template_dispatch_tma_ws.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h index 3c2d31fd82..5cd0b12b41 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h @@ -434,8 +434,8 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized( break; if (gemm_config.sm_version == 90) { - if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation< - T, WeightType, EpilogueTag, FUSION>()) { + if constexpr (!IsMXFPX && kernels::cutlass_kernels::isValidHopperMOESpecialisation< + T, WeightType, EpilogueTag, FUSION>()) { switch (gemm_config.tile_config_sm90) { SHAPE_CASE(90, 128, 16, 128) SHAPE_CASE(90, 128, 32, 128)