diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 45a8904375e2b..0f12829c2b9ad 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1990,6 +1990,27 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">, let hasVerifier = 1; } +def LdStMatrixShapeAttr : NVVM_Attr<"LdStMatrixShape", "ld_st_matrix_shape"> { + let summary = "Matrix shape for ldmatrix and stmatrix"; + let parameters = (ins "int":$m, "int":$n); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def LdStMatrixEltTypeB16 : I32EnumAttrCase<"B16", 0, "b16">; +def LdStMatrixEltTypeB8 : I32EnumAttrCase<"B8", 1, "b8">; +def LdStMatrixEltTypeB8X16_B6X16_P32 : I32EnumAttrCase<"B8X16_B6X16_P32", 2, "b8x16.b6x16_p32">; +def LdStMatrixEltTypeB8X16_B4X16_P64 : I32EnumAttrCase<"B8X16_B4X16_P64", 3, "b8x16.b4x16_p64">; + +def LdStMatrixEltType : I32EnumAttr<"LdStMatrixEltType", "Element type for ldmatrix and stmatrix", + [LdStMatrixEltTypeB16, LdStMatrixEltTypeB8, + LdStMatrixEltTypeB8X16_B6X16_P32, LdStMatrixEltTypeB8X16_B4X16_P64]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::NVVM"; +} +def LdStMatrixEltTypeAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, Arguments<(ins LLVM_PointerShared:$ptr, Variadic:$sources, @@ -2021,13 +2042,16 @@ def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">, Results<(outs AnyType:$res)>, - Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> { + Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr: $num, + MMALayoutAttr: $layout, + LdStMatrixShapeAttr: $shape, + LdStMatrixEltTypeAttr: $eltType)> { let summary = "cooperative matrix load"; string llvmBuilder = [{ auto operands = moduleTranslation.lookupValues(opInst.getOperands()); - auto intId = getLdMatrixIntrinsicId($layout, $num); + auto intId = getLdMatrixIntrinsicId($layout, $num, $shape, $eltType); $res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()}); }]; diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 80b3d85488495..b1ed889194ab0 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -285,11 +285,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern { Value srcPtr = getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType, adaptor.getSrcMemref(), adaptor.getIndices()); + auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8); Value ldMatrixResult = b.create( ldMatrixResultType, srcPtr, /*num=*/op.getNumTiles(), /*layout=*/op.getTranspose() ? NVVM::MMALayout::col - : NVVM::MMALayout::row); + : NVVM::MMALayout::row, + /*shape=*/shape, /*eltType=*/NVVM::LdStMatrixEltType::B16); // The ldmatrix operation returns either a single i32 value or a struct of // i32 values. Here we unpack those values and cast them back to their diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 6e29b129e8835..a4912fc6990b6 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -802,19 +802,58 @@ LogicalResult NVVM::LdMatrixOp::verify() { if (addressSpace != NVVM::kSharedMemorySpace) return emitOpError("expected source pointer in memory space 3"); - if (getNum() != 1 && getNum() != 2 && getNum() != 4) - return emitOpError("expected num attribute to be 1, 2 or 4"); + uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN(); + if (m == 8 && n == 8) { + if (num != 1 && num != 2 && num != 4) { + return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 " + "matrix"); + } + if (getEltType() != LdStMatrixEltType::B16) { + return emitOpError("expected element type to be b16 for 8x8 matrix"); + } + } else if (m == 8 && n == 16) { + if (num != 1 && num != 2 && num != 4) { + return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 " + "matrix"); + } + if (getLayout() != MMALayout::row) { + return emitOpError("expected layout to be row for 8x16 matrix"); + } + if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 && + getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) { + return emitOpError("expected element type to be b8x16.b4x16_p64 or " + "b8x16.b6x16_p32 for 8x16 matrix"); + } + } else if (m == 16 && n == 16) { + if (num != 1 && num != 2) { + return emitOpError("expected num attribute to be 1 or 2 for 16x16 " + "matrix"); + } + if (getLayout() != MMALayout::col) { + return emitOpError("expected layout to be col for 16x16 matrix"); + } + if (getEltType() != LdStMatrixEltType::B8 && + getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 && + getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) { + return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or " + "b8x16.b6x16_p32 for 16x16 matrix"); + } + } else { + return emitOpError("expected shape to be 8x8, 8x16 or 16x16"); + } Type i32 = IntegerType::get(getContext(), 32); - if (getNum() == 1 && getType() != i32) + uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num); + if (numElements == 1 && getType() != i32) return emitOpError("expected destination type is i32"); - if (getNum() == 2 || getNum() == 4) { + if (numElements == 2 || numElements == 4) { Type dstType = LLVM::LLVMStructType::getLiteral( - getContext(), SmallVector(getNum(), i32)); + getContext(), SmallVector(numElements, i32)); if (getType() != dstType) return emitOpError("expected destination type is a structure of ") - << getNum() << " elements of type i32"; + << numElements << " elements of type i32"; } + return success(); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index eecca64c4bf81..0904d2b49f184 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -134,33 +134,90 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) { llvm_unreachable("unsupported vote kind"); } -/// Return the intrinsic ID associated with ldmatrix for the given paramters. -static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, - int32_t num) { +static llvm::Intrinsic::ID +getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, + NVVM::LdStMatrixShapeAttr shape, + NVVM::LdStMatrixEltType eltType) { if (layout == NVVM::MMALayout::row) { - switch (num) { - case 1: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16; - case 2: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16; - case 4: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16; - default: - llvm_unreachable("unsupported number of matrix"); + if (shape.getM() == 8 && shape.getN() == 8 && + eltType == NVVM::LdStMatrixEltType::B16) { + switch (num) { + case 1: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16; + case 2: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16; + case 4: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16; + } + } else if (shape.getM() == 8 && shape.getN() == 16 && + eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { + switch (num) { + case 1: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32; + case 2: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32; + case 4: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32; + } + } else if (shape.getM() == 8 && shape.getN() == 16 && + eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { + switch (num) { + case 1: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64; + case 2: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64; + case 4: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64; + } } - } else { - switch (num) { - case 1: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16; - case 2: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16; - case 4: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16; - default: - llvm_unreachable("unsupported number of matrix"); + if (shape.getM() == 8 && shape.getN() == 8 && + eltType == NVVM::LdStMatrixEltType::B16) { + switch (num) { + case 1: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16; + case 2: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16; + case 4: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16; + } + } else if (shape.getM() == 16 && shape.getN() == 16 && + eltType == NVVM::LdStMatrixEltType::B8) { + switch (num) { + case 1: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8; + case 2: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8; + } + } else if (shape.getM() == 16 && shape.getN() == 16 && + eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { + switch (num) { + case 1: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32; + case 2: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32; + } + } else if (shape.getM() == 16 && shape.getN() == 16 && + eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { + switch (num) { + case 1: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64; + case 2: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64; + } } } + llvm_unreachable("unknown ldmatrix kind"); } /// Return the intrinsic ID associated with st.bulk for the given address type. diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index d0bc806e0aa8c..bc2204ee63ede 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -159,7 +159,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec // CHECK-LABEL: @ldmatrix_x4 func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> { %c0 = arith.constant 0 : index - // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32) + // CHECK: nvvm.ldmatrix {{%.+}} {eltType = #nvvm.ld_st_matrix_elttype, layout = #nvvm.mma_layout, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)> %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16> // CHECK: llvm.extractvalue // CHECK: llvm.bitcast @@ -179,7 +179,7 @@ func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> { // CHECK-LABEL: @ldmatrix_x1 func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> { %c0 = arith.constant 0 : index - // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout, num = 1 : i32} {{.*}} -> i32 + // CHECK: nvvm.ldmatrix {{%.+}} {eltType = #nvvm.ld_st_matrix_elttype, layout = #nvvm.mma_layout, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape} : {{.*}} -> i32 %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16> // CHECK: llvm.bitcast // CHECK: llvm.insertvalue diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index bd1106e304c60..fc38ed4c6ecfe 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1114,33 +1114,95 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector< // ----- -llvm.func @wmmald_matrix(%arg0: !llvm.ptr) { +llvm.func @ld_matrix(%arg0: !llvm.ptr) { // expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr) -> i32 + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr) -> i32 llvm.return } // ----- -llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { - // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}} - %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> i32 +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x8 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 llvm.return } // ----- -llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}} - %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32)> + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32)> llvm.return } // ----- -llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) { +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}} - %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b16 for 8x8 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4 for 8x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32)> + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected layout to be row for 8x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b8x16.b4x16_p64 or b8x16.b6x16_p32 for 8x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1 or 2 for 16x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected layout to be col for 16x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + llvm.return +} + +// ----- + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected element type to be b8, b8x16.b4x16_p64 or b8x16.b6x16_p32 for 16x16 matrix}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + llvm.return +} + +llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { + // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}} + %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 llvm.return } diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index c7fa41c98ac92..6a4edd0d22a08 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -385,17 +385,6 @@ llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) { llvm.return } -// CHECK-LABEL: llvm.func @ld_matrix -llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { - // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout, num = 1 : i32} : (!llvm.ptr<3>) -> i32 - %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> i32 - // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout, num = 2 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> - %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> - // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> - %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> - llvm.return -} - // CHECK-LABEL: llvm.func @redux_sync llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 { // CHECK: nvvm.redux.sync add %{{.*}} diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index f86a04186f512..c4a15b1d62b4a 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -559,17 +559,47 @@ llvm.func @llvm_nvvm_cp_async_bulk_wait_group() { // CHECK-LABEL: @ld_matrix llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) { // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}) - %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> i32 + %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}) - %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}) - %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> - // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}) - %l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> i32 + %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + + // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}) + %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}) - %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}) - %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + + // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) + %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) + %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) + %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,eltType =#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + + // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) + %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> i32 + // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) + %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) + %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + + // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}) + %m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}) + %m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,eltType =#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + + // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) + %m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}}) + %m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> + + // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) + %m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout, shape =#nvvm.ld_st_matrix_shape, eltType = #nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> + // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}}) + %m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout, shape = #nvvm.ld_st_matrix_shape,eltType =#nvvm.ld_st_matrix_elttype} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> llvm.return }