-
Notifications
You must be signed in to change notification settings - Fork 896
Implement cutlass_fused_moe mxfp8
#2581
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5576dad
083e3db
8ac78b2
4dead9d
44e7b59
9393a7a
8e1091b
4bedd31
d45dcc6
f28fcfb
aba577a
6ab67b4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -82,7 +82,7 @@ class DtypeUtils { | |
|
|
||
| class FusedMoeRunner : public tvm::ffi::ModuleObj { | ||
| public: | ||
| template <typename TypeAct, typename TypeWeight, bool NeedQuant = false> | ||
| template <typename TypeAct, typename TypeWeight, bool NeedQuant = false, bool IsMXFPX = false> | ||
| std::unique_ptr<kernels::CutlassMoeFCRunnerInterface> switch_output_type(DLDataType output_type) { | ||
| switch (encode_dlpack_dtype(output_type)) { | ||
| case int64_code: // INT64 == FP4 | ||
|
|
@@ -94,19 +94,20 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { | |
| // return std::make_unique<kernels::CutlassMoeFCRunner<Type, Type>>(); | ||
| case float16_code: | ||
| if constexpr (NeedQuant) { | ||
| return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, half, half>>(); | ||
| return std::make_unique< | ||
| kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, half, half, half, IsMXFPX>>(); | ||
| } else { | ||
| return std::make_unique< | ||
| kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, half, TypeAct>>(); | ||
| kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, half, TypeAct, half, IsMXFPX>>(); | ||
| } | ||
| #ifdef ENABLE_BF16 | ||
| case bfloat16_code: | ||
| if constexpr (NeedQuant) { | ||
| return std::make_unique< | ||
| kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, __nv_bfloat16, __nv_bfloat16>>(); | ||
| return std::make_unique<kernels::CutlassMoeFCRunner< | ||
| TypeAct, TypeWeight, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, IsMXFPX>>(); | ||
| } else { | ||
| return std::make_unique< | ||
| kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, __nv_bfloat16, TypeAct>>(); | ||
| return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, __nv_bfloat16, | ||
| TypeAct, __nv_bfloat16, IsMXFPX>>(); | ||
| } | ||
| #endif | ||
| default: | ||
|
|
@@ -145,7 +146,9 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { | |
| #endif | ||
|
|
||
| #ifdef ENABLE_FP8 | ||
| if (isFp8Quant()) { | ||
| if (isWMxfp8AMxfp8Quant()) { | ||
| mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp8_e4m3, false, true>(mOutputDtype); | ||
| } else if (isFp8Quant()) { | ||
| mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp8_e4m3>(mOutputDtype); | ||
| } | ||
| #endif | ||
|
|
@@ -397,8 +400,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { | |
| static_cast<int>(experts_per_token), | ||
| static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(), | ||
| static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, | ||
| use_lora, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, | ||
| enable_pdl, stream); | ||
| use_lora, lora_params, mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, min_latency_mode, | ||
| min_latency_params, enable_pdl, stream); | ||
| #else | ||
| mKernelRunner->runMoe( | ||
| input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr, | ||
|
|
@@ -414,7 +417,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { | |
| static_cast<int>(experts_per_token), | ||
| static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(), | ||
| static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params, | ||
| mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, stream); | ||
| mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, min_latency_mode, min_latency_params, | ||
| enable_pdl, stream); | ||
| #endif | ||
| } | ||
|
|
||
|
|
@@ -490,8 +494,10 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { | |
| << "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size."; | ||
| } | ||
|
|
||
| TVM_FFI_ICHECK(!input_sf.has_value() || isWMxfp4AMxfp8Quant() || isNvfp4Quant()) | ||
| TVM_FFI_ICHECK(!input_sf.has_value() || isMxfp8ActScalingQuant() || isNvfp4Quant()) | ||
| << "Block-scaling factors provided for non block-scaling quantization"; | ||
| TVM_FFI_ICHECK(!isMxfp8ActScalingQuant() || input_sf.has_value()) | ||
| << "input_sf must be provided when use_mxfp8_act_scaling=True"; | ||
|
|
||
| int experts_per_token = token_selected_experts.size(1); | ||
| int64_t num_rows = input.size(0); | ||
|
|
@@ -581,8 +587,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { | |
| static_cast<int>(experts_per_token), | ||
| static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(), | ||
| static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, | ||
| use_lora_ml, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, | ||
| enable_pdl, stream); | ||
| use_lora_ml, lora_params, mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, | ||
| min_latency_mode, min_latency_params, enable_pdl, stream); | ||
| #else | ||
| mKernelRunner->runMoe( | ||
| input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr, | ||
|
|
@@ -598,8 +604,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { | |
| static_cast<int>(experts_per_token), | ||
| static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(), | ||
| static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, use_lora_ml, | ||
| lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, | ||
| stream); | ||
| lora_params, mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, min_latency_mode, | ||
| min_latency_params, enable_pdl, stream); | ||
| #endif | ||
| } | ||
|
|
||
|
|
@@ -838,8 +844,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { | |
| bool min_latency_mode) { | ||
| size_t moe_workspace_size = mKernelRunner->getWorkspaceSize( | ||
| num_rows, hidden_size, inter_size, num_experts, experts_per_token, activation_type, | ||
| parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling, min_latency_mode, | ||
| mUseW4GroupScaling); | ||
| parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, | ||
| min_latency_mode, mUseW4GroupScaling); | ||
| size_t src_to_dest_map_size = experts_per_token * num_rows * sizeof(int); | ||
|
|
||
| std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size}; | ||
|
|
@@ -862,7 +868,66 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { | |
| int64_t num_experts_on_rank, int64_t hidden_size, int64_t inter_size, | ||
| Optional<Array<Tensor>> quant_scales, | ||
| ActivationType base_activation_type = ActivationType::Swiglu) const { | ||
| if (isFp8Quant()) { | ||
| if (isWMxfp8AMxfp8Quant()) { | ||
| #ifdef USING_OSS_CUTLASS_MOE_GEMM | ||
| TVM_FFI_ICHECK(quant_scales.has_value()) | ||
| << "Expecting quant scales for MXFP8xMXFP8 quantization"; | ||
| TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 4) | ||
| << "Expecting 4 quant scales for MXFP8xMXFP8 quantization"; | ||
|
|
||
| TensorView fc1_weight_block = quant_scales.value()[0]; | ||
| TensorView fc1_global = quant_scales.value()[1]; | ||
| TensorView fc2_weight_block = quant_scales.value()[2]; | ||
| TensorView fc2_global = quant_scales.value()[3]; | ||
|
|
||
| // The input for scale fc1_weight_block / fc2_weight_block is packed into INT32 | ||
| constexpr int FP8_PER_INT32 = 4; | ||
| CHECK_INPUT_TYPE(fc1_weight_block, dl_int32); | ||
| CHECK_INPUT_TYPE(fc1_global, dl_float32); | ||
| CHECK_INPUT_TYPE(fc2_weight_block, dl_int32); | ||
| CHECK_INPUT_TYPE(fc2_global, dl_float32); | ||
| CHECK_DIM(3, fc1_weight_block); | ||
| CHECK_DIM(1, fc1_global); | ||
| CHECK_DIM(3, fc2_weight_block); | ||
| CHECK_DIM(1, fc2_global); | ||
| TVM_FFI_ICHECK( | ||
| fc1_weight_block.size(0) == num_experts_on_rank && | ||
| fc1_weight_block.size(1) == | ||
| TmaWarpSpecializedGroupedGemmInput::alignToSfDim( | ||
| inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) * | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this only supposed to work with gated activations? The bf16 variant of this kernel supports both gated and non-gated activations.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It works with gating, reference the unit test here: https://github.com/zianglih/flashinfer/blob/aba577ad95f7998b46616dd5c0fa7f8b1818f717/tests/moe/test_trtllm_cutlass_fused_moe.py#L1417-L1418 Also I have tried this kernel in SGLang sgl-project/sglang#18945 and can run Qwen3-30B-A3B without problems. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My question was for non gated activations like squared relu. Does it work with them? I tested yesterday and it did not. |
||
| 2 && | ||
| fc1_weight_block.size(2) * FP8_PER_INT32 * | ||
| TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize == | ||
| TmaWarpSpecializedGroupedGemmInput::alignToSfDim( | ||
| hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX)) | ||
| << "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 " | ||
| "// block_scale_vector_size)"; | ||
| TVM_FFI_ICHECK_EQ(fc1_global.size(0), num_experts_on_rank) | ||
| << "fc1 global size must be (num_experts_on_rank,)"; | ||
| TVM_FFI_ICHECK( | ||
| fc2_weight_block.size(0) == num_experts_on_rank && | ||
| fc2_weight_block.size(1) == | ||
| TmaWarpSpecializedGroupedGemmInput::alignToSfDim( | ||
| hidden_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) && | ||
| fc2_weight_block.size(2) * FP8_PER_INT32 * | ||
| TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize == | ||
| TmaWarpSpecializedGroupedGemmInput::alignToSfDim( | ||
| inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX)) | ||
| << "fc2 weight block size must be (num_experts_on_rank, hidden_size, inter_size // 4 // " | ||
| "block_scale_vector_size)"; | ||
| TVM_FFI_ICHECK_EQ(fc2_global.size(0), num_experts_on_rank) | ||
| << "fc2 global size must be (num_experts_on_rank,)"; | ||
|
|
||
| return kernels::QuantParams::MXFP8MXFP8( | ||
| static_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(fc1_weight_block.data_ptr()), | ||
| static_cast<float const*>(fc1_global.data_ptr()), | ||
| static_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(fc2_weight_block.data_ptr()), | ||
| static_cast<float const*>(fc2_global.data_ptr())); | ||
| #else | ||
| TVM_FFI_ICHECK(false) | ||
| << "MXFP8 x MXFP8 quantization is not supported in OSS Cutlass Moe Gemm"; | ||
| #endif | ||
| } else if (isFp8Quant()) { | ||
| TVM_FFI_ICHECK(quant_scales.has_value()) << "Expecting quant scales for fp8 quantization"; | ||
| TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 4) | ||
| << "Expecting 4 quant scales for fp8 quantization"; | ||
|
|
@@ -1168,9 +1233,16 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { | |
|
|
||
| bool isFp8Quant() const { | ||
| return !mUseDeepSeekFP8BlockScaling && mActivationDtype == dl_float8_e4m3fn && | ||
| mWeightDtype == dl_float8_e4m3fn; | ||
| mWeightDtype == dl_float8_e4m3fn && !mUseMxfp8ActScaling; | ||
| } | ||
|
|
||
| bool isWMxfp8AMxfp8Quant() const { | ||
| return !mUseDeepSeekFP8BlockScaling && mActivationDtype == dl_float8_e4m3fn && | ||
| mWeightDtype == dl_float8_e4m3fn && mUseMxfp8ActScaling; | ||
| } | ||
|
|
||
| bool isMxfp8ActScalingQuant() const { return isWMxfp8AMxfp8Quant() || isWMxfp4AMxfp8Quant(); } | ||
|
|
||
| bool isNvfp4Quant() const { | ||
| return mWeightDtype == dl_int64 && | ||
| mActivationDtype != dl_float8_e4m3fn; // FP8 activation does not use FP4 | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add MXFP8
input_sfvalidation to the non‑min‑latency path.runMoeMinLatencyenforcesinput_sfwhenuse_mxfp8_act_scaling=True, butrunMoedoes not. That can permit null scale factors and lead to invalid reads in the MXFP8 act-scaling path.💡 Suggested fix (mirror min-latency guard)
@@ TVM_FFI_ICHECK_EQ(fc1_expert_weights.size(0), fc2_expert_weights.size(0)) << "fc1_expert_weights and fc2_expert_weights must have the same number of experts."; @@ if (isGatedActivation(base_activation_type)) { TVM_FFI_ICHECK_EQ(fc1_expert_weights.size(1), fc2_expert_weights.size(2) * mInnerDimMultiplier * 2) << "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size."; } else { TVM_FFI_ICHECK_EQ(fc1_expert_weights.size(1), fc2_expert_weights.size(2) * mInnerDimMultiplier) << "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size."; } + + TVM_FFI_ICHECK(!input_sf.has_value() || isMxfp8ActScalingQuant() || isNvfp4Quant()) + << "Block-scaling factors provided for non block-scaling quantization"; + TVM_FFI_ICHECK(!isMxfp8ActScalingQuant() || input_sf.has_value()) + << "input_sf must be provided when use_mxfp8_act_scaling=True";🤖 Prompt for AI Agents