Skip to content

[MLIR][NVVM] Support generating all the ldmatrix intrinsics from NVVM ops #148783

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1990,6 +1990,27 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
let hasVerifier = 1;
}

def LdStMatrixShapeAttr : NVVM_Attr<"LdStMatrixShape", "ld_st_matrix_shape"> {
let summary = "Matrix shape for ldmatrix and stmatrix";
let parameters = (ins "int":$m, "int":$n);
let assemblyFormat = "`<` struct(params) `>`";
}

def LdStMatrixEltTypeB16 : I32EnumAttrCase<"B16", 0, "b16">;
def LdStMatrixEltTypeB8 : I32EnumAttrCase<"B8", 1, "b8">;
def LdStMatrixEltTypeB8X16_B6X16_P32 : I32EnumAttrCase<"B8X16_B6X16_P32", 2, "b8x16.b6x16_p32">;
def LdStMatrixEltTypeB8X16_B4X16_P64 : I32EnumAttrCase<"B8X16_B4X16_P64", 3, "b8x16.b4x16_p64">;

def LdStMatrixEltType : I32EnumAttr<"LdStMatrixEltType", "Element type for ldmatrix and stmatrix",
[LdStMatrixEltTypeB16, LdStMatrixEltTypeB8,
LdStMatrixEltTypeB8X16_B6X16_P32, LdStMatrixEltTypeB8X16_B4X16_P64]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
def LdStMatrixEltTypeAttr : EnumAttr<NVVM_Dialect, LdStMatrixEltType, "ld_st_matrix_elttype"> {
let assemblyFormat = "`<` $value `>`";
}

def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
Arguments<(ins LLVM_PointerShared:$ptr,
Variadic<I32>:$sources,
Expand Down Expand Up @@ -2021,13 +2042,16 @@ def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,

def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
Results<(outs AnyType:$res)>,
Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {
Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr: $num,
MMALayoutAttr: $layout,
LdStMatrixShapeAttr: $shape,
LdStMatrixEltTypeAttr: $elttype)> {

let summary = "cooperative matrix load";

string llvmBuilder = [{
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
auto intId = getLdMatrixIntrinsicId($layout, $num);
auto intId = getLdMatrixIntrinsicId($layout, $num, $shape, $elttype);
$res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()});
}];

Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
Value srcPtr =
getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
adaptor.getSrcMemref(), adaptor.getIndices());
auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8);
Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
ldMatrixResultType, srcPtr,
/*num=*/op.getNumTiles(),
/*layout=*/op.getTranspose() ? NVVM::MMALayout::col
: NVVM::MMALayout::row);
: NVVM::MMALayout::row,
/*shape=*/shape, /*elttype=*/NVVM::LdStMatrixEltType::B16);

// The ldmatrix operation returns either a single i32 value or a struct of
// i32 values. Here we unpack those values and cast them back to their
Expand Down
12 changes: 8 additions & 4 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -806,14 +806,18 @@ LogicalResult NVVM::LdMatrixOp::verify() {
return emitOpError("expected num attribute to be 1, 2 or 4");

Type i32 = IntegerType::get(getContext(), 32);
if (getNum() == 1 && getType() != i32)
uint32_t num = getNum();
if (getShape().getM() == 16 && getShape().getN() == 16) {
num *= 2;
}
if (num == 1 && getType() != i32)
return emitOpError("expected destination type is i32");
if (getNum() == 2 || getNum() == 4) {
if (num == 2 || num == 4) {
Type dstType = LLVM::LLVMStructType::getLiteral(
getContext(), SmallVector<Type>(getNum(), i32));
getContext(), SmallVector<Type>(num, i32));
if (getType() != dstType)
return emitOpError("expected destination type is a structure of ")
<< getNum() << " elements of type i32";
<< num << " elements of type i32";
}
return success();
}
Expand Down
101 changes: 79 additions & 22 deletions mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,33 +134,90 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
llvm_unreachable("unsupported vote kind");
}

