diff --git a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp
index fb034f988..8ca0d21d8 100644
--- a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp
+++ b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp
@@ -34,7 +34,9 @@ template <
     typename accscalar_t,
     typename index_t,
     int mode,
-    int vec_size>
+    int vec_size,
+    bool per_sample_weights_defined,
+    bool padding_idx_defined>
 void embedding_bag(
     scalar_t* const output,
     const scalar_t* const weights,
@@ -46,13 +48,13 @@ void embedding_bag(
     const scalar_t* const per_sample_weights,
     int64_t index_size,
     int64_t bag_num,
-    int64_t vec_len,
+    int64_t feature_dim,
     index_t padding_idx,
     bool ignore_offsets,
     int64_t num_row) {
-  using vec_t = at::detail::Array<scalar_t, vec_size>;
-  using vec_acc_t = at::detail::Array<accscalar_t, vec_size>;
-  using vec_idx_t = at::detail::Array<index_t, vec_size>;
+  using vec_t = memory::aligned_vector<scalar_t, vec_size>;
+  using vec_acc_t = memory::aligned_vector<accscalar_t, vec_size>;
+  using vec_idx_t = memory::aligned_vector<index_t, vec_size>;
   using KernelClass = EmbeddingBagKernelFunctor<
       scalar_t,
       accscalar_t,
@@ -61,15 +63,20 @@ void embedding_bag(
       vec_size,
       vec_t,
       vec_acc_t,
-      vec_idx_t>;
+      vec_idx_t,
+      per_sample_weights_defined,
+      padding_idx_defined>;
   vec_t* o_vec = reinterpret_cast<vec_t*>(output);
   const vec_t* w_vec = reinterpret_cast<const vec_t*>(weights);
   vec_idx_t* max_idx_vec = reinterpret_cast<vec_idx_t*>(max_index);
 
-  vec_len = vec_len / vec_size;
-  BatchKernelConfig cfg = BatchKernelConfig::make_config(
-      bag_num, vec_len, 1, bag_num, true, BatchKernelConfig::Policy::pAdaptive);
+  int vectorized_feature_dim = feature_dim / vec_size;
+  int64_t work_group_size = syclDeviceMaxWorkGroupSize();
+  // TODO: we can set a smaller num_work_group and add a for loop in the kernel
+  int64_t num_work_group = ceil_div(
+      static_cast<int64_t>(bag_num * vectorized_feature_dim),
+      static_cast<int64_t>(work_group_size));
   index_t fixing_bag_size = ignore_offsets ? index_size / bag_num : 0;
   auto kfn = KernelClass(
@@ -81,93 +88,210 @@ void embedding_bag(
       per_sample_weights,
       index_size,
       bag_num,
-      vec_len,
+      vectorized_feature_dim,
       padding_idx,
       ignore_offsets,
       o_vec,
       w_vec,
       max_idx_vec,
-      cfg,
       fixing_bag_size,
       num_row);
   sycl_kernel_submit(
-      cfg.global_size(), cfg.group_size(), getCurrentSYCLQueue(), kfn);
+      num_work_group * work_group_size,
+      work_group_size,
+      getCurrentSYCLQueue(),
+      kfn);
 }
 
-#define EMBBAG_KERNEL_ACC( \
-    scalar_t, \
-    accscalar_t, \
-    index_t, \
-    mode, \
-    vec_size, \
-    output, \
-    weight, \
-    input, \
-    offset, \
-    offset2bag, \
-    bag_size, \
-    max_indices, \
-    per_sample_weights, \
-    index_len, \
-    bag_num, \
-    vec_len, \
-    padding_idx, \
-    ignore_offsets, \
-    num_row) \
-  embedding_bag<scalar_t, accscalar_t, index_t, mode, vec_size>( \
-      output.mutable_data_ptr<scalar_t>(), \
-      weight.const_data_ptr<scalar_t>(), \
-      indices.const_data_ptr<index_t>(), \
-      offsets.const_data_ptr<index_t>(), \
-      offset2bag.mutable_data_ptr<index_t>(), \
-      bag_size.mutable_data_ptr<index_t>(), \
-      max_indices.mutable_data_ptr<index_t>(), \
-      per_sample_weights.defined() \
-          ? per_sample_weights.const_data_ptr<scalar_t>() \
-          : nullptr, \
-      index_size, \
-      bag_num, \
-      vec_len, \
-      padding_idx, \
-      ignore_offsets, \
-      num_row)
-
-#define EMBBAG_KERNEL_NO_ACC( \
-    scalar_t, \
-    index_t, \
-    mode, \
-    vec_size, \
-    output, \
-    weight, \
-    input, \
-    offset, \
-    offset2bag, \
-    bag_size, \
-    max_indices, \
-    per_sample_weights, \
-    index_len, \
-    bag_num, \
-    vec_len, \
-    padding_idx, \
-    ignore_offsets, \
-    num_row) \
-  embedding_bag<scalar_t, scalar_t, index_t, mode, vec_size>( \
-      output.mutable_data_ptr<scalar_t>(), \
-      weight.const_data_ptr<scalar_t>(), \
-      indices.const_data_ptr<index_t>(), \
-      offsets.const_data_ptr<index_t>(), \
-      offset2bag.mutable_data_ptr<index_t>(), \
-      bag_size.mutable_data_ptr<index_t>(), \
-      max_indices.mutable_data_ptr<index_t>(), \
-      per_sample_weights.defined() \
-          ? per_sample_weights.const_data_ptr<scalar_t>() \
-          : nullptr, \
-      index_size, \
-      bag_num, \
-      vec_len, \
-      padding_idx, \
-      ignore_offsets, \
-      num_row)
+#define EMBBAG_KERNEL_ACC( \
+    scalar_t, \
+    accscalar_t, \
+    index_t, \
+    mode, \
+    vec_size, \
+    output, \
+    weight, \
+    input, \
+    offset, \
+    offset2bag, \
+    bag_size, \
+    max_indices, \
+    per_sample_weights, \
+    index_len, \
+    bag_num, \
+    feature_dim, \
+    padding_idx, \
+    ignore_offsets, \
+    num_row) \
+  if (per_sample_weights.defined() && padding_idx != -1) \
+    embedding_bag<scalar_t, accscalar_t, index_t, mode, vec_size, true, true>( \
+        output.mutable_data_ptr<scalar_t>(), \
+        weight.const_data_ptr<scalar_t>(), \
+        indices.const_data_ptr<index_t>(), \
+        offsets.const_data_ptr<index_t>(), \
+        offset2bag.mutable_data_ptr<index_t>(), \
+        bag_size.mutable_data_ptr<index_t>(), \
+        max_indices.mutable_data_ptr<index_t>(), \
+        per_sample_weights.const_data_ptr<scalar_t>(), \
+        index_size, \
+        bag_num, \
+        feature_dim, \
+        padding_idx, \
+        ignore_offsets, \
+        num_row); \
+  else if (!per_sample_weights.defined() && padding_idx != -1) \
+    embedding_bag< \
+        scalar_t, \
+        accscalar_t, \
+        index_t, \
+        mode, \
+        vec_size, \
+        false, \
+        true>( \
+        output.mutable_data_ptr<scalar_t>(), \
+        weight.const_data_ptr<scalar_t>(), \
+        indices.const_data_ptr<index_t>(), \
+        offsets.const_data_ptr<index_t>(), \
+        offset2bag.mutable_data_ptr<index_t>(), \
+        bag_size.mutable_data_ptr<index_t>(), \
+        max_indices.mutable_data_ptr<index_t>(), \
+        nullptr, \
+        index_size, \
+        bag_num, \
+        feature_dim, \
+        padding_idx, \
+        ignore_offsets, \
+        num_row); \
+  else if (per_sample_weights.defined() && padding_idx == -1) \
+    embedding_bag< \
+        scalar_t, \
+        accscalar_t, \
+        index_t, \
+        mode, \
+        vec_size, \
+        true, \
+        false>( \
+        output.mutable_data_ptr<scalar_t>(), \
+        weight.const_data_ptr<scalar_t>(), \
+        indices.const_data_ptr<index_t>(), \
+        offsets.const_data_ptr<index_t>(), \
+        offset2bag.mutable_data_ptr<index_t>(), \
+        bag_size.mutable_data_ptr<index_t>(), \
+        max_indices.mutable_data_ptr<index_t>(), \
+        per_sample_weights.const_data_ptr<scalar_t>(), \
+        index_size, \
+        bag_num, \
+        feature_dim, \
+        padding_idx, \
+        ignore_offsets, \
+        num_row); \
+  else \
+    embedding_bag< \
+        scalar_t, \
+        accscalar_t, \
+        index_t, \
+        mode, \
+        vec_size, \
+        false, \
+        false>( \
+        output.mutable_data_ptr<scalar_t>(), \
+        weight.const_data_ptr<scalar_t>(), \
+        indices.const_data_ptr<index_t>(), \
+        offsets.const_data_ptr<index_t>(), \
+        offset2bag.mutable_data_ptr<index_t>(), \
+        bag_size.mutable_data_ptr<index_t>(), \
+        max_indices.mutable_data_ptr<index_t>(), \
+        nullptr, \
+        index_size, \
+        bag_num, \
+        feature_dim, \
+        padding_idx, \
+        ignore_offsets, \
+        num_row);
+
+#define EMBBAG_KERNEL_NO_ACC( \
+    scalar_t, \
+    index_t, \
+    mode, \
+    vec_size, \
+    output, \
+    weight, \
+    input, \
+    offset, \
+    offset2bag, \
+    bag_size, \
+    max_indices, \
+    per_sample_weights, \
+    index_len, \
+    bag_num, \
+    feature_dim, \
+    padding_idx, \
+    ignore_offsets, \
+    num_row) \
+  if (per_sample_weights.defined() && padding_idx != -1) \
+    embedding_bag<scalar_t, scalar_t, index_t, mode, vec_size, true, true>( \
+        output.mutable_data_ptr<scalar_t>(), \
+        weight.const_data_ptr<scalar_t>(), \
+        indices.const_data_ptr<index_t>(), \
+        offsets.const_data_ptr<index_t>(), \
+        offset2bag.mutable_data_ptr<index_t>(), \
+        bag_size.mutable_data_ptr<index_t>(), \
+        max_indices.mutable_data_ptr<index_t>(), \
+        per_sample_weights.const_data_ptr<scalar_t>(), \
+        index_size, \
+        bag_num, \
+        feature_dim, \
+        padding_idx, \
+        ignore_offsets, \
+        num_row); \
+  else if (!per_sample_weights.defined() && padding_idx != -1) \
+    embedding_bag<scalar_t, scalar_t, index_t, mode, vec_size, false, true>( \
+        output.mutable_data_ptr<scalar_t>(), \
+        weight.const_data_ptr<scalar_t>(), \
+        indices.const_data_ptr<index_t>(), \
+        offsets.const_data_ptr<index_t>(), \
+        offset2bag.mutable_data_ptr<index_t>(), \
+        bag_size.mutable_data_ptr<index_t>(), \
+        max_indices.mutable_data_ptr<index_t>(), \
+        nullptr, \
+        index_size, \
+        bag_num, \
+        feature_dim, \
+        padding_idx, \
+        ignore_offsets, \
+        num_row); \
+  else if (per_sample_weights.defined() && padding_idx == -1) \
+    embedding_bag<scalar_t, scalar_t, index_t, mode, vec_size, true, false>( \
+        output.mutable_data_ptr<scalar_t>(), \
+        weight.const_data_ptr<scalar_t>(), \
+        indices.const_data_ptr<index_t>(), \
+        offsets.const_data_ptr<index_t>(), \
+        offset2bag.mutable_data_ptr<index_t>(), \
+        bag_size.mutable_data_ptr<index_t>(), \
+        max_indices.mutable_data_ptr<index_t>(), \
+        per_sample_weights.const_data_ptr<scalar_t>(), \
+        index_size, \
+        bag_num, \
+        feature_dim, \
+        padding_idx, \
+        ignore_offsets, \
+        num_row); \
+  else \
+    embedding_bag<scalar_t, scalar_t, index_t, mode, vec_size, false, false>( \
+        output.mutable_data_ptr<scalar_t>(), \
+        weight.const_data_ptr<scalar_t>(), \
+        indices.const_data_ptr<index_t>(), \
+        offsets.const_data_ptr<index_t>(), \
+        offset2bag.mutable_data_ptr<index_t>(), \
+        bag_size.mutable_data_ptr<index_t>(), \
+        max_indices.mutable_data_ptr<index_t>(), \
+        nullptr, \
+        index_size, \
+        bag_num, \
+        feature_dim, \
+        padding_idx, \
+        ignore_offsets, \
+        num_row);
 
 void embedding_bag_sum_template(
     const Tensor& indices,
@@ -180,7 +304,7 @@ void embedding_bag_sum_template(
     Tensor& max_indices,
     int64_t index_size,
     int64_t bag_num,
-    int64_t vec_len,
+    int64_t feature_dim,
     int64_t padding_idx,
     bool ignore_offsets) {
   uint64_t num_row = weights.size(0);
@@ -201,7 +325,7 @@ void embedding_bag_sum_template(
       per_sample_weights, \
       index_len, \
       bag_num, \
-      vec_len, \
+      feature_dim, \
       padding_idx, \
       ignore_offsets, \
       num_row)
@@ -217,7 +341,19 @@ void embedding_bag_sum_template(
   using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
   int vec_size = memory::can_vectorize_up_to<scalar_t>(
       (char*)weights.const_data_ptr<scalar_t>());
-  vec_size = vec_len % vec_size == 0 ? vec_size : 1;
+  vec_size = feature_dim % vec_size == 0 ? vec_size : 1;
+  int num_sub_wg =
+      bag_num * feature_dim / vec_size / syclMaxSubGroupSize();
+  int thread_slots = syclGpuEuCount() * syclGpuHWThreadsPerEU();
+  for (int v = vec_size; v != 1;
+       v = v / 2, num_sub_wg = num_sub_wg * 2) {
+    if (2 * num_sub_wg > thread_slots) {
+      // peak occupancy = num_sub_wg / thread_slots
+      // it should be > 50%
+      vec_size = v;
+      break;
+    }
+  }
   switch (vec_size) {
     case 8:
       EXTEND_EMBBAG_SUM_KERNEL_VEC(8);
@@ -248,7 +384,7 @@ void embedding_bag_mean_template(
     Tensor& max_indices,
     int64_t index_size,
     int64_t bag_num,
-    int64_t vec_len,
+    int64_t feature_dim,
     int64_t padding_idx,
     bool ignore_offsets) {
   uint64_t num_row = weights.size(0);
@@ -269,7 +405,7 @@ void embedding_bag_mean_template(
       per_sample_weights, \
       index_len, \
       bag_num, \
-      vec_len, \
+      feature_dim, \
       padding_idx, \
       ignore_offsets, \
       num_row)
@@ -285,7 +421,19 @@ void embedding_bag_mean_template(
   using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
   int vec_size = memory::can_vectorize_up_to<scalar_t>(
       (char*)weights.const_data_ptr<scalar_t>());
-  vec_size = vec_len % vec_size == 0 ? vec_size : 1;
+  vec_size = feature_dim % vec_size == 0 ? vec_size : 1;
+  int num_sub_wg =
+      bag_num * feature_dim / vec_size / syclMaxSubGroupSize();
+  int thread_slots = syclGpuEuCount() * syclGpuHWThreadsPerEU();
+  for (int v = vec_size; v != 1;
+       v = v / 2, num_sub_wg = num_sub_wg * 2) {
+    if (2 * num_sub_wg > thread_slots) {
+      // peak occupancy = num_sub_wg / thread_slots
+      // it should be > 50%
+      vec_size = v;
+      break;
+    }
+  }
   switch (vec_size) {
     case 8:
       EXTEND_EMBBAG_MEAN_KERNEL_VEC(8);
@@ -316,7 +464,7 @@ void embedding_bag_max_template(
     Tensor& max_indices,
     int64_t index_size,
     int64_t bag_num,
-    int64_t vec_len,
+    int64_t feature_dim,
     int64_t padding_idx,
     bool ignore_offsets) {
   uint64_t num_row = weights.size(0);
@@ -336,7 +484,7 @@ void embedding_bag_max_template(
      per_sample_weights, \
       index_len, \
       bag_num, \
-      vec_len, \
+      feature_dim, \
       padding_idx, \
       ignore_offsets, \
       num_row)
@@ -352,7 +500,19 @@ void embedding_bag_max_template(
   // using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
   int vec_size = memory::can_vectorize_up_to<scalar_t>(
       (char*)weights.const_data_ptr<scalar_t>());
-  vec_size = vec_len % vec_size == 0 ? vec_size : 1;
+  vec_size = feature_dim % vec_size == 0 ? vec_size : 1;
+  int num_sub_wg =
+      bag_num * feature_dim / vec_size / syclMaxSubGroupSize();
+  int thread_slots = syclGpuEuCount() * syclGpuHWThreadsPerEU();
+  for (int v = vec_size; v != 1;
+       v = v / 2, num_sub_wg = num_sub_wg * 2) {
+    if (2 * num_sub_wg > thread_slots) {
+      // peak occupancy = num_sub_wg / thread_slots
+      // it should be > 50%
+      vec_size = v;
+      break;
+    }
+  }
   switch (vec_size) {
     case 8:
       EXTEND_EMBBAG_MAX_KERNEL_VEC(8);
diff --git a/src/ATen/native/xpu/sycl/EmbeddingBag.h b/src/ATen/native/xpu/sycl/EmbeddingBag.h
index 8e80182d7..d33084809 100644
--- a/src/ATen/native/xpu/sycl/EmbeddingBag.h
+++ b/src/ATen/native/xpu/sycl/EmbeddingBag.h
@@ -21,118 +21,129 @@ template <
     int vec_size,
     typename vec_t,
     typename vec_acc_t,
-    typename vec_idx_t>
+    typename vec_idx_t,
+    bool per_sample_weights_defined,
+    bool padding_idx_defined>
 struct EmbeddingBagKernelFunctor {
-  void operator()(sycl::nd_item<2> item) const {
-    auto desc = cfg_.get_item_desc(item);
-    index_t start = 0, end = 0;
-    int64_t off_off = -1;
+  void operator()(sycl::nd_item<1> item) const {
+    auto thread_id = item.get_global_linear_id();
+    if (thread_id < bag_num_ * vectorized_feature_dim_len_) {
+      auto current_feature = thread_id % vectorized_feature_dim_len_;
+      auto current_bag = thread_id / vectorized_feature_dim_len_;
+      index_t start, end;
+      bool last_bag = current_bag == bag_num_ - 1;
+      if (!ignore_offsets_) {
+        start = offset_[current_bag];
+        end = last_bag ? index_size_ : offset_[current_bag + 1];
+      } else {
+        start = current_bag * fixing_bag_size_;
+        end = start + fixing_bag_size_;
+      }
 
-    do {
-      if (desc.glb_problem < cfg_.problem_ &&
-          desc.glb_batch < cfg_.problem_batch_) {
-        bool walk_on_bag = desc.glb_batch != off_off;
-        if (walk_on_bag) {
-          off_off = desc.glb_batch;
-          bool last_bag = off_off == bag_num_ - 1;
-          if (!ignore_offsets_) {
-            start = offset_[off_off];
-            end = last_bag ? index_size_ : offset_[off_off + 1];
-          } else {
-            start = off_off * fixing_bag_size_;
-            end = start + fixing_bag_size_;
-          }
-        }
+      vec_acc_t value, value_max;
+      vec_idx_t index_max;
+      index_t padding_cnt = 0;
 
-        vec_acc_t value, value_max;
-        vec_idx_t index_max;
-        index_t padding_cnt = 0;
+#pragma unroll
+      for (int i = 0; i < vec_size; i++) {
+        value[i] = 0;
+      }
+      if constexpr (mode == MODE_MAX) {
 #pragma unroll
         for (int i = 0; i < vec_size; i++) {
-          value[i] = 0;
           value_max[i] = at::numeric_limits<accscalar_t>::lower_bound();
           index_max[i] = -1;
         }
+      }
+      index_t index_offset, weight_index;
+      vec_t wei_load;
+      auto handle_non_padding = [&]() {
+        wei_load = w_vec_
+            [weight_index * vectorized_feature_dim_len_ + current_feature];
 
-        for (index_t off = start; off < end; off++) {
-          index_t index_off = off;
-          index_t vec_idx = index_[index_off];
-          SYCL_KERNEL_ASSERT(vec_idx < num_row_);
-
-          if (walk_on_bag && desc.glb_problem == 0) {
-            offset2bag_[index_off] = off_off;
-          }
-
-          if (padding_idx_ != vec_idx) {
-            index_t i_off = vec_idx * vec_len_ + desc.glb_problem;
-            vec_t other = w_vec_[i_off];
-
-            if constexpr (mode == MODE_SUM) {
+        if constexpr (mode == MODE_SUM) {
 #pragma unroll
-              for (int i = 0; i < vec_size; i++) {
-                if (per_sample_weights_) {
-                  other[i] *= per_sample_weights_[index_off];
-                }
-                value[i] += other[i];
-              }
-            } else if constexpr (mode == MODE_MEAN) {
+          for (int i = 0; i < vec_size; i++) {
+            if constexpr (per_sample_weights_defined) {
+              wei_load[i] *= per_sample_weights_[index_offset];
+            }
+            value[i] += wei_load[i];
+          }
+        } else if constexpr (mode == MODE_MEAN) {
 #pragma unroll
-              for (int i = 0; i < vec_size; i++) {
-                value[i] += other[i];
-              }
-            } else if constexpr (mode == MODE_MAX) {
+          for (int i = 0; i < vec_size; i++) {
+            value[i] += wei_load[i];
+          }
+        } else if constexpr (mode == MODE_MAX) {
 #pragma unroll
-              for (int i = 0; i < vec_size; i++) {
-                if (other[i] > value_max[i]) {
-                  value_max[i] = other[i];
-                  if (max_index_) {
-                    index_max[i] = vec_idx;
-                  }
-                }
+          for (int i = 0; i < vec_size; i++) {
+            if (wei_load[i] > value_max[i]) {
+              value_max[i] = wei_load[i];
+              if (max_index_) {
+                index_max[i] = weight_index;
               }
             }
+          }
+        }
+      };
+
+      for (index_t offset_in_bag = start; offset_in_bag < end;
+           offset_in_bag++) {
+        index_offset = offset_in_bag;
+        weight_index = index_[index_offset];
+        SYCL_KERNEL_ASSERT(weight_index < num_row_);
+
+        if (current_feature == 0)
+          offset2bag_[index_offset] = current_bag;
+
+        if constexpr (padding_idx_defined) {
+          if (padding_idx_ != weight_index) {
+            handle_non_padding();
           } else {
            padding_cnt++;
           }
+        } else {
+          handle_non_padding();
         }
+      }
 
-        int64_t bsize = end - start - padding_cnt;
-        if (desc.glb_problem == 0) {
-          bag_size_[off_off] = bsize;
-        }
+      int64_t bsize = end - start - padding_cnt;
+      if (current_feature == 0) {
+        bag_size_[current_bag] = bsize;
+      }
 
-        index_t o_off = off_off * vec_len_ + desc.glb_problem;
-        if constexpr (mode == MODE_SUM) {
-          vec_t o;
+      index_t o_off =
+          current_bag * vectorized_feature_dim_len_ + current_feature;
+      if constexpr (mode == MODE_SUM) {
+        vec_t o;
 #pragma unroll
-          for (int i = 0; i < vec_size; i++) {
-            o[i] = value[i];
-          }
-          o_vec_[o_off] = o;
-        } else if constexpr (mode == MODE_MEAN) {
-          vec_t o;
-          bsize = bsize == 0 ? 1 : bsize;
+        for (int i = 0; i < vec_size; i++) {
+          o[i] = value[i];
+        }
+        o_vec_[o_off] = o;
+      } else if constexpr (mode == MODE_MEAN) {
+        vec_t o;
+        bsize = bsize == 0 ? 1 : bsize;
 #pragma unroll
-          for (int i = 0; i < vec_size; i++) {
-            o[i] = value[i] / bsize;
-          }
-          o_vec_[o_off] = o;
-        } else if constexpr (mode == MODE_MAX) {
-          vec_t padding;
+        for (int i = 0; i < vec_size; i++) {
+          o[i] = value[i] / bsize;
+        }
+        o_vec_[o_off] = o;
+      } else if constexpr (mode == MODE_MAX) {
+        vec_t padding;
 #pragma unroll
-          for (int i = 0; i < vec_size; i++) {
-            padding[i] = 0;
-          }
-          o_vec_[o_off] =
-              value_max[0] == at::numeric_limits<accscalar_t>::lower_bound()
-                  ? padding
-                  : value_max;
-          if (max_index_) {
-            max_idx_vec_[o_off] = index_max;
-          }
+        for (int i = 0; i < vec_size; i++) {
+          padding[i] = 0;
+        }
+        o_vec_[o_off] =
+            value_max[0] == at::numeric_limits<accscalar_t>::lower_bound()
+                ? padding
+                : value_max;
+        if (max_index_) {
+          max_idx_vec_[o_off] = index_max;
        }
       }
-    } while (cfg_.next(item, desc));
+    }
   }
 
   EmbeddingBagKernelFunctor(
       const index_t* const index,
@@ -143,13 +154,12 @@ struct EmbeddingBagKernelFunctor {
       const scalar_t* const per_sample_weights,
       int64_t index_size,
       int64_t bag_num,
-      int64_t vec_len,
+      int64_t vectorized_feature_dim_len,
       index_t padding_idx,
       bool ignore_offsets,
       vec_t* o_vec,
       const vec_t* w_vec,
       vec_idx_t* max_idx_vec,
-      BatchKernelConfig cfg,
       index_t fixing_bag_size,
       index_t num_row)
       : index_(index),
@@ -160,13 +170,12 @@ struct EmbeddingBagKernelFunctor {
         per_sample_weights_(per_sample_weights),
         index_size_(index_size),
         bag_num_(bag_num),
-        vec_len_(vec_len),
+        vectorized_feature_dim_len_(vectorized_feature_dim_len),
         padding_idx_(padding_idx),
         ignore_offsets_(ignore_offsets),
         o_vec_(o_vec),
         w_vec_(w_vec),
         max_idx_vec_(max_idx_vec),
-        cfg_(cfg),
         fixing_bag_size_(fixing_bag_size),
         num_row_(num_row) {}
 
@@ -179,13 +188,12 @@ struct EmbeddingBagKernelFunctor {
   const scalar_t* const per_sample_weights_;
   int64_t index_size_;
   int64_t bag_num_;
-  int64_t vec_len_;
+  int64_t vectorized_feature_dim_len_;
   index_t padding_idx_;
   bool ignore_offsets_;
   vec_t* o_vec_;
   const vec_t* w_vec_;
   vec_idx_t* max_idx_vec_;
-  BatchKernelConfig cfg_;
   index_t fixing_bag_size_;
   index_t num_row_;
 };
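
A few notes on the patch, each with a small self-contained sketch (illustrative only, not part of the change). First, the launch geometry: the `BatchKernelConfig::pAdaptive` configuration is replaced by a flat 1-D launch in which every in-range work-item owns exactly one (bag, vectorized-feature) pair, which is why the `cfg_.next()` loop disappears from the kernel. The host-side sketch below reproduces that index arithmetic with made-up sizes; `work_group_size` stands in for `syclDeviceMaxWorkGroupSize()`.

```cpp
// Host-side sketch of the new flat 1-D dispatch (illustrative only).
#include <cstdint>
#include <cstdio>

static int64_t ceil_div(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  const int64_t bag_num = 128;          // number of output bags
  const int64_t feature_dim = 256;      // embedding width in elements
  const int vec_size = 4;               // elements per vectorized load
  const int64_t work_group_size = 1024; // assumed device max WG size

  const int64_t vectorized_feature_dim = feature_dim / vec_size;
  const int64_t total_items = bag_num * vectorized_feature_dim;
  const int64_t num_work_group = ceil_div(total_items, work_group_size);

  // Each in-range thread maps bijectively onto one (bag, feature-chunk)
  // pair; threads with thread_id >= total_items simply return.
  const int64_t sample_ids[] = {0, total_items / 2, total_items - 1};
  for (int64_t thread_id : sample_ids) {
    const int64_t current_feature = thread_id % vectorized_feature_dim;
    const int64_t current_bag = thread_id / vectorized_feature_dim;
    std::printf("thread %lld -> bag %lld, feature chunk %lld\n",
                (long long)thread_id, (long long)current_bag,
                (long long)current_feature);
  }
  std::printf("launch: %lld groups x %lld items\n",
              (long long)num_work_group, (long long)work_group_size);
}
```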
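Second, the `vec_size` fallback added to all three templates is occupancy-driven: starting from the widest vector the pointer alignment allows, it halves the width (which doubles the projected subgroup count) until the launch would fill more than half of the device's hardware thread slots, and keeps the initial width if no candidate qualifies. A minimal sketch of the same loop, with assumed device numbers standing in for `syclMaxSubGroupSize()`, `syclGpuEuCount()`, and `syclGpuHWThreadsPerEU()`:

```cpp
// Sketch of the occupancy-driven vec_size selection (illustrative only).
#include <cstdio>

int pick_vec_size(long bag_num, long feature_dim, int max_vec_size,
                  int sub_group_size, int thread_slots) {
  // Widest vector only if the feature dim divides evenly, as in the patch.
  int vec_size = (feature_dim % max_vec_size == 0) ? max_vec_size : 1;
  long num_sub_wg = bag_num * feature_dim / vec_size / sub_group_size;
  // Scan v downward, doubling the projected subgroup count each step, and
  // take the first width whose launch exceeds 50% of the thread slots.
  // If none qualifies, vec_size keeps its initial value.
  for (int v = vec_size; v != 1; v /= 2, num_sub_wg *= 2) {
    if (2 * num_sub_wg > thread_slots) {
      vec_size = v;
      break;
    }
  }
  return vec_size;
}

int main() {
  // Assumed device: 32-wide subgroups, 4096 HW thread slots.
  // Small workload: vec_size 8 would leave most slots idle, so the
  // heuristic trades vector width for more subgroups and picks 4.
  std::printf("small: vec_size=%d\n", pick_vec_size(1600, 256, 8, 32, 4096));
  // Large workload: plenty of subgroups even at vec_size 8.
  std::printf("large: vec_size=%d\n", pick_vec_size(100000, 256, 8, 32, 4096));
}
```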
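Lastly, on the two new bool template parameters: they let `if constexpr` discard the per-sample-weight multiply and the padding-index test at compile time, so the per-index hot loop carries no runtime flag checks, and the four `if`/`else` branches in the rewritten macros merely select the matching instantiation. A stripped-down illustration of the pattern, with hypothetical names rather than the kernel itself:

```cpp
// Compile-time specialization via a bool template parameter (sketch).
#include <cstdio>

template <bool per_sample_weights_defined>
float accumulate(const float* weights, const float* per_sample_weights,
                 const int* indices, int n) {
  float acc = 0.f;
  for (int i = 0; i < n; i++) {
    float w = weights[indices[i]];
    if constexpr (per_sample_weights_defined) {
      // This branch exists only in the <true> instantiation; the <false>
      // instantiation never touches per_sample_weights.
      w *= per_sample_weights[i];
    }
    acc += w;
  }
  return acc;
}

int main() {
  const float weights[] = {1.f, 2.f, 3.f};
  const float psw[] = {0.5f, 0.5f};
  const int idx[] = {0, 2};
  std::printf("%f\n", accumulate<true>(weights, psw, idx, 2));      // 2.0
  std::printf("%f\n", accumulate<false>(weights, nullptr, idx, 2)); // 4.0
}
```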