diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu index 6469b9a0cd..e8d78329b7 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu @@ -32,9 +32,12 @@ template class CutlassMoeFCRunner; #ifdef ENABLE_FP8 // template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half>; +template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half, __nv_fp8_e4m3, half, true>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>; #ifdef ENABLE_BF16 template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>; +template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16, __nv_fp8_e4m3, + __nv_bfloat16, true>; template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_fp8_e4m3>; diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 0cead70f37..a7f03548aa 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -1297,6 +1297,8 @@ __global__ void computeStridesTmaWarpSpecializedKernel( quant_params.fp8_mxfp4); setupIfSelected(TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaledConfig{}, quant_params.mxfp8_mxfp4); + setupIfSelected(TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaledConfig{}, + quant_params.mxfp8_mxfp8); assert(gemm_m <= INT32_MAX); assert(gemm1_n > 0 && gemm1_n <= INT32_MAX); @@ -1585,9 +1587,9 @@ void expandInputRowsKernelLauncher( // Always MXFP8 if constexpr (std::is_same_v && !std::is_same_v) { - TLLM_CHECK_WITH_INFO(quant_params.mxfp8_mxfp4.fc1.weight_block_scale || prequant_scales, - "MXFP8xMXFP4 block scaling or prequant_scales or prequant_scales " - "parameters not provided"); + TLLM_CHECK_WITH_INFO(quant_params.mxfp8_mxfp4.fc1.weight_block_scale || + quant_params.mxfp8_mxfp8.fc1.weight_block_scale || prequant_scales, + "MXFP8 block scaling or prequant_scales parameters not provided"); return prequant_scales ? &expandInputRowsKernel< InputActivationsType, ExpandedActivationsType, @@ -1600,7 +1602,8 @@ void expandInputRowsKernelLauncher( else if constexpr (std::is_same_v && std::is_same_v) { TLLM_CHECK_WITH_INFO(!prequant_scales, "FP8 is not supported for AWQ"); - return quant_params.mxfp8_mxfp4.fc1.weight_block_scale + return (quant_params.mxfp8_mxfp4.fc1.weight_block_scale || + quant_params.mxfp8_mxfp8.fc1.weight_block_scale) ? &expandInputRowsKernel< InputActivationsType, ExpandedActivationsType, TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, false> @@ -2329,7 +2332,10 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 "NVFP4 block scaling is expected for FP4xFP4"); return fn(NVFP4); } else if constexpr (std::is_same_v) { - return quant_params.mxfp8_mxfp4.fc2.weight_block_scale ? fn(MXFPX) : fn(NONE); + return (quant_params.mxfp8_mxfp4.fc2.weight_block_scale || + quant_params.mxfp8_mxfp8.fc2.weight_block_scale) + ? fn(MXFPX) + : fn(NONE); } else #endif { @@ -2526,16 +2532,17 @@ void dequantFP8(OutputType* output, InputType const* input, int64_t const* num_v } template -CutlassMoeFCRunner::CutlassMoeFCRunner() + bool IsMXFPX, class Enable> +CutlassMoeFCRunner::CutlassMoeFCRunner() : blockscale_gemm_runner_{ std::make_unique>()} {} template + bool IsMXFPX, class Enable> std::map> -CutlassMoeFCRunner::getWorkspaceDeviceBufferSizes(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, @@ -2544,6 +2551,7 @@ CutlassMoeFCRunner -size_t -CutlassMoeFCRunner::getWorkspaceSize( - int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts, int const experts_per_token, ActivationType activation_type, - MOEParallelismConfig parallelism_config, bool use_lora, bool use_deepseek_fp8_block_scale, - bool min_latency_mode, bool use_awq) { + bool IsMXFPX, class Enable> +size_t CutlassMoeFCRunner::getWorkspaceSize(int64_t const num_rows, + int64_t const hidden_size, + int64_t const inter_size, int const num_experts, + int const experts_per_token, + ActivationType activation_type, + MOEParallelismConfig parallelism_config, + bool use_lora, + bool use_deepseek_fp8_block_scale, + bool use_mxfp8_act_scaling, + bool min_latency_mode, bool use_awq) { int const ep_size = parallelism_config.ep_size; TLLM_CHECK_WITH_INFO(num_experts % ep_size == 0, "Number of experts must be a multiple of ep size"); auto sizes_map = getWorkspaceDeviceBufferSizes( num_rows, hidden_size, inter_size, num_experts / ep_size, experts_per_token, activation_type, - use_lora, use_deepseek_fp8_block_scale, min_latency_mode, use_awq); + use_lora, use_deepseek_fp8_block_scale, use_mxfp8_act_scaling, min_latency_mode, use_awq); std::vector sizes(sizes_map.size()); std::transform(sizes_map.begin(), sizes_map.end(), sizes.begin(), [](auto& v) { return v.second.first; }); @@ -2732,8 +2745,8 @@ CutlassMoeFCRunner:: } template -void CutlassMoeFCRunner +void CutlassMoeFCRunner::configureWsPtrs(char* ws_ptr, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, @@ -2742,10 +2755,11 @@ void CutlassMoeFCRunner + bool IsMXFPX, class Enable> kernels::fp8_blockscale_gemm::CutlassFp8BlockScaleGemmRunnerInterface* -CutlassMoeFCRunner::getDeepSeekBlockScaleGemmRunner() const { TLLM_CHECK_WITH_INFO( (std::is_same_v && std::is_same_v), @@ -2868,15 +2882,16 @@ CutlassMoeFCRunner -void CutlassMoeFCRunner::BlockScaleFC1( - DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, T* const output, - void* const gemm_output, int64_t const* const expert_first_token_offset, - WeightType const* const fc1_expert_weights, ScaleBiasType const* const fc1_expert_biases, - float const* const fc2_fp8_quant, int64_t const num_rows, int64_t const expanded_num_rows, - int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, - ActivationParams fc1_activation_type, QuantParams& quant_params, bool enable_pdl, - cudaStream_t stream) { + bool IsMXFPX, class Enable> +void CutlassMoeFCRunner:: + BlockScaleFC1(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, T* const output, + void* const gemm_output, int64_t const* const expert_first_token_offset, + WeightType const* const fc1_expert_weights, + ScaleBiasType const* const fc1_expert_biases, float const* const fc2_fp8_quant, + int64_t const num_rows, int64_t const expanded_num_rows, + int64_t const hidden_size, int64_t const inter_size, + int const num_experts_per_node, ActivationParams fc1_activation_type, + QuantParams& quant_params, bool enable_pdl, cudaStream_t stream) { bool const is_gated_activation = isGatedActivation(fc1_activation_type); int shape_n = is_gated_activation ? inter_size * 2 : inter_size; @@ -2900,18 +2915,20 @@ void CutlassMoeFCRunner -void CutlassMoeFCRunner::BlockScaleFC2( - DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output, - OutputType* const final_output, int64_t const* const expert_first_token_offset, - WeightType const* const fc2_expert_weights, ScaleBiasType const* const fc2_expert_biases, - float const* const unpermuted_final_scales, int const* const unpermuted_row_to_permuted_row, - int const* const permuted_row_to_unpermuted_row, int const* const token_selected_experts, - int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, - int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const unpadded_hidden_size, - int64_t const inter_size, int64_t const num_experts_per_node, int64_t const k, - MOEParallelismConfig parallelism_config, bool const enable_alltoall, QuantParams& quant_params, - bool enable_pdl, cudaStream_t stream) { + bool IsMXFPX, class Enable> +void CutlassMoeFCRunner:: + BlockScaleFC2( + DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output, + OutputType* const final_output, int64_t const* const expert_first_token_offset, + WeightType const* const fc2_expert_weights, ScaleBiasType const* const fc2_expert_biases, + float const* const unpermuted_final_scales, int const* const unpermuted_row_to_permuted_row, + int const* const permuted_row_to_unpermuted_row, int const* const token_selected_experts, + int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, + int64_t const expanded_num_rows, int64_t const hidden_size, + int64_t const unpadded_hidden_size, int64_t const inter_size, + int64_t const num_experts_per_node, int64_t const k, + MOEParallelismConfig parallelism_config, bool const enable_alltoall, + QuantParams& quant_params, bool enable_pdl, cudaStream_t stream) { int shape_n = hidden_size; int shape_k = inter_size; @@ -2931,12 +2948,15 @@ void CutlassMoeFCRunner -T const* -CutlassMoeFCRunner::applyPrequantScale( - void* smoothed_act, void const* permuted_data, void const* prequant_scales, - int64_t const* num_valid_tokens_ptr, int64_t const expanded_num_rows, int64_t const seq_len, - bool const use_awq, cudaStream_t stream) { + bool IsMXFPX, class Enable> +T const* CutlassMoeFCRunner::applyPrequantScale(void* smoothed_act, + void const* permuted_data, + void const* prequant_scales, + int64_t const* num_valid_tokens_ptr, + int64_t const expanded_num_rows, + int64_t const seq_len, bool const use_awq, + cudaStream_t stream) { T const* gemm_input; bool use_prequant_scale_kernel = use_awq && !std::is_same_v; if (use_prequant_scale_kernel) { @@ -2958,9 +2978,9 @@ CutlassMoeFCRunner: } template -void CutlassMoeFCRunner::gemm1( - MoeGemmRunner& gemm_runner, + bool IsMXFPX, class Enable> +void CutlassMoeFCRunner::gemm1( + MoeGemmRunner& gemm_runner, DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, T* const output, void* const intermediate_result, int64_t const* const expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput const tma_ws_input_template, @@ -3188,9 +3208,9 @@ void CutlassMoeFCRunner -void CutlassMoeFCRunner::gemm2( - MoeGemmRunner& gemm_runner, + bool IsMXFPX, class Enable> +void CutlassMoeFCRunner::gemm2( + MoeGemmRunner& gemm_runner, DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, void* const gemm_output, OutputType* const final_output, int64_t const* const expert_first_token_offset, @@ -3317,8 +3337,8 @@ void CutlassMoeFCRunner -bool CutlassMoeFCRunner +bool CutlassMoeFCRunner::setupLoraWorkspace(int64_t expanded_num_rows, int64_t num_rows, int64_t inter_size, int64_t hidden_size, int start_expert, bool is_gated_activation, @@ -3411,12 +3431,15 @@ bool CutlassMoeFCRunner -auto CutlassMoeFCRunner::loraFC1( - int64_t expanded_num_rows, int64_t inter_size, int64_t hidden_size, int num_experts_per_node, - int start_expert, int64_t const* num_valid_tokens_ptr, bool is_gated_activation, - ScaleBiasType const* fc1_expert_biases, LoraParams& lora_params, float const* input_fp8_dequant, - cudaStream_t stream) -> ScaleBiasType const* { + bool IsMXFPX, class Enable> +auto CutlassMoeFCRunner::loraFC1(int64_t expanded_num_rows, int64_t inter_size, + int64_t hidden_size, int num_experts_per_node, + int start_expert, int64_t const* num_valid_tokens_ptr, + bool is_gated_activation, + ScaleBiasType const* fc1_expert_biases, + LoraParams& lora_params, float const* input_fp8_dequant, + cudaStream_t stream) -> ScaleBiasType const* { TLLM_CHECK_WITH_INFO(!act_fp4, "LoRA does not support FP4 activations"); std::vector& host_permuted_fc1_weight_ptrs = host_lora_workspace_.host_permuted_fc1_weight_ptrs; @@ -3490,11 +3513,13 @@ auto CutlassMoeFCRunner -void CutlassMoeFCRunner::loraFC2( - int64_t inter_size, int64_t hidden_size, int num_experts_per_node, int start_expert, - int64_t const* num_valid_tokens_ptr, int64_t num_tokens, LoraParams& lora_params, - float const* fc2_fp8_quant, cudaStream_t stream) { + bool IsMXFPX, class Enable> +void CutlassMoeFCRunner::loraFC2(int64_t inter_size, int64_t hidden_size, + int num_experts_per_node, int start_expert, + int64_t const* num_valid_tokens_ptr, int64_t num_tokens, + LoraParams& lora_params, float const* fc2_fp8_quant, + cudaStream_t stream) { std::vector& host_permuted_fc2_weight_ptrs = host_lora_workspace_.host_permuted_fc2_weight_ptrs; std::vector& host_permuted_fc2_lora_ranks = @@ -3531,19 +3556,20 @@ void CutlassMoeFCRunner -void CutlassMoeFCRunner::runMoe( - void const* input_activations_void, void const* input_sf_void, bool const swizzled_input_sf, - int const* token_selected_experts, float const* token_final_scales, - void const* fc1_expert_weights_void, void const* fc1_expert_biases_void, - ActivationParams fc1_activation_type, void const* fc2_expert_weights_void, - void const* fc2_expert_biases_void, QuantParams quant_params, int64_t const num_rows, - int64_t const hidden_size, int64_t const unpadded_hidden_size, int64_t const inter_size, - int const full_num_experts, int const experts_per_token, char* workspace_ptr, - void* final_output_void, int* unpermuted_row_to_permuted_row, - MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool use_lora, - LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, - MoeMinLatencyParams& min_latency_params, bool enable_pdl, cudaStream_t stream) { + bool IsMXFPX, class Enable> +void CutlassMoeFCRunner:: + runMoe(void const* input_activations_void, void const* input_sf_void, + bool const swizzled_input_sf, int const* token_selected_experts, + float const* token_final_scales, void const* fc1_expert_weights_void, + void const* fc1_expert_biases_void, ActivationParams fc1_activation_type, + void const* fc2_expert_weights_void, void const* fc2_expert_biases_void, + QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, + int64_t const unpadded_hidden_size, int64_t const inter_size, int const full_num_experts, + int const experts_per_token, char* workspace_ptr, void* final_output_void, + int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, + bool const enable_alltoall, bool use_lora, LoraParams& lora_params, + bool use_deepseek_fp8_block_scale, bool use_mxfp8_act_scaling, bool min_latency_mode, + MoeMinLatencyParams& min_latency_params, bool enable_pdl, cudaStream_t stream) { static constexpr bool int_scales_required = std::is_same::value || std::is_same::value || use_wfp4a16; @@ -3566,6 +3592,14 @@ void CutlassMoeFCRunner::value) == 0, - "Hidden size %d does not meet minimum alignment requirements for MXFP8_MXFP4 MOE GEMM %d", + "Hidden size %d does not meet minimum alignment requirements for MXFP8 MOE GEMM %d", (int)hidden_size, (int)(64 * 8 / sizeof_bits::value)); TLLM_CHECK_WITH_INFO( inter_size % (64 * 8 / sizeof_bits::value) == 0, - "Inter size %d does not meet minimum alignment requirements for MXFP8_MXFP4 MOE GEMM %d", + "Inter size %d does not meet minimum alignment requirements for MXFP8 MOE GEMM %d", (int)inter_size, (int)(64 * 8 / sizeof_bits::value)); } else { // For NoSmem epilogue schedule, we need to align the output of the GEMM to 256 bits, for gated @@ -3651,6 +3686,15 @@ void CutlassMoeFCRunner + bool IsMXFPX, class Enable> std::pair -CutlassMoeFCRunner:: +CutlassMoeFCRunner:: computeStridesTmaWarpSpecialized( int64_t const* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput layout_info1, TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, @@ -3935,9 +3979,9 @@ CutlassMoeFCRunner:: } template + bool IsMXFPX, class Enable> std::pair -CutlassMoeFCRunner:: +CutlassMoeFCRunner:: computeStridesTmaWarpSpecializedLowLatency( TmaWarpSpecializedGroupedGemmInput layout_info1, TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, @@ -3954,9 +3998,9 @@ CutlassMoeFCRunner:: } template + bool IsMXFPX, class Enable> std::pair -CutlassMoeFCRunner:: +CutlassMoeFCRunner:: setupTmaWarpSpecializedInputs(int64_t num_rows, int64_t expanded_num_rows, ActivationParams fc1_activation_type, int64_t hidden_size, int64_t unpadded_hidden_size, int64_t inter_size, diff --git a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu index 48c79ad858..a59204ef9a 100644 --- a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu +++ b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu @@ -82,7 +82,7 @@ class DtypeUtils { class FusedMoeRunner : public tvm::ffi::ModuleObj { public: - template + template std::unique_ptr switch_output_type(DLDataType output_type) { switch (encode_dlpack_dtype(output_type)) { case int64_code: // INT64 == FP4 @@ -94,19 +94,20 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { // return std::make_unique>(); case float16_code: if constexpr (NeedQuant) { - return std::make_unique>(); + return std::make_unique< + kernels::CutlassMoeFCRunner>(); } else { return std::make_unique< - kernels::CutlassMoeFCRunner>(); + kernels::CutlassMoeFCRunner>(); } #ifdef ENABLE_BF16 case bfloat16_code: if constexpr (NeedQuant) { - return std::make_unique< - kernels::CutlassMoeFCRunner>(); + return std::make_unique>(); } else { - return std::make_unique< - kernels::CutlassMoeFCRunner>(); + return std::make_unique>(); } #endif default: @@ -145,7 +146,9 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { #endif #ifdef ENABLE_FP8 - if (isFp8Quant()) { + if (isWMxfp8AMxfp8Quant()) { + mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp8_e4m3, false, true>(mOutputDtype); + } else if (isFp8Quant()) { mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp8_e4m3>(mOutputDtype); } #endif @@ -397,8 +400,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { static_cast(experts_per_token), static_cast(workspace_info.workspace.data_ptr()), output.data_ptr(), static_cast(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, - use_lora, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, - enable_pdl, stream); + use_lora, lora_params, mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, min_latency_mode, + min_latency_params, enable_pdl, stream); #else mKernelRunner->runMoe( input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr, @@ -414,7 +417,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { static_cast(experts_per_token), static_cast(workspace_info.workspace.data_ptr()), output.data_ptr(), static_cast(workspace_info.src_to_dest_map), parallelism_config, false, lora_params, - mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, stream); + mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, min_latency_mode, min_latency_params, + enable_pdl, stream); #endif } @@ -490,8 +494,10 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { << "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size."; } - TVM_FFI_ICHECK(!input_sf.has_value() || isWMxfp4AMxfp8Quant() || isNvfp4Quant()) + TVM_FFI_ICHECK(!input_sf.has_value() || isMxfp8ActScalingQuant() || isNvfp4Quant()) << "Block-scaling factors provided for non block-scaling quantization"; + TVM_FFI_ICHECK(!isMxfp8ActScalingQuant() || input_sf.has_value()) + << "input_sf must be provided when use_mxfp8_act_scaling=True"; int experts_per_token = token_selected_experts.size(1); int64_t num_rows = input.size(0); @@ -581,8 +587,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { static_cast(experts_per_token), static_cast(workspace_info.workspace.data_ptr()), output.data_ptr(), static_cast(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, - use_lora_ml, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, - enable_pdl, stream); + use_lora_ml, lora_params, mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, + min_latency_mode, min_latency_params, enable_pdl, stream); #else mKernelRunner->runMoe( input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr, @@ -598,8 +604,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { static_cast(experts_per_token), static_cast(workspace_info.workspace.data_ptr()), output.data_ptr(), static_cast(workspace_info.src_to_dest_map), parallelism_config, false, use_lora_ml, - lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, - stream); + lora_params, mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, min_latency_mode, + min_latency_params, enable_pdl, stream); #endif } @@ -838,8 +844,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { bool min_latency_mode) { size_t moe_workspace_size = mKernelRunner->getWorkspaceSize( num_rows, hidden_size, inter_size, num_experts, experts_per_token, activation_type, - parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling, min_latency_mode, - mUseW4GroupScaling); + parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, + min_latency_mode, mUseW4GroupScaling); size_t src_to_dest_map_size = experts_per_token * num_rows * sizeof(int); std::vector workspaces{moe_workspace_size, src_to_dest_map_size}; @@ -862,7 +868,66 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { int64_t num_experts_on_rank, int64_t hidden_size, int64_t inter_size, Optional> quant_scales, ActivationType base_activation_type = ActivationType::Swiglu) const { - if (isFp8Quant()) { + if (isWMxfp8AMxfp8Quant()) { +#ifdef USING_OSS_CUTLASS_MOE_GEMM + TVM_FFI_ICHECK(quant_scales.has_value()) + << "Expecting quant scales for MXFP8xMXFP8 quantization"; + TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 4) + << "Expecting 4 quant scales for MXFP8xMXFP8 quantization"; + + TensorView fc1_weight_block = quant_scales.value()[0]; + TensorView fc1_global = quant_scales.value()[1]; + TensorView fc2_weight_block = quant_scales.value()[2]; + TensorView fc2_global = quant_scales.value()[3]; + + // The input for scale fc1_weight_block / fc2_weight_block is packed into INT32 + constexpr int FP8_PER_INT32 = 4; + CHECK_INPUT_TYPE(fc1_weight_block, dl_int32); + CHECK_INPUT_TYPE(fc1_global, dl_float32); + CHECK_INPUT_TYPE(fc2_weight_block, dl_int32); + CHECK_INPUT_TYPE(fc2_global, dl_float32); + CHECK_DIM(3, fc1_weight_block); + CHECK_DIM(1, fc1_global); + CHECK_DIM(3, fc2_weight_block); + CHECK_DIM(1, fc2_global); + TVM_FFI_ICHECK( + fc1_weight_block.size(0) == num_experts_on_rank && + fc1_weight_block.size(1) == + TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) * + 2 && + fc1_weight_block.size(2) * FP8_PER_INT32 * + TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize == + TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX)) + << "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 " + "// block_scale_vector_size)"; + TVM_FFI_ICHECK_EQ(fc1_global.size(0), num_experts_on_rank) + << "fc1 global size must be (num_experts_on_rank,)"; + TVM_FFI_ICHECK( + fc2_weight_block.size(0) == num_experts_on_rank && + fc2_weight_block.size(1) == + TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + hidden_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) && + fc2_weight_block.size(2) * FP8_PER_INT32 * + TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize == + TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX)) + << "fc2 weight block size must be (num_experts_on_rank, hidden_size, inter_size // 4 // " + "block_scale_vector_size)"; + TVM_FFI_ICHECK_EQ(fc2_global.size(0), num_experts_on_rank) + << "fc2 global size must be (num_experts_on_rank,)"; + + return kernels::QuantParams::MXFP8MXFP8( + static_cast(fc1_weight_block.data_ptr()), + static_cast(fc1_global.data_ptr()), + static_cast(fc2_weight_block.data_ptr()), + static_cast(fc2_global.data_ptr())); +#else + TVM_FFI_ICHECK(false) + << "MXFP8 x MXFP8 quantization is not supported in OSS Cutlass Moe Gemm"; +#endif + } else if (isFp8Quant()) { TVM_FFI_ICHECK(quant_scales.has_value()) << "Expecting quant scales for fp8 quantization"; TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 4) << "Expecting 4 quant scales for fp8 quantization"; @@ -1168,9 +1233,16 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { bool isFp8Quant() const { return !mUseDeepSeekFP8BlockScaling && mActivationDtype == dl_float8_e4m3fn && - mWeightDtype == dl_float8_e4m3fn; + mWeightDtype == dl_float8_e4m3fn && !mUseMxfp8ActScaling; } + bool isWMxfp8AMxfp8Quant() const { + return !mUseDeepSeekFP8BlockScaling && mActivationDtype == dl_float8_e4m3fn && + mWeightDtype == dl_float8_e4m3fn && mUseMxfp8ActScaling; + } + + bool isMxfp8ActScalingQuant() const { return isWMxfp8AMxfp8Quant() || isWMxfp4AMxfp8Quant(); } + bool isNvfp4Quant() const { return mWeightDtype == dl_int64 && mActivationDtype != dl_float8_e4m3fn; // FP8 activation does not use FP4 diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h index b77efbcac1..4e85da745c 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h @@ -239,11 +239,11 @@ constexpr bool isGatedActivation(ActivationType activation_type) { activation_type == ActivationType::SwigluBias; } -template +template class MoeGemmRunner { public: MoeGemmRunner(); @@ -273,6 +273,8 @@ class MoeGemmRunner { static constexpr bool use_fp8 = false; static constexpr bool use_w4afp8 = false; #endif + static constexpr bool use_mxfp8 = use_fp8 && IsMXFPX; + static constexpr bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; #if defined(ENABLE_FP4) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index e278269b97..a5e786b148 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -270,6 +270,19 @@ struct QuantParams { GemmInputs fc2; } mxfp8_mxfp4; + // MXFP8 MXFP8 quantization params + // This mode uses block scaled MXFP8 activations and MXFP8 weights. + struct MXFP8MXFP8Inputs { + struct GemmInputs { + TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* weight_block_scale = + nullptr; // (experts, n, k / 32) + float const* global_scale = nullptr; // (num_experts_per_node, ) + }; + + GemmInputs fc1; + GemmInputs fc2; + } mxfp8_mxfp8; + // FP4 quantization params struct FP4Inputs { struct GemmInputs { @@ -357,6 +370,17 @@ struct QuantParams { return qp; } + static QuantParams MXFP8MXFP8( + TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* fc1_weight_block_scale, + float const* fc1_global_scale, // + TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* fc2_weight_block_scale, + float const* fc2_global_scale) { + QuantParams qp; + qp.mxfp8_mxfp8.fc1 = {fc1_weight_block_scale, fc1_global_scale}; + qp.mxfp8_mxfp8.fc2 = {fc2_weight_block_scale, fc2_global_scale}; + return qp; + } + static QuantParams FP4( float const* fc1_act_global_scale, TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* fc1_weight_block_scale, @@ -426,8 +450,8 @@ class CutlassMoeFCRunnerInterface { int64_t const inter_size, int const num_experts, int const experts_per_token, ActivationType activation_type, MOEParallelismConfig parallelism_config, bool use_lora, - bool use_deepseek_fp8_block_scale, bool min_latency_mode, - bool use_awq) = 0; + bool use_deepseek_fp8_block_scale, bool use_mxfp8_act_scaling, + bool min_latency_mode, bool use_awq) = 0; virtual void setTactic(std::optional gemm1_config, std::optional gemm2_config) = 0; virtual std::vector getTactics(MoeGemmId gemm_id) = 0; @@ -443,8 +467,9 @@ class CutlassMoeFCRunnerInterface { void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, - bool min_latency_mode, MoeMinLatencyParams& min_latency_params, - bool enable_pdl, cudaStream_t stream) = 0; + bool use_mxfp8_act_scaling, bool min_latency_mode, + MoeMinLatencyParams& min_latency_params, bool enable_pdl, + cudaStream_t stream) = 0; // Aliases for profiling the gemms virtual void gemm1(void const* const input, void* const output, void* const intermediate_result, @@ -524,12 +549,12 @@ template + bool IsMXFPX = false, typename Enable = void> class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { using DeepSeekBlockScaleGemmRunner = tensorrt_llm::kernels::fp8_blockscale_gemm::CutlassFp8BlockScaleGemmRunnerInterface; using ScaleBiasType = BackBoneType; - using Self = CutlassMoeFCRunner; + using Self = CutlassMoeFCRunner; #if defined(ENABLE_FP4) #if defined(ENABLE_BF16) @@ -574,7 +599,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { static constexpr bool use_fp4 = false; #endif - static constexpr bool use_block_scaling = use_fp4 || use_wfp4afp8; + static constexpr bool use_mxfp8 = use_fp8 && IsMXFPX; + static constexpr bool use_block_scaling = use_fp4 || use_wfp4afp8 || use_mxfp8; // This should leave the variable unchanged in any currently supported configuration using UnfusedGemmOutputType = BackBoneType; @@ -596,8 +622,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { int64_t const fc1_output_size, int const num_experts, int const experts_per_token, ActivationType activation_type, MOEParallelismConfig parallelism_config, bool use_lora, - bool use_deepseek_fp8_block_scale, bool min_latency_mode, - bool use_awq) override; + bool use_deepseek_fp8_block_scale, bool use_mxfp8_act_scaling, + bool min_latency_mode, bool use_awq) override; void setTactic(std::optional gemm1_config, std::optional gemm2_config) override { @@ -624,12 +650,13 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool use_lora, - LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, + LoraParams& lora_params, bool use_deepseek_fp8_block_scale, + bool use_mxfp8_act_scaling, bool min_latency_mode, MoeMinLatencyParams& min_latency_params, bool enable_pdl, cudaStream_t stream) override; // We make these GEMM1 & GEMM2 static because they need to be stateless for the profiler to work - static void gemm1(MoeGemmRunner& gemm_runner, + static void gemm1(MoeGemmRunner& gemm_runner, // This argument must not be null if fp8 block scaling is being used. // The gemm_runner will be ignored in that case. NOTE: it would // be great if we could consolidate gemm_runner and fp8_blockscale_gemm_runner. @@ -655,7 +682,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { int* num_active_experts_per, int* active_expert_global_ids, bool enable_pdl); static void gemm2( - MoeGemmRunner& gemm_runner, + MoeGemmRunner& gemm_runner, DeepSeekBlockScaleGemmRunner* fp8_blockscale_gemm_runner, T const* const input, void* const gemm_output, OutputType* const final_output, int64_t const* const expert_first_token_offset, @@ -838,12 +865,14 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { std::map> getWorkspaceDeviceBufferSizes( int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int const experts_per_token, ActivationType activation_type, - bool use_lora, bool use_deepseek_fp8_block_scale, bool min_latency_mode, bool use_awq); + bool use_lora, bool use_deepseek_fp8_block_scale, bool use_mxfp8_act_scaling, + bool min_latency_mode, bool use_awq); void configureWsPtrs(char* ws_ptr, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int const experts_per_token, ActivationType activation_type, MOEParallelismConfig parallelism_config, bool use_lora, - bool use_deepseek_fp8_block_scale, bool min_latency_mode, bool use_awq); + bool use_deepseek_fp8_block_scale, bool use_mxfp8_act_scaling, + bool min_latency_mode, bool use_awq); private: bool mayHaveDifferentGEMMOutputType() const { @@ -865,9 +894,10 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { // TODO: This should eventually take the quant params to give more flexibility static auto getScalingType() { - return use_wfp4afp8 ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX - : use_fp4 ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 - : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE; + return (use_wfp4afp8 || use_mxfp8) + ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX + : use_fp4 ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 + : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE; } bool setupLoraWorkspace(int64_t expanded_num_rows, int64_t num_rows, int64_t inter_size, @@ -916,7 +946,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { int64_t const expanded_num_rows, int64_t const seq_len, bool const use_awq, cudaStream_t stream); - MoeGemmRunner moe_gemm_runner_; + MoeGemmRunner moe_gemm_runner_; std::unique_ptr blockscale_gemm_runner_; std::optional gemm1_config_; diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl index db5788bfdd..6c71c6acd8 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl @@ -255,7 +255,7 @@ using namespace cutlass::epilogue; static_assert(cutlass::platform::is_same::value || IsWFP4AFP8, \ "TMA warp specialized MOE implementation does not support mixed input types"); \ \ - constexpr static bool IsBlockScaled = IsFP4 || IsWFP4AFP8; \ + constexpr static bool IsBlockScaled = IsFP4 || IsWFP4AFP8 || IsMXFPX; \ static_assert(!IsBlockScaled || IsBlackwell, "Block scaled is only implemented for SM100"); \ \ static_assert(FUSION == EpilogueFusion::NONE || FUSION == EpilogueFusion::FINALIZE, \ diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu index 08b7ce1930..e78f8aa823 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu @@ -19,8 +19,10 @@ namespace tensorrt_llm::kernels::cutlass_kernels { #ifdef ENABLE_FP8 template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half>; +template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half, half, true>; #ifdef ENABLE_BF16 template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>; +template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16, true>; #endif // template class MoeGemmRunner<__nv_fp8_e5m2, __nv_fp8_e5m2>; #endif diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h index dded5014bf..5a4e01ad12 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h @@ -525,17 +525,19 @@ void dispatchMoeGemmToCutlass( namespace tensorrt_llm::kernels::cutlass_kernels { -template +template std::vector -MoeGemmRunner::getConfigs( +MoeGemmRunner::getConfigs( bool supports_finalize_fusion) const { return getConfigs(sm_, supports_finalize_fusion); } -template +template std::vector -MoeGemmRunner::getConfigs(int sm, - bool supports_finalize_fusion) { +MoeGemmRunner::getConfigs( + int sm, bool supports_finalize_fusion) { std::vector candidate_configs = getTmaWarpSpecializedConfigs(sm, supports_finalize_fusion); std::vector ampere_configs = getAmpereConfigs(sm); @@ -543,9 +545,10 @@ MoeGemmRunner::getConfigs(int sm, return candidate_configs; } -template +template std::vector -MoeGemmRunner::getAmpereConfigs(int sm) { +MoeGemmRunner::getAmpereConfigs(int sm) { using tensorrt_llm::cutlass_extensions::CutlassGemmConfig; static constexpr auto weight_only_flag = std::is_same::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY; @@ -571,9 +574,10 @@ MoeGemmRunner::getAmpereConfigs(int sm return ampere_configs; } -template +template std::vector -MoeGemmRunner::getTmaWarpSpecializedConfigs( +MoeGemmRunner::getTmaWarpSpecializedConfigs( int sm, bool supports_finalize_fusion) { using tensorrt_llm::cutlass_extensions::CutlassGemmConfig; static constexpr auto weight_only_flag = @@ -668,15 +672,18 @@ MoeGemmRunner::getTmaWarpSpecializedCo return tma_ws_configs; } -template -bool MoeGemmRunner::isTmaWarpSpecialized( +template +bool MoeGemmRunner::isTmaWarpSpecialized( cutlass_extensions::CutlassGemmConfig gemm_config) const { bool config_is_tma_warp_specialized = gemm_config.is_tma_warp_specialized; return supportsTmaWarpSpecialized() && config_is_tma_warp_specialized; } -template -bool MoeGemmRunner::supportsTmaWarpSpecialized(int sm) { +template +bool MoeGemmRunner::supportsTmaWarpSpecialized( + int sm) { return (sm == 90 && tensorrt_llm::kernels::cutlass_kernels::isValidHopperMOESpecialisation()) || @@ -687,14 +694,16 @@ bool MoeGemmRunner::supportsTmaWarpSpe tensorrt_llm::kernels::cutlass_kernels::isValidSM120MOESpecialisation()); } -template -int MoeGemmRunner::getSM() const { +template +int MoeGemmRunner::getSM() const { return this->sm_; } // currently support sm80 bf16/fp16 gate activation, only set predication tensor for m direction -template -bool MoeGemmRunner::supportsFusedGatedActivation( +template +bool MoeGemmRunner::supportsFusedGatedActivation( ActivationType activation_type, int gemm_n, int gemm_k) const { constexpr bool ENABLE_FUSED_GATED_ACTIVATION = true; return (activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu) && @@ -703,16 +712,18 @@ bool MoeGemmRunner::supportsFusedGated ENABLE_FUSED_GATED_ACTIVATION; } -template -bool MoeGemmRunner::isFusedGatedActivation( +template +bool MoeGemmRunner::isFusedGatedActivation( cutlass_extensions::CutlassGemmConfig gemm_config, ActivationType activation_type, int gemm_n, int gemm_k) const { return supportsFusedGatedActivation(activation_type, gemm_n, gemm_k) && !gemm_config.is_tma_warp_specialized; } -template -MoeGemmRunner::MoeGemmRunner() { +template +MoeGemmRunner::MoeGemmRunner() { int device{-1}; tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device)); sm_ = tensorrt_llm::common::getSMVersion(); @@ -720,9 +731,10 @@ MoeGemmRunner::MoeGemmRunner() { cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); } -template +template template -void MoeGemmRunner::dispatchToArch( +void MoeGemmRunner::dispatchToArch( GroupedGemmInput inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs) { static_assert( @@ -917,8 +929,9 @@ void MoeGemmRunner::dispatchToArch( } } -template -size_t MoeGemmRunner::getMaxWorkspaceSize( +template +size_t MoeGemmRunner::getMaxWorkspaceSize( int num_experts) const { if (num_experts != num_experts_) { TLLM_LOG_TRACE("Calling getMaxWorkspaceSize() with a new expert count %d vs %d", num_experts, @@ -929,8 +942,9 @@ size_t MoeGemmRunner::getMaxWorkspaceS return gemm_workspace_size_; } -template -size_t MoeGemmRunner::calcMaxWorkspaceSize( +template +size_t MoeGemmRunner::calcMaxWorkspaceSize( int num_experts) const { if constexpr (use_w4_groupwise) { return cutlass_kernels_oss::calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput::calcMaxWorkspace auto fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE; if constexpr (use_wfp4afp8) { fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; + } else if constexpr (use_mxfp8) { + fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; } else if (use_fp4) { fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4; } @@ -986,16 +1002,18 @@ size_t MoeGemmRunner::calcMaxWorkspace } } -template +template template -void MoeGemmRunner::runGemm( +void MoeGemmRunner::runGemm( GroupedGemmInput inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs) { dispatchToArch(inputs, hopper_inputs); } -template -void MoeGemmRunner::moeGemmBiasAct( +template +void MoeGemmRunner::moeGemmBiasAct( GroupedGemmInput inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs) { switch (inputs.activation_type) { @@ -1026,8 +1044,9 @@ void MoeGemmRunner::moeGemmBiasAct( } } -template -void MoeGemmRunner::moeGemm( +template +void MoeGemmRunner::moeGemm( GroupedGemmInput inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs) { runGemm(inputs, hopper_inputs); diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h index 5adacd0ce2..5cd0b12b41 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h @@ -65,27 +65,27 @@ using tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput using EpilogueFusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion; template + EpilogueFusion FUSION, typename TileShape, typename ClusterShape, bool IsMXFPX> auto getDispatchFunctionForSM100(cutlass_extensions::EpilogueScheduleType epilogue_schedule, bool dynamic_cga, bool swap_ab) { auto select_swap_ab = [dynamic_cga, epilogue_schedule](auto swap_ab_t) { auto select_dynamic_cga = [epilogue_schedule](auto dynamic_cga_t) { #if defined(ENABLE_FP4) constexpr bool is_block_scaled = - std::is_same_v || std::is_same_v; + IsMXFPX || std::is_same_v || std::is_same_v; #else - constexpr bool is_block_scaled = false; + constexpr bool is_block_scaled = IsMXFPX; #endif if constexpr ((!is_block_scaled || Arch::kMinComputeCapability == 103) && FUSION != EpilogueFusion::FINALIZE) { auto func_map = std::array{ &kernels::cutlass_kernels_oss::tma_warp_specialized_generic_moe_gemm_kernelLauncher< Arch, T, WeightType, OutputType, cutlass::epilogue::PtrArrayNoSmemWarpSpecialized, - EpilogueTag, FUSION, TileShape, ClusterShape, is_wfp4afp8, + EpilogueTag, FUSION, TileShape, ClusterShape, IsMXFPX, decltype(dynamic_cga_t)::value, false, decltype(swap_ab_t)::value>, &kernels::cutlass_kernels_oss::tma_warp_specialized_generic_moe_gemm_kernelLauncher< Arch, T, WeightType, OutputType, cutlass::epilogue::PtrArrayTmaWarpSpecialized, - EpilogueTag, FUSION, TileShape, ClusterShape, is_wfp4afp8, + EpilogueTag, FUSION, TileShape, ClusterShape, IsMXFPX, decltype(dynamic_cga_t)::value, false, decltype(swap_ab_t)::value> }; @@ -100,8 +100,8 @@ auto getDispatchFunctionForSM100(cutlass_extensions::EpilogueScheduleType epilog "No Smem epilogue schedule is not supported for block scaled types or finalize fusion"); return &kernels::cutlass_kernels_oss::tma_warp_specialized_generic_moe_gemm_kernelLauncher< Arch, T, WeightType, OutputType, cutlass::epilogue::PtrArrayTmaWarpSpecialized, - EpilogueTag, FUSION, TileShape, ClusterShape, is_wfp4afp8, - decltype(dynamic_cga_t)::value, false, decltype(swap_ab_t)::value>; + EpilogueTag, FUSION, TileShape, ClusterShape, IsMXFPX, decltype(dynamic_cga_t)::value, + false, decltype(swap_ab_t)::value>; } }; return dynamic_cga ? select_dynamic_cga(tensorrt_llm::common::ConstBool{}) @@ -112,7 +112,7 @@ auto getDispatchFunctionForSM100(cutlass_extensions::EpilogueScheduleType epilog } template + EpilogueFusion FUSION, typename TileShape, typename ClusterShape, bool IsMXFPX> void dispatchMoeGemmFinalDispatchTmaWarpSpecialized( TmaWarpSpecializedGroupedGemmInput hopper_input, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, @@ -166,64 +166,69 @@ void dispatchMoeGemmFinalDispatchTmaWarpSpecialized( #else constexpr static bool is_wfp4afp8 = false; #endif - if constexpr (is_wfp4afp8) { - TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type == - TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, - "MXFPX is the only supported scaling type for WFP4AFP8"); +#if defined(ENABLE_FP8) + constexpr static bool is_wfp8afp8 = + std::is_same_v && std::is_same_v; +#else + constexpr static bool is_wfp8afp8 = false; +#endif + constexpr static bool supports_mxfpx = is_wfp4afp8 || is_wfp8afp8; + if constexpr (IsMXFPX && !supports_mxfpx) { + TLLM_THROW("MXFPX is not supported for the selected weight combination"); + } else if constexpr (!IsMXFPX && is_wfp4afp8) { + TLLM_THROW("MXFPX is the only supported scaling type for WFP4AFP8"); } else { - TLLM_CHECK_WITH_INFO(hopper_input.fpX_block_scaling_type != - TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, - "MXFPX is not supported for the selected weight combination"); - } - - if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120) { - bool const dynamic_cga = - gemm_config.dynamic_cluster_shape != cutlass_extensions::ClusterShape::Undefined; - bool const swap_ab = hopper_input.swap_ab; - auto cluster_shape = - cutlass_extensions::enum_to_shape_tuple(gemm_config.dynamic_cluster_shape); - auto cluster_shape_cute = cute::Shape{ - std::get<0>(cluster_shape), std::get<1>(cluster_shape), cute::_1{}}; - auto cluster_shape_fallback = - cutlass_extensions::enum_to_shape_tuple(gemm_config.fallback_cluster_shape); - auto cluster_shape_cute_fallback = cute::Shape{ - std::get<0>(cluster_shape_fallback), std::get<1>(cluster_shape_fallback), cute::_1{}}; - - // HACK debug the gemm_config used to produce selected_func - // std::cout << "[SM100 gemm_config] sm_version=" << gemm_config.sm_version - // << ", tile_config_sm100=" << static_cast(gemm_config.tile_config_sm100) - // << ", epilogue_schedule=" << static_cast(gemm_config.epilogue_schedule) - // << ", dynamic_cluster_shape=" << - // static_cast(gemm_config.dynamic_cluster_shape) - // << ", fallback_cluster_shape=" - // << static_cast(gemm_config.fallback_cluster_shape) << std::endl; - - auto selected_func = - getDispatchFunctionForSM100( - gemm_config.epilogue_schedule, dynamic_cga, swap_ab); - selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, - workspace_size, cluster_shape_cute, cluster_shape_cute_fallback); - } else if constexpr (Arch::kMinComputeCapability >= 120 || Arch::kMinComputeCapability == 90) { - using EpilogueSchedule = void; // These are hardcoded in the launcher - constexpr bool dynamic_cga = false; - auto selected_func = - hopper_input.swap_ab - ? kernels::cutlass_kernels_oss::tma_warp_specialized_generic_moe_gemm_kernelLauncher< - Arch, T, WeightType, OutputType, EpilogueSchedule, EpilogueTag, FUSION, - TileShape, ClusterShape, is_wfp4afp8, dynamic_cga, false, true> - : kernels::cutlass_kernels_oss::tma_warp_specialized_generic_moe_gemm_kernelLauncher< - Arch, T, WeightType, OutputType, EpilogueSchedule, EpilogueTag, FUSION, - TileShape, ClusterShape, is_wfp4afp8, dynamic_cga, false, false>; - - selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, - workspace_size, {}, {}); + if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120) { + bool const dynamic_cga = + gemm_config.dynamic_cluster_shape != cutlass_extensions::ClusterShape::Undefined; + bool const swap_ab = hopper_input.swap_ab; + auto cluster_shape = + cutlass_extensions::enum_to_shape_tuple(gemm_config.dynamic_cluster_shape); + auto cluster_shape_cute = cute::Shape{ + std::get<0>(cluster_shape), std::get<1>(cluster_shape), cute::_1{}}; + auto cluster_shape_fallback = + cutlass_extensions::enum_to_shape_tuple(gemm_config.fallback_cluster_shape); + auto cluster_shape_cute_fallback = cute::Shape{ + std::get<0>(cluster_shape_fallback), std::get<1>(cluster_shape_fallback), cute::_1{}}; + + // HACK debug the gemm_config used to produce selected_func + // std::cout << "[SM100 gemm_config] sm_version=" << gemm_config.sm_version + // << ", tile_config_sm100=" << static_cast(gemm_config.tile_config_sm100) + // << ", epilogue_schedule=" << static_cast(gemm_config.epilogue_schedule) + // << ", dynamic_cluster_shape=" << + // static_cast(gemm_config.dynamic_cluster_shape) + // << ", fallback_cluster_shape=" + // << static_cast(gemm_config.fallback_cluster_shape) << std::endl; + + auto selected_func = + getDispatchFunctionForSM100( + gemm_config.epilogue_schedule, dynamic_cga, swap_ab); + selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, + workspace_size, cluster_shape_cute, cluster_shape_cute_fallback); + } else if constexpr (Arch::kMinComputeCapability >= 120 || + Arch::kMinComputeCapability == 90) { + using EpilogueSchedule = void; // These are hardcoded in the launcher + constexpr bool dynamic_cga = false; + auto selected_func = + hopper_input.swap_ab + ? kernels::cutlass_kernels_oss:: + tma_warp_specialized_generic_moe_gemm_kernelLauncher< + Arch, T, WeightType, OutputType, EpilogueSchedule, EpilogueTag, FUSION, + TileShape, ClusterShape, IsMXFPX, dynamic_cga, false, true> + : kernels::cutlass_kernels_oss:: + tma_warp_specialized_generic_moe_gemm_kernelLauncher< + Arch, T, WeightType, OutputType, EpilogueSchedule, EpilogueTag, FUSION, + TileShape, ClusterShape, IsMXFPX, dynamic_cga, false, false>; + selected_func(hopper_input, num_experts, multi_processor_count, stream, occupancy, + workspace_size, {}, {}); + } } } } template + typename WeightType, bool IsMXFPX> constexpr bool are_tile_shapes_supported_sm100() { // We use a runtime cluster shape for SM100, so we only support 1x1x1 and 2x1x1 cluster shapes. if (cute::size<0>(ClusterShape{}) > 2 || cute::size<1>(ClusterShape{}) != 1 || @@ -262,6 +267,15 @@ constexpr bool are_tile_shapes_supported_sm100() { } #endif +#ifdef ENABLE_FP8 + if constexpr (IsMXFPX && std::is_same_v && + std::is_same_v) { + if ((TileN != 64 && TileN != 128 && TileN != 256) || TileM != 128) { + return false; + } + } +#endif + if constexpr (std::is_same_v) { if constexpr ((TileN == 16 || TileN == 8) && cute::size<0>(ClusterShape{}) == 1 && cute::size<1>(ClusterShape{}) == 1) { @@ -311,10 +325,11 @@ constexpr bool are_tile_shapes_supported_sm120() { that may not be very useful in practice. */ template + typename WeightType, bool IsMXFPX> constexpr bool are_tile_shapes_supported() { if constexpr (Arch::kMinComputeCapability >= 100 && Arch::kMinComputeCapability < 120) { - return are_tile_shapes_supported_sm100(); + return are_tile_shapes_supported_sm100(); } else if constexpr (Arch::kMinComputeCapability == 120 || Arch::kMinComputeCapability == 121) { return are_tile_shapes_supported_sm120(); } @@ -339,7 +354,7 @@ constexpr bool are_tile_shapes_supported() { } template + EpilogueFusion FUSION, typename TileShape, bool IsMXFPX> void dispatchMoeGemmSelectClusterShapeTmaWarpSpecialized( TmaWarpSpecializedGroupedGemmInput hopper_input, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, @@ -350,9 +365,10 @@ void dispatchMoeGemmSelectClusterShapeTmaWarpSpecialized( #define SHAPE_CASE(M, N, K) \ case cutlass_extensions::ClusterShape::ClusterShape_##M##x##N##x##K: { \ using ClusterShape = Shape<_##M, _##N, _##K>; \ - if constexpr (are_tile_shapes_supported()) { \ + if constexpr (are_tile_shapes_supported()) { \ dispatchMoeGemmFinalDispatchTmaWarpSpecialized( \ + FUSION, TileShape, ClusterShape, IsMXFPX>( \ hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, \ workspace_size); \ break; \ @@ -387,20 +403,23 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized( cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy, size_t* workspace_size) { using namespace cute; - -#define SHAPE_CASE(SMVERSION, M, N, K) \ - case cutlass_extensions::CutlassTileConfigSM##SMVERSION::CtaShape##M##x##N##x##K##B: { \ - constexpr int KtileBytes = \ - (K * 8) / \ - cutlass::sizeof_bits< \ - typename kernels::cutlass_kernels::TllmToCutlassTypeAdapter::type>::value; \ - using KTileDim = Int; \ - using TileShape = Shape<_##M, _##N, KTileDim>; \ - dispatchMoeGemmSelectClusterShapeTmaWarpSpecialized< \ - cutlass::arch::Sm##SMVERSION, T, WeightType, OutputType, EpilogueTag, FUSION, TileShape>( \ - hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, \ - workspace_size); \ - break; \ + auto dispatch_by_mxfpx = [&](auto is_mxfpx_t) { + constexpr bool IsMXFPX = decltype(is_mxfpx_t)::value; + +#define SHAPE_CASE(SMVERSION, M, N, K) \ + case cutlass_extensions::CutlassTileConfigSM##SMVERSION::CtaShape##M##x##N##x##K##B: { \ + constexpr int KtileBytes = \ + (K * 8) / \ + cutlass::sizeof_bits< \ + typename kernels::cutlass_kernels::TllmToCutlassTypeAdapter::type>::value; \ + using KTileDim = Int; \ + using TileShape = Shape<_##M, _##N, KTileDim>; \ + dispatchMoeGemmSelectClusterShapeTmaWarpSpecialized( \ + hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, \ + workspace_size); \ + break; \ } #define DEFAULT_CASE(SMVERSION) \ case cutlass_extensions::CutlassTileConfigSM##SMVERSION::Undefined: \ @@ -414,76 +433,87 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized( (int)gemm_config.tile_config_sm##SMVERSION); \ break; - if (gemm_config.sm_version == 90) { - if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation()) { - switch (gemm_config.tile_config_sm90) { - SHAPE_CASE(90, 128, 16, 128) - SHAPE_CASE(90, 128, 32, 128) - SHAPE_CASE(90, 128, 64, 128) - SHAPE_CASE(90, 128, 128, 128) - SHAPE_CASE(90, 128, 256, 128) - SHAPE_CASE(90, 256, 128, 128) - DEFAULT_CASE(90) + if (gemm_config.sm_version == 90) { + if constexpr (!IsMXFPX && kernels::cutlass_kernels::isValidHopperMOESpecialisation< + T, WeightType, EpilogueTag, FUSION>()) { + switch (gemm_config.tile_config_sm90) { + SHAPE_CASE(90, 128, 16, 128) + SHAPE_CASE(90, 128, 32, 128) + SHAPE_CASE(90, 128, 64, 128) + SHAPE_CASE(90, 128, 128, 128) + SHAPE_CASE(90, 128, 256, 128) + SHAPE_CASE(90, 256, 128, 128) + DEFAULT_CASE(90) + } + } else { + TLLM_THROW("Unsupported SM90 configuration requested"); } - } else { - TLLM_THROW("Unsupported SM90 configuration requested"); } - } #if defined(ENABLE_FP4) && defined(COMPILE_BLACKWELL_SM103_TMA_GROUPED_GEMMS) - // Check this before SM100 because we fall back to SM100 if not NVFP4 - else if (gemm_config.sm_version == 103 && std::is_same_v && - std::is_same_v) { - if constexpr (kernels::cutlass_kernels::isValidBlackwellMOESpecialisation< - T, WeightType, EpilogueTag, FUSION>()) { - switch (gemm_config.tile_config_sm100) { - SHAPE_CASE(103, 128, 128, 128) - SHAPE_CASE(103, 128, 256, 128) - - DEFAULT_CASE(100) // 100 because we use the same member variable for SM100 and SM103 + // Check this before SM100 because we fall back to SM100 if not NVFP4 + else if (gemm_config.sm_version == 103 && std::is_same_v && + std::is_same_v) { + if constexpr (kernels::cutlass_kernels::isValidBlackwellMOESpecialisation< + T, WeightType, EpilogueTag, FUSION>()) { + switch (gemm_config.tile_config_sm100) { + SHAPE_CASE(103, 128, 128, 128) + SHAPE_CASE(103, 128, 256, 128) + + DEFAULT_CASE(100) // 100 because we use the same member variable for SM100 and SM103 + } + } else { + TLLM_THROW("Unsupported SM103 configuration requested"); } - } else { - TLLM_THROW("Unsupported SM103 configuration requested"); } - } #endif - else if (gemm_config.sm_version >= 100 && gemm_config.sm_version < 120) { - if constexpr (kernels::cutlass_kernels::isValidBlackwellMOESpecialisation< - T, WeightType, EpilogueTag, FUSION>()) { - switch (gemm_config.tile_config_sm100) { - SHAPE_CASE(100, 64, 32, 128) - SHAPE_CASE(100, 64, 64, 128) - SHAPE_CASE(100, 64, 128, 128) - SHAPE_CASE(100, 64, 256, 128) - - SHAPE_CASE(100, 128, 16, 128) - SHAPE_CASE(100, 128, 32, 128) - SHAPE_CASE(100, 128, 64, 128) - SHAPE_CASE(100, 128, 128, 128) - SHAPE_CASE(100, 128, 256, 128) - - // SHAPE_CASE(100, 128, 128, 64) - // SHAPE_CASE(100, 128, 256, 64) - DEFAULT_CASE(100) + else if (gemm_config.sm_version >= 100 && gemm_config.sm_version < 120) { + if constexpr (kernels::cutlass_kernels::isValidBlackwellMOESpecialisation< + T, WeightType, EpilogueTag, FUSION>()) { + switch (gemm_config.tile_config_sm100) { + SHAPE_CASE(100, 64, 32, 128) + SHAPE_CASE(100, 64, 64, 128) + SHAPE_CASE(100, 64, 128, 128) + SHAPE_CASE(100, 64, 256, 128) + + SHAPE_CASE(100, 128, 16, 128) + SHAPE_CASE(100, 128, 32, 128) + SHAPE_CASE(100, 128, 64, 128) + SHAPE_CASE(100, 128, 128, 128) + SHAPE_CASE(100, 128, 256, 128) + + // SHAPE_CASE(100, 128, 128, 64) + // SHAPE_CASE(100, 128, 256, 64) + DEFAULT_CASE(100) + } + } else { + TLLM_THROW("Unsupported SM100 configuration requested"); } - } else { - TLLM_THROW("Unsupported SM100 configuration requested"); - } - } else if (gemm_config.sm_version == 120 || gemm_config.sm_version == 121) { - TLLM_LOG_TRACE("At %s, SM120 config=%d", __PRETTY_FUNCTION__, - (int)gemm_config.tile_config_sm120); - if constexpr (kernels::cutlass_kernels::isValidSM120MOESpecialisation()) { - switch (gemm_config.tile_config_sm120) { - SHAPE_CASE(120, 128, 128, 64) - SHAPE_CASE(120, 128, 128, 128) - SHAPE_CASE(120, 128, 256, 64) - SHAPE_CASE(120, 256, 128, 64) - DEFAULT_CASE(120) + } else if (gemm_config.sm_version == 120 || gemm_config.sm_version == 121) { + char const* const pretty_function = __PRETTY_FUNCTION__; + TLLM_LOG_TRACE("At %s, SM120 config=%d", pretty_function, + (int)(gemm_config.tile_config_sm120)); + if constexpr (kernels::cutlass_kernels::isValidSM120MOESpecialisation< + T, WeightType, EpilogueTag, FUSION>()) { + switch (gemm_config.tile_config_sm120) { + SHAPE_CASE(120, 128, 128, 64) + SHAPE_CASE(120, 128, 128, 128) + SHAPE_CASE(120, 128, 256, 64) + SHAPE_CASE(120, 256, 128, 64) + DEFAULT_CASE(120) + } } } - } #undef SHAPE_CASE +#undef DEFAULT_CASE + }; + + bool const use_mxfpx = hopper_input.fpX_block_scaling_type == + TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; + if (use_mxfpx) { + dispatch_by_mxfpx(tensorrt_llm::common::ConstBool{}); + } else { + dispatch_by_mxfpx(tensorrt_llm::common::ConstBool{}); + } } template diff --git a/flashinfer/jit/gemm/cutlass/generate_kernels.py b/flashinfer/jit/gemm/cutlass/generate_kernels.py index 09ef34303d..a7d1959361 100644 --- a/flashinfer/jit/gemm/cutlass/generate_kernels.py +++ b/flashinfer/jit/gemm/cutlass/generate_kernels.py @@ -407,6 +407,16 @@ def is_gemm_op_valid_sm100(op): ): return False + # MXFP block-scaled paths currently follow the same shape/schedule limits as FP4 block scaling. + if op.is_mx_fpx: + if tile_n not in [64, 128, 256] or tile_m != 128: + return False + if ( + op.arch == 100 + and op.epi_schedule == EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm + ): + return False + # Shapes for fp8 small N shapes if ( (op.act_type == DataType.e4m3) @@ -913,31 +923,40 @@ def generate_sm100_grouped_gemm_operations(is_arch_enabled, arch): if dtype in [DataType.e4m3, e2m1]: otypes = [DataType.f16, DataType.bf16] - for otype in otypes: - moe_gemm_operation = TrtLlm_GemmLauncher( - GemmKind.Grouped, - arch, - dtype, - weight_type, - otype, - otype, - otype, - quant_op, - epi_tag, - cta_shape_mnk, - warp_shape, - stages, - cga_shape, - mainloop_schedule, - epi_schedule, - epi_fusion, - is_mx_fpx=(dtype == DataType.e4m3 and weight_type == e2m1), - dynamic_cga=dynamic_cga, - swap_ab=swap_ab, - ) + mxfp_modes = [False] + if dtype == DataType.e4m3 and weight_type == e2m1: + # MXFP8 x MXFP4 path. + mxfp_modes = [True] + elif dtype == DataType.e4m3 and weight_type == DataType.e4m3: + # Emit both regular FP8xFP8 and MXFP8xMXFP8 variants. + mxfp_modes = [False, True] - if is_op_valid(moe_gemm_operation): - operations.append(moe_gemm_operation) + for otype in otypes: + for is_mx_fpx in mxfp_modes: + moe_gemm_operation = TrtLlm_GemmLauncher( + GemmKind.Grouped, + arch, + dtype, + weight_type, + otype, + otype, + otype, + quant_op, + epi_tag, + cta_shape_mnk, + warp_shape, + stages, + cga_shape, + mainloop_schedule, + epi_schedule, + epi_fusion, + is_mx_fpx=is_mx_fpx, + dynamic_cga=dynamic_cga, + swap_ab=swap_ab, + ) + + if is_op_valid(moe_gemm_operation): + operations.append(moe_gemm_operation) return operations diff --git a/tests/moe/test_trtllm_cutlass_fused_moe.py b/tests/moe/test_trtllm_cutlass_fused_moe.py index e579141ce6..9c8a547583 100644 --- a/tests/moe/test_trtllm_cutlass_fused_moe.py +++ b/tests/moe/test_trtllm_cutlass_fused_moe.py @@ -1179,6 +1179,35 @@ def quant_mxfp4_batches(a, num_experts): return result_quant_a, result_sfs +def quant_mxfp8_batches(a, num_experts): + quant_a = [] + sfs = [] + for i in range(num_experts): + a_fp8, a_sf = mxfp8_quantize(a[i].cuda(), True, 32) + quant_a.append(a_fp8) + sfs.append(a_sf) + + result_quant_a = torch.stack(quant_a) + result_sfs = torch.stack(sfs) + + return result_quant_a, result_sfs + + +def pack_mxfp8_scales_u8_to_int32_batches( + scale_u8: torch.Tensor, rows: int, cols: int +) -> torch.Tensor: + num_experts = scale_u8.size(0) + aligned_rows = ceil_div(rows, 128) * 128 + k_scales = cols // 32 + aligned_k_scales = ceil_div(k_scales, 4) * 4 + return ( + scale_u8.contiguous() + .view(num_experts, aligned_rows, aligned_k_scales) + .view(torch.int32) + .contiguous() + ) + + def dequant_mxfp4_batches( mat_fp4: torch.Tensor, scale_tensor: torch.Tensor, @@ -1195,6 +1224,26 @@ def dequant_mxfp4_batches( ) +def dequant_mxfp8_batches( + mat_fp8: torch.Tensor, + scale_tensor: torch.Tensor, +): + num_batches = mat_fp8.size(0) + + scale_tensor = scale_tensor.view(num_batches, -1) + + return torch.stack( + [ + mxfp8_dequantize_host( + mat_fp8[b, :, :].cpu().view(torch.uint8), + scale_tensor[b, :].cpu().view(torch.uint8).reshape(-1), + True, + ) + for b in range(num_batches) + ] + ) + + @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("num_experts", NUM_EXPERTS) @@ -1328,6 +1377,117 @@ def test_moe_mxfp8_mxfp4( torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("num_experts", NUM_EXPERTS) +@pytest.mark.parametrize("top_k", TOP_K_VALUES) +@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) +@pytest.mark.parametrize("otype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize( + ("alpha", "beta", "limit"), [(None, None, None), (0.5, 0.0, 7.0), (1.702, 1.0, 7.0)] +) +@pytest.mark.skipif( + torch.cuda.get_device_capability()[0] not in [10], + reason="MXFP8xMXFP8 is only supported on SM100 for now", +) +def test_moe_mxfp8_mxfp8( + batch_size, + hidden_size, + num_experts, + top_k, + intermediate_size, + otype, + alpha, + beta, + limit, +): + """Test MoE with MXFP8 activations and MXFP8 weights.""" + if top_k > num_experts: + pytest.skip( + f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})" + ) + + torch.manual_seed(42) + e = num_experts + m = batch_size + n = intermediate_size + k = hidden_size + + x = torch.randn(m, k, dtype=otype).cuda() + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=otype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=otype) / 10 + + mxfp8_x, mxfp8_x_sf = mxfp8_quantize(x, True, 32) + mxfp8_w1, mxfp8_w1_scale = quant_mxfp8_batches(w1, e) + mxfp8_w2, mxfp8_w2_scale = quant_mxfp8_batches(w2, e) + mxfp8_w1_scale_i32 = pack_mxfp8_scales_u8_to_int32_batches(mxfp8_w1_scale, 2 * n, k) + mxfp8_w2_scale_i32 = pack_mxfp8_scales_u8_to_int32_batches(mxfp8_w2_scale, k, n) + + router_logits = torch.randn(m, e, dtype=otype).cuda() + routing_weights, selected_experts = compute_routing(router_logits, top_k) + + fake_input_scale = torch.ones(e, device=x.device, dtype=torch.float32) + quant_scales = [ + mxfp8_w1_scale_i32, + fake_input_scale, + mxfp8_w2_scale_i32, + fake_input_scale, + ] + + flash_output = torch.zeros_like(x) + + if alpha is not None and limit is not None and beta is not None: + alpha_t = torch.ones(e, device=x.device) * alpha + limit_t = torch.ones(e, device=x.device) * limit + beta_t = torch.ones(e, device=x.device) * beta + else: + alpha_t = None + limit_t = None + beta_t = None + + _ = fused_moe.cutlass_fused_moe( + mxfp8_x, + selected_experts.to(torch.int), + routing_weights, + mxfp8_w1.contiguous(), + mxfp8_w2.contiguous(), + otype, + swiglu_alpha=alpha_t, + swiglu_limit=limit_t, + swiglu_beta=beta_t, + quant_scales=quant_scales, + input_sf=mxfp8_x_sf, + use_mxfp8_act_scaling=True, + output=flash_output, + ) + + dq_mxfp8_x = ( + mxfp8_dequantize_host( + mxfp8_x.cpu().view(torch.uint8), + mxfp8_x_sf.cpu().view(torch.uint8).reshape(-1), + True, + ) + .cuda() + .to(otype) + ) + dq_mxfp8_w1 = dequant_mxfp8_batches(mxfp8_w1, mxfp8_w1_scale).cuda().to(otype) + dq_mxfp8_w2 = dequant_mxfp8_batches(mxfp8_w2, mxfp8_w2_scale).cuda().to(otype) + + ref_output = compute_with_experts( + e, + dq_mxfp8_x, + dq_mxfp8_w1, + dq_mxfp8_w2, + selected_experts, + routing_weights, + alpha, + beta, + limit, + ) + + torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1) + + def dequant_mxfp4_batches_host( mat_fp4: torch.Tensor, scale_tensor: torch.Tensor,