[draft] add tileN = 8,16 to SM120 blockscale GEMM.#3495
Conversation
|
Important Review skippedDraft detected. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces support for smaller tile sizes (tile_n = 8 and tile_n = 16) for SM120/121 in FP4 and MXFP8 GEMM kernels, along with emitting StreamK scheduler launchers. However, several critical issues were identified: adding tile_n = 8 to the dispatch macros in both MXFP4 and NVFP4 group GEMM files will cause linker errors because the corresponding template instantiations are excluded when swap_ab is false. Additionally, the new tile_n = 8 configurations are missing from the heuristic candidate array, and test coverage for tile_n = 8 needs to be added to the test suite while safely handling the invalid swap_ab = false configuration.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| } else if (tile_n == 8) { \ | ||
| constexpr int TILE_N = 8; \ | ||
| return __VA_ARGS__(); \ |
There was a problem hiding this comment.
Critical Linker Error Bug
Adding tile_n == 8 to the dispatch macro here will cause a linker error (undefined reference) during compilation.
Why this happens:
tile_nandswap_abare runtime variables dispatched via nested macros (DISPATCH_TILE_NandDISPATCH_SWAP_AB).- When
tile_n == 8, the compiler generates branches for bothSWAP_AB = trueandSWAP_AB = false. - In the
SWAP_AB = falsebranch, it attempts to callCutlassMXFP4GroupwiseScaledGroupGEMMSM120<128, 8, 128, false>. - However, this template instantiation is explicitly excluded in
csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinjavia{% if not (tile_n == 8 and swap_ab == "false") %}. - Because the symbol is referenced but never defined, the linker fails.
Suggested Fix:
Use if constexpr in the innermost block (around line 134) to prevent instantiating the invalid configuration:
if constexpr (is_valid_config<c_type_in_a, c_type_in_b, c_type_sf_a, c_type_sf_b, c_type_out>()) {
if constexpr (TILE_N == 8 && !SWAP_AB) {
TVM_FFI_ICHECK(false) << "tile_n=8 is only supported with swap_ab=true";
return false;
} else {
using cutlass_t_in_a = cutlass_dtype_t<c_type_in_a>;
// ... existing dispatch code ...
}
}| } else if (tile_n == 8) { \ | ||
| constexpr int TILE_N = 8; \ | ||
| return __VA_ARGS__(); \ |
There was a problem hiding this comment.
Critical Linker Error Bug
Adding tile_n == 8 to the dispatch macro here will cause a linker error (undefined reference) during compilation.
Why this happens:
tile_nandswap_abare runtime variables dispatched via nested macros (DISPATCH_TILE_NandDISPATCH_SWAP_AB).- When
tile_n == 8, the compiler generates branches for bothSWAP_AB = trueandSWAP_AB = false. - In the
SWAP_AB = falsebranch, it attempts to callCutlassNVFP4GroupwiseScaledGroupGEMMSM120<128, 8, 128, false>. - However, this template instantiation is explicitly excluded in
csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinjavia{% if not (tile_n == 8 and swap_ab == "false") %}. - Because the symbol is referenced but never defined, the linker fails.
Suggested Fix:
Use if constexpr in the innermost block (around line 141) to prevent instantiating the invalid configuration:
if constexpr (is_valid_config<c_type_in_a, c_type_in_b, c_type_sf_a, c_type_sf_b, c_type_out>()) {
if constexpr (TILE_N == 8 && !SWAP_AB) {
TVM_FFI_ICHECK(false) << "tile_n=8 is only supported with swap_ab=true";
return false;
} else {
using cutlass_t_in_a = cutlass_dtype_t<c_type_in_a>;
// ... existing dispatch code ...
}
}| CutlassTileConfigSM120::CtaShape128x128x256B, CutlassTileConfigSM120::CtaShape256x128x128B, | ||
| CutlassTileConfigSM120::CtaShape128x32x128B, CutlassTileConfigSM120::CtaShape128x32x64B, | ||
| CutlassTileConfigSM120::CtaShape128x64x128B, CutlassTileConfigSM120::CtaShape128x64x64B, | ||
| CutlassTileConfigSM120::CtaShape128x16x128B, CutlassTileConfigSM120::CtaShape128x16x64B, |
There was a problem hiding this comment.
Missing Candidate Configurations in Heuristic
You have added CtaShape128x8x128B and CtaShape128x8x64B to the CutlassTileConfigSM120 enum and the switch-case dispatchers, but they are missing from the all_tiles candidate array in get_candidate_configs_sm120.
Without adding them here, the heuristic search will never consider or profile these tile_n = 8 configurations, making them completely unreachable at runtime when using the heuristic path.
CutlassTileConfigSM120::CtaShape128x16x128B, CutlassTileConfigSM120::CtaShape128x16x64B,
CutlassTileConfigSM120::CtaShape128x8x128B, CutlassTileConfigSM120::CtaShape128x8x64B,|
|
||
| for swap_ab in [True, False]: | ||
| for tile_n in [32, 64, 128]: | ||
| for tile_n in [16, 32, 64, 128]: |
There was a problem hiding this comment.
Missing Test Coverage for tile_n = 8
You added tile_n = 8 support to the Grouped GEMM kernels, but it is excluded from the test suite here.
To test tile_n = 8 without triggering the ValueError when swap_ab is False, you can dynamically adjust the tile_n list based on swap_ab:
for swap_ab in [True, False]:
for tile_n in ([8, 16, 32, 64, 128] if swap_ab else [16, 32, 64, 128]):| mma_sm_list = [1] | ||
| tile_m_list = [128] | ||
| tile_n_list = [32, 64, 128] | ||
| tile_n_list = [16, 32, 64, 128] |
There was a problem hiding this comment.
Missing Test Coverage for tile_n = 8
You added tile_n = 8 support to the Grouped GEMM kernels, but it is excluded from the test suite here.
To test tile_n = 8 safely, add 8 to tile_n_list and skip the invalid swap_ab = False configuration inside the test loop:
# In the loop:
for mma_sm, tile_m, tile_n, tile_k, swap_ab in product(...):
if tile_n == 8 and not swap_ab:
continue| tile_n_list = [16, 32, 64, 128] | |
| tile_n_list = [8, 16, 32, 64, 128] |
📌 Description
Note: it's intentionally not added for grouped GEMM to reduce complexity, since there is nearly no speedup (1-2%).

🔍 Related Issues
NVIDIA/cutlass#3292
🧪 Tests
unittest, etc.).Reviewer Notes