Implement cutlass_fused_moe mxfp8#2581
Conversation
Summary of Changes
Hello @zianglih, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a new mixed-precision quantization scheme, MXFP8xMXFP8, for the cutlass_fused_moe kernel.
Activity
Code Review
This pull request introduces support for mxfp8 x mxfp8 quantization in the cutlass_fused_moe kernel. The changes are extensive, touching kernel implementations, bindings, and adding new tests to validate the functionality. The implementation correctly adds new code paths for mxfp8 handling, including new quantization modes and scaling types.
My review focuses on improving code maintainability by addressing several instances of code duplication. I've identified repeated logic for determining scaling types and for dispatching kernels, and I've suggested refactoring these into helper functions or using other C++ patterns to reduce redundancy. These changes should make the code cleaner and easier to maintain in the future.
...v_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h
auto fpX_scaling_type = getScalingType();
if constexpr (use_fp8) {
  if (use_mxfp8_fp8_block_scaling) {
    fpX_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX;
  }
}
This logic to determine fpX_scaling_type is duplicated in configureWsPtrs (lines 2755-2760) and setupTmaWarpSpecializedInputs (lines 3933-3938). To improve maintainability and reduce code duplication, consider extracting this logic into a private helper function within the CutlassMoeFCRunner class.
For example:
__host__ __device__ inline TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType
getFpXScalingTypeHelper(bool use_mxfp8_fp8_block_scaling) const {
auto fpX_scaling_type = getScalingType();
if constexpr (use_fp8) {
if (use_mxfp8_fp8_block_scaling) {
fpX_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX;
}
}
return fpX_scaling_type;
}
Then, you can call this helper function in all three places.
/bot run
@flashinfer-bot run
[FAILED] Pipeline #44336503: 7/20 passed
📝 Walkthrough
Adds MXFPX/MXFP8 activation-scaling and block-scaling support across the Cutlass MoE backend: a new template parameter IsMXFPX is threaded through the runner, GEMM dispatch, and kernel launchers.
Changes
Sequence Diagram(s)
sequenceDiagram
participant Binding as Host Binding
participant Runner as CutlassMoeFCRunner<IsMXFPX>
participant MoeGemm as MoeGemmRunner<IsMXFPX>
participant Device as GPU (CUTLASS / TMA)
Binding->>Runner: prepare inputs, quant params, use_mxfp8_act_scaling
Runner->>MoeGemm: get workspace sizes, configureWsPtrs, select kernel
MoeGemm->>Device: dispatch GEMM/TMA (MXFPX or non‑MXFP path)
Device-->>MoeGemm: kernel results
MoeGemm-->>Runner: post-process (accumulate/gate/dequant)
Runner-->>Binding: final outputs
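To ground what MXFP8 block scaling means numerically in the flow above, here is a minimal NumPy sketch, assuming the standard MX layout of 32-element blocks sharing a power-of-two (E8M0) scale; it is an illustration only, not the kernel's code, and it omits per-element rounding to the E4M3 format:

```python
import numpy as np

BLOCK = 32            # MX block size per the OCP Microscaling spec
FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in float8_e4m3

def quantize_mxfp8(x):
    """Split x into 32-element blocks, each sharing a power-of-two (E8M0) scale."""
    blocks = x.reshape(-1, BLOCK)
    amax = np.abs(blocks).max(axis=1, keepdims=True)
    # Smallest power-of-two scale that brings each block into E4M3 range.
    exp = np.where(amax > 0, np.ceil(np.log2(amax / FP8_E4M3_MAX)), 0.0)
    scale = np.exp2(exp)
    q = blocks / scale  # element mantissa rounding to E4M3 omitted in this sketch
    return q, scale

def dequantize_mxfp8(q, scale):
    return (q * scale).reshape(-1)

x = np.linspace(-1000.0, 1000.0, 64)
q, scale = quantize_mxfp8(x)
assert np.abs(q).max() <= FP8_E4M3_MAX             # every block fits the E4M3 range
assert np.allclose(dequantize_mxfp8(q, scale), x)  # exact here only because rounding is omitted
```

The real kernels additionally pack the E8M0 scales into the swizzled scale-factor layout the TMA path expects; this sketch only shows the block/scale relationship.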
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Possibly related PRs
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 2
🧹 Nitpick comments (2)
tests/moe/test_trtllm_cutlass_fused_moe.py (1)
1389-1392: GPU architecture skip should use flashinfer.utils functions per coding guidelines.
The @pytest.mark.skipif uses torch.cuda.get_device_capability() directly. The same pattern is used elsewhere in this file (e.g., lines 481-484, 1256-1258), but the coding guidelines require using flashinfer.utils functions like get_compute_capability or is_sm90a_supported for architecture-gated skips. Consider updating all new tests to follow this guideline.
As per coding guidelines:
tests/**/*.py: Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip tests on unsupported GPU architectures.
#!/bin/bash
# Check what flashinfer.utils functions are available for capability checks
rg -n "def get_compute_capability\|def is_sm" --type=py -g '*/utils*'
# Check if any test files in the repo already follow the guideline
rg -n "from flashinfer.utils import\|flashinfer.utils.get_compute" --type=py -g 'tests/**'
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/moe/test_trtllm_cutlass_fused_moe.py` around lines 1389 - 1392, The skip decorator in the test uses torch.cuda.get_device_capability() directly; replace it with the utility helpers from flashinfer.utils (e.g., import and call get_compute_capability() or the appropriate is_smXX_supported helper) so GPU-architecture gating follows project guidelines; update the `@pytest.mark.skipif` on the test (around the decorator that currently checks torch.cuda.get_device_capability()) to call flashinfer.utils.get_compute_capability() or is_sm90a_supported/is_sm100_supported as appropriate and mirror the pattern used in other tests in this file.
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu (1)
889-900: fc1_weight_block validation always assumes gated (×2) activation — correct for current usage but potentially fragile.
The check at line 894 unconditionally multiplies inter_size by 2 for the fc1 N-dimension. This is correct for SwiGLU/SwigluBias (gated activations), which is the only activation type currently used with MXFP8. However, the NVFP4 path (further below) conditionally handles both gated and non-gated cases via isGatedActivation(base_activation_type).
If non-gated activations are supported for MXFP8 in the future, this check will need to be updated.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu` around lines 889 - 900, The fc1_weight_block size check unconditionally assumes a gated (×2) N-dimension; update the validation in the block that references fc1_weight_block, TmaWarpSpecializedGroupedGemmInput::alignToSfDim, and inter_size to only multiply inter_size by 2 when the activation is gated (use isGatedActivation(base_activation_type) or equivalent), otherwise use inter_size as-is, and adjust the error message to reflect both gated and non-gated expected shapes so the check matches the NVFP4 branch behavior.
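To make the reviewed shape expectation concrete, a hypothetical Python sketch follows; the helper names `align_to_sf_dim` and `expected_fc1_scale_n` are illustrative stand-ins, not the binding's actual API, and the exact rounding behavior of `alignToSfDim` is an assumption:

```python
def align_to_sf_dim(dim, alignment):
    # Round dim up to the next multiple of alignment (what alignToSfDim is described as doing).
    return ((dim + alignment - 1) // alignment) * alignment

def expected_fc1_scale_n(inter_size, min_n_align, gated):
    # The check under review always takes the gated branch (x2); a gating-aware
    # version would branch on isGatedActivation(base_activation_type) instead.
    aligned = align_to_sf_dim(inter_size, min_n_align)
    return 2 * aligned if gated else aligned

assert align_to_sf_dim(3000, 128) == 3072             # rounds up to alignment
assert expected_fc1_scale_n(3072, 128, gated=True) == 6144
assert expected_fc1_scale_n(3000, 128, gated=False) == 3072
```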
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh`:
- Around line 3578-3586: The current gate use_mxfp8_weight_block_scales only
checks fc1.weight_block_scale and can still dereference fc2.global_scale; change
the logic to enable MXFP8 block-scale mode only when fp8_scales_required
indicates MXFP8 (same flag used elsewhere) AND both
quant_params.mxfp8_mxfp4.fc1.weight_block_scale and
quant_params.mxfp8_mxfp4.fc2.weight_block_scale are set and both corresponding
global_scale pointers are non-null; then use that boolean to choose
fc1_fp8_dequant and fc2_fp8_dequant (otherwise fall back to quant_params.fp8.*
dequant pointers). Ensure the symbol names mentioned
(use_mxfp8_weight_block_scales, quant_params.mxfp8_mxfp4.fc1.weight_block_scale,
quant_params.mxfp8_mxfp4.fc2.weight_block_scale,
quant_params.mxfp8_mxfp4.fc1.global_scale,
quant_params.mxfp8_mxfp4.fc2.global_scale, fp8_scales_required, fc1_fp8_dequant,
fc2_fp8_dequant) are used to locate and update the condition and selections.
In `@tests/moe/test_trtllm_cutlass_fused_moe.py`:
- Around line 1448-1462: Add a GPU-architecture skip to the test function
test_moe_mxfp8_mxfp8 by decorating it with pytest.mark.skipif using the
flashinfer.utils helpers: call get_compute_capability() and pass it to
is_sm90a_supported (or the appropriate helper for MXFP8 support) and skip when
those indicate unsupported hardware; ensure you import pytest and the helpers
(get_compute_capability, is_sm90a_supported) and place the skip decorator
immediately above the test_moe_mxfp8_mxfp8 definition so the test is skipped on
unsupported GPUs.
---
Duplicate comments:
In `@csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh`:
- Around line 4108-4117: This is a duplicate of the MXFP8 validation/gating code
— remove the redundant block and consolidate the FP8 gating logic so only the
earlier validated ternary selection for dequant scales is used; specifically,
keep a single place that computes the fc1/fc2 dequant scale (the ternaries using
std::is_same_v<WeightType, __nv_fp8_e4m3> and
quant_params.mxfp8_mxfp4.fc?.global_scale vs quant_params.fp8.dequant_fc?) and
pass those same values along with fc1_expert_weights, fc2_expert_weights,
fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, fc1_expert_biases,
fc2_bias to the kernel — delete or merge the duplicated lines to avoid repeated
gating/validation.
---
Nitpick comments:
In `@csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu`:
- Around line 889-900: The fc1_weight_block size check unconditionally assumes a
gated (×2) N-dimension; update the validation in the block that references
fc1_weight_block, TmaWarpSpecializedGroupedGemmInput::alignToSfDim, and
inter_size to only multiply inter_size by 2 when the activation is gated (use
isGatedActivation(base_activation_type) or equivalent), otherwise use inter_size
as-is, and adjust the error message to reflect both gated and non-gated expected
shapes so the check matches the NVFP4 branch behavior.
In `@tests/moe/test_trtllm_cutlass_fused_moe.py`:
- Around line 1389-1392: The skip decorator in the test uses
torch.cuda.get_device_capability() directly; replace it with the utility helpers
from flashinfer.utils (e.g., import and call get_compute_capability() or the
appropriate is_smXX_supported helper) so GPU-architecture gating follows project
guidelines; update the `@pytest.mark.skipif` on the test (around the decorator
that currently checks torch.cuda.get_device_capability()) to call
flashinfer.utils.get_compute_capability() or
is_sm90a_supported/is_sm100_supported as appropriate and mirror the pattern used
in other tests in this file.
There was a problem hiding this comment.
♻️ Duplicate comments (1)
tests/moe/test_trtllm_cutlass_fused_moe.py (1)
1389-1392: ⚠️ Potential issue | 🟠 Major
Use flashinfer.utils helpers for architecture skip gating (Line 1390).
This new test still hard-codes torch.cuda.get_device_capability() instead of the required flashinfer.utils-based skip helpers.
Suggested patch:
 import pytest
 from flashinfer.fused_moe.core import ActivationType
 import torch
 from torch.nn import functional as F
+from flashinfer.utils import get_compute_capability
@@
 @pytest.mark.skipif(
-    torch.cuda.get_device_capability()[0] not in [10],
+    get_compute_capability()[0] != 10,
     reason="MXFP8xMXFP8 is only supported on SM100 for now",
 )
#!/bin/bash
# Verify architecture-skip helper usage in this test file.
rg -n "from flashinfer.utils import|get_compute_capability|is_sm90a_supported|torch.cuda.get_device_capability" tests/moe/test_trtllm_cutlass_fused_moe.py -C 2
As per coding guidelines:
tests/**/*.py: Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip unsupported GPU architectures.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/moe/test_trtllm_cutlass_fused_moe.py` around lines 1389 - 1392, Replace the hard-coded torch.cuda.get_device_capability() check in the pytest skip marker with the flashinfer.utils helpers: import and call get_compute_capability() (or an appropriate helper such as is_sm90a_supported()) and use its result to gate the test instead of directly calling torch.cuda.get_device_capability; update the skipif condition at the pytest.mark.skipif decorator around the test (the block using torch.cuda.get_device_capability) to use get_compute_capability() or is_sm90a_supported() from flashinfer.utils so the test follows the repository's architecture-skip helpers.
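The pattern the reviewer asks for can be sketched as a reusable marker; `get_compute_capability` here is a hypothetical stand-in for the `flashinfer.utils` helper, hard-coded so the sketch is self-contained:

```python
import pytest

def get_compute_capability():
    # Hypothetical stand-in for the flashinfer.utils helper; in the real test
    # suite this would query the current GPU's (major, minor) capability.
    return (10, 0)

# Reusable marker mirroring the suggested patch: skip everywhere except SM100.
requires_sm100 = pytest.mark.skipif(
    get_compute_capability()[0] != 10,
    reason="MXFP8xMXFP8 is only supported on SM100 for now",
)

@requires_sm100
def test_moe_mxfp8_mxfp8():
    pass
```

Defining the marker once lets every new MXFP8 test reuse the same gate instead of repeating the condition.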
🧹 Nitpick comments (1)
flashinfer/jit/gemm/cutlass/generate_kernels.py (1)
410-418: Deduplicate block-scaled validity checks to avoid drift.
The MXFP branch repeats the exact same constraints as the FP4 block-scaled branch. Collapsing them into one predicate will keep future tuning changes synchronized.
♻️ Suggested refactor
-    # FP4 Has some much more limited sizes
-    if op.act_type == e2m1 or op.weight_type == e2m1:
-        if tile_n not in [64, 128, 256] or tile_m != 128:
-            return False
-        # TODO Revert this once cutlass adds support for blockscaled + no smem
-        if (
-            op.arch == 100
-            and op.epi_schedule == EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm
-        ):
-            return False
-
-    # MXFP block-scaled paths currently follow the same shape/schedule limits as FP4 block scaling.
-    if op.is_mx_fpx:
+    # FP4 and MXFP block-scaled paths currently share the same shape/schedule limits.
+    uses_block_scaled_limits = (
+        op.act_type == e2m1 or op.weight_type == e2m1 or op.is_mx_fpx
+    )
+    if uses_block_scaled_limits:
         if tile_n not in [64, 128, 256] or tile_m != 128:
             return False
         if (
             op.arch == 100
             and op.epi_schedule == EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm
         ):
             return False
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/jit/gemm/cutlass/generate_kernels.py` around lines 410 - 418, The MXFP branch duplicates the same block-scaled validity constraints used by the FP4 block-scaled branch; combine these into a single predicate to avoid drift by moving the shared checks into a common condition used by both paths. Update the logic that currently tests tile_n, tile_m, op.arch and op.epi_schedule (references: op.is_mx_fpx, tile_n, tile_m, op.arch, op.epi_schedule, EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm) so the tile_n in [64,128,256] / tile_m == 128 and the special-case arch/epilogue check are expressed once and reused for FP4 and MXFP block-scaled validation.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@tests/moe/test_trtllm_cutlass_fused_moe.py`:
- Around line 1389-1392: Replace the hard-coded
torch.cuda.get_device_capability() check in the pytest skip marker with the
flashinfer.utils helpers: import and call get_compute_capability() (or an
appropriate helper such as is_sm90a_supported()) and use its result to gate the
test instead of directly calling torch.cuda.get_device_capability; update the
skipif condition at the pytest.mark.skipif decorator around the test (the block
using torch.cuda.get_device_capability) to use get_compute_capability() or
is_sm90a_supported() from flashinfer.utils so the test follows the repository's
architecture-skip helpers.
---
Nitpick comments:
In `@flashinfer/jit/gemm/cutlass/generate_kernels.py`:
- Around line 410-418: The MXFP branch duplicates the same block-scaled validity
constraints used by the FP4 block-scaled branch; combine these into a single
predicate to avoid drift by moving the shared checks into a common condition
used by both paths. Update the logic that currently tests tile_n, tile_m,
op.arch and op.epi_schedule (references: op.is_mx_fpx, tile_n, tile_m, op.arch,
op.epi_schedule, EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm) so the
tile_n in [64,128,256] / tile_m == 128 and the special-case arch/epilogue check
are expressed once and reused for FP4 and MXFP block-scaled validation.
Hopper still failing. Can confirm:
Can confirm H100 also passes:
fc1_weight_block.size(0) == num_experts_on_rank &&
fc1_weight_block.size(1) ==
    TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
        inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) *
Is this only supposed to work with gated activations? The bf16 variant of this kernel supports both gated and non-gated activations.
It works with gating; see the unit test here: https://github.com/zianglih/flashinfer/blob/aba577ad95f7998b46616dd5c0fa7f8b1818f717/tests/moe/test_trtllm_cutlass_fused_moe.py#L1417-L1418
Also I have tried this kernel in SGLang sgl-project/sglang#18945 and can run Qwen3-30B-A3B without problems.
My question was about non-gated activations like squared ReLU. Does it work with them? I tested yesterday and it did not.
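For readers following the thread, the distinction at issue in a generic NumPy sketch (not this kernel's implementation): gated activations like SwiGLU split fc1's output into gate and up halves, so fc1's N dimension is doubled, while non-gated ones like squared ReLU preserve the shape:

```python
import numpy as np

def swiglu(x):
    # Gated: fc1 output carries gate and up halves, so its last dim is 2 * inter_size.
    gate, up = np.split(x, 2, axis=-1)
    return (gate / (1.0 + np.exp(-gate))) * up   # SiLU(gate) * up

def squared_relu(x):
    # Non-gated: shape-preserving, fc1's last dim is just inter_size.
    return np.maximum(x, 0.0) ** 2

x = np.ones((4, 8))
assert swiglu(x).shape == (4, 4)        # halves the last dimension
assert squared_relu(x).shape == (4, 8)  # preserves it
```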
Hi, I locally repro the H100 errors
This was the error I am seeing:
Hi @aleozlx, this is exactly the CI error. Have you tried clearing any cache? Thanks!
This is a fresh container, I don't think there is anything to clear?
I switched to the main branch, and they passed
Now trying running tests after
I'll also retry on Hopper
Hi @aleozlx, rerunning both Blackwell and Hopper now, Blackwell also shows the failure now. Let me debug.
Head branch was pushed to by a user without write access
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (1)
406-516: ⚠️ Potential issue | 🔴 Critical
Verify that explicit instantiation TUs have matching IsMXFPX=false variants for all SM versions and type combinations now being dispatched.
The lambda dispatch_by_mxfpx is instantiated twice at runtime (lines 513, 515) with both IsMXFPX=true and IsMXFPX=false, causing the nested SHAPE_CASE macros to generate references to tma_warp_specialized_generic_moe_gemm_kernelLauncher<> with both boolean variants across all supported SM versions (SM90, SM100, SM120). This widens the linker surface: legacy FP8 paths now bake IsMXFPX=false into symbols that may not have explicit instantiations in the compilation units.
If moe_gemm_kernels_fp8_fp8.cu and related type-specific TUs do not emit matching IsMXFPX=false variants for all the tile/cluster shapes and epilogue configurations now being dispatched, this change is the root cause of the undefined symbol errors in fused_moe_90.so.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h` around lines 406 - 516, The change instantiates dispatch_by_mxfpx with both IsMXFPX=true and false which causes references to tma_warp_specialized_generic_moe_gemm_kernelLauncher<> (via dispatchMoeGemmSelectClusterShapeTmaWarpSpecialized and the SHAPE_CASE macro) for SM90/100/120 even when no explicit TU instantiations exist; fix by adding matching explicit instantiations for the IsMXFPX=false variants (all relevant template parameter combinations: T, WeightType, EpilogueTag, FUSION and all tile shapes used in SHAPE_CASE for SM90/100/120) into the corresponding type-specific TUs (e.g., the FP8/FP4 kernel TUs), or alternately narrow the runtime dispatch so dispatch_by_mxfpx is only instantiated for the boolean value that actually has corresponding explicit instantiations to avoid generating undefined symbols.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In
`@csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h`:
- Around line 491-504: The SM120 branch can silently fall through when
kernels::cutlass_kernels::isValidSM120MOESpecialisation<T,WeightType,EpilogueTag,FUSION>()
is false causing count==0 and a no-op; update the else path for the if constexpr
in the gemm_config.sm_version == 120/121 block to fail loudly (throw an
exception or call a fatal logger) with a clear message referencing SM120 and the
tile_config_sm120 value (use pretty_function or include
gemm_config.tile_config_sm120 in the message) so invalid SM120 specialisations
cannot proceed; ensure the thrown exception type matches surrounding error
handling conventions.
---
Outside diff comments:
In
`@csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h`:
- Around line 406-516: The change instantiates dispatch_by_mxfpx with both
IsMXFPX=true and false which causes references to
tma_warp_specialized_generic_moe_gemm_kernelLauncher<> (via
dispatchMoeGemmSelectClusterShapeTmaWarpSpecialized and the SHAPE_CASE macro)
for SM90/100/120 even when no explicit TU instantiations exist; fix by adding
matching explicit instantiations for the IsMXFPX=false variants (all relevant
template parameter combinations: T, WeightType, EpilogueTag, FUSION and all tile
shapes used in SHAPE_CASE for SM90/100/120) into the corresponding type-specific
TUs (e.g., the FP8/FP4 kernel TUs), or alternately narrow the runtime dispatch
so dispatch_by_mxfpx is only instantiated for the boolean value that actually
has corresponding explicit instantiations to avoid generating undefined symbols.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 0ec1b1ce-bd64-4a88-a347-ddfc149c8f83
📒 Files selected for processing (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h
} else if (gemm_config.sm_version == 120 || gemm_config.sm_version == 121) {
  char const* const pretty_function = __PRETTY_FUNCTION__;
  TLLM_LOG_TRACE("At %s, SM120 config=%d", pretty_function,
                 (int)(gemm_config.tile_config_sm120));
  if constexpr (kernels::cutlass_kernels::isValidSM120MOESpecialisation<
                    T, WeightType, EpilogueTag, FUSION>()) {
    switch (gemm_config.tile_config_sm120) {
      SHAPE_CASE(120, 128, 128, 64)
      SHAPE_CASE(120, 128, 128, 128)
      SHAPE_CASE(120, 128, 256, 64)
      SHAPE_CASE(120, 256, 128, 64)
      DEFAULT_CASE(120)
    }
  }
Throw on invalid SM120 specialisations instead of silently falling through.
If isValidSM120MOESpecialisation<...>() is false, this branch currently exits without a throw. In the workspace path that can leave count == 0, and in the execution path it becomes a no-op instead of a hard failure.
💡 Proposed fix
} else if (gemm_config.sm_version == 120 || gemm_config.sm_version == 121) {
char const* const pretty_function = __PRETTY_FUNCTION__;
TLLM_LOG_TRACE("At %s, SM120 config=%d", pretty_function,
(int)(gemm_config.tile_config_sm120));
if constexpr (kernels::cutlass_kernels::isValidSM120MOESpecialisation<
T, WeightType, EpilogueTag, FUSION>()) {
switch (gemm_config.tile_config_sm120) {
SHAPE_CASE(120, 128, 128, 64)
SHAPE_CASE(120, 128, 128, 128)
SHAPE_CASE(120, 128, 256, 64)
SHAPE_CASE(120, 256, 128, 64)
DEFAULT_CASE(120)
}
+ } else {
+ TLLM_THROW("Unsupported SM120 configuration requested");
}
}📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
} else if (gemm_config.sm_version == 120 || gemm_config.sm_version == 121) {
  char const* const pretty_function = __PRETTY_FUNCTION__;
  TLLM_LOG_TRACE("At %s, SM120 config=%d", pretty_function,
                 (int)(gemm_config.tile_config_sm120));
  if constexpr (kernels::cutlass_kernels::isValidSM120MOESpecialisation<
                    T, WeightType, EpilogueTag, FUSION>()) {
    switch (gemm_config.tile_config_sm120) {
      SHAPE_CASE(120, 128, 128, 64)
      SHAPE_CASE(120, 128, 128, 128)
      SHAPE_CASE(120, 128, 256, 64)
      SHAPE_CASE(120, 256, 128, 64)
      DEFAULT_CASE(120)
    }
  } else {
    TLLM_THROW("Unsupported SM120 configuration requested");
  }
}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h`
around lines 491 - 504, The SM120 branch can silently fall through when
kernels::cutlass_kernels::isValidSM120MOESpecialisation<T,WeightType,EpilogueTag,FUSION>()
is false causing count==0 and a no-op; update the else path for the if constexpr
in the gemm_config.sm_version == 120/121 block to fail loudly (throw an
exception or call a fatal logger) with a clear message referencing SM120 and the
tile_config_sm120 value (use pretty_function or include
gemm_config.tile_config_sm120 in the message) so invalid SM120 specialisations
cannot proceed; ensure the thrown exception type matches surrounding error
handling conventions.
The previous Blackwell failure I saw was due to a stale ninja file, irrelevant to this PR. The Hopper failure is a real regression, now fixed by 6ab67b4.
Blackwell:
Hopper:
/bot run
@flashinfer-bot run
Signed-off-by: Amey Naik <[email protected]>

📌 Description
@HumansAnd
#2505 implements mxfp8 for trtllm backend.
However, in SGLang, --moe-runner-backend flashinfer_trtllm bypasses the SGLang topk implementation and does not work with expert routing replay in MoE RL.
We want to implement mxfp8 x mxfp8 for cutlass_fused_moe, which works with MoE RL training.
This PR mainly reuses the existing code path for WMxfp4AMxfp8Quant:
flashinfer/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu
Line 1191 in 952b6ab
🔍 Related Issues
miles MXFP8/NVFP4 RL roadmap: radixark/miles#615
SGLang FlashInfer MXFP8 integration: sgl-project/sglang#18945
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).
Reviewer Notes
Summary by CodeRabbit
New Features
- Toggleable MXFPX/MXFP8 activation-scaling across MOE inference, updating workspace sizing, kernel selection, block-scaling and dispatch to enable MXFP8-aware execution and validation.
- Added MXFP8×MXFP8 quantization mode and emitted MXFPX-aware GEMM/kernel variants; public APIs now expose an MXFPX/activation-scaling flag.
Tests
- Added unit tests and helpers for MXFP8 quantization, packing/dequantization, and end-to-end MXFP8×MXFP8 MOE inference validation.