From 3cd2b1ec78d804170309d5fe0c5543b4b2a596f8 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 2 Oct 2025 20:49:23 +0000 Subject: [PATCH] Modified CanonicalizeGemmInput() logic to pull from column-wise data for FP8 GEMM on Blackwell when row-wise is not available. Signed-off-by: Alp Dener --- .../common/gemm/cublaslt_gemm.cu | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index f287072bcb..a1b501faea 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -101,6 +101,16 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } + } else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) { + // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed + // data with the mirrored transpose-flag if we don't have row-wise data. + NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype), + "Input A is missing column-wise usage"); + ret.A = A.columnwise_data.dptr; + ret.transA = is_A_transposed ? CUBLAS_OP_N : CUBLAS_OP_T; + ret.Atype = A.columnwise_data.dtype; + ret.A_scale_inv = A.columnwise_scale_inv.dptr; + ret.lda = is_A_transposed ? m : k; } } else if (is_mxfp_scaling(A.scaling_mode)) { // MXFP8 @@ -160,6 +170,16 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); } + } else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) { + // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed + // data with the mirrored transpose-flag if we don't have row-wise data. + NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype), + "Input B is missing column-wise usage"); + ret.B = B.columnwise_data.dptr; + ret.transB = is_B_transposed ? CUBLAS_OP_N : CUBLAS_OP_T; + ret.Btype = B.columnwise_data.dtype; + ret.B_scale_inv = B.columnwise_scale_inv.dptr; + ret.ldb = is_B_transposed ? k : n; } } else if (is_mxfp_scaling(B.scaling_mode)) { // MXFP8