@@ -2061,13 +2061,12 @@ class joint_matrix {
20612061// / \tparam [in] T The type of result variable
20622062// / \param [in] addr The address of the matrix in local memory
20632063// / \param [in] m The private memory to store the matrix
2064- // / \param [in] item The sycl::nd_item index space class
20652064// / \param [in] trans Indicates whether the matrix is to be loaded transposed
20662065// / \param [in] mat The matrix index to be loaded
2067- template <typename T, typename ItemT >
2068- void ldmatrix (uintptr_t addr, T *m, const ItemT &item, bool trans = false ,
2069- unsigned mat = 0 ) {
2070- int lane = item. get_sub_group () .get_local_linear_id ();
2066+ template <typename T>
2067+ void ldmatrix (uintptr_t addr, T *m, bool trans = false , unsigned mat = 0 ) {
2068+ auto sg = sycl::ext::oneapi::this_work_item::get_sub_group ();
2069+ int lane = sg .get_local_linear_id ();
20712070
20722071 int lane_group8_row = lane / 8 ;
20732072 int lane_group8_col = lane % 8 ;
@@ -2079,8 +2078,8 @@ void ldmatrix(uintptr_t addr, T *m, const ItemT &item, bool trans = false,
20792078 src_lane += 1 ;
20802079
20812080 // Broadcast the address from the source lane
2082- auto recv_addr_uintp = dpct::select_from_sub_group (
2083- item. get_sub_group () , addr, mat * 8 + src_lane);
2081+ auto recv_addr_uintp =
2082+ dpct::select_from_sub_group (sg , addr, mat * 8 + src_lane);
20842083
20852084 // Cast the received address from uintptr_t to the type of 'm'
20862085 auto recv_addr = reinterpret_cast <T *>(recv_addr_uintp);
@@ -2092,10 +2091,10 @@ void ldmatrix(uintptr_t addr, T *m, const ItemT &item, bool trans = false,
20922091 int src_lane = (lane % 4 ) * 2 ;
20932092
20942093 // Broadcast the address from the source lane
2095- auto recv_addr_uintp_1 = dpct::select_from_sub_group (
2096- item. get_sub_group () , addr, mat * 8 + src_lane);
2097- auto recv_addr_uintp_2 = dpct::select_from_sub_group (
2098- item. get_sub_group () , addr, mat * 8 + src_lane + 1 );
2094+ auto recv_addr_uintp_1 =
2095+ dpct::select_from_sub_group (sg , addr, mat * 8 + src_lane);
2096+ auto recv_addr_uintp_2 =
2097+ dpct::select_from_sub_group (sg , addr, mat * 8 + src_lane + 1 );
20992098
21002099 // Cast the received address from uintptr_t to 'half *'
21012100 auto recv_addr_1 = reinterpret_cast <sycl::half *>(recv_addr_uintp_1);
@@ -2118,15 +2117,13 @@ void ldmatrix(uintptr_t addr, T *m, const ItemT &item, bool trans = false,
21182117// / \param [in] addr The address of the matrix in local memory
21192118// / \param [in] m1 The private memory to store data of 1st matrix
21202119// / \param [in] m2 The private memory to store data of 2nd matrix
2121- // / \param [in] item The sycl::nd_item index space class
21222120// / \param [in] trans Indicates whether the matrix is to be loaded transposed
2123- template <typename T, typename ItemT>
2124- void ldmatrix (uintptr_t addr, T *m1, T *m2, const ItemT &item,
2125- bool trans = false ) {
2121+ template <typename T>
2122+ void ldmatrix (uintptr_t addr, T *m1, T *m2, bool trans = false ) {
21262123 // Load 1st matrix into m1 by forwarding to the single-matrix overload (mat = 0)
2127- ldmatrix (addr, m1, item, trans, 0 );
2124+ ldmatrix (addr, m1, trans, 0 );
21282125 // Load 2nd matrix into m2 by forwarding to the single-matrix overload (mat = 1)
2129- ldmatrix (addr, m2, item, trans, 1 );
2126+ ldmatrix (addr, m2, trans, 1 );
21302127}
21312128
21322129// / Loads four 8x8 b16 matrices from local memory to private memory (32 bits per wi)
@@ -2137,19 +2134,17 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, const ItemT &item,
21372134// / \param [in] m2 The private memory to store data of 2nd matrix
21382135// / \param [in] m3 The private memory to store data of 3rd matrix
21392136// / \param [in] m4 The private memory to store data of 4th matrix
2140- // / \param [in] item The sycl::nd_item index space class
21412137// / \param [in] trans Indicates whether the matrix is to be loaded transposed
2142- template <typename T, typename ItemT>
2143- void ldmatrix (uintptr_t addr, T *m1, T *m2, T *m3, T *m4, const ItemT &item,
2144- bool trans = false ) {
2138+ template <typename T>
2139+ void ldmatrix (uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false ) {
21452140 // Load 1st matrix into m1 by forwarding to the single-matrix overload (mat = 0)
2146- ldmatrix (addr, m1, item, trans, 0 );
2141+ ldmatrix (addr, m1, trans, 0 );
21472142 // Load 2nd matrix into m2 by forwarding to the single-matrix overload (mat = 1)
2148- ldmatrix (addr, m2, item, trans, 1 );
2143+ ldmatrix (addr, m2, trans, 1 );
21492144 // Load 3rd matrix into m3 by forwarding to the single-matrix overload (mat = 2)
2150- ldmatrix (addr, m3, item, trans, 2 );
2145+ ldmatrix (addr, m3, trans, 2 );
21512146 // Load 4th matrix into m4 by forwarding to the single-matrix overload (mat = 3)
2152- ldmatrix (addr, m4, item, trans, 3 );
2147+ ldmatrix (addr, m4, trans, 3 );
21532148}
21542149
21552150} // namespace matrix
0 commit comments