Skip to content

Commit 22018c1

Browse files
Added limitation for b16 type
1 parent 8078cee commit 22018c1

File tree

2 files changed

+54
-33
lines changed

2 files changed

+54
-33
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,6 +1315,11 @@ class SYCLGen : public SYCLGenBase {
13151315
bool handle_ldmatrix(const InlineAsmInstruction *Inst) override {
13161316
if (Inst->getNumInputOperands() != 1)
13171317
return SYCLGenError();
1318+
1319+
const auto *Type = dyn_cast<InlineAsmBuiltinType>(Inst->getType(0));
1320+
1321+
if (!Type || Type->getKind() != InlineAsmBuiltinType::b16)
1322+
return SYCLGenError();
13181323

13191324
const InlineAsmVectorExpr *VE;
13201325
if (VE = dyn_cast<InlineAsmVectorExpr>(Inst->getOutputOperand())) {

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2056,14 +2056,16 @@ class joint_matrix {
20562056
const size_t num_elements;
20572057
};
20582058

2059-
/// Loads 1 8x8 b16 (128 bytes) matrix from private memory to local memory per
2060-
/// sub-group. Requires the sub-group size of kernel calling this function to
2061-
/// be 32. 'mat' specifies the matrix index to be loaded. The first '(mat + 1) *
2062-
/// 8' work items of sub-group contain the starting address of their respective
2063-
/// matrix row in 'addr'. After distributing addresses to other work items, each
2064-
/// of the 32 work items load 32-bits (2 packed 16-bit data) into 'm' for a
2065-
/// total of 128 bytes. 'trans' specifies to perform a transposed/non-transposed
2066-
/// load by each work item like below
2059+
/// Collectively loads 1 8x8 b16 (128 bytes) matrix from private memory to local
2060+
/// memory per sub-group. Requires the sub-group size of kernel calling this
2061+
/// function to be 32.
2062+
/// 'mat' specifies the matrix index to be loaded. The first '(mat + 1) * 8'
2063+
/// work items of sub-group contain the starting address of their respective
2064+
/// matrix row in 'addr'.
2065+
/// After distributing addresses to other work items, each of the 32 work items
2066+
/// load 32-bits (2 packed 16-bit data) into 'm' for a total of 128 bytes.
2067+
/// 'trans' specifies to perform a transposed/non-transposed load by each work
2068+
/// item like below
20672069
/// Row Major: Each row of the matrix is loaded by a group of 4 work items(wi)
20682070
/// row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
20692071
/// row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
@@ -2076,9 +2078,11 @@ class joint_matrix {
20762078
/// ...
20772079
/// row-6: wi3 wi7 wi11 ... wi31
20782080
/// row-7: wi3 wi7 wi11 ... wi31
2079-
/// \tparam [in] T The type of result variable
2080-
/// \param [in] addr The address of the matrix in local memory
2081-
/// \param [in] m The private memory to store the matrix
2081+
/// \tparam [in] T Type of result variable (currently only supports 16-bit type)
2082+
/// \param [in] addr The starting address of corresponding matrix row for a work
2083+
/// item in local memory
2084+
/// \param [in] m The private memory to store the matrix. It points to 2 b16
2085+
/// type elements.
20822086
/// \param [in] trans Indicates whether the matrix to be loaded transposed
20832087
/// \param [in] mat The matrix index to be loaded
20842088
template <typename T>
@@ -2129,13 +2133,15 @@ void ldmatrix(uintptr_t addr, T *m, bool trans = false, unsigned mat = 0) {
21292133
}
21302134
}
21312135

2132-
/// Loads 2 8x8 b16 (256 bytes) matrix from private memory to local memory per
2133-
/// sub-group. Requires the sub-group size of kernel calling this function to
2134-
/// be 32. The first 16 work items of sub-group contain the starting address of
2135-
/// their respective matrix row in 'addr'. After distributing addresses to other
2136-
/// work items, each of the 32 work items load 64-bits (32-bits per matrix) into
2137-
/// 'm1' & 'm2' for a total of 256 bytes. 'trans' specifies to perform a
2138-
/// transposed/non-transposed load by each work item like below
2136+
/// Collectively loads 2 8x8 b16 (256 bytes) matrix from private memory to local
2137+
/// memory per sub-group. Requires the sub-group size of kernel calling this
2138+
/// function to be 32.
2139+
/// The first 16 work items of sub-group contain the starting address of their
2140+
/// respective matrix row in 'addr'.
2141+
/// After distributing addresses to other work items, each of the 32 work items
2142+
/// load 64-bits (32-bits per matrix) into 'm1' & 'm2' for a total of 256 bytes.
2143+
/// 'trans' specifies to perform a transposed/non-transposed load by each work
2144+
/// item like below
21392145
/// Row Major: Each row of the matrices is loaded by a group of 4 work items(wi)
21402146
/// row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
21412147
/// row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
@@ -2148,10 +2154,13 @@ void ldmatrix(uintptr_t addr, T *m, bool trans = false, unsigned mat = 0) {
21482154
/// ...
21492155
/// row-6: wi3 wi7 wi11 ... wi31
21502156
/// row-7: wi3 wi7 wi11 ... wi31
2151-
/// \tparam [in] T The type of result variable
2152-
/// \param [in] addr The address of the matrix in local memory
2153-
/// \param [in] m1 The private memory to store data of 1st matrix
2154-
/// \param [in] m2 The private memory to store data of 2nd matrix
2157+
/// \tparam [in] T Type of result variable (currently only supports 16-bit type)
2158+
/// \param [in] addr The starting address of corresponding matrix row for a work
2159+
/// item in local memory
2160+
/// \param [in] m1 The private memory to store the data of 1st matrix. It points
2161+
/// to 2 b16 type elements.
2162+
/// \param [in] m2 The private memory to store the data of 2nd matrix. It points
2163+
/// to 2 b16 type elements.
21552164
/// \param [in] trans Indicates whether the matrix to be loaded transposed
21562165
template <typename T>
21572166
void ldmatrix(uintptr_t addr, T *m1, T *m2, bool trans = false) {
@@ -2161,14 +2170,16 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, bool trans = false) {
21612170
ldmatrix(addr, m2, trans, 1);
21622171
}
21632172

2164-
/// Loads 4 8x8 b16 (512 bytes) matrix from private memory to local memory per
2165-
/// sub-group. Requires the sub-group size of kernel calling this function to
2166-
/// be 32. Each work item of sub-group contains the starting address of their
2173+
/// Collectively loads 4 8x8 b16 (512 bytes) matrix from private memory to local
2174+
/// memory per sub-group. Requires the sub-group size of kernel calling this
2175+
/// function to be 32.
2176+
/// Each work item of sub-group contains the starting address of their
21672177
/// respective matrix row in 'addr'.
21682178
/// After distributing addresses to other work items, each of the 32 work items
21692179
/// load 128-bits (32-bits per matrix) into 'm1', 'm2', 'm3' & 'm4' for a total
2170-
/// of 512 bytes. 'trans' specifies to perform a transposed/non-transposed load
2171-
/// by each work item like below
2180+
/// of 512 bytes.
2181+
/// 'trans' specifies to perform a transposed/non-transposed load by each work
2182+
/// item like below
21722183
/// Row Major: Each row of the matrices is loaded by a group of 4 work items(wi)
21732184
/// row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
21742185
/// row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
@@ -2181,12 +2192,17 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, bool trans = false) {
21812192
/// ...
21822193
/// row-6: wi3 wi7 wi11 ... wi31
21832194
/// row-7: wi3 wi7 wi11 ... wi31
2184-
/// \tparam [in] T The type of result variable
2185-
/// \param [in] addr The address of the matrix in local memory
2186-
/// \param [in] m1 The private memory to store data of 1st matrix
2187-
/// \param [in] m2 The private memory to store data of 2nd matrix
2188-
/// \param [in] m3 The private memory to store data of 3rd matrix
2189-
/// \param [in] m4 The private memory to store data of 4th matrix
2195+
/// \tparam [in] T Type of result variable (currently only supports 16-bit type)
2196+
/// \param [in] addr The starting address of corresponding matrix row for a work
2197+
/// item in local memory
2198+
/// \param [in] m1 The private memory to store the data of 1st matrix. It points
2199+
/// to 2 b16 type elements.
2200+
/// \param [in] m2 The private memory to store the data of 2nd matrix. It points
2201+
/// to 2 b16 type elements.
2202+
/// \param [in] m3 The private memory to store the data of 3rd matrix. It points
2203+
/// to 2 b16 type elements.
2204+
/// \param [in] m4 The private memory to store the data of 4th matrix. It points
2205+
/// to 2 b16 type elements.
21902206
/// \param [in] trans Indicates whether the matrix to be loaded transposed
21912207
template <typename T>
21922208
void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) {

0 commit comments

Comments
 (0)