From fac30998dfd72b1481210db77a85aa4c788f177a Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Tue, 15 Jul 2025 14:17:13 +0800 Subject: [PATCH 1/8] Support generating all the ldmatrix intrinsics from NVVM ops --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 36 ++++++- .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 3 +- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 12 ++- .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 101 ++++++++++++++---- .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 4 +- mlir/test/Dialect/LLVMIR/invalid.mlir | 17 ++- mlir/test/Dialect/LLVMIR/nvvm.mlir | 11 -- mlir/test/Target/LLVMIR/nvvmir.mlir | 44 ++++++-- 8 files changed, 175 insertions(+), 53 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 45a8904375e2b..cfb21e8331d05 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1990,6 +1990,35 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">, let hasVerifier = 1; } +def LdStMatrixShapeM8N8 : I32EnumAttrCase<"M8N8", 0, "m8n8">; +def LdStMatrixShapeM8N16 : I32EnumAttrCase<"M8N16", 1, "m8n16">; +def LdStMatrixShapeM16N8 : I32EnumAttrCase<"M16N8", 2, "m16n8">; +def LdStMatrixShapeM16N16 : I32EnumAttrCase<"M16N16", 3, "m16n16">; + +def LdStMatrixShape : I32EnumAttr<"LdStMatrixShape", "Matrix shape for ldmatrix and stmatrix", + [LdStMatrixShapeM8N8, LdStMatrixShapeM8N16, LdStMatrixShapeM16N8, LdStMatrixShapeM16N16]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::NVVM"; +} +def LdStMatrixShapeAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def LdStMatrixEltTypeB16 : I32EnumAttrCase<"B16", 0, "b16">; +def LdStMatrixEltTypeB8 : I32EnumAttrCase<"B8", 1, "b8">; +def LdStMatrixEltTypeB8X16_B6X16_P32 : I32EnumAttrCase<"B8X16_B6X16_P32", 2, "b8x16.b6x16_p32">; +def LdStMatrixEltTypeB8X16_B4X16_P64 : I32EnumAttrCase<"B8X16_B4X16_P64", 3, "b8x16.b4x16_p64">; + +def LdStMatrixEltType : 
I32EnumAttr<"LdStMatrixEltType", "Element type for ldmatrix and stmatrix", + [LdStMatrixEltTypeB16, LdStMatrixEltTypeB8, + LdStMatrixEltTypeB8X16_B6X16_P32, LdStMatrixEltTypeB8X16_B4X16_P64]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::NVVM"; +} +def LdStMatrixEltTypeAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, Arguments<(ins LLVM_PointerShared:$ptr, Variadic:$sources, @@ -2021,13 +2050,16 @@ def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">, Results<(outs AnyType:$res)>, - Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> { + Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr: $num, + MMALayoutAttr: $layout, + LdStMatrixShapeAttr: $shape, + LdStMatrixEltTypeAttr: $elttype)> { let summary = "cooperative matrix load"; string llvmBuilder = [{ auto operands = moduleTranslation.lookupValues(opInst.getOperands()); - auto intId = getLdMatrixIntrinsicId($layout, $num); + auto intId = getLdMatrixIntrinsicId($layout, $num, $shape, $elttype); $res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()}); }]; diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 80b3d85488495..470dc2512a9ad 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -289,7 +289,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern { ldMatrixResultType, srcPtr, /*num=*/op.getNumTiles(), /*layout=*/op.getTranspose() ? NVVM::MMALayout::col - : NVVM::MMALayout::row); + : NVVM::MMALayout::row, + NVVM::LdStMatrixShape::M8N8, NVVM::LdStMatrixEltType::B16); // The ldmatrix operation returns either a single i32 value or a struct of // i32 values. 
Here we unpack those values and cast them back to their diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 6e29b129e8835..93c155b67fb5c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -806,14 +806,18 @@ LogicalResult NVVM::LdMatrixOp::verify() { return emitOpError("expected num attribute to be 1, 2 or 4"); Type i32 = IntegerType::get(getContext(), 32); - if (getNum() == 1 && getType() != i32) + uint32_t num = getNum(); + if (getShape() == LdStMatrixShape::M16N16) { + num *= 2; + } + if (num == 1 && getType() != i32) return emitOpError("expected destination type is i32"); - if (getNum() == 2 || getNum() == 4) { + if (num == 2 || num == 4) { Type dstType = LLVM::LLVMStructType::getLiteral( - getContext(), SmallVector(getNum(), i32)); + getContext(), SmallVector(num, i32)); if (getType() != dstType) return emitOpError("expected destination type is a structure of ") - << getNum() << " elements of type i32"; + << num << " elements of type i32"; } return success(); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index eecca64c4bf81..5d13933519c54 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -134,33 +134,90 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) { llvm_unreachable("unsupported vote kind"); } -/// Return the intrinsic ID associated with ldmatrix for the given paramters. 
-static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, - int32_t num) { +static llvm::Intrinsic::ID +getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, + NVVM::LdStMatrixShape shape, + NVVM::LdStMatrixEltType elttype) { if (layout == NVVM::MMALayout::row) { - switch (num) { - case 1: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16; - case 2: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16; - case 4: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16; - default: - llvm_unreachable("unsupported number of matrix"); + if (shape == NVVM::LdStMatrixShape::M8N8 && + elttype == NVVM::LdStMatrixEltType::B16) { + switch (num) { + case 1: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16; + case 2: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16; + case 4: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16; + } + } else if (shape == NVVM::LdStMatrixShape::M8N16 && + elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { + switch (num) { + case 1: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32; + case 2: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32; + case 4: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32; + } + } else if (shape == NVVM::LdStMatrixShape::M8N16 && + elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { + switch (num) { + case 1: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64; + case 2: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64; + case 4: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64; + } } - } else { - switch (num) { - case 1: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16; - case 2: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16; - case 4: - return 
llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16; - default: - llvm_unreachable("unsupported number of matrix"); + if (shape == NVVM::LdStMatrixShape::M8N8 && + elttype == NVVM::LdStMatrixEltType::B16) { + switch (num) { + case 1: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16; + case 2: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16; + case 4: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16; + } + } else if (shape == NVVM::LdStMatrixShape::M16N16 && + elttype == NVVM::LdStMatrixEltType::B8) { + switch (num) { + case 1: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8; + case 2: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8; + } + } else if (shape == NVVM::LdStMatrixShape::M16N16 && + elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { + switch (num) { + case 1: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32; + case 2: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32; + } + } else if (shape == NVVM::LdStMatrixShape::M16N16 && + elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { + switch (num) { + case 1: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64; + case 2: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64; + } } } + llvm_unreachable("unsupported matrix configuration"); } /// Return the intrinsic ID associated with st.bulk for the given address type. 
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index d0bc806e0aa8c..75a556f471373 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -159,7 +159,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec // CHECK-LABEL: @ldmatrix_x4 func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> { %c0 = arith.constant 0 : index - // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32) + // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype, layout = #nvvm.mma_layout, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)> %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16> // CHECK: llvm.extractvalue // CHECK: llvm.bitcast @@ -179,7 +179,7 @@ func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> { // CHECK-LABEL: @ldmatrix_x1 func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> { %c0 = arith.constant 0 : index - // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout, num = 1 : i32} {{.*}} -> i32 + // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype, layout = #nvvm.mma_layout, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape} : {{.*}} -> i32 %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16> // CHECK: llvm.bitcast // CHECK: llvm.insertvalue diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index bd1106e304c60..f9def0877d71a 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1116,7 +1116,7 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector< llvm.func 
@wmmald_matrix(%arg0: !llvm.ptr) { // expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr) -> i32 + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr) -> i32 llvm.return } @@ -1124,7 +1124,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr) { llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}} - %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> i32 + %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 llvm.return } @@ -1132,7 +1132,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32)> + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32)> llvm.return } @@ -1140,10 +1140,19 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}} - %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> llvm.return } +// ----- + +llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 
{{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + llvm.return +} + + // ----- llvm.func @caller() { diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index c7fa41c98ac92..6a4edd0d22a08 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -385,17 +385,6 @@ llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) { llvm.return } -// CHECK-LABEL: llvm.func @ld_matrix -llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { - // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout, num = 1 : i32} : (!llvm.ptr<3>) -> i32 - %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> i32 - // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout, num = 2 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> - %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> - // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> - %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> - llvm.return -} - // CHECK-LABEL: llvm.func @redux_sync llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 { // CHECK: nvvm.redux.sync add %{{.*}} diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index f86a04186f512..89429a762db92 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -559,17 +559,47 @@ llvm.func @llvm_nvvm_cp_async_bulk_wait_group() { // CHECK-LABEL: @ld_matrix llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}) - 
%l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> i32 + %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}) - %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}) - %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> - // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}) - %l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> i32 + %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + + // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}) + %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}) - %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, 
i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}) - %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + + // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) + %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) + %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) + %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,elttype =#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + + // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) + %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) + %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + // CHECK: call { i32, i32, i32, i32 } 
@llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) + %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + + // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}) + %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}) + %m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,elttype =#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + + // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) + %m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) + %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + + // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) + %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + // CHECK: call { i32, i32, i32, i32 } 
@llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) + %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,elttype =#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> llvm.return } From 784b7494581a1df03e16ec0faacc85d35c23960d Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Wed, 16 Jul 2025 15:05:37 +0800 Subject: [PATCH 2/8] Modify the arguments of the ldmatrix op --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 16 +++------ .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 3 +- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 2 +- .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 18 +++++----- .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 4 +-- mlir/test/Dialect/LLVMIR/invalid.mlir | 10 +++--- mlir/test/Target/LLVMIR/nvvmir.mlir | 36 +++++++++---------- 7 files changed, 41 insertions(+), 48 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index cfb21e8331d05..6af9f4e36be3d 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1990,18 +1990,10 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">, let hasVerifier = 1; } -def LdStMatrixShapeM8N8 : I32EnumAttrCase<"M8N8", 0, "m8n8">; -def LdStMatrixShapeM8N16 : I32EnumAttrCase<"M8N16", 1, "m8n16">; -def LdStMatrixShapeM16N8 : I32EnumAttrCase<"M16N8", 2, "m16n8">; -def LdStMatrixShapeM16N16 : I32EnumAttrCase<"M16N16", 3, "m16n16">; - -def LdStMatrixShape : I32EnumAttr<"LdStMatrixShape", "Matrix shape for ldmatrix and stmatrix", - [LdStMatrixShapeM8N8, LdStMatrixShapeM8N16, LdStMatrixShapeM16N8, LdStMatrixShapeM16N16]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::NVVM"; -} -def LdStMatrixShapeAttr : EnumAttr { - let assemblyFormat = "`<` $value `>`"; +def LdStMatrixShapeAttr : NVVM_Attr<"LdStMatrixShape", "ld_st_matrix_shape"> { + let summary = "Matrix shape for 
ldmatrix and stmatrix"; + let parameters = (ins "int":$m, "int":$n); + let assemblyFormat = "`<` struct(params) `>`"; } def LdStMatrixEltTypeB16 : I32EnumAttrCase<"B16", 0, "b16">; diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 470dc2512a9ad..53eeabb16c984 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -285,12 +285,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern { Value srcPtr = getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType, adaptor.getSrcMemref(), adaptor.getIndices()); + auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8); Value ldMatrixResult = b.create( ldMatrixResultType, srcPtr, /*num=*/op.getNumTiles(), /*layout=*/op.getTranspose() ? NVVM::MMALayout::col : NVVM::MMALayout::row, - NVVM::LdStMatrixShape::M8N8, NVVM::LdStMatrixEltType::B16); + /*shape=*/shape, /*elttype=*/NVVM::LdStMatrixEltType::B16); // The ldmatrix operation returns either a single i32 value or a struct of // i32 values. 
Here we unpack those values and cast them back to their diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 93c155b67fb5c..fbb78ed487448 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -807,7 +807,7 @@ LogicalResult NVVM::LdMatrixOp::verify() { Type i32 = IntegerType::get(getContext(), 32); uint32_t num = getNum(); - if (getShape() == LdStMatrixShape::M16N16) { + if (getShape().getM() == 16 && getShape().getN() == 16) { num *= 2; } if (num == 1 && getType() != i32) diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index 5d13933519c54..098336cc035a4 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -136,10 +136,10 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) { static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, - NVVM::LdStMatrixShape shape, + NVVM::LdStMatrixShapeAttr shape, NVVM::LdStMatrixEltType elttype) { if (layout == NVVM::MMALayout::row) { - if (shape == NVVM::LdStMatrixShape::M8N8 && + if (shape.getM() == 8 && shape.getN() == 8 && elttype == NVVM::LdStMatrixEltType::B16) { switch (num) { case 1: @@ -149,7 +149,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, case 4: return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16; } - } else if (shape == NVVM::LdStMatrixShape::M8N16 && + } else if (shape.getM() == 8 && shape.getN() == 16 && elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { switch (num) { case 1: @@ -162,7 +162,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, return llvm::Intrinsic:: nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32; } - } else if (shape == NVVM::LdStMatrixShape::M8N16 && + } else if (shape.getM() == 8 && shape.getN() 
== 16 && elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { switch (num) { case 1: @@ -177,7 +177,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, } } } else { - if (shape == NVVM::LdStMatrixShape::M8N8 && + if (shape.getM() == 8 && shape.getN() == 8 && elttype == NVVM::LdStMatrixEltType::B16) { switch (num) { case 1: @@ -187,7 +187,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, case 4: return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16; } - } else if (shape == NVVM::LdStMatrixShape::M16N16 && + } else if (shape.getM() == 16 && shape.getN() == 16 && elttype == NVVM::LdStMatrixEltType::B8) { switch (num) { case 1: @@ -195,7 +195,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, case 2: return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8; } - } else if (shape == NVVM::LdStMatrixShape::M16N16 && + } else if (shape.getM() == 16 && shape.getN() == 16 && elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { switch (num) { case 1: @@ -205,7 +205,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, return llvm::Intrinsic:: nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32; } - } else if (shape == NVVM::LdStMatrixShape::M16N16 && + } else if (shape.getM() == 16 && shape.getN() == 16 && elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { switch (num) { case 1: @@ -217,7 +217,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, } } } - llvm_unreachable("unsupported matrix configuration"); + llvm_unreachable("unknown ldmatrix kind"); } /// Return the intrinsic ID associated with st.bulk for the given address type. 
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index 75a556f471373..2c0ed9b68a3c8 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -159,7 +159,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec // CHECK-LABEL: @ldmatrix_x4 func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> { %c0 = arith.constant 0 : index - // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype, layout = #nvvm.mma_layout, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)> + // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype, layout = #nvvm.mma_layout, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)> %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16> // CHECK: llvm.extractvalue // CHECK: llvm.bitcast @@ -179,7 +179,7 @@ func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> { // CHECK-LABEL: @ldmatrix_x1 func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> { %c0 = arith.constant 0 : index - // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype, layout = #nvvm.mma_layout, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape} : {{.*}} -> i32 + // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype, layout = #nvvm.mma_layout, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape} : {{.*}} -> i32 %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16> // CHECK: llvm.bitcast // CHECK: llvm.insertvalue diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index f9def0877d71a..6c0c942a041c6 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ 
b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1116,7 +1116,7 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector< llvm.func @wmmald_matrix(%arg0: !llvm.ptr) { // expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr) -> i32 + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr) -> i32 llvm.return } @@ -1124,7 +1124,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr) { llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}} - %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 llvm.return } @@ -1132,7 +1132,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32)> + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32)> llvm.return } @@ -1140,7 +1140,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}} - %l = nvvm.ldmatrix %arg0 {num = 4 : 
i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> llvm.return } @@ -1148,7 +1148,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 llvm.return } diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index 89429a762db92..69d791138ec71 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -559,47 +559,47 @@ llvm.func @llvm_nvvm_cp_async_bulk_wait_group() { // CHECK-LABEL: @ld_matrix llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}) - %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}) - %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + 
%l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}) - %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}) - %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}) - %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}) - %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, 
shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) - %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) - %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) - %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,elttype =#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,elttype =#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) - %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape 
=#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) - %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) - %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}) - %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}) - %m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,elttype =#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %m16n16_l2t = nvvm.ldmatrix %arg0{num = 
2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,elttype =#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) - %m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) - %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) - %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) - %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,elttype 
=#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,elttype =#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> llvm.return } From 3b7c285b7e8370bd4962a39da7ff0b0f0bacdf26 Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Mon, 21 Jul 2025 16:24:21 +0800 Subject: [PATCH 3/8] Follow the convention of eltType in the naming --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 4 +-- .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 2 +- .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 16 ++++----- .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 4 +-- mlir/test/Dialect/LLVMIR/invalid.mlir | 10 +++--- mlir/test/Target/LLVMIR/nvvmir.mlir | 36 +++++++++---------- 6 files changed, 36 insertions(+), 36 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 6af9f4e36be3d..0f12829c2b9ad 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -2045,13 +2045,13 @@ def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">, Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr: $num, MMALayoutAttr: $layout, LdStMatrixShapeAttr: $shape, - LdStMatrixEltTypeAttr: $elttype)> { + LdStMatrixEltTypeAttr: $eltType)> { let summary = "cooperative matrix load"; string llvmBuilder = [{ auto operands = moduleTranslation.lookupValues(opInst.getOperands()); - auto intId = getLdMatrixIntrinsicId($layout, $num, $shape, $elttype); + auto intId = getLdMatrixIntrinsicId($layout, $num, $shape, $eltType); $res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()}); }]; diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 53eeabb16c984..b1ed889194ab0 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp 
@@ -291,7 +291,7 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern { /*num=*/op.getNumTiles(), /*layout=*/op.getTranspose() ? NVVM::MMALayout::col : NVVM::MMALayout::row, - /*shape=*/shape, /*elttype=*/NVVM::LdStMatrixEltType::B16); + /*shape=*/shape, /*eltType=*/NVVM::LdStMatrixEltType::B16); // The ldmatrix operation returns either a single i32 value or a struct of // i32 values. Here we unpack those values and cast them back to their diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index 098336cc035a4..0904d2b49f184 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -137,10 +137,10 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) { static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, NVVM::LdStMatrixShapeAttr shape, - NVVM::LdStMatrixEltType elttype) { + NVVM::LdStMatrixEltType eltType) { if (layout == NVVM::MMALayout::row) { if (shape.getM() == 8 && shape.getN() == 8 && - elttype == NVVM::LdStMatrixEltType::B16) { + eltType == NVVM::LdStMatrixEltType::B16) { switch (num) { case 1: return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16; @@ -150,7 +150,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16; } } else if (shape.getM() == 8 && shape.getN() == 16 && - elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { + eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { switch (num) { case 1: return llvm::Intrinsic:: @@ -163,7 +163,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32; } } else if (shape.getM() == 8 && shape.getN() == 16 && - elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { + eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { 
switch (num) { case 1: return llvm::Intrinsic:: @@ -178,7 +178,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, } } else { if (shape.getM() == 8 && shape.getN() == 8 && - elttype == NVVM::LdStMatrixEltType::B16) { + eltType == NVVM::LdStMatrixEltType::B16) { switch (num) { case 1: return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16; @@ -188,7 +188,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16; } } else if (shape.getM() == 16 && shape.getN() == 16 && - elttype == NVVM::LdStMatrixEltType::B8) { + eltType == NVVM::LdStMatrixEltType::B8) { switch (num) { case 1: return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8; @@ -196,7 +196,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8; } } else if (shape.getM() == 16 && shape.getN() == 16 && - elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { + eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { switch (num) { case 1: return llvm::Intrinsic:: @@ -206,7 +206,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32; } } else if (shape.getM() == 16 && shape.getN() == 16 && - elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { + eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { switch (num) { case 1: return llvm::Intrinsic:: diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index 2c0ed9b68a3c8..bc2204ee63ede 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -159,7 +159,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec // CHECK-LABEL: @ldmatrix_x4 func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> { %c0 = arith.constant 0 : index 
- // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype, layout = #nvvm.mma_layout, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)> + // CHECK: nvvm.ldmatrix {{%.+}} {eltType = #nvvm.ld_st_matrix_elttype, layout = #nvvm.mma_layout, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)> %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16> // CHECK: llvm.extractvalue // CHECK: llvm.bitcast @@ -179,7 +179,7 @@ func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> { // CHECK-LABEL: @ldmatrix_x1 func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> { %c0 = arith.constant 0 : index - // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype, layout = #nvvm.mma_layout, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape} : {{.*}} -> i32 + // CHECK: nvvm.ldmatrix {{%.+}} {eltType = #nvvm.ld_st_matrix_elttype, layout = #nvvm.mma_layout, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape} : {{.*}} -> i32 %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16> // CHECK: llvm.bitcast // CHECK: llvm.insertvalue diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 6c0c942a041c6..f13960708f800 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1116,7 +1116,7 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector< llvm.func @wmmald_matrix(%arg0: !llvm.ptr) { // expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr) -> i32 + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = 
#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr) -> i32 llvm.return } @@ -1124,7 +1124,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr) { llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}} - %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 llvm.return } @@ -1132,7 +1132,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32)> + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32)> llvm.return } @@ -1140,7 +1140,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}} - %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> llvm.return } @@ -1148,7 +1148,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op 
expected destination type is a structure of 2 elements of type i32}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 llvm.return } diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index 69d791138ec71..c4a15b1d62b4a 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -559,47 +559,47 @@ llvm.func @llvm_nvvm_cp_async_bulk_wait_group() { // CHECK-LABEL: @ld_matrix llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}) - %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}) - %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}) - %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %l4 = 
nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}) - %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}) - %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}) - %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) - %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape 
=#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) - %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) - %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,elttype =#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,eltType =#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) - %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) - %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape = 
#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) - %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}) - %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}) - %m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,elttype =#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,eltType =#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) - %m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + 
%m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) - %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) - %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, elttype = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) - %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,elttype =#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,eltType =#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> llvm.return } From 7af2a81556cd3f4e196c1a41a805681929346ebc Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Mon, 21 Jul 2025 17:08:32 +0800 Subject: [PATCH 4/8] Add verifier for ldmatrix --- 
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 55 +++++++++++++---- mlir/test/Dialect/LLVMIR/invalid.mlir | 69 +++++++++++++++++++--- 2 files changed, 106 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index fbb78ed487448..a4912fc6990b6 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -802,23 +802,58 @@ LogicalResult NVVM::LdMatrixOp::verify() { if (addressSpace != NVVM::kSharedMemorySpace) return emitOpError("expected source pointer in memory space 3"); - if (getNum() != 1 && getNum() != 2 && getNum() != 4) - return emitOpError("expected num attribute to be 1, 2 or 4"); + uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN(); + if (m == 8 && n == 8) { + if (num != 1 && num != 2 && num != 4) { + return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 " + "matrix"); + } + if (getEltType() != LdStMatrixEltType::B16) { + return emitOpError("expected element type to be b16 for 8x8 matrix"); + } + } else if (m == 8 && n == 16) { + if (num != 1 && num != 2 && num != 4) { + return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 " + "matrix"); + } + if (getLayout() != MMALayout::row) { + return emitOpError("expected layout to be row for 8x16 matrix"); + } + if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 && + getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) { + return emitOpError("expected element type to be b8x16.b4x16_p64 or " + "b8x16.b6x16_p32 for 8x16 matrix"); + } + } else if (m == 16 && n == 16) { + if (num != 1 && num != 2) { + return emitOpError("expected num attribute to be 1 or 2 for 16x16 " + "matrix"); + } + if (getLayout() != MMALayout::col) { + return emitOpError("expected layout to be col for 16x16 matrix"); + } + if (getEltType() != LdStMatrixEltType::B8 && + getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 && + getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) { 
+ return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or " + "b8x16.b6x16_p32 for 16x16 matrix"); + } + } else { + return emitOpError("expected shape to be 8x8, 8x16 or 16x16"); + } Type i32 = IntegerType::get(getContext(), 32); - uint32_t num = getNum(); - if (getShape().getM() == 16 && getShape().getN() == 16) { - num *= 2; - } - if (num == 1 && getType() != i32) + uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num); + if (numElements == 1 && getType() != i32) return emitOpError("expected destination type is i32"); - if (num == 2 || num == 4) { + if (numElements == 2 || numElements == 4) { Type dstType = LLVM::LLVMStructType::getLiteral( - getContext(), SmallVector(num, i32)); + getContext(), SmallVector(numElements, i32)); if (getType() != dstType) return emitOpError("expected destination type is a structure of ") - << num << " elements of type i32"; + << numElements << " elements of type i32"; } + return success(); } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index f13960708f800..fc38ed4c6ecfe 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1114,7 +1114,7 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector< // ----- -llvm.func @wmmald_matrix(%arg0: !llvm.ptr) { +llvm.func @ld_matrix(%arg0: !llvm.ptr) { // expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}} %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr) -> i32 llvm.return @@ -1122,15 +1122,15 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr) { // ----- -llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { - // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}} +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x8 matrix}} %l 
= nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 llvm.return } // ----- -llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}} %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32)> llvm.return @@ -1138,7 +1138,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { // ----- -llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}} %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> llvm.return @@ -1146,12 +1146,65 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { // ----- -llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { - // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b16 for 8x8 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout, shape = 
#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32)> + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected layout to be row for 8x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b8x16.b4x16_p64 or b8x16.b6x16_p32 for 8x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1 or 2 for 16x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected layout to be col for 16x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> llvm.return } +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b8, b8x16.b4x16_p64 or b8x16.b6x16_p32 for 16x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + llvm.return +} + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 
{{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + llvm.return +} // ----- From c0f544e727955f4803cbe04fc3a5eb19f422b68e Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Wed, 23 Jul 2025 11:18:30 +0800 Subject: [PATCH 5/8] Change "ld_st_matrix_elttype" to "ld_st_matrix_elt_type" --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 2 +- .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 4 +-- mlir/test/Dialect/LLVMIR/invalid.mlir | 24 ++++++------- mlir/test/Target/LLVMIR/nvvmir.mlir | 36 +++++++++---------- 4 files changed, 33 insertions(+), 33 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 0f12829c2b9ad..e7f7d37d6b17d 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -2007,7 +2007,7 @@ def LdStMatrixEltType : I32EnumAttr<"LdStMatrixEltType", "Element type for ldmat let genSpecializedAttr = 0; let cppNamespace = "::mlir::NVVM"; } -def LdStMatrixEltTypeAttr : EnumAttr { +def LdStMatrixEltTypeAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; } diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index bc2204ee63ede..090c8bb42ac17 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -159,7 +159,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec // CHECK-LABEL: @ldmatrix_x4 func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> { %c0 = arith.constant 0 : index - // CHECK: nvvm.ldmatrix {{%.+}} {eltType = #nvvm.ld_st_matrix_elttype, layout = #nvvm.mma_layout, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape} : {{.*}} -> !llvm.struct<(i32, i32, 
i32, i32)> + // CHECK: nvvm.ldmatrix {{%.+}} {eltType = #nvvm.ld_st_matrix_elt_type, layout = #nvvm.mma_layout, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)> %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16> // CHECK: llvm.extractvalue // CHECK: llvm.bitcast @@ -179,7 +179,7 @@ func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> { // CHECK-LABEL: @ldmatrix_x1 func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> { %c0 = arith.constant 0 : index - // CHECK: nvvm.ldmatrix {{%.+}} {eltType = #nvvm.ld_st_matrix_elttype, layout = #nvvm.mma_layout, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape} : {{.*}} -> i32 + // CHECK: nvvm.ldmatrix {{%.+}} {eltType = #nvvm.ld_st_matrix_elt_type, layout = #nvvm.mma_layout, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape} : {{.*}} -> i32 %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16> // CHECK: llvm.bitcast // CHECK: llvm.insertvalue diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index fc38ed4c6ecfe..e0e42e46830bc 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1116,7 +1116,7 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector< llvm.func @ld_matrix(%arg0: !llvm.ptr) { // expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr) -> i32 + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr) -> i32 llvm.return } @@ -1124,7 +1124,7 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr) { llvm.func @ld_matrix(%arg0: 
!llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x8 matrix}} - %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 llvm.return } @@ -1132,7 +1132,7 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32)> + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32)> llvm.return } @@ -1140,7 +1140,7 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}} - %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> llvm.return } @@ -1148,7 +1148,7 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b16 for 8x8 matrix}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> 
i32 + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 llvm.return } @@ -1156,7 +1156,7 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x16 matrix}} - %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32)> + %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32)> llvm.return } @@ -1164,7 +1164,7 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected layout to be row for 8x16 matrix}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 llvm.return } @@ -1172,7 +1172,7 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b8x16.b4x16_p64 or b8x16.b6x16_p32 for 8x16 matrix}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 llvm.return } @@ -1180,7 +1180,7 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { llvm.func 
@ld_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1 or 2 for 16x16 matrix}} - %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> llvm.return } @@ -1188,7 +1188,7 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected layout to be col for 16x16 matrix}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> llvm.return } @@ -1196,13 +1196,13 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b8, b8x16.b4x16_p64 or b8x16.b6x16_p32 for 16x16 matrix}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 llvm.return } llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType 
= #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 llvm.return } diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index c4a15b1d62b4a..f502a2f6b5769 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -559,47 +559,47 @@ llvm.func @llvm_nvvm_cp_async_bulk_wait_group() { // CHECK-LABEL: @ld_matrix llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}) - %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}) - %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}) - %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, 
i32)> // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}) - %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}) - %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}) - %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) - %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32.p3(ptr 
addrspace(3) %{{.*}}) - %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) - %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,eltType =#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,eltType =#nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) - %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) - %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } 
@llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) - %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}) - %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}) - %m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,eltType =#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,eltType =#nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) - %m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} 
: (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) - %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) - %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) - %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,eltType =#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,eltType =#nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> llvm.return } From dfd19cbc592d70ccdc5f4f9462c84de33f549598 Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Wed, 23 Jul 2025 11:27:21 +0800 Subject: [PATCH 6/8] Simplify the structure of `getLdMatrixIntrinsicId` --- .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 57 ++++++++----------- 1 file changed, 25 insertions(+), 32 deletions(-) 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index 0904d2b49f184..33b784eb56504 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -138,19 +138,26 @@ static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, NVVM::LdStMatrixShapeAttr shape, NVVM::LdStMatrixEltType eltType) { - if (layout == NVVM::MMALayout::row) { - if (shape.getM() == 8 && shape.getN() == 8 && - eltType == NVVM::LdStMatrixEltType::B16) { - switch (num) { - case 1: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16; - case 2: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16; - case 4: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16; - } - } else if (shape.getM() == 8 && shape.getN() == 16 && - eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { + if (shape.getM() == 8 && shape.getN() == 8) { + switch (num) { + case 1: + return (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16 + : llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16; + case 2: + return (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16 + : llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16; + case 4: + return (layout == NVVM::MMALayout::row) + ? 
llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16 + : llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16; + } + } else if (shape.getM() == 8 && shape.getN() == 16) { + if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { switch (num) { case 1: return llvm::Intrinsic:: @@ -162,8 +169,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, return llvm::Intrinsic:: nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32; } - } else if (shape.getM() == 8 && shape.getN() == 16 && - eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { + } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { switch (num) { case 1: return llvm::Intrinsic:: @@ -176,27 +182,15 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64; } } - } else { - if (shape.getM() == 8 && shape.getN() == 8 && - eltType == NVVM::LdStMatrixEltType::B16) { - switch (num) { - case 1: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16; - case 2: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16; - case 4: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16; - } - } else if (shape.getM() == 16 && shape.getN() == 16 && - eltType == NVVM::LdStMatrixEltType::B8) { + } else if (shape.getM() == 16 && shape.getN() == 16) { + if (eltType == NVVM::LdStMatrixEltType::B8) { switch (num) { case 1: return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8; case 2: return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8; } - } else if (shape.getM() == 16 && shape.getN() == 16 && - eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { + } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { switch (num) { case 1: return llvm::Intrinsic:: @@ -205,8 +199,7 @@ getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, return llvm::Intrinsic:: nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32; } - } else if (shape.getM() 
== 16 && shape.getN() == 16 && - eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { + } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { switch (num) { case 1: return llvm::Intrinsic:: From 40b098b53dd7dc2c6aec1eb39a8eeb0d84b63e45 Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Wed, 23 Jul 2025 11:32:47 +0800 Subject: [PATCH 7/8] Move the negative test to nvvmir-invalid.mlir --- mlir/test/Dialect/LLVMIR/invalid.mlir | 94 --------------------- mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 94 +++++++++++++++++++++ 2 files changed, 94 insertions(+), 94 deletions(-) diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index e0e42e46830bc..c094f37a15e66 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1114,100 +1114,6 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector< // ----- -llvm.func @ld_matrix(%arg0: !llvm.ptr) { - // expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr) -> i32 - llvm.return -} - -// ----- - -llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { - // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x8 matrix}} - %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 - llvm.return -} - -// ----- - -llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { - // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32)> - llvm.return -} - -// ----- - -llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { - // expected-error@+1 
{{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}} - %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> - llvm.return -} - -// ----- - -llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { - // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b16 for 8x8 matrix}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 - llvm.return -} - -// ----- - -llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { - // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x16 matrix}} - %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32)> - llvm.return -} - -// ----- - -llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { - // expected-error@+1 {{'nvvm.ldmatrix' op expected layout to be row for 8x16 matrix}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 - llvm.return -} - -// ----- - -llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { - // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b8x16.b4x16_p64 or b8x16.b6x16_p32 for 8x16 matrix}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 - llvm.return -} - -// ----- - -llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { - // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1 or 2 for 16x16 matrix}} - %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = 
#nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> - llvm.return -} - -// ----- - -llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { - // expected-error@+1 {{'nvvm.ldmatrix' op expected layout to be col for 16x16 matrix}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> - llvm.return -} - -// ----- - -llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { - // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b8, b8x16.b4x16_p64 or b8x16.b6x16_p32 for 16x16 matrix}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 - llvm.return -} - -llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { - // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 - llvm.return -} - -// ----- - llvm.func @caller() { // expected-error @below {{expected function call to produce a value}} llvm.call @callee() : () -> () diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index 8c4f0aafd36a7..0f713b9c918ee 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -312,3 +312,97 @@ llvm.func @nvvm_prefetch_uniform_with_invalid_addr_space(%global_ptr: !llvm.ptr< nvvm.prefetch level = L1 uniform, %global_ptr : !llvm.ptr<1> llvm.return } + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = 
#nvvm.ld_st_matrix_elt_type} : (!llvm.ptr) -> i32 + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x8 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32)> + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}} + %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b16 for 8x8 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32)> + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected layout to be row for 8x16 matrix}} + %l = nvvm.ldmatrix 
%arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b8x16.b4x16_p64 or b8x16.b6x16_p32 for 8x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1 or 2 for 16x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected layout to be col for 16x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b8, b8x16.b4x16_p64 or b8x16.b6x16_p32 for 16x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 + llvm.return +} + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32 + llvm.return +} \ No newline at 
end of file From a7d309966ebf0dfa27d1e88849c894365c931561 Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Wed, 23 Jul 2025 11:38:42 +0800 Subject: [PATCH 8/8] Keep the pointer as shared instead of any --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 2 +- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 5 ----- mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 8 -------- 3 files changed, 1 insertion(+), 14 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index e7f7d37d6b17d..1bc7bd0d744ef 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -2042,7 +2042,7 @@ def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">, Results<(outs AnyType:$res)>, - Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr: $num, + Arguments<(ins LLVM_PointerShared: $ptr, I32Attr: $num, MMALayoutAttr: $layout, LdStMatrixShapeAttr: $shape, LdStMatrixEltTypeAttr: $eltType)> { diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index a4912fc6990b6..42389c0008a80 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -797,11 +797,6 @@ LogicalResult NVVM::WMMAMmaOp::verify() { } LogicalResult NVVM::LdMatrixOp::verify() { - unsigned addressSpace = - llvm::cast(getPtr().getType()).getAddressSpace(); - if (addressSpace != NVVM::kSharedMemorySpace) - return emitOpError("expected source pointer in memory space 3"); - uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN(); if (m == 8 && n == 8) { if (num != 1 && num != 2 && num != 4) { diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index 0f713b9c918ee..254373ec3f47c 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -315,14 +315,6 @@ llvm.func 
@nvvm_prefetch_uniform_with_invalid_addr_space(%global_ptr: !llvm.ptr< // ----- -llvm.func @ld_matrix(%arg0: !llvm.ptr) { - // expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr) -> i32 - llvm.return -} - -// ----- - llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x8 matrix}} %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elt_type} : (!llvm.ptr<3>) -> i32