
Implement cutlass_fused_moe mxfp8#2581

Merged
aleozlx merged 12 commits into flashinfer-ai:main from zianglih:mxfp8
Mar 7, 2026

Conversation

@zianglih
Contributor

@zianglih zianglih commented Feb 18, 2026

📌 Description


#2505 implements mxfp8 for the trtllm backend.

However, in SGLang, --moe-runner-backend flashinfer_trtllm bypasses SGLang's top-k implementation and does not work with expert-routing replay in MoE RL.

This PR implements mxfp8 x mxfp8 for cutlass_fused_moe, which does work with MoE RL training. It mainly reuses the existing code path for WMxfp4AMxfp8Quant.
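For background, MXFP8 stores each value as FP8 E4M3 with one shared power-of-two (E8M0) scale per 32-element block. A minimal Python sketch of that round trip follows; the helper names are illustrative, and the final rounding of each element to the E4M3 grid is deliberately omitted, so this is not the kernel's actual code:

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite E4M3 magnitude
BLOCK = 32            # elements sharing one E8M0 scale in MXFP8

def quant_block_mxfp8(vals):
    """Pick a power-of-two scale so the block's max magnitude fits in E4M3,
    then scale and clamp. Real MXFP8 additionally rounds each scaled value
    to the E4M3 grid; that step is omitted here for brevity."""
    amax = max(abs(v) for v in vals)
    if amax == 0.0:
        return 0, [0.0] * len(vals)
    exp = math.ceil(math.log2(amax / FP8_E4M3_MAX))  # unbiased E8M0 exponent
    scale = 2.0 ** exp
    q = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale)) for v in vals]
    return exp, q  # the format stores exp + 127 as a one-byte scale

def dequant_block_mxfp8(exp, q):
    """Invert the quantization by multiplying by the shared block scale."""
    scale = 2.0 ** exp
    return [v * scale for v in q]
```

Because the per-block scale is an exact power of two, the only error sources in the real format are E4M3 rounding and clamping.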

🔍 Related Issues

miles MXFP8/NVFP4 RL roadmap: radixark/miles#615
SGLang FlashInfer MXFP8 integration: sgl-project/sglang#18945

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or my preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Toggleable MXFPX/MXFP8 activation scaling across MoE inference, updating workspace sizing, kernel selection, block scaling, and dispatch to enable MXFP8-aware execution and validation.
    • Added an MXFP8×MXFP8 quantization mode and emitted MXFPX-aware GEMM/kernel variants; public APIs now expose an MXFPX activation-scaling flag.
  • Tests

    • Added unit tests and helpers for MXFP8 quantization, packing/dequantization, and end-to-end MXFP8×MXFP8 MOE inference validation.

@gemini-code-assist
Contributor

Summary of Changes

Hello @zianglih, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new mixed-precision quantization scheme, MXFP8xMXFP8, for the cutlass_fused_moe operation. The primary goal is to enable efficient Mixture-of-Experts (MoE) training, particularly for Reinforcement Learning (RL) scenarios, by providing a quantization method that integrates seamlessly with existing MoE training pipelines. This enhancement addresses a gap where the flashinfer_trtllm backend's top-k implementation was not compatible with expert routing replay in MoE RL.

Highlights

  • MXFP8xMXFP8 Quantization Support: Implemented support for MXFP8 activations and MXFP8 weights within the cutlass_fused_moe kernel, enabling a new mixed-precision quantization mode.
  • MoE RL Training Compatibility: The new MXFP8xMXFP8 implementation is designed to work with MoE RL training, addressing limitations encountered with the flashinfer_trtllm backend's top-k implementation.
  • Code Path Reuse and Extension: Existing code paths for WMxfp4AMxfp8Quant were largely reused and extended to accommodate the new MXFP8xMXFP8 quantization, minimizing redundant code.
  • Dynamic Kernel Dispatching: Updated kernel dispatching logic to correctly identify and utilize the new MXFP8 block-scaled kernels, including checks for supported tile shapes on SM100 architectures.
  • Comprehensive Testing: Added new utility functions for MXFP8 quantization and dequantization, along with a dedicated test case to validate the moe_mxfp8_mxfp8 functionality.


Changelog
  • csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh
    • Updated getWorkspaceDeviceBufferSizes and configureWsPtrs function signatures to include a use_mxfp8_fp8_block_scaling boolean parameter.
    • Modified the logic for determining fpX_scaling_type to correctly apply MXFPX scaling when use_mxfp8_fp8_block_scaling is true and use_fp8 is enabled.
    • Adjusted the assignment of fc1_fp4_act_scale_ and fc2_fp4_act_scale_ to consider use_mxfp8_fp8_block_scaling alongside use_block_scaling.
    • Updated the fc1_fp8_dequant and fc2_fp8_dequant pointers to conditionally use mxfp8_mxfp4 global scales when MXFP8 weight block scales are active.
    • Added a new use_mxfp8_fp8_block_scaling local variable and passed it to configureWsPtrs.
    • Modified layout_info1.fpX_block_scaling_type and layout_info2.fpX_block_scaling_type to dynamically set MXFPX scaling based on WeightType and mxfp8_mxfp4 weight block scales.
  • csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu
    • Extended the isFp8Quant() check to include isWFp8AMxfp8Quant() for kernel runner selection.
    • Added new member functions isWFp8AMxfp8Quant() and isMxfp8ActScalingQuant() to differentiate between various FP8 quantization modes.
    • Updated getQuantParams to handle isWFp8AMxfp8Quant() by checking for 4 quantization scales and extracting fc1_weight_block, fc1_global, fc2_weight_block, and fc2_global scales, with specific dimension and type checks.
    • Added a conditional compilation block for USING_OSS_CUTLASS_MOE_GEMM to support MXFP8xMXFP8 quantization.
    • Modified isFp8Quant() to explicitly exclude mUseMxfp8ActScaling.
    • Added new TVM_FFI_ICHECK statements to validate input_sf presence when isMxfp8ActScalingQuant() is true.
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h
    • Updated comments for the MXFP8MXFP4Inputs struct to clarify its reuse for MXFP8xMXFP8 (FP8 weights with MXFPX block scales).
    • Added a static MXFP8MXFP8 method to QuantParams that reuses the MXFP8MXFP4 constructor.
    • Modified the CutlassMoeFCRunner class declaration to include use_mxfp8_fp8_block_scaling in the getWorkspaceDeviceBufferSizes and configureWsPtrs method signatures.
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl
    • Updated the IsBlockScaled constexpr to include IsMXFPX in its condition.
    • Modified ElementWeightBlockScaled to conditionally use cutlass::mx_float8_t<ElementWeight> when IsMXFPX is true and WeightType is not SafeFP4, otherwise using cutlass::mx_float4_t<ElementWeight>.
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h
    • Modified calcMaxWorkspace to allocate workspace using MXFPX requirements when use_fp8 is true, ensuring sufficient space for both regular FP8 and MXFP8 modes.
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h
    • Updated getDispatchFunctionForSM100 to pass a boolean IsMXFPX template parameter instead of is_wfp4afp8.
    • Added new constexpr variables is_wfp8amxfp8, is_wfp8amxfp8_tile_supported, supports_mxfpx, and supports_mxfpx_tile to manage MXFPX dispatching.
    • Introduced checks for MXFPX support and tile shape compatibility within dispatchMoeGemmFinalDispatchTmaWarpSpecialized.
    • Updated the kernel dispatching logic to conditionally select functions based on use_mxfpx and its support for specific tile shapes.
  • flashinfer/jit/gemm/cutlass/generate_kernels.py
    • Added is_gemm_op_valid_sm100 logic to enforce specific tile shape constraints (M=128, N in [64, 128, 256]) for MXFP block-scaled paths.
    • Modified generate_sm100_grouped_gemm_operations to emit both regular FP8xFP8 and MXFP8xMXFP8 variants by iterating through mxfp_modes.
  • tests/moe/test_trtllm_cutlass_fused_moe.py
    • Added quant_mxfp8_batches function to quantize batches of tensors to MXFP8.
    • Added pack_mxfp8_scales_u8_to_int32_batches function to pack MXFP8 scales from uint8 to int32.
    • Added dequant_mxfp8_batches function to dequantize batches of MXFP8 tensors.
    • Implemented a new test case test_moe_mxfp8_mxfp8 to validate the functionality of MoE with MXFP8 activations and MXFP8 weights, including quantization, execution, and comparison with a reference output.
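The scale-packing helper added to the tests operates on torch tensors; at the byte level it amounts to reinterpreting four uint8 scale bytes as one int32. A plain-Python sketch of that packing (the function names and the little-endian ordering are assumptions about the helper, not copied from it):

```python
def pack_u8_to_i32(scale_bytes):
    """Pack each group of four uint8 scale bytes into one little-endian
    two's-complement int32, mirroring a uint8 -> int32 tensor view."""
    assert len(scale_bytes) % 4 == 0, "scale count must be a multiple of 4"
    packed = []
    for i in range(0, len(scale_bytes), 4):
        b0, b1, b2, b3 = scale_bytes[i:i + 4]
        word = b0 | (b1 << 8) | (b2 << 16) | (b3 << 24)
        if word >= 1 << 31:        # reinterpret the top bit as the sign bit
            word -= 1 << 32
        packed.append(word)
    return packed

def unpack_i32_to_u8(packed):
    """Invert the packing back to individual uint8 scale bytes."""
    out = []
    for word in packed:
        word &= 0xFFFFFFFF
        out.extend((word >> s) & 0xFF for s in (0, 8, 16, 24))
    return out
```

The round trip is lossless, so packing is purely a layout change for the kernel's int32 scale buffers.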
Activity
  • The pull request is currently marked as Work In Progress (WIP).

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for mxfp8 x mxfp8 quantization in the cutlass_fused_moe kernel. The changes are extensive, touching kernel implementations, bindings, and adding new tests to validate the functionality. The implementation correctly adds new code paths for mxfp8 handling, including new quantization modes and scaling types.

My review focuses on improving code maintainability by addressing several instances of code duplication. I've identified repeated logic for determining scaling types and for dispatching kernels, and I've suggested refactoring these into helper functions or using other C++ patterns to reduce redundancy. These changes should make the code cleaner and easier to maintain in the future.

Comment on lines +2598 to +2603
auto fpX_scaling_type = getScalingType();
if constexpr (use_fp8) {
  if (use_mxfp8_fp8_block_scaling) {
    fpX_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX;
  }
}

Severity: medium

This logic to determine fpX_scaling_type is duplicated in configureWsPtrs (lines 2755-2760) and setupTmaWarpSpecializedInputs (lines 3933-3938). To improve maintainability and reduce code duplication, consider extracting this logic into a private helper function within the CutlassMoeFCRunner class.

For example:

__host__ __device__ inline TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType
getFpXScalingTypeHelper(bool use_mxfp8_fp8_block_scaling) const {
    auto fpX_scaling_type = getScalingType();
    if constexpr (use_fp8) {
        if (use_mxfp8_fp8_block_scaling) {
            fpX_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX;
        }
    }
    return fpX_scaling_type;
}

Then, you can call this helper function in all three places.

@yzh119
Collaborator

yzh119 commented Feb 19, 2026

/bot run

@yzh119
Collaborator

yzh119 commented Feb 19, 2026

@flashinfer-bot run

@flashinfer-bot
Collaborator

GitLab MR !326 has been created, and the CI pipeline #44336503 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #44336503: 7/20 passed

@coderabbitai
Contributor

coderabbitai bot commented Feb 22, 2026

📝 Walkthrough

Walkthrough

Adds MXFPX/MXFP8 activation-scaling and block-scaling support across the Cutlass MoE backend: new template parameter IsMXFPX, a use_mxfp8_act_scaling flag, MXFP8-MXFP8 quant paths, workspace sizing/pointer wiring updates, dispatch/template changes, explicit instantiations, and tests for MXFP8 flows.

Changes

Cohort / File(s) Summary
Core Cutlass MoE runner
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh, csrc/fused_moe/cutlass_backend/cutlass_fused_moe_instantiation.cu
Add IsMXFPX template param; thread use_mxfp8_act_scaling through workspace sizing, configureWsPtrs, BlockScale/GEMM signatures; add explicit MXFPX FP8 instantiations.
Binding & quant routing
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu
Propagate IsMXFPX into runner selection; add isWMxfp8AMxfp8Quant() / isMxfp8ActScalingQuant(); build/validate MXFP8MXFP8 quant params and pass mUseMxfp8ActScaling to run sites.
Moe GEMM interfaces & dispatch
csrc/nv_internal/.../cutlass_kernels/include/moe_kernels.h, .../moe_gemm/moe_gemm_template_dispatch.h, .../moe_gemm/moe_gemm_template_dispatch_tma_ws.h, csrc/nv_internal/.../cutlass_kernels/include/moe_gemm_kernels.h
Add IsMXFPX template param to runners/interfaces; add use_mxfp8_act_scaling to public APIs; propagate MXFPX through block-scaling, tile-shape selection, and dispatch.
Kernel launcher / TMA paths
csrc/nv_internal/.../moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl, .../moe_gemm/moe_gemm_template_dispatch_tma_ws.h
Include IsMXFPX in IsBlockScaled and replace FP4-specific gating with MXFPX-aware gating in launcher selection and dispatch.
GEMM kernel generation (JIT)
flashinfer/jit/gemm/cutlass/generate_kernels.py
Emit MXFP variants: add mxfp_modes to generate MXFP8xMXFP8 launchers where valid; enforce SM100 tile rules for MXFP paths.
FP8 explicit instantiations
csrc/nv_internal/.../moe_gemm/moe_gemm_kernels_fp8_fp8.cu
Add explicit MoeGemmRunner instantiations for FP8 with IsMXFPX=true.
Tests & utilities
tests/moe/test_trtllm_cutlass_fused_moe.py
Add MXFP8 test helpers (quant_mxfp8_batches, pack_mxfp8_scales_u8_to_int32_batches, dequant_mxfp8_batches) and test_moe_mxfp8_mxfp8 E2E test exercising MXFP8-MXFP8 path.

Sequence Diagram(s)

sequenceDiagram
    participant Binding as Host Binding
    participant Runner as CutlassMoeFCRunner<IsMXFPX>
    participant MoeGemm as MoeGemmRunner<IsMXFPX>
    participant Device as GPU (CUTLASS / TMA)

    Binding->>Runner: prepare inputs, quant params, use_mxfp8_act_scaling
    Runner->>MoeGemm: get workspace sizes, configureWsPtrs, select kernel
    MoeGemm->>Device: dispatch GEMM/TMA (MXFPX or non‑MXFP path)
    Device-->>MoeGemm: kernel results
    MoeGemm-->>Runner: post-process (accumulate/gate/dequant)
    Runner-->>Binding: final outputs

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • djmmoss
  • bkryu
  • yongwww
  • cyx-6
  • nv-yunzheq
  • jiahanc
  • IwakuraRein
  • jimmyzho

Poem

"I'm a rabbit in silicon glades, hopping MXFPX trails,
I stash the scales and flip the flags, across runners, rows, and rails,
I pack the bytes and tune the kernels, tests cheer every little hop,
With quant and dispatch I skip and tap — precise hops that never stop! 🐇✨"

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name: Docstring Coverage
Status: ⚠️ Warning
Explanation: Docstring coverage is 11.76%, which is below the required threshold of 80.00%.
Resolution: Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
• Title check: ✅ Passed. The title clearly and concisely describes the main feature being implemented (MXFP8 support for cutlass_fused_moe), which aligns with the changeset's primary purpose.
• Description check: ✅ Passed. The description addresses why the change is needed (SGLang MoE RL compatibility), provides context on related prior work, and includes links to related issues, though it lacks explicit detail on what testing was performed locally.


@zianglih zianglih marked this pull request as draft February 22, 2026 06:18

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (2)
tests/moe/test_trtllm_cutlass_fused_moe.py (1)

1389-1392: GPU architecture skip should use flashinfer.utils functions per coding guidelines.

The @pytest.mark.skipif uses torch.cuda.get_device_capability() directly. The same pattern is used elsewhere in this file (e.g., lines 481-484, 1256-1258), but the coding guidelines require using flashinfer.utils functions like get_compute_capability or is_sm90a_supported for architecture-gated skips. Consider updating all new tests to follow this guideline.

As per coding guidelines: tests/**/*.py: Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip tests on unsupported GPU architectures.

#!/bin/bash
# Check what flashinfer.utils functions are available for capability checks
rg -n "def get_compute_capability|def is_sm" --type=py -g '*/utils*'
# Check if any test files in the repo already follow the guideline
rg -n "from flashinfer.utils import|flashinfer.utils.get_compute" --type=py -g 'tests/**'
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_trtllm_cutlass_fused_moe.py` around lines 1389 - 1392, The
skip decorator in the test uses torch.cuda.get_device_capability() directly;
replace it with the utility helpers from flashinfer.utils (e.g., import and call
get_compute_capability() or the appropriate is_smXX_supported helper) so
GPU-architecture gating follows project guidelines; update the
`@pytest.mark.skipif` on the test (around the decorator that currently checks
torch.cuda.get_device_capability()) to call
flashinfer.utils.get_compute_capability() or
is_sm90a_supported/is_sm100_supported as appropriate and mirror the pattern used
in other tests in this file.
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu (1)

889-900: fc1_weight_block validation always assumes gated (×2) activation — correct for current usage but potentially fragile.

The check at line 894 unconditionally multiplies inter_size by 2 for the fc1 N-dimension. This is correct for SwiGLU/SwigluBias (gated activations), which is the only activation type currently used with MXFP8. However, the NVFP4 path (further below) conditionally handles both gated and non-gated cases via isGatedActivation(base_activation_type).

If non-gated activations are supported for MXFP8 in the future, this check will need to be updated.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu`
around lines 889 - 900, The fc1_weight_block size check unconditionally assumes
a gated (×2) N-dimension; update the validation in the block that references
fc1_weight_block, TmaWarpSpecializedGroupedGemmInput::alignToSfDim, and
inter_size to only multiply inter_size by 2 when the activation is gated (use
isGatedActivation(base_activation_type) or equivalent), otherwise use inter_size
as-is, and adjust the error message to reflect both gated and non-gated expected
shapes so the check matches the NVFP4 branch behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh`:
- Around line 3578-3586: The current gate use_mxfp8_weight_block_scales only
checks fc1.weight_block_scale and can still dereference fc2.global_scale; change
the logic to enable MXFP8 block-scale mode only when fp8_scales_required
indicates MXFP8 (same flag used elsewhere) AND both
quant_params.mxfp8_mxfp4.fc1.weight_block_scale and
quant_params.mxfp8_mxfp4.fc2.weight_block_scale are set and both corresponding
global_scale pointers are non-null; then use that boolean to choose
fc1_fp8_dequant and fc2_fp8_dequant (otherwise fall back to quant_params.fp8.*
dequant pointers). Ensure the symbol names mentioned
(use_mxfp8_weight_block_scales, quant_params.mxfp8_mxfp4.fc1.weight_block_scale,
quant_params.mxfp8_mxfp4.fc2.weight_block_scale,
quant_params.mxfp8_mxfp4.fc1.global_scale,
quant_params.mxfp8_mxfp4.fc2.global_scale, fp8_scales_required, fc1_fp8_dequant,
fc2_fp8_dequant) are used to locate and update the condition and selections.

In `@tests/moe/test_trtllm_cutlass_fused_moe.py`:
- Around line 1448-1462: Add a GPU-architecture skip to the test function
test_moe_mxfp8_mxfp8 by decorating it with pytest.mark.skipif using the
flashinfer.utils helpers: call get_compute_capability() and pass it to
is_sm90a_supported (or the appropriate helper for MXFP8 support) and skip when
those indicate unsupported hardware; ensure you import pytest and the helpers
(get_compute_capability, is_sm90a_supported) and place the skip decorator
immediately above the test_moe_mxfp8_mxfp8 definition so the test is skipped on
unsupported GPUs.

---

Duplicate comments:
In `@csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh`:
- Around line 4108-4117: This is a duplicate of the MXFP8 validation/gating code
— remove the redundant block and consolidate the FP8 gating logic so only the
earlier validated ternary selection for dequant scales is used; specifically,
keep a single place that computes the fc1/fc2 dequant scale (the ternaries using
std::is_same_v<WeightType, __nv_fp8_e4m3> and
quant_params.mxfp8_mxfp4.fc?.global_scale vs quant_params.fp8.dequant_fc?) and
pass those same values along with fc1_expert_weights, fc2_expert_weights,
fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, fc1_expert_biases,
fc2_bias to the kernel — delete or merge the duplicated lines to avoid repeated
gating/validation.

---

Nitpick comments:
In `@csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu`:
- Around line 889-900: The fc1_weight_block size check unconditionally assumes a
gated (×2) N-dimension; update the validation in the block that references
fc1_weight_block, TmaWarpSpecializedGroupedGemmInput::alignToSfDim, and
inter_size to only multiply inter_size by 2 when the activation is gated (use
isGatedActivation(base_activation_type) or equivalent), otherwise use inter_size
as-is, and adjust the error message to reflect both gated and non-gated expected
shapes so the check matches the NVFP4 branch behavior.

In `@tests/moe/test_trtllm_cutlass_fused_moe.py`:
- Around line 1389-1392: The skip decorator in the test uses
torch.cuda.get_device_capability() directly; replace it with the utility helpers
from flashinfer.utils (e.g., import and call get_compute_capability() or the
appropriate is_smXX_supported helper) so GPU-architecture gating follows project
guidelines; update the `@pytest.mark.skipif` on the test (around the decorator
that currently checks torch.cuda.get_device_capability()) to call
flashinfer.utils.get_compute_capability() or
is_sm90a_supported/is_sm100_supported as appropriate and mirror the pattern used
in other tests in this file.

@zianglih zianglih requested review from bkryu and cyx-6 as code owners March 3, 2026 22:35

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (1)
tests/moe/test_trtllm_cutlass_fused_moe.py (1)

1389-1392: ⚠️ Potential issue | 🟠 Major

Use flashinfer.utils helpers for architecture skip gating (Line 1390).

This new test still hard-codes torch.cuda.get_device_capability() instead of the required flashinfer.utils-based skip helpers.

Suggested patch
 import pytest
 from flashinfer.fused_moe.core import ActivationType
 import torch
 from torch.nn import functional as F
+from flashinfer.utils import get_compute_capability
@@
 @pytest.mark.skipif(
-    torch.cuda.get_device_capability()[0] not in [10],
+    get_compute_capability()[0] != 10,
     reason="MXFP8xMXFP8 is only supported on SM100 for now",
 )
#!/bin/bash
# Verify architecture-skip helper usage in this test file.
rg -n "from flashinfer.utils import|get_compute_capability|is_sm90a_supported|torch.cuda.get_device_capability" tests/moe/test_trtllm_cutlass_fused_moe.py -C 2

As per coding guidelines: tests/**/*.py: Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip unsupported GPU architectures.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_trtllm_cutlass_fused_moe.py` around lines 1389 - 1392, Replace
the hard-coded torch.cuda.get_device_capability() check in the pytest skip
marker with the flashinfer.utils helpers: import and call
get_compute_capability() (or an appropriate helper such as is_sm90a_supported())
and use its result to gate the test instead of directly calling
torch.cuda.get_device_capability; update the skipif condition at the
pytest.mark.skipif decorator around the test (the block using
torch.cuda.get_device_capability) to use get_compute_capability() or
is_sm90a_supported() from flashinfer.utils so the test follows the repository's
architecture-skip helpers.
🧹 Nitpick comments (1)
flashinfer/jit/gemm/cutlass/generate_kernels.py (1)

410-418: Deduplicate block-scaled validity checks to avoid drift.

The MXFP branch repeats the exact same constraints as the FP4 block-scaled branch. Collapsing them into one predicate will keep future tuning changes synchronized.

♻️ Suggested refactor
-    # FP4 Has some much more limited sizes
-    if op.act_type == e2m1 or op.weight_type == e2m1:
-        if tile_n not in [64, 128, 256] or tile_m != 128:
-            return False
-        # TODO Revert this once cutlass adds support for blockscaled + no smem
-        if (
-            op.arch == 100
-            and op.epi_schedule == EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm
-        ):
-            return False
-
-    # MXFP block-scaled paths currently follow the same shape/schedule limits as FP4 block scaling.
-    if op.is_mx_fpx:
+    # FP4 and MXFP block-scaled paths currently share the same shape/schedule limits.
+    uses_block_scaled_limits = (
+        op.act_type == e2m1 or op.weight_type == e2m1 or op.is_mx_fpx
+    )
+    if uses_block_scaled_limits:
         if tile_n not in [64, 128, 256] or tile_m != 128:
             return False
         if (
             op.arch == 100
             and op.epi_schedule == EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm
         ):
             return False
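The shared predicate suggested by that refactor can be sketched in Python; the function name and flat parameter list are illustrative, since generate_kernels.py actually inspects operation objects:

```python
BLOCK_SCALED_TILE_N = (64, 128, 256)

def block_scaled_tile_ok(tile_m, tile_n, arch, epi_no_smem):
    """Shared tile limits for FP4 and MXFP block-scaled grouped GEMMs:
    tile M must be 128 and tile N one of the supported sizes."""
    if tile_m != 128 or tile_n not in BLOCK_SCALED_TILE_N:
        return False
    # block-scaled + no-smem epilogue is not yet supported on SM100
    if arch == 100 and epi_no_smem:
        return False
    return True
```

Centralizing the check this way means a future tile-shape change only needs to be made once for both the FP4 and MXFP paths.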
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/jit/gemm/cutlass/generate_kernels.py` around lines 410 - 418, The
MXFP branch duplicates the same block-scaled validity constraints used by the
FP4 block-scaled branch; combine these into a single predicate to avoid drift by
moving the shared checks into a common condition used by both paths. Update the
logic that currently tests tile_n, tile_m, op.arch and op.epi_schedule
(references: op.is_mx_fpx, tile_n, tile_m, op.arch, op.epi_schedule,
EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm) so the tile_n in
[64,128,256] / tile_m == 128 and the special-case arch/epilogue check are
expressed once and reused for FP4 and MXFP block-scaled validation.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@tests/moe/test_trtllm_cutlass_fused_moe.py`:
- Around line 1389-1392: Replace the hard-coded
torch.cuda.get_device_capability() check in the pytest skip marker with the
flashinfer.utils helpers: import and call get_compute_capability() (or an
appropriate helper such as is_sm90a_supported()) and use its result to gate the
test instead of directly calling torch.cuda.get_device_capability; update the
skipif condition at the pytest.mark.skipif decorator around the test (the block
using torch.cuda.get_device_capability) to use get_compute_capability() or
is_sm90a_supported() from flashinfer.utils so the test follows the repository's
architecture-skip helpers.

---

Nitpick comments:
In `@flashinfer/jit/gemm/cutlass/generate_kernels.py`:
- Around line 410-418: The MXFP branch duplicates the same block-scaled validity
constraints used by the FP4 block-scaled branch; combine these into a single
predicate to avoid drift by moving the shared checks into a common condition
used by both paths. Update the logic that currently tests tile_n, tile_m,
op.arch and op.epi_schedule (references: op.is_mx_fpx, tile_n, tile_m, op.arch,
op.epi_schedule, EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm) so the
tile_n in [64,128,256] / tile_m == 128 and the special-case arch/epilogue check
are expressed once and reused for FP4 and MXFP block-scaled validation.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f28fcfb and aba577a.

📒 Files selected for processing (2)
  • flashinfer/jit/gemm/cutlass/generate_kernels.py
  • tests/moe/test_trtllm_cutlass_fused_moe.py

@IwakuraRein IwakuraRein added run-ci and removed run-ci labels Mar 4, 2026
@zianglih
Contributor Author

zianglih commented Mar 5, 2026

Hopper still failing. Can confirm test_trtllm_cutlass_fused_moe passes on Blackwell:

pytest tests/moe/test_trtllm_cutlass_fused_moe.py
=========================================== test session starts ===========================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /root/flashinfer-mxfp8-src
configfile: pytest.ini
plugins: hydra-core-1.3.2, anyio-4.12.1, typeguard-4.4.4
collected 29 items                                                                                        

tests/moe/test_trtllm_cutlass_fused_moe.py ......................sssssss          [100%]

=================================== warnings summary ====================================
tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_fp8[otype0-wtype0-128-2-2-128-1]
  /root/flashinfer-mxfp8-src/tests/moe/test_trtllm_cutlass_fused_moe.py:445: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
    hidden_states_scale = torch.tensor(hidden_states_scale[0]).cuda()

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================= 22 passed, 7 skipped, 1 warning in 239.50s (0:03:59) ==================
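For readers unfamiliar with the format these tests exercise: MXFP8 stores e4m3 elements with a shared power-of-two (e8m0) scale per 32-element block, per the OCP MX spec. Below is a minimal pure-Python sketch of the quantize/dequantize round trip (function names are illustrative; the real kernels handle subnormals, saturation, and rounding modes more carefully):

```python
import math

E4M3_MAX = 448.0   # largest finite e4m3 value
BLOCK = 32         # MX block size per the OCP MX spec

def quantize_e4m3(v: float) -> float:
    """Round v to the nearest representable e4m3 value (sketch: flushes
    subnormals below 2**-6 to zero, saturates at E4M3_MAX)."""
    if v == 0.0:
        return 0.0
    sign = -1.0 if v < 0 else 1.0
    m = min(abs(v), E4M3_MAX)
    e = math.floor(math.log2(m))
    if e < -6:                      # below the e4m3 normal range
        return 0.0
    step = 2.0 ** (e - 3)           # 3 mantissa bits
    return sign * min(round(m / step) * step, E4M3_MAX)

def mxfp8_quantize_block(block):
    """One MX block: e4m3 elements sharing a power-of-two (e8m0) scale."""
    assert len(block) == BLOCK
    amax = max(abs(x) for x in block)
    # Standard MX recipe: shared exponent = floor(log2(amax)) - emax(e4m3),
    # where emax(e4m3) = 8; elements still exceeding E4M3_MAX saturate.
    shared_exp = math.floor(math.log2(amax)) - 8 if amax > 0 else 0
    scale = 2.0 ** shared_exp
    return scale, [quantize_e4m3(x / scale) for x in block]

def mxfp8_dequantize_block(scale, q):
    return [scale * x for x in q]
```

In the `mxfp8 x mxfp8` path both the activations and the weights carry these per-block e8m0 scales, which is what the new block-scaled GEMM dispatch consumes.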

@zianglih
Contributor Author

zianglih commented Mar 5, 2026

Can confirm H100 also passes:

root@b35f7fa2ee4d:/flashinfer-mxfp8-src# pytest tests/moe/test_trtllm_cutlass_fused_moe.py
======================= test session starts =======================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /flashinfer-mxfp8-src
configfile: pytest.ini
plugins: anyio-4.11.0
collected 29 items                                                

tests/moe/test_trtllm_cutlass_fused_moe.py ..ssssssss...... [ 55%]
ssssss.......                                                         [100%]

============================= warnings summary ==============================
tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_fp8[otype0-wtype0-128-2-2-128-1]
  /flashinfer-mxfp8-src/tests/moe/test_trtllm_cutlass_fused_moe.py:445: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
    hidden_states_scale = torch.tensor(hidden_states_scale[0]).cuda()

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========== 15 passed, 14 skipped, 1 warning in 2047.03s (0:34:07) ===========
root@b35f7fa2ee4d:/flashinfer-mxfp8-src# nvidia-smi
Thu Mar  5 06:04:43 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.08             Driver Version: 550.127.08     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 PCIe               On  |   00000000:D2:00.0 Off |                    0 |
| N/A   26C    P0             46W /  310W |       1MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
root@b35f7fa2ee4d:/flashinfer-mxfp8-src# 

fc1_weight_block.size(0) == num_experts_on_rank &&
fc1_weight_block.size(1) ==
TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) *

Is this only supposed to work with gated activations? The bf16 variant of this kernel supports both gated and non-gated activations.

Contributor Author

It works with gating, reference the unit test here: https://github.com/zianglih/flashinfer/blob/aba577ad95f7998b46616dd5c0fa7f8b1818f717/tests/moe/test_trtllm_cutlass_fused_moe.py#L1417-L1418

Also I have tried this kernel in SGLang sgl-project/sglang#18945 and can run Qwen3-30B-A3B without problems.


My question was for non gated activations like squared relu. Does it work with them? I tested yesterday and it did not.

@aleozlx
Collaborator

aleozlx commented Mar 6, 2026

hi, I can repro the H100 errors locally

============================================================ short test summary info =============================================================
FAILED tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe[128-2-2-128-1] - RuntimeError: Check failed: (lib_handle_ != nullptr) is false: Failed to load dynamic shared library /home/aleyang/.cache/flashinfer/0.6.4/90...
FAILED tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_fp8[otype0-wtype0-128-2-2-128-1] - RuntimeError: Check failed: (lib_handle_ != nullptr) is false: Failed to load dynamic shared library /home/aleyang/.cache/flashinfer/0.6.4/90...
FAILED tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_expert_parallel[128-2-8-128-1] - RuntimeError: Check failed: (lib_handle_ != nullptr) is false: Failed to load dynamic shared library /home/aleyang/.cache/flashinfer/0.6.4/90...
FAILED tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_tensor_parallel[128-2-2-128-1] - RuntimeError: Check failed: (lib_handle_ != nullptr) is false: Failed to load dynamic shared library /home/aleyang/.cache/flashinfer/0.6.4/90...
FAILED tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_tensor_parallel[128-4-2-128-1] - RuntimeError: Check failed: (lib_handle_ != nullptr) is false: Failed to load dynamic shared library /home/aleyang/.cache/flashinfer/0.6.4/90...
FAILED tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_tensor_expert_parallel[128-2-2-8-128-1] - RuntimeError: Check failed: (lib_handle_ != nullptr) is false: Failed to load dynamic shared library /home/aleyang/.cache/flashinfer/0.6.4/90...
FAILED tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_tensor_expert_parallel[128-4-2-8-128-1] - RuntimeError: Check failed: (lib_handle_ != nullptr) is false: Failed to load dynamic shared library /home/aleyang/.cache/flashinfer/0.6.4/90...
FAILED tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_fp8_block_scaling[128-2-2-128-1] - RuntimeError: Check failed: (lib_handle_ != nullptr) is false: Failed to load dynamic shared library /home/aleyang/.cache/flashinfer/0.6.4/90...
FAILED tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_bf16_mxfp4[None-None-None-128-2-2-128-1] - RuntimeError: Check failed: (lib_handle_ != nullptr) is false: Failed to load dynamic shared library /home/aleyang/.cache/flashinfer/0.6.4/90...
FAILED tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_bf16_mxfp4[0.5-0.0-7.0-128-2-2-128-1] - RuntimeError: Check failed: (lib_handle_ != nullptr) is false: Failed to load dynamic shared library /home/aleyang/.cache/flashinfer/0.6.4/90...
FAILED tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_bf16_mxfp4[1.702-1.0-7.0-128-2-2-128-1] - RuntimeError: Check failed: (lib_handle_ != nullptr) is false: Failed to load dynamic shared library /home/aleyang/.cache/flashinfer/0.6.4/90...
FAILED tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_w4a8[False-dtype0-128-2-2-128-1] - RuntimeError: Check failed: (lib_handle_ != nullptr) is false: Failed to load dynamic shared library /home/aleyang/.cache/flashinfer/0.6.4/90...
FAILED tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_w4a8[False-dtype1-128-2-2-128-1] - RuntimeError: Check failed: (lib_handle_ != nullptr) is false: Failed to load dynamic shared library /home/aleyang/.cache/flashinfer/0.6.4/90...
FAILED tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_w4a8[True-dtype0-128-2-2-128-1] - RuntimeError: Check failed: (lib_handle_ != nullptr) is false: Failed to load dynamic shared library /home/aleyang/.cache/flashinfer/0.6.4/90...
FAILED tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_w4a8[True-dtype1-128-2-2-128-1] - RuntimeError: Check failed: (lib_handle_ != nullptr) is false: Failed to load dynamic shared library /home/aleyang/.cache/flashinfer/0.6.4/90...
============================================= 15 failed, 20 skipped, 1 warning in 2887.52s (0:48:07) =============================================
aleyang@08d592f87dea:/workspace/flashinfer$ nvidia-smi
Fri Mar  6 22:52:42 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 590.48.01              Driver Version: 590.48.01      CUDA Version: 13.1     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 PCIe               On  |   00000000:41:00.0 Off |                    0 |
| N/A   43C    P0             47W /  350W |       0MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
aleyang@08d592f87dea:/workspace/flashinfer$ history | tail -n 4
    2  pip install --no-build-isolation -e .
    3  pytest tests/moe/test_trtllm_cutlass_fused_moe.py
    4  nvidia-smi
    5  history | tail -n 4
aleyang@08d592f87dea:/workspace/flashinfer$ git log -1
commit aba577ad95f7998b46616dd5c0fa7f8b1818f717 (HEAD -> mxfp8, mxfp8/mxfp8)
Merge: f28fcfb6 2371ee86
Author: Ziang Li <[email protected]>
Date:   Tue Mar 3 14:35:14 2026 -0800

    Merge branch 'main' into mxfp8

@aleozlx
Collaborator

aleozlx commented Mar 6, 2026

this was the error i am seeing

E   RuntimeError: Check failed: (lib_handle_ != nullptr) is false: Failed to load dynamic shared library /home/aleyang/.cache/flashinfer/0.6.4/90a/cached_ops/fused_moe_90/fused_moe_90.so /home/aleyang/.cache/flashinfer/0.6.4/90a/cached_ops/fused_moe_90/fused_moe_90.so: undefined symbol: _ZN12tensorrt_llm7kernels19cutlass_kernels_oss52tma_warp_specialized_generic_moe_gemm_kernelLauncherIN7cutlass4arch4Sm90E13__nv_fp8_e4m3S6_6__halfvNS_18cutlass_extensions17EpilogueOpDefaultELNS0_15cutlass_kernels34TmaWarpSpecializedGroupedGemmInput14EpilogueFusionE0EN4cute5tupleIJNSD_1CILi128EEENSF_ILi16EEESG_EEENSE_IJNSF_ILi2EEENSF_ILi1EEESK_EEELb1ELb0ELb0ELb0EEEvSB_iiP11CUstream_stPiPmNSE_IJiiSK_EEESQ_
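The undefined symbol above is an Itanium-ABI mangled C++ name; real tools like `c++filt` or `nm -C` decode it to the missing `tma_warp_specialized_generic_moe_gemm_kernelLauncher<...>` instantiation. As a toy illustration of the mangling scheme only (this handles just a plain function name plus builtin parameter codes — nothing close to the templated symbol here):

```python
# Toy Itanium-ABI demangler: handles only `_Z<len><name><builtin-args>`.
# Real tools (c++filt, nm -C) handle templates, namespaces, and much more.
BUILTINS = {"i": "int", "f": "float", "d": "double", "c": "char",
            "b": "bool", "v": "void", "m": "unsigned long"}

def demangle_simple(sym: str) -> str:
    if not sym.startswith("_Z"):
        return sym                       # not a mangled name: C symbols pass through
    rest = sym[2:]
    i = 0
    while i < len(rest) and rest[i].isdigit():
        i += 1                           # leading digits give the name length
    if i == 0:
        raise ValueError("unsupported mangling")
    name_len = int(rest[:i])
    name = rest[i:i + name_len]
    args = [BUILTINS[ch] for ch in rest[i + name_len:]]
    return f"{name}({', '.join(args)})"
```

Demangling the failing symbol is usually the fastest way to see which explicit template instantiation a TU forgot to emit.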

@zianglih
Contributor Author

zianglih commented Mar 6, 2026

Hi @aleozlx, this is exactly the CI error. Have you tried clearing any caches? Thanks!

@aleozlx
Collaborator

aleozlx commented Mar 6, 2026

this is a fresh container, i don't think there is anything to clear?

@aleozlx
Collaborator

aleozlx commented Mar 6, 2026

I switched to the main branch, and they passed

aleyang@08d592f87dea:/workspace/flashinfer$ git checkout main
M	requirements.txt
Switched to branch 'main'
Your branch is up to date with 'origin/main'.
aleyang@08d592f87dea:/workspace/flashinfer$ pytest tests/moe/test_trtllm_cutlass_fused_moe.py
============================================================== test session starts ===============================================================
platform linux -- Python 3.12.12, pytest-9.0.2, pluggy-1.6.0
rootdir: /workspace/flashinfer
configfile: pytest.ini
collected 29 items

tests/moe/test_trtllm_cutlass_fused_moe.py ..ssssssss......ssssss.......                                                                   [100%]

================================================================ warnings summary ================================================================
tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_fp8[otype0-wtype0-128-2-2-128-1]
  /workspace/flashinfer/tests/moe/test_trtllm_cutlass_fused_moe.py:445: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
    hidden_states_scale = torch.tensor(hidden_states_scale[0]).cuda()

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================= 15 passed, 14 skipped, 1 warning in 2305.76s (0:38:25)

@aleozlx
Collaborator

aleozlx commented Mar 6, 2026

now trying to rerun the tests after `rm -frv ~/.cache/flashinfer/*`, let's see

@zianglih
Contributor Author

zianglih commented Mar 6, 2026

I'll also retry on Hopper

@zianglih
Contributor Author

zianglih commented Mar 7, 2026

Hi @aleozlx, rerunning both Blackwell and Hopper now, Blackwell also shows the failure now. Let me debug.

auto-merge was automatically disabled March 7, 2026 01:05

Head branch was pushed to by a user without write access

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (1)

406-516: ⚠️ Potential issue | 🔴 Critical

Verify that explicit instantiation TUs have matching IsMXFPX=false variants for all SM versions and type combinations now being dispatched.

The lambda dispatch_by_mxfpx is instantiated twice at runtime (lines 513, 515) with both IsMXFPX=true and IsMXFPX=false, causing the nested SHAPE_CASE macros to generate references to tma_warp_specialized_generic_moe_gemm_kernelLauncher<> with both boolean variants across all supported SM versions (SM90, SM100, SM120). This widens the linker surface: legacy FP8 paths now bake IsMXFPX=false into symbols that may not have explicit instantiations in the compilation units.

If moe_gemm_kernels_fp8_fp8.cu and related type-specific TUs do not emit matching IsMXFPX=false variants for all the tile/cluster shapes and epilogue configurations now being dispatched, this change is the root cause of the undefined symbol errors in fused_moe_90.so.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h`
around lines 406 - 516, The change instantiates dispatch_by_mxfpx with both
IsMXFPX=true and false which causes references to
tma_warp_specialized_generic_moe_gemm_kernelLauncher<> (via
dispatchMoeGemmSelectClusterShapeTmaWarpSpecialized and the SHAPE_CASE macro)
for SM90/100/120 even when no explicit TU instantiations exist; fix by adding
matching explicit instantiations for the IsMXFPX=false variants (all relevant
template parameter combinations: T, WeightType, EpilogueTag, FUSION and all tile
shapes used in SHAPE_CASE for SM90/100/120) into the corresponding type-specific
TUs (e.g., the FP8/FP4 kernel TUs), or alternately narrow the runtime dispatch
so dispatch_by_mxfpx is only instantiated for the boolean value that actually
has corresponding explicit instantiations to avoid generating undefined symbols.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In
`@csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h`:
- Around line 491-504: The SM120 branch can silently fall through when
kernels::cutlass_kernels::isValidSM120MOESpecialisation<T,WeightType,EpilogueTag,FUSION>()
is false causing count==0 and a no-op; update the else path for the if constexpr
in the gemm_config.sm_version == 120/121 block to fail loudly (throw an
exception or call a fatal logger) with a clear message referencing SM120 and the
tile_config_sm120 value (use pretty_function or include
gemm_config.tile_config_sm120 in the message) so invalid SM120 specialisations
cannot proceed; ensure the thrown exception type matches surrounding error
handling conventions.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 0ec1b1ce-bd64-4a88-a347-ddfc149c8f83

📥 Commits

Reviewing files that changed from the base of the PR and between aba577a and 6ab67b4.

📒 Files selected for processing (1)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h

Comment on lines +491 to 504
    } else if (gemm_config.sm_version == 120 || gemm_config.sm_version == 121) {
      char const* const pretty_function = __PRETTY_FUNCTION__;
      TLLM_LOG_TRACE("At %s, SM120 config=%d", pretty_function,
                     (int)(gemm_config.tile_config_sm120));
      if constexpr (kernels::cutlass_kernels::isValidSM120MOESpecialisation<
                        T, WeightType, EpilogueTag, FUSION>()) {
        switch (gemm_config.tile_config_sm120) {
          SHAPE_CASE(120, 128, 128, 64)
          SHAPE_CASE(120, 128, 128, 128)
          SHAPE_CASE(120, 128, 256, 64)
          SHAPE_CASE(120, 256, 128, 64)
          DEFAULT_CASE(120)
        }
      }
Contributor

⚠️ Potential issue | 🟠 Major

Throw on invalid SM120 specialisations instead of silently falling through.

If isValidSM120MOESpecialisation<...>() is false, this branch currently exits without a throw. In the workspace path that can leave count == 0, and in the execution path it becomes a no-op instead of a hard failure.

💡 Proposed fix
     } else if (gemm_config.sm_version == 120 || gemm_config.sm_version == 121) {
       char const* const pretty_function = __PRETTY_FUNCTION__;
       TLLM_LOG_TRACE("At %s, SM120 config=%d", pretty_function,
                      (int)(gemm_config.tile_config_sm120));
       if constexpr (kernels::cutlass_kernels::isValidSM120MOESpecialisation<
                         T, WeightType, EpilogueTag, FUSION>()) {
         switch (gemm_config.tile_config_sm120) {
           SHAPE_CASE(120, 128, 128, 64)
           SHAPE_CASE(120, 128, 128, 128)
           SHAPE_CASE(120, 128, 256, 64)
           SHAPE_CASE(120, 256, 128, 64)
           DEFAULT_CASE(120)
         }
+      } else {
+        TLLM_THROW("Unsupported SM120 configuration requested");
       }
     }

@zianglih
Contributor Author

zianglih commented Mar 7, 2026

The previous Blackwell failure I saw was due to a stale ninja file, unrelated to this PR.
(screenshot omitted)

The Hopper failure is a real regression, now fixed by 6ab67b4.

Blackwell:

root@B200-146:~/flashinfer-mxfp8-src# git log -1
commit 6ab67b45c922c4919d67314d1412172ad289215c (HEAD -> mxfp8, origin/mxfp8)
Author: Ziang Li <[email protected]>
Date:   Fri Mar 6 17:04:48 2026 -0800

    Disallow `IsMXFPX` for sm90
root@B200-146:~/flashinfer-mxfp8-src# git status
On branch mxfp8
Your branch is up to date with 'origin/mxfp8'.

Changes not staged for commit:
  (use "git add <file>..." to update what will be committed)
  (use "git restore <file>..." to discard changes in working directory)
        modified:   flashinfer/jit/core.py

no changes added to commit (use "git add" and/or "git commit -a")
root@B200-146:~/flashinfer-mxfp8-src# 
root@B200-146:~/flashinfer-mxfp8-src# pytest tests/moe/test_trtllm_cutlass_fused_moe.py
============================== test session starts ==============================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /root/flashinfer-mxfp8-src
configfile: pytest.ini
plugins: hydra-core-1.3.2, anyio-4.12.1, typeguard-4.4.4
collected 35 items                                                              

tests/moe/test_trtllm_cutlass_fused_moe.py ............................ss [ 85%]
sssss                                                                     [100%]

=============================== warnings summary ================================
tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_fp8[otype0-wtype0-128-2-2-128-1]
  /root/flashinfer-mxfp8-src/tests/moe/test_trtllm_cutlass_fused_moe.py:445: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
    hidden_states_scale = torch.tensor(hidden_states_scale[0]).cuda()

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============= 28 passed, 7 skipped, 1 warning in 159.54s (0:02:39) ==============

Hopper:

root@04bb9764b66d:/flashinfer-mxfp8-src# pytest tests/moe/test_trtllm_cutlass_fused_moe.py
===================================== test session starts ======================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /flashinfer-mxfp8-src
configfile: pytest.ini
plugins: anyio-4.11.0
collected 35 items                                                                             

tests/moe/test_trtllm_cutlass_fused_moe.py 
..ssssssss......ssssssssssss.......           [100%]

======================================= warnings summary =======================================
tests/moe/test_trtllm_cutlass_fused_moe.py::test_moe_fp8[otype0-wtype0-128-2-2-128-1]
  /flashinfer-mxfp8-src/tests/moe/test_trtllm_cutlass_fused_moe.py:445: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
    hidden_states_scale = torch.tensor(hidden_states_scale[0]).cuda()

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================== 15 passed, 20 skipped, 1 warning in 364.08s (0:06:04) =====================
root@04bb9764b66d:/flashinfer-mxfp8-src# 
root@04bb9764b66d:/flashinfer-mxfp8-src# git log -1
commit 6ab67b45c922c4919d67314d1412172ad289215c (HEAD -> mxfp8, origin/mxfp8)
Author: Ziang Li <[email protected]>
Date:   Fri Mar 6 17:04:48 2026 -0800

    Disallow `IsMXFPX` for sm90
root@04bb9764b66d:/flashinfer-mxfp8-src# git status
On branch mxfp8
Your branch is up to date with 'origin/mxfp8'.

nothing to commit, working tree clean

@aleozlx aleozlx removed the run-ci label Mar 7, 2026
@aleozlx
Collaborator

aleozlx commented Mar 7, 2026

/bot run

@aleozlx
Collaborator

aleozlx commented Mar 7, 2026

@flashinfer-bot run

@flashinfer-bot
Collaborator

GitLab MR !326 has been updated with latest changes, and the CI pipeline #45573960 is currently running. I'll report back once the pipeline job completes.

@aleozlx aleozlx enabled auto-merge (squash) March 7, 2026 03:07
@aleozlx aleozlx merged commit 62d1b0c into flashinfer-ai:main Mar 7, 2026
42 of 55 checks passed
frankwang28 pushed a commit to frankwang28/flashinfer that referenced this pull request Mar 18, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

@HumansAnd

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

flashinfer-ai#2505 implements mxfp8
for trtllm backend.

However, in SGLang, `--moe-runner-backend flashinfer_trtllm` bypasses
SGLang topk implementation and does not work with expert routing replay
in MoE RL.

We want to implement `mxfp8 x mxfp8` for `cutlass_fused_moe` which works
with MoE RL training.

This PR mainly reuses existing code path for `WMxfp4AMxfp8Quant`:

https://github.com/flashinfer-ai/flashinfer/blob/952b6ab2838d676b4257fcc23bb00f67fdd38efc/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu#L1191

## 🔍 Related Issues

<!-- Link any related issues here -->
miles MXFP8/NVFP4 RL roadmap:
radixark/miles#615
SGLang FlashInfer MXFP8 integration:
sgl-project/sglang#18945

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->



<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Toggleable MXFPX/MXFP8 activation-scaling across MOE inference,
updating workspace sizing, kernel selection, block-scaling and dispatch
to enable MXFP8-aware execution and validation.
* Added MXFP8×MXFP8 quantization mode and emitted MXFPX-aware
GEMM/kernel variants; public APIs now expose an MXFPX/activation-scaling
flag.

* **Tests**
* Added unit tests and helpers for MXFP8 quantization,
packing/dequantization, and end-to-end MXFP8×MXFP8 MOE inference
validation.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
@zianglih zianglih deleted the mxfp8 branch April 6, 2026 08:09