@@ -2056,14 +2056,16 @@ class joint_matrix {
20562056 const size_t num_elements;
20572057};
20582058
2059- // / Loads 1 8x8 b16 (128 bytes) matrix from private memory to local memory per
2060- // / sub-group. Requires the sub-group size of kernel calling this function to
2061- // / be 32. 'mat' specifies the matrix index to be loaded. The first '(mat + 1) *
2062- // / 8' work items of sub-group contain the starting address of their respective
2063- // / matrix row in 'addr'. After distributing addresses to other work items, each
2064- // / of the 32 work items load 32-bits (2 packed 16-bit data) into 'm' for a
2065- // / total of 128 bytes. 'trans' specifies to perform a transposed/non-transposed
2066- // / load by each work item like below
2059+ // / Collectively loads 1 8x8 b16 (128 bytes) matrix from private memory to local
2060+ // / memory per sub-group. Requires the sub-group size of kernel calling this
2061+ // / function to be 32.
2062+ // / 'mat' specifies the matrix index to be loaded. The first '(mat + 1) * 8'
2063+ // / work items of sub-group contain the starting address of their respective
2064+ // / matrix row in 'addr'.
2065+ // / After distributing addresses to other work items, each of the 32 work items
2066+ // / load 32-bits (2 packed 16-bit data) into 'm' for a total of 128 bytes.
2067+ // / 'trans' specifies to perform a transposed/non-transposed load by each work
2068+ // / item like below
20672069// / Row Major: Each row of the matrix is loaded by a group of 4 work items(wi)
20682070// / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
20692071// / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
@@ -2076,9 +2078,11 @@ class joint_matrix {
20762078// / ...
20772079// / row-6: wi3 wi7 wi11 ... wi31
20782080// / row-7: wi3 wi7 wi11 ... wi31
2079- // / \tparam [in] T The type of result variable
2080- // / \param [in] addr The address of the matrix in local memory
2081- // / \param [in] m The private memory to store the matrix
2081+ // / \tparam [in] T Type of result variable (currently only supports 16-bit type)
2082+ // / \param [in] addr The starting address of corresponding matrix row for a work
2083+ // / item in local memory
2084+ // / \param [in] m The private memory to store the matrix. It points to 2 b16
2085+ // / type elements.
20822086// / \param [in] trans Indicates whether the matrix to be loaded transposed
20832087// / \param [in] mat The matrix index to be loaded
20842088template <typename T>
@@ -2129,13 +2133,15 @@ void ldmatrix(uintptr_t addr, T *m, bool trans = false, unsigned mat = 0) {
21292133 }
21302134}
21312135
2132- // / Loads 2 8x8 b16 (256 bytes) matrix from private memory to local memory per
2133- // / sub-group. Requires the sub-group size of kernel calling this function to
2134- // / be 32. The first 16 work items of sub-group contain the starting address of
2135- // / their respective matrix row in 'addr'. After distributing addresses to other
2136- // / work items, each of the 32 work items load 64-bits (32-bits per matrix) into
2137- // / 'm1' & 'm2' for a total of 256 bytes. 'trans' specifies to perform a
2138- // / transposed/non-transposed load by each work item like below
2136+ // / Collectively loads 2 8x8 b16 (256 bytes) matrix from private memory to local
2137+ // / memory per sub-group. Requires the sub-group size of kernel calling this
2138+ // / function to be 32.
2139+ // / The first 16 work items of sub-group contain the starting address of their
2140+ // / respective matrix row in 'addr'.
2141+ // / After distributing addresses to other work items, each of the 32 work items
2142+ // / load 64-bits (32-bits per matrix) into 'm1' & 'm2' for a total of 256 bytes.
2143+ // / 'trans' specifies to perform a transposed/non-transposed load by each work
2144+ // / item like below
21392145// / Row Major: Each row of the matrices is loaded by a group of 4 work items(wi)
21402146// / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
21412147// / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
@@ -2148,10 +2154,13 @@ void ldmatrix(uintptr_t addr, T *m, bool trans = false, unsigned mat = 0) {
21482154// / ...
21492155// / row-6: wi3 wi7 wi11 ... wi31
21502156// / row-7: wi3 wi7 wi11 ... wi31
2151- // / \tparam [in] T The type of result variable
2152- // / \param [in] addr The address of the matrix in local memory
2153- // / \param [in] m1 The private memory to store data of 1st matrix
2154- // / \param [in] m2 The private memory to store data of 2nd matrix
2157+ // / \tparam [in] T Type of result variable (currently only supports 16-bit type)
2158+ // / \param [in] addr The starting address of corresponding matrix row for a work
2159+ // / item in local memory
2160+ // / \param [in] m1 The private memory to store the data of 1st matrix. It points
2161+ // / to 2 b16 type elements.
2162+ // / \param [in] m2 The private memory to store the data of 2nd matrix. It points
2163+ // / to 2 b16 type elements.
21552164// / \param [in] trans Indicates whether the matrix to be loaded transposed
21562165template <typename T>
21572166void ldmatrix (uintptr_t addr, T *m1, T *m2, bool trans = false ) {
@@ -2161,14 +2170,16 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, bool trans = false) {
21612170 ldmatrix (addr, m2, trans, 1 );
21622171}
21632172
2164- // / Loads 4 8x8 b16 (512 bytes) matrix from private memory to local memory per
2165- // / sub-group. Requires the sub-group size of kernel calling this function to
2166- // / be 32. Each work item of sub-group contains the starting address of their
2173+ // / Collectively loads 4 8x8 b16 (512 bytes) matrix from private memory to local
2174+ // / memory per sub-group. Requires the sub-group size of kernel calling this
2175+ // / function to be 32.
2176+ // / Each work item of sub-group contains the starting address of their
21672177// / respective matrix row in 'addr'.
21682178// / After distributing addresses to other work items, each of the 32 work items
21692179// / load 128-bits (32-bits per matrix) into 'm1', 'm2', 'm3' & 'm4' for a total
2170- // / of 512 bytes. 'trans' specifies to perform a transposed/non-transposed load
2171- // / by each work item like below
2180+ // / of 512 bytes.
2181+ // / 'trans' specifies to perform a transposed/non-transposed load by each work
2182+ // / item like below
21722183// / Row Major: Each row of the matrices is loaded by a group of 4 work items(wi)
21732184// / row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
21742185// / row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
@@ -2181,12 +2192,17 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, bool trans = false) {
21812192// / ...
21822193// / row-6: wi3 wi7 wi11 ... wi31
21832194// / row-7: wi3 wi7 wi11 ... wi31
2184- // / \tparam [in] T The type of result variable
2185- // / \param [in] addr The address of the matrix in local memory
2186- // / \param [in] m1 The private memory to store data of 1st matrix
2187- // / \param [in] m2 The private memory to store data of 2nd matrix
2188- // / \param [in] m3 The private memory to store data of 3rd matrix
2189- // / \param [in] m4 The private memory to store data of 4th matrix
2195+ // / \tparam [in] T Type of result variable (currently only supports 16-bit type)
2196+ // / \param [in] addr The starting address of corresponding matrix row for a work
2197+ // / item in local memory
2198+ // / \param [in] m1 The private memory to store the data of 1st matrix. It points
2199+ // / to 2 b16 type elements.
2200+ // / \param [in] m2 The private memory to store the data of 2nd matrix. It points
2201+ // / to 2 b16 type elements.
2202+ // / \param [in] m3 The private memory to store the data of 3rd matrix. It points
2203+ // / to 2 b16 type elements.
2204+ // / \param [in] m4 The private memory to store the data of 4th matrix. It points
2205+ // / to 2 b16 type elements.
21902206// / \param [in] trans Indicates whether the matrix to be loaded transposed
21912207template <typename T>
21922208void ldmatrix (uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false ) {
0 commit comments