[draft] add tileN = 8,16 to SM120 blockscale GEMM. #3495

Draft
b8zhong wants to merge 1 commit into flashinfer-ai:main from bzhng-development:brayden/sm120-tile-n-16

@b8zhong b8zhong commented Jun 2, 2026

Contributor

📌 Description

Note: it's intentionally not added for grouped GEMM to reduce complexity, since there is nearly no speedup (1-2%).

🔍 Related Issues

NVIDIA/cutlass#3292

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

coderabbitai Bot commented Jun 2, 2026

Contributor

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 24e4007a-7b5a-4ab6-b180-f7001d216963

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

@b8zhong changed the title from "add tileN = 8,16 to SM120 blockscale GEMM." to "[draft] add tileN = 8,16 to SM120 blockscale GEMM." on Jun 2, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces support for smaller tile sizes (tile_n = 8 and tile_n = 16) for SM120/121 in FP4 and MXFP8 GEMM kernels, along with emitting StreamK scheduler launchers. However, several critical issues were identified: adding tile_n = 8 to the dispatch macros in both MXFP4 and NVFP4 group GEMM files will cause linker errors because the corresponding template instantiations are excluded when swap_ab is false. Additionally, the new tile_n = 8 configurations are missing from the heuristic candidate array, and test coverage for tile_n = 8 needs to be added to the test suite while safely handling the invalid swap_ab = false configuration.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +46 to +48
} else if (tile_n == 8) { \
constexpr int TILE_N = 8; \
return __VA_ARGS__(); \

Contributor

critical

Critical Linker Error Bug

Adding tile_n == 8 to the dispatch macro here will cause an undefined-reference error at link time.

Why this happens:

  1. tile_n and swap_ab are runtime variables dispatched via nested macros (DISPATCH_TILE_N and DISPATCH_SWAP_AB).
  2. When tile_n == 8, the compiler generates branches for both SWAP_AB = true and SWAP_AB = false.
  3. In the SWAP_AB = false branch, it attempts to call CutlassMXFP4GroupwiseScaledGroupGEMMSM120<128, 8, 128, false>.
  4. However, this template instantiation is explicitly excluded in csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja via {% if not (tile_n == 8 and swap_ab == "false") %}.
  5. Because the symbol is referenced but never defined, the linker fails.
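
The mismatch described in the steps above can be modeled outside the build as a set difference. This is a minimal Python sketch, not the project's generator: the value lists and helper names are assumptions for illustration, and only the exclusion rule (tile_n == 8 with swap_ab == false) is taken from the quoted template.

```python
from itertools import product

# Assumed value sets for illustration; the real lists live in the dispatch
# macros and in the kernel-instantiation template.
TILE_N_VALUES = [8, 16, 32, 64, 128]
SWAP_AB_VALUES = [True, False]


def dispatched_configs():
    """Every (tile_n, swap_ab) pair the nested runtime dispatch can reference."""
    return set(product(TILE_N_VALUES, SWAP_AB_VALUES))


def instantiated_configs():
    """Pairs that receive a definition, mirroring the template's exclusion rule."""
    return {
        (tile_n, swap_ab)
        for tile_n, swap_ab in product(TILE_N_VALUES, SWAP_AB_VALUES)
        if not (tile_n == 8 and not swap_ab)
    }


def undefined_references():
    """Pairs that dispatch refers to but that are never defined."""
    return dispatched_configs() - instantiated_configs()


print(undefined_references())  # prints {(8, False)}
```

Any non-empty result corresponds to a symbol the linker would fail to resolve.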

Suggested Fix:

Use if constexpr in the innermost block (around line 134) to prevent instantiating the invalid configuration:

if constexpr (is_valid_config<c_type_in_a, c_type_in_b, c_type_sf_a, c_type_sf_b, c_type_out>()) {
  if constexpr (TILE_N == 8 && !SWAP_AB) {
    TVM_FFI_ICHECK(false) << "tile_n=8 is only supported with swap_ab=true";
    return false;
  } else {
    using cutlass_t_in_a = cutlass_dtype_t<c_type_in_a>;
    // ... existing dispatch code ...
  }
}

Comment on lines +46 to +48
} else if (tile_n == 8) { \
constexpr int TILE_N = 8; \
return __VA_ARGS__(); \

Contributor

critical

Critical Linker Error Bug

Adding tile_n == 8 to the dispatch macro here will cause an undefined-reference error at link time.

Why this happens:

  1. tile_n and swap_ab are runtime variables dispatched via nested macros (DISPATCH_TILE_N and DISPATCH_SWAP_AB).
  2. When tile_n == 8, the compiler generates branches for both SWAP_AB = true and SWAP_AB = false.
  3. In the SWAP_AB = false branch, it attempts to call CutlassNVFP4GroupwiseScaledGroupGEMMSM120<128, 8, 128, false>.
  4. However, this template instantiation is explicitly excluded in csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinja via {% if not (tile_n == 8 and swap_ab == "false") %}.
  5. Because the symbol is referenced but never defined, the linker fails.

Suggested Fix:

Use if constexpr in the innermost block (around line 141) to prevent instantiating the invalid configuration:

if constexpr (is_valid_config<c_type_in_a, c_type_in_b, c_type_sf_a, c_type_sf_b, c_type_out>()) {
  if constexpr (TILE_N == 8 && !SWAP_AB) {
    TVM_FFI_ICHECK(false) << "tile_n=8 is only supported with swap_ab=true";
    return false;
  } else {
    using cutlass_t_in_a = cutlass_dtype_t<c_type_in_a>;
    // ... existing dispatch code ...
  }
}
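
The effect of the guard can be illustrated with a small Python analogue. This is only a sketch: Python has no compile-time branches, so the check here runs at call time, and `KERNELS`, `make_kernel`, and `dispatch` are hypothetical names rather than FlashInfer APIs.

```python
def make_kernel(tile_n, swap_ab):
    """Return a stand-in callable for one instantiated kernel (hypothetical)."""
    def kernel():
        return f"ran tile_n={tile_n} swap_ab={swap_ab}"
    return kernel


# Hypothetical registry standing in for the explicitly instantiated kernels.
# The (8, False) entry is deliberately absent, mirroring the template exclusion.
KERNELS = {
    (tile_n, swap_ab): make_kernel(tile_n, swap_ab)
    for tile_n in (8, 16, 32, 64, 128)
    for swap_ab in (True, False)
    if not (tile_n == 8 and not swap_ab)
}


def dispatch(tile_n, swap_ab):
    # Reject the unsupported pair with a clear message before any lookup,
    # rather than failing later on a missing entry.
    if tile_n == 8 and not swap_ab:
        raise ValueError("tile_n=8 is only supported with swap_ab=true")
    return KERNELS[(tile_n, swap_ab)]()
```

In the C++ version the same check is resolved at compile time, so the unsupported branch never references the missing instantiation at all.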

CutlassTileConfigSM120::CtaShape128x128x256B, CutlassTileConfigSM120::CtaShape256x128x128B,
CutlassTileConfigSM120::CtaShape128x32x128B, CutlassTileConfigSM120::CtaShape128x32x64B,
CutlassTileConfigSM120::CtaShape128x64x128B, CutlassTileConfigSM120::CtaShape128x64x64B,
CutlassTileConfigSM120::CtaShape128x16x128B, CutlassTileConfigSM120::CtaShape128x16x64B,

Contributor

high

Missing Candidate Configurations in Heuristic

You have added CtaShape128x8x128B and CtaShape128x8x64B to the CutlassTileConfigSM120 enum and the switch-case dispatchers, but they are missing from the all_tiles candidate array in get_candidate_configs_sm120.

Without adding them here, the heuristic search will never consider or profile these tile_n = 8 configurations, making them completely unreachable at runtime when using the heuristic path.

      CutlassTileConfigSM120::CtaShape128x16x128B,  CutlassTileConfigSM120::CtaShape128x16x64B,
      CutlassTileConfigSM120::CtaShape128x8x128B,   CutlassTileConfigSM120::CtaShape128x8x64B,
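
A consistency check along these lines could catch enum entries that never reach the candidate list. The sketch below is in Python for brevity and uses a hypothetical `TileConfig` enum and `ALL_TILES` list. It borrows only the shape names quoted above and does not reproduce the full CutlassTileConfigSM120 enum or get_candidate_configs_sm120.

```python
from enum import Enum, auto


class TileConfig(Enum):
    # Hypothetical subset of shapes, named after the entries quoted above.
    CtaShape128x16x128B = auto()
    CtaShape128x16x64B = auto()
    CtaShape128x8x128B = auto()
    CtaShape128x8x64B = auto()


# Candidate list as it stands before the suggested change: no 128x8 entries.
ALL_TILES = [
    TileConfig.CtaShape128x16x128B,
    TileConfig.CtaShape128x16x64B,
]


def unreachable_configs(candidates):
    """Enum members that a heuristic iterating over `candidates` never sees."""
    return [cfg for cfg in TileConfig if cfg not in candidates]


print([cfg.name for cfg in unreachable_configs(ALL_TILES)])
# prints ['CtaShape128x8x128B', 'CtaShape128x8x64B']
```

Once the two 128x8 entries are appended to the candidate list, the check returns an empty list.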


  for swap_ab in [True, False]:
-     for tile_n in [32, 64, 128]:
+     for tile_n in [16, 32, 64, 128]:

Contributor

medium

Missing Test Coverage for tile_n = 8

You added tile_n = 8 support to the Grouped GEMM kernels, but it is excluded from the test suite here.

To test tile_n = 8 without triggering the ValueError when swap_ab is False, you can dynamically adjust the tile_n list based on swap_ab:

for swap_ab in [True, False]:
    for tile_n in ([8, 16, 32, 64, 128] if swap_ab else [16, 32, 64, 128]):
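
As a quick sanity check, the combinations that loop would visit can be enumerated directly. This is a minimal sketch; `tile_n_choices` and `valid_combos` are illustrative names, not part of the test suite.

```python
def tile_n_choices(swap_ab):
    """tile_n values to test for a given swap_ab, following the suggestion above."""
    return [8, 16, 32, 64, 128] if swap_ab else [16, 32, 64, 128]


def valid_combos():
    """All (swap_ab, tile_n) pairs the nested loops would visit."""
    return [
        (swap_ab, tile_n)
        for swap_ab in [True, False]
        for tile_n in tile_n_choices(swap_ab)
    ]


combos = valid_combos()
print(len(combos))  # prints 9
```

The pair (False, 8) never appears, so the unsupported configuration is not exercised.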

  mma_sm_list = [1]
  tile_m_list = [128]
- tile_n_list = [32, 64, 128]
+ tile_n_list = [16, 32, 64, 128]

Contributor

medium

Missing Test Coverage for tile_n = 8

You added tile_n = 8 support to the Grouped GEMM kernels, but it is excluded from the test suite here.

To test tile_n = 8 safely, add 8 to tile_n_list and skip the invalid swap_ab = False configuration inside the test loop:

# In the loop:
for mma_sm, tile_m, tile_n, tile_k, swap_ab in product(...):
    if tile_n == 8 and not swap_ab:
        continue
Suggested change:
- tile_n_list = [16, 32, 64, 128]
+ tile_n_list = [8, 16, 32, 64, 128]
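
Put together, the filtered loop might look like the sketch below. The values of `tile_k_list` and `swap_ab_list` are assumptions, since the arguments to `product` are elided in the comment, and `iter_configs` is a hypothetical helper.

```python
from itertools import product

mma_sm_list = [1]
tile_m_list = [128]
tile_n_list = [8, 16, 32, 64, 128]
tile_k_list = [128]           # assumption: not shown in the diff
swap_ab_list = [True, False]  # assumption: not shown in the diff


def iter_configs():
    """Yield every configuration except the unsupported tile_n=8, swap_ab=False."""
    for mma_sm, tile_m, tile_n, tile_k, swap_ab in product(
        mma_sm_list, tile_m_list, tile_n_list, tile_k_list, swap_ab_list
    ):
        if tile_n == 8 and not swap_ab:
            continue
        yield mma_sm, tile_m, tile_n, tile_k, swap_ab


configs = list(iter_configs())
```

Keeping the skip inside a generator leaves the list definitions flat, which matches the one-line suggested change.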

2 participants