/// Return the intrinsic ID associated with ldmatrix for the given paramters.
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
int32_t num) {
static llvm::Intrinsic::ID
getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
NVVM::LdStMatrixShapeAttr shape,
NVVM::LdStMatrixEltType elttype) {
if (layout == NVVM::MMALayout::row) {
switch (num) {
case 1:
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
case 2:
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
case 4:
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
default:
llvm_unreachable("unsupported number of matrix");
if (shape.getM() == 8 && shape.getN() == 8 &&
elttype == NVVM::LdStMatrixEltType::B16) {
switch (num) {
case 1:
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
case 2:
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
case 4:
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
}
} else if (shape.getM() == 8 && shape.getN() == 16 &&
elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
switch (num) {
case 1:
return llvm::Intrinsic::
nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32;
case 2:
return llvm::Intrinsic::
nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32;
case 4:
return llvm::Intrinsic::
nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
}
} else if (shape.getM() == 8 && shape.getN() == 16 &&
elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
switch (num) {
case 1:
return llvm::Intrinsic::
nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64;
case 2:
return llvm::Intrinsic::
nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64;
case 4:
return llvm::Intrinsic::
nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64;
}
}

} else {
switch (num) {
case 1:
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
case 2:
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
case 4:
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
default:
llvm_unreachable("unsupported number of matrix");
if (shape.getM() == 8 && shape.getN() == 8 &&
elttype == NVVM::LdStMatrixEltType::B16) {
switch (num) {
case 1:
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
case 2:
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
case 4:
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
}
} else if (shape.getM() == 16 && shape.getN() == 16 &&
elttype == NVVM::LdStMatrixEltType::B8) {
switch (num) {
case 1:
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
case 2:
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
}
} else if (shape.getM() == 16 && shape.getN() == 16 &&
elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
switch (num) {
case 1:
return llvm::Intrinsic::
nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32;
case 2:
return llvm::Intrinsic::
nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
}
} else if (shape.getM() == 16 && shape.getN() == 16 &&
elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
switch (num) {
case 1:
return llvm::Intrinsic::
nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64;
case 2:
return llvm::Intrinsic::
nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64;
}
}
}
llvm_unreachable("unknown ldmatrix kind");
}

/// Return the intrinsic ID associated with st.bulk for the given address type.
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec
// CHECK-LABEL: @ldmatrix_x4
func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> {
%c0 = arith.constant 0 : index
// CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32)
// CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)>
%a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
// CHECK: llvm.extractvalue
// CHECK: llvm.bitcast
Expand All @@ -179,7 +179,7 @@ func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> {
// CHECK-LABEL: @ldmatrix_x1
func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> {
%c0 = arith.constant 0 : index
// CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} {{.*}} -> i32
// CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>} : {{.*}} -> i32
%a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>
// CHECK: llvm.bitcast
// CHECK: llvm.insertvalue
Expand Down
17 changes: 13 additions & 4 deletions mlir/test/Dialect/LLVMIR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1116,34 +1116,43 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<

llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
// expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
%l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr) -> i32
%l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr) -> i32
llvm.return
}

// -----

llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
// expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}}
%l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
%l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
llvm.return
}

// -----

llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
// expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}}
%l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
%l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
llvm.return
}

// -----

llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
// expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
%l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
%l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
llvm.return
}

// -----

llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
// expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}}
%l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> i32
llvm.return
}


// -----

llvm.func @caller() {
Expand Down
11 changes: 0 additions & 11 deletions mlir/test/Dialect/LLVMIR/nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -385,17 +385,6 @@ llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
llvm.return
}

// CHECK-LABEL: llvm.func @ld_matrix
llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} : (!llvm.ptr<3>) -> i32
%l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 2 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
%l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
%l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
llvm.return
}

// CHECK-LABEL: llvm.func @redux_sync
llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 {
// CHECK: nvvm.redux.sync add %{{.*}}
Expand Down
44 changes: 37 additions & 7 deletions mlir/test/Target/LLVMIR/nvvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -559,17 +559,47 @@ llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
// CHECK-LABEL: @ld_matrix
llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})
%l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
%l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}})
%l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
%l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}})
%l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}})
%l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> i32
%l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>

// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}})
%l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}})
%l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
%l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}})
%l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
%l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>

// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
%m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> i32
// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
%m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
%m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>,elttype =#nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>

// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
%m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m = 8, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
%m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
%m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m = 8, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>

// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8.p3(ptr addrspace(3) %{{.*}})
%m16n16_l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8.p3(ptr addrspace(3) %{{.*}})
%m16n16_l2t = nvvm.ldmatrix %arg0{num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>,elttype =#nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>

// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
%m16n16_b6_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
%m16n16_b6_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>

// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
%m16n16_b4_l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m = 16, n = 16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
%m16n16_b4_l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>,elttype =#nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
llvm.return
}

Expand Down