Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@ template class CutlassMoeFCRunner<half, cutlass::uint4b_t>;
#ifdef ENABLE_FP8
// template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3>;
template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half>;
template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half, __nv_fp8_e4m3, half, true>;
template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>;
#ifdef ENABLE_BF16
template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>;
template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16, __nv_fp8_e4m3,
__nv_bfloat16, true>;
template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>;
template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>;
template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_fp8_e4m3>;
Expand Down
228 changes: 136 additions & 92 deletions csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class DtypeUtils {

class FusedMoeRunner : public tvm::ffi::ModuleObj {
public:
template <typename TypeAct, typename TypeWeight, bool NeedQuant = false>
template <typename TypeAct, typename TypeWeight, bool NeedQuant = false, bool IsMXFPX = false>
std::unique_ptr<kernels::CutlassMoeFCRunnerInterface> switch_output_type(DLDataType output_type) {
switch (encode_dlpack_dtype(output_type)) {
case int64_code: // INT64 == FP4
Expand All @@ -94,19 +94,20 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
// return std::make_unique<kernels::CutlassMoeFCRunner<Type, Type>>();
case float16_code:
if constexpr (NeedQuant) {
return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, half, half>>();
return std::make_unique<
kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, half, half, half, IsMXFPX>>();
} else {
return std::make_unique<
kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, half, TypeAct>>();
kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, half, TypeAct, half, IsMXFPX>>();
}
#ifdef ENABLE_BF16
case bfloat16_code:
if constexpr (NeedQuant) {
return std::make_unique<
kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, __nv_bfloat16, __nv_bfloat16>>();
return std::make_unique<kernels::CutlassMoeFCRunner<
TypeAct, TypeWeight, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, IsMXFPX>>();
} else {
return std::make_unique<
kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, __nv_bfloat16, TypeAct>>();
return std::make_unique<kernels::CutlassMoeFCRunner<TypeAct, TypeWeight, __nv_bfloat16,
TypeAct, __nv_bfloat16, IsMXFPX>>();
}
#endif
default:
Expand Down Expand Up @@ -145,7 +146,9 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
#endif

#ifdef ENABLE_FP8
if (isFp8Quant()) {
if (isWMxfp8AMxfp8Quant()) {
mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp8_e4m3, false, true>(mOutputDtype);
} else if (isFp8Quant()) {
mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp8_e4m3>(mOutputDtype);
}
#endif
Expand Down Expand Up @@ -397,8 +400,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall,
use_lora, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
enable_pdl, stream);
use_lora, lora_params, mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, min_latency_mode,
min_latency_params, enable_pdl, stream);
Comment on lines +403 to +404
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Add MXFP8 input_sf validation to the non‑min‑latency path.

runMoeMinLatency enforces input_sf when use_mxfp8_act_scaling=True, but runMoe does not. That can permit null scale factors and lead to invalid reads in the MXFP8 act-scaling path.

💡 Suggested fix (mirror min-latency guard)
@@
     TVM_FFI_ICHECK_EQ(fc1_expert_weights.size(0), fc2_expert_weights.size(0))
         << "fc1_expert_weights and fc2_expert_weights must have the same number of experts.";
@@
     if (isGatedActivation(base_activation_type)) {
       TVM_FFI_ICHECK_EQ(fc1_expert_weights.size(1),
                         fc2_expert_weights.size(2) * mInnerDimMultiplier * 2)
           << "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.";
     } else {
       TVM_FFI_ICHECK_EQ(fc1_expert_weights.size(1),
                         fc2_expert_weights.size(2) * mInnerDimMultiplier)
           << "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size.";
     }
+
+    TVM_FFI_ICHECK(!input_sf.has_value() || isMxfp8ActScalingQuant() || isNvfp4Quant())
+        << "Block-scaling factors provided for non block-scaling quantization";
+    TVM_FFI_ICHECK(!isMxfp8ActScalingQuant() || input_sf.has_value())
+        << "input_sf must be provided when use_mxfp8_act_scaling=True";
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu`
around lines 403 - 404, The non‑min‑latency path in runMoe is missing validation
for MXFP8 input scale factors: when use_mxfp8_act_scaling is true you must
verify input_sf is non‑null (mirror the guard in runMoeMinLatency). Update
runMoe to check use_mxfp8_act_scaling and assert or return an error if input_sf
is null before proceeding into the MXFP8 act-scaling path, using the same
check/logic pattern used in runMoeMinLatency to prevent invalid reads of the
scale factors.

#else
mKernelRunner->runMoe(
input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr,
Expand All @@ -414,7 +417,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, lora_params,
mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, stream);
mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, min_latency_mode, min_latency_params,
enable_pdl, stream);
#endif
}

Expand Down Expand Up @@ -490,8 +494,10 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
<< "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size.";
}

TVM_FFI_ICHECK(!input_sf.has_value() || isWMxfp4AMxfp8Quant() || isNvfp4Quant())
TVM_FFI_ICHECK(!input_sf.has_value() || isMxfp8ActScalingQuant() || isNvfp4Quant())
<< "Block-scaling factors provided for non block-scaling quantization";
TVM_FFI_ICHECK(!isMxfp8ActScalingQuant() || input_sf.has_value())
<< "input_sf must be provided when use_mxfp8_act_scaling=True";

int experts_per_token = token_selected_experts.size(1);
int64_t num_rows = input.size(0);
Expand Down Expand Up @@ -581,8 +587,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall,
use_lora_ml, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
enable_pdl, stream);
use_lora_ml, lora_params, mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling,
min_latency_mode, min_latency_params, enable_pdl, stream);
#else
mKernelRunner->runMoe(
input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr,
Expand All @@ -598,8 +604,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
static_cast<int>(experts_per_token),
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, false, use_lora_ml,
lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl,
stream);
lora_params, mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling, min_latency_mode,
min_latency_params, enable_pdl, stream);
#endif
}

Expand Down Expand Up @@ -838,8 +844,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
bool min_latency_mode) {
size_t moe_workspace_size = mKernelRunner->getWorkspaceSize(
num_rows, hidden_size, inter_size, num_experts, experts_per_token, activation_type,
parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling, min_latency_mode,
mUseW4GroupScaling);
parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling, mUseMxfp8ActScaling,
min_latency_mode, mUseW4GroupScaling);
size_t src_to_dest_map_size = experts_per_token * num_rows * sizeof(int);

std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size};
Expand All @@ -862,7 +868,66 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
int64_t num_experts_on_rank, int64_t hidden_size, int64_t inter_size,
Optional<Array<Tensor>> quant_scales,
ActivationType base_activation_type = ActivationType::Swiglu) const {
if (isFp8Quant()) {
if (isWMxfp8AMxfp8Quant()) {
#ifdef USING_OSS_CUTLASS_MOE_GEMM
TVM_FFI_ICHECK(quant_scales.has_value())
<< "Expecting quant scales for MXFP8xMXFP8 quantization";
TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 4)
<< "Expecting 4 quant scales for MXFP8xMXFP8 quantization";

TensorView fc1_weight_block = quant_scales.value()[0];
TensorView fc1_global = quant_scales.value()[1];
TensorView fc2_weight_block = quant_scales.value()[2];
TensorView fc2_global = quant_scales.value()[3];

// The input for scale fc1_weight_block / fc2_weight_block is packed into INT32
constexpr int FP8_PER_INT32 = 4;
CHECK_INPUT_TYPE(fc1_weight_block, dl_int32);
CHECK_INPUT_TYPE(fc1_global, dl_float32);
CHECK_INPUT_TYPE(fc2_weight_block, dl_int32);
CHECK_INPUT_TYPE(fc2_global, dl_float32);
CHECK_DIM(3, fc1_weight_block);
CHECK_DIM(1, fc1_global);
CHECK_DIM(3, fc2_weight_block);
CHECK_DIM(1, fc2_global);
TVM_FFI_ICHECK(
fc1_weight_block.size(0) == num_experts_on_rank &&
fc1_weight_block.size(1) ==
TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) *
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this only supposed to work with gated activations? The bf16 variant of this kernel supports both gated and non-gated activations.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It works with gated activations; see the unit test here: https://github.com/zianglih/flashinfer/blob/aba577ad95f7998b46616dd5c0fa7f8b1818f717/tests/moe/test_trtllm_cutlass_fused_moe.py#L1417-L1418

I have also tried this kernel in SGLang (sgl-project/sglang#18945) and can run Qwen3-30B-A3B without problems.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My question was about non-gated activations like squared ReLU. Does it work with them? I tested yesterday and it did not.

2 &&
fc1_weight_block.size(2) * FP8_PER_INT32 *
TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize ==
TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX))
<< "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 "
"// block_scale_vector_size)";
TVM_FFI_ICHECK_EQ(fc1_global.size(0), num_experts_on_rank)
<< "fc1 global size must be (num_experts_on_rank,)";
TVM_FFI_ICHECK(
fc2_weight_block.size(0) == num_experts_on_rank &&
fc2_weight_block.size(1) ==
TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
hidden_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) &&
fc2_weight_block.size(2) * FP8_PER_INT32 *
TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize ==
TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX))
<< "fc2 weight block size must be (num_experts_on_rank, hidden_size, inter_size // 4 // "
"block_scale_vector_size)";
TVM_FFI_ICHECK_EQ(fc2_global.size(0), num_experts_on_rank)
<< "fc2 global size must be (num_experts_on_rank,)";

return kernels::QuantParams::MXFP8MXFP8(
static_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(fc1_weight_block.data_ptr()),
static_cast<float const*>(fc1_global.data_ptr()),
static_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(fc2_weight_block.data_ptr()),
static_cast<float const*>(fc2_global.data_ptr()));
#else
TVM_FFI_ICHECK(false)
<< "MXFP8 x MXFP8 quantization is not supported in OSS Cutlass Moe Gemm";
#endif
} else if (isFp8Quant()) {
TVM_FFI_ICHECK(quant_scales.has_value()) << "Expecting quant scales for fp8 quantization";
TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 4)
<< "Expecting 4 quant scales for fp8 quantization";
Expand Down Expand Up @@ -1168,9 +1233,16 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {

bool isFp8Quant() const {
return !mUseDeepSeekFP8BlockScaling && mActivationDtype == dl_float8_e4m3fn &&
mWeightDtype == dl_float8_e4m3fn;
mWeightDtype == dl_float8_e4m3fn && !mUseMxfp8ActScaling;
}

bool isWMxfp8AMxfp8Quant() const {
return !mUseDeepSeekFP8BlockScaling && mActivationDtype == dl_float8_e4m3fn &&
mWeightDtype == dl_float8_e4m3fn && mUseMxfp8ActScaling;
}

bool isMxfp8ActScalingQuant() const { return isWMxfp8AMxfp8Quant() || isWMxfp4AMxfp8Quant(); }

bool isNvfp4Quant() const {
return mWeightDtype == dl_int64 &&
mActivationDtype != dl_float8_e4m3fn; // FP8 activation does not use FP4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,11 +239,11 @@ constexpr bool isGatedActivation(ActivationType activation_type) {
activation_type == ActivationType::SwigluBias;
}

template <typename T, /*The type used for activations/scales/compute*/
typename WeightType, /* The type for the MoE weights */
typename OutputType, /* The output type for the GEMM */
typename ScaleBiasType = OutputType /* The type for the scales/bias */
>
template <typename T, /*The type used for activations/scales/compute*/
typename WeightType, /* The type for the MoE weights */
typename OutputType, /* The output type for the GEMM */
typename ScaleBiasType = OutputType, /* The type for the scales/bias */
bool IsMXFPX = false>
class MoeGemmRunner {
public:
MoeGemmRunner();
Expand Down Expand Up @@ -273,6 +273,8 @@ class MoeGemmRunner {
static constexpr bool use_fp8 = false;
static constexpr bool use_w4afp8 = false;
#endif
static constexpr bool use_mxfp8 = use_fp8 && IsMXFPX;

static constexpr bool use_w4_groupwise = use_w4afp8 || use_wfp4a16;

#if defined(ENABLE_FP4)
Expand Down
Loading
Loading