From 9d1a19dd2d9faaacfa638bca16392e7da42b757a Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Sat, 7 Jun 2025 08:47:38 +0000 Subject: [PATCH 01/11] save --- src/ATen/native/xpu/sycl/EmbeddingBag.cpp | 271 ++++++++++++++++------ src/ATen/native/xpu/sycl/EmbeddingBag.h | 75 +++--- 2 files changed, 238 insertions(+), 108 deletions(-) diff --git a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp index fb034f9887..e6196fdbcc 100644 --- a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp +++ b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp @@ -34,7 +34,9 @@ template < typename accscalar_t, typename index_t, int mode, - int vec_size> + int vec_size, + bool per_sample_weights_defined, + bool padding_idx_defined> void embedding_bag( scalar_t* const output, const scalar_t* const weights, @@ -61,7 +63,9 @@ void embedding_bag( vec_size, vec_t, vec_acc_t, - vec_idx_t>; + vec_idx_t, + per_sample_weights_defined, + padding_idx_defined>; vec_t* o_vec = reinterpret_cast(output); const vec_t* w_vec = reinterpret_cast(weights); @@ -94,80 +98,195 @@ void embedding_bag( cfg.global_size(), cfg.group_size(), getCurrentSYCLQueue(), kfn); } -#define EMBBAG_KERNEL_ACC( \ - scalar_t, \ - accscalar_t, \ - index_t, \ - mode, \ - vec_size, \ - output, \ - weight, \ - input, \ - offset, \ - offset2bag, \ - bag_size, \ - max_indices, \ - per_sample_weights, \ - index_len, \ - bag_num, \ - vec_len, \ - padding_idx, \ - ignore_offsets, \ - num_row) \ - embedding_bag( \ - output.mutable_data_ptr(), \ - weight.const_data_ptr(), \ - indices.const_data_ptr(), \ - offsets.const_data_ptr(), \ - offset2bag.mutable_data_ptr(), \ - bag_size.mutable_data_ptr(), \ - max_indices.mutable_data_ptr(), \ - per_sample_weights.defined() \ - ? per_sample_weights.const_data_ptr() \ - : nullptr, \ - index_size, \ - bag_num, \ - vec_len, \ - padding_idx, \ - ignore_offsets, \ - num_row) - -#define EMBBAG_KERNEL_NO_ACC( \ - scalar_t, \ - index_t, \ - mode, \ - vec_size, \ - output, \ - weight, \ - input, \ - offset, \ - offset2bag, \ - bag_size, \ - max_indices, \ - per_sample_weights, \ - index_len, \ - bag_num, \ - vec_len, \ - padding_idx, \ - ignore_offsets, \ - num_row) \ - embedding_bag( \ - output.mutable_data_ptr(), \ - weight.const_data_ptr(), \ - indices.const_data_ptr(), \ - offsets.const_data_ptr(), \ - offset2bag.mutable_data_ptr(), \ - bag_size.mutable_data_ptr(), \ - max_indices.mutable_data_ptr(), \ - per_sample_weights.defined() \ - ? per_sample_weights.const_data_ptr() \ - : nullptr, \ - index_size, \ - bag_num, \ - vec_len, \ - padding_idx, \ - ignore_offsets, \ - num_row) +#define EMBBAG_KERNEL_ACC( \ + scalar_t, \ + accscalar_t, \ + index_t, \ + mode, \ + vec_size, \ + output, \ + weight, \ + input, \ + offset, \ + offset2bag, \ + bag_size, \ + max_indices, \ + per_sample_weights, \ + index_len, \ + bag_num, \ + vec_len, \ + padding_idx, \ + ignore_offsets, \ + num_row) \ + if (per_sample_weights.defined() && padding_idx != -1) \ + embedding_bag( \ + output.mutable_data_ptr(), \ + weight.const_data_ptr(), \ + indices.const_data_ptr(), \ + offsets.const_data_ptr(), \ + offset2bag.mutable_data_ptr(), \ + bag_size.mutable_data_ptr(), \ + max_indices.mutable_data_ptr(), \ + per_sample_weights.const_data_ptr(), \ + index_size, \ + bag_num, \ + vec_len, \ + padding_idx, \ + ignore_offsets, \ + num_row); \ + else if (!per_sample_weights.defined() && padding_idx != -1) \ + embedding_bag< \ + scalar_t, \ + accscalar_t, \ + index_t, \ + mode, \ + vec_size, \ + false, \ + true>( \ + output.mutable_data_ptr(), \ + weight.const_data_ptr(), \ + indices.const_data_ptr(), \ + offsets.const_data_ptr(), \ + offset2bag.mutable_data_ptr(), \ + bag_size.mutable_data_ptr(), \ + max_indices.mutable_data_ptr(), \ + nullptr, \ + index_size, \ + bag_num, \ + vec_len, \ + padding_idx, \ + ignore_offsets, \ + num_row); \ + else if (per_sample_weights.defined() && padding_idx == -1) \ + embedding_bag< \ + scalar_t, \ + accscalar_t, \ + index_t, \ + mode, \ + vec_size, \ + true, \ + false>( \ + output.mutable_data_ptr(), \ + weight.const_data_ptr(), \ + indices.const_data_ptr(), \ + offsets.const_data_ptr(), \ + offset2bag.mutable_data_ptr(), \ + bag_size.mutable_data_ptr(), \ + max_indices.mutable_data_ptr(), \ + per_sample_weights.const_data_ptr(), \ + index_size, \ + bag_num, \ + vec_len, \ + padding_idx, \ + ignore_offsets, \ + num_row); \ + else \ + embedding_bag< \ + scalar_t, \ + accscalar_t, \ + index_t, \ + mode, \ + vec_size, \ + false, \ + false>( \ + output.mutable_data_ptr(), \ + weight.const_data_ptr(), \ + indices.const_data_ptr(), \ + offsets.const_data_ptr(), \ + offset2bag.mutable_data_ptr(), \ + bag_size.mutable_data_ptr(), \ + max_indices.mutable_data_ptr(), \ + nullptr, \ + index_size, \ + bag_num, \ + vec_len, \ + padding_idx, \ + ignore_offsets, \ + num_row); + +#define EMBBAG_KERNEL_NO_ACC( \ + scalar_t, \ + index_t, \ + mode, \ + vec_size, \ + output, \ + weight, \ + input, \ + offset, \ + offset2bag, \ + bag_size, \ + max_indices, \ + per_sample_weights, \ + index_len, \ + bag_num, \ + vec_len, \ + padding_idx, \ + ignore_offsets, \ + num_row) \ + if (per_sample_weights.defined() && padding_idx != -1) \ + embedding_bag( \ + output.mutable_data_ptr(), \ + weight.const_data_ptr(), \ + indices.const_data_ptr(), \ + offsets.const_data_ptr(), \ + offset2bag.mutable_data_ptr(), \ + bag_size.mutable_data_ptr(), \ + max_indices.mutable_data_ptr(), \ + per_sample_weights.const_data_ptr(), \ + index_size, \ + bag_num, \ + vec_len, \ + padding_idx, \ + ignore_offsets, \ + num_row); \ + else if (!per_sample_weights.defined() && padding_idx != -1) \ + embedding_bag( \ + output.mutable_data_ptr(), \ + weight.const_data_ptr(), \ + indices.const_data_ptr(), \ + offsets.const_data_ptr(), \ + offset2bag.mutable_data_ptr(), \ + bag_size.mutable_data_ptr(), \ + max_indices.mutable_data_ptr(), \ + nullptr, \ + index_size, \ + bag_num, \ + vec_len, \ + padding_idx, \ + ignore_offsets, \ + num_row); \ + else if (per_sample_weights.defined() && padding_idx == -1) \ + embedding_bag( \ + output.mutable_data_ptr(), \ + weight.const_data_ptr(), \ + indices.const_data_ptr(), \ + offsets.const_data_ptr(), \ + offset2bag.mutable_data_ptr(), \ + bag_size.mutable_data_ptr(), \ + max_indices.mutable_data_ptr(), \ + per_sample_weights.const_data_ptr(), \ + index_size, \ + bag_num, \ + vec_len, \ + padding_idx, \ + ignore_offsets, \ + num_row); \ + else \ + embedding_bag( \ + output.mutable_data_ptr(), \ + weight.const_data_ptr(), \ + indices.const_data_ptr(), \ + offsets.const_data_ptr(), \ + offset2bag.mutable_data_ptr(), \ + bag_size.mutable_data_ptr(), \ + max_indices.mutable_data_ptr(), \ + nullptr, \ + index_size, \ + bag_num, \ + vec_len, \ + padding_idx, \ + ignore_offsets, \ + num_row); void embedding_bag_sum_template( const Tensor& indices, diff --git a/src/ATen/native/xpu/sycl/EmbeddingBag.h b/src/ATen/native/xpu/sycl/EmbeddingBag.h index 8e80182d73..904bc615d2 100644 --- a/src/ATen/native/xpu/sycl/EmbeddingBag.h +++ b/src/ATen/native/xpu/sycl/EmbeddingBag.h @@ -21,7 +21,9 @@ template < int vec_size, typename vec_t, typename vec_acc_t, - typename vec_idx_t> + typename vec_idx_t, + bool per_sample_weights_defined, + bool padding_idx_defined> struct EmbeddingBagKernelFunctor { void operator()(sycl::nd_item<2> item) const { auto desc = cfg_.get_item_desc(item); @@ -53,46 +55,55 @@ struct EmbeddingBagKernelFunctor { value_max[i] = at::numeric_limits::lower_bound(); index_max[i] = -1; } + index_t index_off, vec_idx, i_off; + vec_t other; + auto handle_non_padding = [&]() { + i_off = vec_idx * vec_len_ + desc.glb_problem; + other = w_vec_[i_off]; - for (index_t off = start; off < end; off++) { - index_t index_off = off; - index_t vec_idx = index_[index_off]; - SYCL_KERNEL_ASSERT(vec_idx < num_row_); - - if (walk_on_bag && desc.glb_problem == 0) { - offset2bag_[index_off] = off_off; - } - - if (padding_idx_ != vec_idx) { - index_t i_off = vec_idx * vec_len_ + desc.glb_problem; - vec_t other = w_vec_[i_off]; - - if constexpr (mode == MODE_SUM) { + if constexpr (mode == MODE_SUM) { #pragma unroll - for (int i = 0; i < vec_size; i++) { - if (per_sample_weights_) { - other[i] *= per_sample_weights_[index_off]; - } - value[i] += other[i]; + for (int i = 0; i < vec_size; i++) { + if constexpr (per_sample_weights_defined) { + other[i] *= per_sample_weights_[index_off]; } - } else if constexpr (mode == MODE_MEAN) { + value[i] += other[i]; + } + } else if constexpr (mode == MODE_MEAN) { #pragma unroll - for (int i = 0; i < vec_size; i++) { - value[i] += other[i]; - } - } else if constexpr (mode == MODE_MAX) { + for (int i = 0; i < vec_size; i++) { + value[i] += other[i]; + } + } else if constexpr (mode == MODE_MAX) { #pragma unroll - for (int i = 0; i < vec_size; i++) { - if (other[i] > value_max[i]) { - value_max[i] = other[i]; - if (max_index_) { - index_max[i] = vec_idx; - } + for (int i = 0; i < vec_size; i++) { + if (other[i] > value_max[i]) { + value_max[i] = other[i]; + if (max_index_) { + index_max[i] = vec_idx; } } } + } + }; + + for (index_t off = start; off < end; off++) { + index_off = off; + vec_idx = index_[index_off]; + // SYCL_KERNEL_ASSERT(vec_idx < num_row_); + + if (walk_on_bag && desc.glb_problem == 0) { + offset2bag_[index_off] = off_off; + } + + if constexpr (padding_idx_defined) { + if (padding_idx_ != vec_idx) { + handle_non_padding(); + } else { + padding_cnt++; + } } else { - padding_cnt++; + handle_non_padding(); } } From 82abd39e72a0d91b94d7ddc12cb80494bb9f0bd3 Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Tue, 10 Jun 2025 13:58:15 +0000 Subject: [PATCH 02/11] fix vectorization --- src/ATen/native/xpu/sycl/EmbeddingBag.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp index e6196fdbcc..097b6e7f86 100644 --- a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp +++ b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp @@ -52,9 +52,9 @@ void embedding_bag( index_t padding_idx, bool ignore_offsets, int64_t num_row) { - using vec_t = at::detail::Array; - using vec_acc_t = at::detail::Array; - using vec_idx_t = at::detail::Array; + using vec_t = memory::aligned_vector; + using vec_acc_t = memory::aligned_vector; + using vec_idx_t = memory::aligned_vector; using KernelClass = EmbeddingBagKernelFunctor< scalar_t, accscalar_t, From 536e8a3a700c232ad62224e78514d633000ef35a Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Tue, 17 Jun 2025 07:46:32 +0000 Subject: [PATCH 03/11] test --- src/ATen/native/xpu/sycl/EmbeddingBag.cpp | 8 +++++++- src/ATen/native/xpu/sycl/EmbeddingBag.h | 9 ++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp index 097b6e7f86..be5edafeaa 100644 --- a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp +++ b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp @@ -72,8 +72,13 @@ void embedding_bag( vec_idx_t* max_idx_vec = reinterpret_cast(max_index); vec_len = vec_len / vec_size; + int tile = 1; + if (32 % vec_len == 0) { + tile = 32 / vec_len; + } + int batch = (bag_num + tile - 1) / tile; BatchKernelConfig cfg = BatchKernelConfig::make_config( - bag_num, vec_len, 1, bag_num, true, BatchKernelConfig::Policy::pAdaptive); + batch, vec_len * tile, 1, batch, true, BatchKernelConfig::Policy::pAdaptive); index_t fixing_bag_size = ignore_offsets ? index_size / bag_num : 0; auto kfn = KernelClass( @@ -86,6 +91,7 @@ void embedding_bag( index_size, bag_num, vec_len, + tile, padding_idx, ignore_offsets, o_vec, diff --git a/src/ATen/native/xpu/sycl/EmbeddingBag.h b/src/ATen/native/xpu/sycl/EmbeddingBag.h index 904bc615d2..f1c85b40e6 100644 --- a/src/ATen/native/xpu/sycl/EmbeddingBag.h +++ b/src/ATen/native/xpu/sycl/EmbeddingBag.h @@ -35,7 +35,7 @@ struct EmbeddingBagKernelFunctor { desc.glb_batch < cfg_.problem_batch_) { bool walk_on_bag = desc.glb_batch != off_off; if (walk_on_bag) { - off_off = desc.glb_batch; + off_off = desc.glb_batch * tile_ + desc.glb_problem / (cfg_.problem_ / tile_); bool last_bag = off_off == bag_num_ - 1; if (!ignore_offsets_) { start = offset_[off_off]; @@ -58,7 +58,7 @@ struct EmbeddingBagKernelFunctor { index_t index_off, vec_idx, i_off; vec_t other; auto handle_non_padding = [&]() { - i_off = vec_idx * vec_len_ + desc.glb_problem; + i_off = vec_idx * vec_len_ + desc.glb_problem % (cfg_.problem_ / tile_); other = w_vec_[i_off]; if constexpr (mode == MODE_SUM) { @@ -112,7 +112,7 @@ struct EmbeddingBagKernelFunctor { bag_size_[off_off] = bsize; } - index_t o_off = off_off * vec_len_ + desc.glb_problem; + index_t o_off = off_off * vec_len_ + desc.glb_problem % (cfg_.problem_ / tile_); if constexpr (mode == MODE_SUM) { vec_t o; #pragma unroll @@ -155,6 +155,7 @@ struct EmbeddingBagKernelFunctor { int64_t index_size, int64_t bag_num, int64_t vec_len, + int64_t tile, index_t padding_idx, bool ignore_offsets, vec_t* o_vec, @@ -172,6 +173,7 @@ struct EmbeddingBagKernelFunctor { index_size_(index_size), bag_num_(bag_num), vec_len_(vec_len), + tile_(tile), padding_idx_(padding_idx), ignore_offsets_(ignore_offsets), o_vec_(o_vec), @@ -191,6 +193,7 @@ struct EmbeddingBagKernelFunctor { int64_t index_size_; int64_t bag_num_; int64_t vec_len_; + int64_t tile_; index_t padding_idx_; bool ignore_offsets_; vec_t* o_vec_; From 1321c5e6cce4041be33b86b2004704af9f98404a Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Thu, 19 Jun 2025 05:05:00 +0000 Subject: [PATCH 04/11] save --- src/ATen/native/xpu/sycl/EmbeddingBag.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp index be5edafeaa..308d2dc14c 100644 --- a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp +++ b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp @@ -73,8 +73,9 @@ void embedding_bag( vec_len = vec_len / vec_size; int tile = 1; - if (32 % vec_len == 0) { - tile = 32 / vec_len; + int sub_group_size = syclMaxSubGroupSize(); + if (sub_group_size % vec_len == 0) { + tile = sub_group_size / vec_len; } int batch = (bag_num + tile - 1) / tile; BatchKernelConfig cfg = BatchKernelConfig::make_config( From ef76d4d3aea74e434959acb0e2390dc56a0c5429 Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Thu, 19 Jun 2025 05:07:22 +0000 Subject: [PATCH 05/11] save --- src/ATen/native/xpu/sycl/EmbeddingBag.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ATen/native/xpu/sycl/EmbeddingBag.h b/src/ATen/native/xpu/sycl/EmbeddingBag.h index f1c85b40e6..ec9dbad6cd 100644 --- a/src/ATen/native/xpu/sycl/EmbeddingBag.h +++ b/src/ATen/native/xpu/sycl/EmbeddingBag.h @@ -90,7 +90,7 @@ struct EmbeddingBagKernelFunctor { for (index_t off = start; off < end; off++) { index_off = off; vec_idx = index_[index_off]; - // SYCL_KERNEL_ASSERT(vec_idx < num_row_); + SYCL_KERNEL_ASSERT(vec_idx < num_row_); if (walk_on_bag && desc.glb_problem == 0) { offset2bag_[index_off] = off_off; From cd17f6a675619f6a35c25d80b1e449761caa170a Mon Sep 17 00:00:00 2001 From: intel Date: Fri, 11 Jul 2025 14:00:23 +0000 Subject: [PATCH 06/11] save --- src/ATen/native/xpu/sycl/EmbeddingBag.cpp | 72 ++++---- src/ATen/native/xpu/sycl/EmbeddingBag.h | 195 +++++++++++----------- 2 files changed, 134 insertions(+), 133 deletions(-) diff --git a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp index 308d2dc14c..1f047f745e 100644 --- a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp +++ b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp @@ -48,7 +48,7 @@ void embedding_bag( const scalar_t* const per_sample_weights, int64_t index_size, int64_t bag_num, - int64_t vec_len, + int64_t feature_dim, index_t padding_idx, bool ignore_offsets, int64_t num_row) { @@ -71,15 +71,12 @@ void embedding_bag( const vec_t* w_vec = reinterpret_cast(weights); vec_idx_t* max_idx_vec = reinterpret_cast(max_index); - vec_len = vec_len / vec_size; - int tile = 1; - int sub_group_size = syclMaxSubGroupSize(); - if (sub_group_size % vec_len == 0) { - tile = sub_group_size / vec_len; - } - int batch = (bag_num + tile - 1) / tile; - BatchKernelConfig cfg = BatchKernelConfig::make_config( - batch, vec_len * tile, 1, batch, true, BatchKernelConfig::Policy::pAdaptive); + int vectorized_feature_dim = feature_dim / vec_size; + int64_t work_group_size = syclMaxWorkGroupSize(); + // TODO: we can set num_work_group = 1024 and add for loop in kernel + int64_t num_work_group = ceil_div( + static_cast(bag_num * vectorized_feature_dim), + static_cast(work_group_size)); index_t fixing_bag_size = ignore_offsets ? index_size / bag_num : 0; auto kfn = KernelClass( @@ -91,18 +88,19 @@ void embedding_bag( per_sample_weights, index_size, bag_num, - vec_len, - tile, + vectorized_feature_dim, padding_idx, ignore_offsets, o_vec, w_vec, max_idx_vec, - cfg, fixing_bag_size, num_row); sycl_kernel_submit( - cfg.global_size(), cfg.group_size(), getCurrentSYCLQueue(), kfn); + num_work_group * work_group_size, + work_group_size, + getCurrentSYCLQueue(), + kfn); } #define EMBBAG_KERNEL_ACC( \ @@ -121,7 +119,7 @@ void embedding_bag( per_sample_weights, \ index_len, \ bag_num, \ - vec_len, \ + feature_dim, \ padding_idx, \ ignore_offsets, \ num_row) \ @@ -137,7 +135,7 @@ void embedding_bag( per_sample_weights.const_data_ptr(), \ index_size, \ bag_num, \ - vec_len, \ + feature_dim, \ padding_idx, \ ignore_offsets, \ num_row); \ @@ -160,7 +158,7 @@ void embedding_bag( nullptr, \ index_size, \ bag_num, \ - vec_len, \ + feature_dim, \ padding_idx, \ ignore_offsets, \ num_row); \ @@ -183,7 +181,7 @@ void embedding_bag( per_sample_weights.const_data_ptr(), \ index_size, \ bag_num, \ - vec_len, \ + feature_dim, \ padding_idx, \ ignore_offsets, \ num_row); \ @@ -206,7 +204,7 @@ void embedding_bag( nullptr, \ index_size, \ bag_num, \ - vec_len, \ + feature_dim, \ padding_idx, \ ignore_offsets, \ num_row); @@ -226,7 +224,7 @@ void embedding_bag( per_sample_weights, \ index_len, \ bag_num, \ - vec_len, \ + feature_dim, \ padding_idx, \ ignore_offsets, \ num_row) \ @@ -242,7 +240,7 @@ void embedding_bag( per_sample_weights.const_data_ptr(), \ index_size, \ bag_num, \ - vec_len, \ + feature_dim, \ padding_idx, \ ignore_offsets, \ num_row); \ @@ -258,7 +256,7 @@ void embedding_bag( nullptr, \ index_size, \ bag_num, \ - vec_len, \ + feature_dim, \ padding_idx, \ ignore_offsets, \ num_row); \ @@ -274,7 +272,7 @@ void embedding_bag( per_sample_weights.const_data_ptr(), \ index_size, \ bag_num, \ - vec_len, \ + feature_dim, \ padding_idx, \ ignore_offsets, \ num_row); \ @@ -290,7 +288,7 @@ void embedding_bag( nullptr, \ index_size, \ bag_num, \ - vec_len, \ + feature_dim, \ padding_idx, \ ignore_offsets, \ num_row); @@ -306,7 +304,7 @@ void embedding_bag_sum_template( Tensor& max_indices, int64_t index_size, int64_t bag_num, - int64_t vec_len, + int64_t feature_dim, int64_t padding_idx, bool ignore_offsets) { uint64_t num_row = weights.size(0); @@ -327,7 +325,7 @@ void embedding_bag_sum_template( per_sample_weights, \ index_len, \ bag_num, \ - vec_len, \ + feature_dim, \ padding_idx, \ ignore_offsets, \ num_row) @@ -343,7 +341,15 @@ void embedding_bag_sum_template( using accscalar_t = at::acc_type_device; int vec_size = memory::can_vectorize_up_to( (char*)weights.const_data_ptr()); - vec_size = vec_len % vec_size == 0 ? vec_size : 1; + vec_size = feature_dim % vec_size == 0 ? vec_size : 1; + for (int v : {8, 4, 2, 1}) { + // We only load one weight[i] in one subgroup, otherwise memory + // load cannot be coalesce + if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0) { + vec_size = v; + break; + } + } switch (vec_size) { case 8: EXTEND_EMBBAG_SUM_KERNEL_VEC(8); @@ -374,7 +380,7 @@ void embedding_bag_mean_template( Tensor& max_indices, int64_t index_size, int64_t bag_num, - int64_t vec_len, + int64_t feature_dim, int64_t padding_idx, bool ignore_offsets) { uint64_t num_row = weights.size(0); @@ -395,7 +401,7 @@ void embedding_bag_mean_template( per_sample_weights, \ index_len, \ bag_num, \ - vec_len, \ + feature_dim, \ padding_idx, \ ignore_offsets, \ num_row) @@ -411,7 +417,7 @@ void embedding_bag_mean_template( using accscalar_t = at::acc_type_device; int vec_size = memory::can_vectorize_up_to( (char*)weights.const_data_ptr()); - vec_size = vec_len % vec_size == 0 ? vec_size : 1; + vec_size = feature_dim % vec_size == 0 ? vec_size : 1; switch (vec_size) { case 8: EXTEND_EMBBAG_MEAN_KERNEL_VEC(8); @@ -442,7 +448,7 @@ void embedding_bag_max_template( Tensor& max_indices, int64_t index_size, int64_t bag_num, - int64_t vec_len, + int64_t feature_dim, int64_t padding_idx, bool ignore_offsets) { uint64_t num_row = weights.size(0); @@ -462,7 +468,7 @@ void embedding_bag_max_template( per_sample_weights, \ index_len, \ bag_num, \ - vec_len, \ + feature_dim, \ padding_idx, \ ignore_offsets, \ num_row) @@ -478,7 +484,7 @@ void embedding_bag_max_template( // using accscalar_t = at::acc_type_device; int vec_size = memory::can_vectorize_up_to( (char*)weights.const_data_ptr()); - vec_size = vec_len % vec_size == 0 ? vec_size : 1; + vec_size = feature_dim % vec_size == 0 ? vec_size : 1; switch (vec_size) { case 8: EXTEND_EMBBAG_MAX_KERNEL_VEC(8); diff --git a/src/ATen/native/xpu/sycl/EmbeddingBag.h b/src/ATen/native/xpu/sycl/EmbeddingBag.h index ec9dbad6cd..4429fd4591 100644 --- a/src/ATen/native/xpu/sycl/EmbeddingBag.h +++ b/src/ATen/native/xpu/sycl/EmbeddingBag.h @@ -25,125 +25,126 @@ template < bool per_sample_weights_defined, bool padding_idx_defined> struct EmbeddingBagKernelFunctor { - void operator()(sycl::nd_item<2> item) const { - auto desc = cfg_.get_item_desc(item); - index_t start = 0, end = 0; - int64_t off_off = -1; + void operator()(sycl::nd_item<1> item) const { + auto thread_id = item.get_global_linear_id(); + if (thread_id < bag_num_ * vectorized_feature_dim_len_) { + auto current_feature = thread_id % vectorized_feature_dim_len_; + auto current_bag = thread_id / vectorized_feature_dim_len_; + index_t start, end; + bool last_bag = current_bag == bag_num_ - 1; + // TODO: add template + if (!ignore_offsets_) { + start = offset_[current_bag]; + end = last_bag ? index_size_ : offset_[current_bag + 1]; + } else { + start = current_bag * fixing_bag_size_; + end = start + fixing_bag_size_; + } - do { - if (desc.glb_problem < cfg_.problem_ && - desc.glb_batch < cfg_.problem_batch_) { - bool walk_on_bag = desc.glb_batch != off_off; - if (walk_on_bag) { - off_off = desc.glb_batch * tile_ + desc.glb_problem / (cfg_.problem_ / tile_); - bool last_bag = off_off == bag_num_ - 1; - if (!ignore_offsets_) { - start = offset_[off_off]; - end = last_bag ? index_size_ : offset_[off_off + 1]; - } else { - start = off_off * fixing_bag_size_; - end = start + fixing_bag_size_; - } - } + vec_acc_t value, value_max; + vec_idx_t index_max; + index_t padding_cnt = 0; - vec_acc_t value, value_max; - vec_idx_t index_max; - index_t padding_cnt = 0; +#pragma unroll + for (int i = 0; i < vec_size; i++) { + value[i] = 0; + } + if constexpr (mode == MODE_MAX) { #pragma unroll for (int i = 0; i < vec_size; i++) { - value[i] = 0; value_max[i] = at::numeric_limits::lower_bound(); index_max[i] = -1; } - index_t index_off, vec_idx, i_off; - vec_t other; - auto handle_non_padding = [&]() { - i_off = vec_idx * vec_len_ + desc.glb_problem % (cfg_.problem_ / tile_); - other = w_vec_[i_off]; + } + index_t index_offset, weight_index; + vec_t wei_load; + auto handle_non_padding = [&]() { + wei_load = w_vec_ + [weight_index * vectorized_feature_dim_len_ + current_feature]; - if constexpr (mode == MODE_SUM) { + if constexpr (mode == MODE_SUM) { #pragma unroll - for (int i = 0; i < vec_size; i++) { - if constexpr (per_sample_weights_defined) { - other[i] *= per_sample_weights_[index_off]; - } - value[i] += other[i]; + for (int i = 0; i < vec_size; i++) { + if constexpr (per_sample_weights_defined) { + wei_load[i] *= per_sample_weights_[index_offset]; } - } else if constexpr (mode == MODE_MEAN) { + value[i] += wei_load[i]; + } + } else if constexpr (mode == MODE_MEAN) { #pragma unroll - for (int i = 0; i < vec_size; i++) { - value[i] += other[i]; - } - } else if constexpr (mode == MODE_MAX) { + for (int i = 0; i < vec_size; i++) { + value[i] += wei_load[i]; + } + } else if constexpr (mode == MODE_MAX) { #pragma unroll - for (int i = 0; i < vec_size; i++) { - if (other[i] > value_max[i]) { - value_max[i] = other[i]; - if (max_index_) { - index_max[i] = vec_idx; - } + for (int i = 0; i < vec_size; i++) { + if (wei_load[i] > value_max[i]) { + value_max[i] = wei_load[i]; + if (max_index_) { + index_max[i] = weight_index; } } } - }; + } + }; - for (index_t off = start; off < end; off++) { - index_off = off; - vec_idx = index_[index_off]; - SYCL_KERNEL_ASSERT(vec_idx < num_row_); + for (index_t offset_in_bag = start; offset_in_bag < end; + offset_in_bag++) { + index_offset = offset_in_bag; + weight_index = index_[index_offset]; + SYCL_KERNEL_ASSERT(weight_index < num_row_); - if (walk_on_bag && desc.glb_problem == 0) { - offset2bag_[index_off] = off_off; - } + if (current_feature == 0) + offset2bag_[index_offset] = current_bag; - if constexpr (padding_idx_defined) { - if (padding_idx_ != vec_idx) { - handle_non_padding(); - } else { - padding_cnt++; - } - } else { + if constexpr (padding_idx_defined) { + if (padding_idx_ != weight_index) { handle_non_padding(); + } else { + padding_cnt++; } + } else { + handle_non_padding(); } + } - int64_t bsize = end - start - padding_cnt; - if (desc.glb_problem == 0) { - bag_size_[off_off] = bsize; - } + int64_t bsize = end - start - padding_cnt; + if (current_feature == 0) { + bag_size_[current_bag] = bsize; + } - index_t o_off = off_off * vec_len_ + desc.glb_problem % (cfg_.problem_ / tile_); - if constexpr (mode == MODE_SUM) { - vec_t o; + index_t o_off = + current_bag * vectorized_feature_dim_len_ + current_feature; + if constexpr (mode == MODE_SUM) { + vec_t o; #pragma unroll - for (int i = 0; i < vec_size; i++) { - o[i] = value[i]; - } - o_vec_[o_off] = o; - } else if constexpr (mode == MODE_MEAN) { - vec_t o; - bsize = bsize == 0 ? 1 : bsize; + for (int i = 0; i < vec_size; i++) { + o[i] = value[i]; + } + o_vec_[o_off] = o; + } else if constexpr (mode == MODE_MEAN) { + vec_t o; + bsize = bsize == 0 ? 1 : bsize; #pragma unroll - for (int i = 0; i < vec_size; i++) { - o[i] = value[i] / bsize; - } - o_vec_[o_off] = o; - } else if constexpr (mode == MODE_MAX) { - vec_t padding; + for (int i = 0; i < vec_size; i++) { + o[i] = value[i] / bsize; + } + o_vec_[o_off] = o; + } else if constexpr (mode == MODE_MAX) { + vec_t padding; #pragma unroll - for (int i = 0; i < vec_size; i++) { - padding[i] = 0; - } - o_vec_[o_off] = - value_max[0] == at::numeric_limits::lower_bound() - ? padding - : value_max; - if (max_index_) { - max_idx_vec_[o_off] = index_max; - } + for (int i = 0; i < vec_size; i++) { + padding[i] = 0; + } + o_vec_[o_off] = + value_max[0] == at::numeric_limits::lower_bound() + ? padding + : value_max; + if (max_index_) { + max_idx_vec_[o_off] = index_max; } } - } while (cfg_.next(item, desc)); + } } EmbeddingBagKernelFunctor( const index_t* const index, @@ -154,14 +155,12 @@ struct EmbeddingBagKernelFunctor { const scalar_t* const per_sample_weights, int64_t index_size, int64_t bag_num, - int64_t vec_len, - int64_t tile, + int64_t vectorized_feature_dim_len, index_t padding_idx, bool ignore_offsets, vec_t* o_vec, const vec_t* w_vec, vec_idx_t* max_idx_vec, - BatchKernelConfig cfg, index_t fixing_bag_size, index_t num_row) : index_(index), @@ -172,14 +171,12 @@ struct EmbeddingBagKernelFunctor { per_sample_weights_(per_sample_weights), index_size_(index_size), bag_num_(bag_num), - vec_len_(vec_len), - tile_(tile), + vectorized_feature_dim_len_(vectorized_feature_dim_len), padding_idx_(padding_idx), ignore_offsets_(ignore_offsets), o_vec_(o_vec), w_vec_(w_vec), max_idx_vec_(max_idx_vec), - cfg_(cfg), fixing_bag_size_(fixing_bag_size), num_row_(num_row) {} @@ -192,14 +189,12 @@ struct EmbeddingBagKernelFunctor { const scalar_t* const per_sample_weights_; int64_t index_size_; int64_t bag_num_; - int64_t vec_len_; - int64_t tile_; + int64_t vectorized_feature_dim_len_; index_t padding_idx_; bool ignore_offsets_; vec_t* o_vec_; const vec_t* w_vec_; vec_idx_t* max_idx_vec_; - BatchKernelConfig cfg_; index_t fixing_bag_size_; index_t num_row_; }; From 9ad79be338243fb07b457db9b671fa3582c17cca Mon Sep 17 00:00:00 2001 From: intel Date: Fri, 11 Jul 2025 14:35:07 +0000 Subject: [PATCH 07/11] save --- src/ATen/native/xpu/sycl/EmbeddingBag.cpp | 16 ++++++++++++++++ src/ATen/native/xpu/sycl/EmbeddingBag.h | 1 - 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp index 1f047f745e..e9c4718eb3 100644 --- a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp +++ b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp @@ -418,6 +418,14 @@ void embedding_bag_mean_template( int vec_size = memory::can_vectorize_up_to( (char*)weights.const_data_ptr()); vec_size = feature_dim % vec_size == 0 ? vec_size : 1; + for (int v : {8, 4, 2, 1}) { + // We only load one weight[i] in one subgroup, otherwise memory + // load cannot be coalesce + if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0) { + vec_size = v; + break; + } + } switch (vec_size) { case 8: EXTEND_EMBBAG_MEAN_KERNEL_VEC(8); @@ -485,6 +493,14 @@ void embedding_bag_max_template( int vec_size = memory::can_vectorize_up_to( (char*)weights.const_data_ptr()); vec_size = feature_dim % vec_size == 0 ? vec_size : 1; + for (int v : {8, 4, 2, 1}) { + // We only load one weight[i] in one subgroup, otherwise memory + // load cannot be coalesce + if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0) { + vec_size = v; + break; + } + } switch (vec_size) { case 8: EXTEND_EMBBAG_MAX_KERNEL_VEC(8); diff --git a/src/ATen/native/xpu/sycl/EmbeddingBag.h b/src/ATen/native/xpu/sycl/EmbeddingBag.h index 4429fd4591..d33084809c 100644 --- a/src/ATen/native/xpu/sycl/EmbeddingBag.h +++ b/src/ATen/native/xpu/sycl/EmbeddingBag.h @@ -32,7 +32,6 @@ struct EmbeddingBagKernelFunctor { auto current_bag = thread_id / vectorized_feature_dim_len_; index_t start, end; bool last_bag = current_bag == bag_num_ - 1; - // TODO: add template if (!ignore_offsets_) { start = offset_[current_bag]; end = last_bag ? index_size_ : offset_[current_bag + 1]; From 4bee8fab7c561e0de093856198ba6eb181c5c158 Mon Sep 17 00:00:00 2001 From: intel Date: Fri, 11 Jul 2025 14:55:06 +0000 Subject: [PATCH 08/11] save --- src/ATen/native/xpu/sycl/EmbeddingBag.cpp | 32 ++++++----------------- 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp index e9c4718eb3..70dfa9e07c 100644 --- a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp +++ b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp @@ -342,14 +342,14 @@ void embedding_bag_sum_template( int vec_size = memory::can_vectorize_up_to( (char*)weights.const_data_ptr()); vec_size = feature_dim % vec_size == 0 ? vec_size : 1; - for (int v : {8, 4, 2, 1}) { - // We only load one weight[i] in one subgroup, otherwise memory - // load cannot be coalesce - if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0) { - vec_size = v; - break; - } - } + // for (int v : {8, 4, 2, 1}) { + // // We only load one weight[i] in one subgroup, otherwise memory + // // load cannot be coalesce + // if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0) { + // vec_size = v; + // break; + // } + // } switch (vec_size) { case 8: EXTEND_EMBBAG_SUM_KERNEL_VEC(8); @@ -418,14 +418,6 @@ void embedding_bag_mean_template( int vec_size = memory::can_vectorize_up_to( (char*)weights.const_data_ptr()); vec_size = feature_dim % vec_size == 0 ? vec_size : 1; - for (int v : {8, 4, 2, 1}) { - // We only load one weight[i] in one subgroup, otherwise memory - // load cannot be coalesce - if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0) { - vec_size = v; - break; - } - } switch (vec_size) { case 8: EXTEND_EMBBAG_MEAN_KERNEL_VEC(8); @@ -493,14 +485,6 @@ void embedding_bag_max_template( int vec_size = memory::can_vectorize_up_to( (char*)weights.const_data_ptr()); vec_size = feature_dim % vec_size == 0 ? vec_size : 1; - for (int v : {8, 4, 2, 1}) { - // We only load one weight[i] in one subgroup, otherwise memory - // load cannot be coalesce - if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0) { - vec_size = v; - break; - } - } switch (vec_size) { case 8: EXTEND_EMBBAG_MAX_KERNEL_VEC(8); From 207c5e7c12ee920c8ca91c065a6ef0fe6be1d4c5 Mon Sep 17 00:00:00 2001 From: intel Date: Fri, 11 Jul 2025 15:02:43 +0000 Subject: [PATCH 09/11] save --- src/ATen/native/xpu/sycl/EmbeddingBag.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp index 70dfa9e07c..c03395b424 100644 --- a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp +++ b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp @@ -73,7 +73,7 @@ void embedding_bag( int vectorized_feature_dim = feature_dim / vec_size; int64_t work_group_size = syclMaxWorkGroupSize(); - // TODO: we can set num_work_group = 1024 and add for loop in kernel + // TODO: we can set a smaller num_work_group and add for loop in kernel int64_t num_work_group = ceil_div( static_cast(bag_num * vectorized_feature_dim), static_cast(work_group_size)); From a53da4efd8bcb014864ebd80aae954beae0a62e1 Mon Sep 17 00:00:00 2001 From: intel Date: Sat, 12 Jul 2025 14:36:12 +0000 Subject: [PATCH 10/11] save --- src/ATen/native/xpu/sycl/EmbeddingBag.cpp | 49 ++++++++++++++++++----- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp index c03395b424..b895ff4177 100644 --- a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp +++ b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp @@ -72,7 +72,7 @@ void embedding_bag( vec_idx_t* max_idx_vec = reinterpret_cast(max_index); int vectorized_feature_dim = feature_dim / vec_size; - int64_t work_group_size = syclMaxWorkGroupSize(); + int64_t work_group_size = syclDeviceMaxWorkGroupSize(); // TODO: we can set a smaller num_work_group and add for loop in kernel int64_t num_work_group = ceil_div( static_cast(bag_num * vectorized_feature_dim), @@ -342,14 +342,19 @@ void embedding_bag_sum_template( int vec_size = memory::can_vectorize_up_to( (char*)weights.const_data_ptr()); vec_size = feature_dim % vec_size == 0 ? vec_size : 1; - // for (int v : {8, 4, 2, 1}) { - // // We only load one weight[i] in one subgroup, otherwise memory - // // load cannot be coalesce - // if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0) { - // vec_size = v; - // break; - // } - // } + auto num_wg = bag_num * feature_dim / vec_size / + syclDeviceMaxWorkGroupSize(); + auto num_xe = syclGpuEuCount() / syclGpuEUCountPerSubslice(); + if (num_wg < num_xe) { + for (int v = vec_size; v != 1; v = v / 2) { + // We only load one weight[i] in one subgroup, otherwise + // memory load cannot be coalesce + if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0) { + vec_size = v; + break; + } + } + } switch (vec_size) { case 8: EXTEND_EMBBAG_SUM_KERNEL_VEC(8); @@ -418,6 +423,19 @@ void embedding_bag_mean_template( int vec_size = memory::can_vectorize_up_to( (char*)weights.const_data_ptr()); vec_size = feature_dim % vec_size == 0 ? vec_size : 1; + auto num_wg = bag_num * feature_dim / vec_size / + syclDeviceMaxWorkGroupSize(); + auto num_xe = syclGpuEuCount() / syclGpuEUCountPerSubslice(); + if (num_wg < num_xe) { + for (int v = vec_size; v != 1; v = v / 2) { + // We only load one weight[i] in one subgroup, otherwise + // memory load cannot be coalesce + if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0) { + vec_size = v; + break; + } + } + } switch (vec_size) { case 8: EXTEND_EMBBAG_MEAN_KERNEL_VEC(8); @@ -485,6 +503,19 @@ void embedding_bag_max_template( int vec_size = memory::can_vectorize_up_to( (char*)weights.const_data_ptr()); vec_size = feature_dim % vec_size == 0 ? vec_size : 1; + auto num_wg = bag_num * feature_dim / vec_size / + syclDeviceMaxWorkGroupSize(); + auto num_xe = syclGpuEuCount() / syclGpuEUCountPerSubslice(); + if (num_wg < num_xe) { + for (int v = vec_size; v != 1; v = v / 2) { + // We only load one weight[i] in one subgroup, otherwise + // memory load cannot be coalesce + if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0) { + vec_size = v; + break; + } + } + } switch (vec_size) { case 8: EXTEND_EMBBAG_MAX_KERNEL_VEC(8); From 9138a677f200c6e910720a898c2d579f90a0a5ad Mon Sep 17 00:00:00 2001 From: intel Date: Tue, 15 Jul 2025 03:02:05 +0000 Subject: [PATCH 11/11] save --- src/ATen/native/xpu/sycl/EmbeddingBag.cpp | 63 +++++++++++------------ 1 file changed, 30 insertions(+), 33 deletions(-) diff --git a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp index b895ff4177..8ca0d21d8b 100644 --- a/src/ATen/native/xpu/sycl/EmbeddingBag.cpp +++ b/src/ATen/native/xpu/sycl/EmbeddingBag.cpp @@ -342,17 +342,16 @@ void embedding_bag_sum_template( int vec_size = memory::can_vectorize_up_to( (char*)weights.const_data_ptr()); vec_size = feature_dim % vec_size == 0 ? vec_size : 1; - auto num_wg = bag_num * feature_dim / vec_size / - syclDeviceMaxWorkGroupSize(); - auto num_xe = syclGpuEuCount() / syclGpuEUCountPerSubslice(); - if (num_wg < num_xe) { - for (int v = vec_size; v != 1; v = v / 2) { - // We only load one weight[i] in one subgroup, otherwise - // memory load cannot be coalesce - if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0) { - vec_size = v; - break; - } + int num_sub_wg = + bag_num * feature_dim / vec_size / syclMaxSubGroupSize(); + int thread_slots = syclGpuEuCount() * syclGpuHWThreadsPerEU(); + for (int v = vec_size; v != 1; + v = v / 2, num_sub_wg = num_sub_wg * 2) { + if (2 * num_sub_wg > thread_slots) { + // peak occurancy = num_sub_wg / thread_slots + // it should > 50% + vec_size = v; + break; } } switch (vec_size) { @@ -423,17 +422,16 @@ void embedding_bag_mean_template( int vec_size = memory::can_vectorize_up_to( (char*)weights.const_data_ptr()); vec_size = feature_dim % vec_size == 0 ? vec_size : 1; - auto num_wg = bag_num * feature_dim / vec_size / - syclDeviceMaxWorkGroupSize(); - auto num_xe = syclGpuEuCount() / syclGpuEUCountPerSubslice(); - if (num_wg < num_xe) { - for (int v = vec_size; v != 1; v = v / 2) { - // We only load one weight[i] in one subgroup, otherwise - // memory load cannot be coalesce - if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0) { - vec_size = v; - break; - } + int num_sub_wg = + bag_num * feature_dim / vec_size / syclMaxSubGroupSize(); + int thread_slots = syclGpuEuCount() * syclGpuHWThreadsPerEU(); + for (int v = vec_size; v != 1; + v = v / 2, num_sub_wg = num_sub_wg * 2) { + if (2 * num_sub_wg > thread_slots) { + // peak occurancy = num_sub_wg / thread_slots + // it should > 50% + vec_size = v; + break; } } switch (vec_size) { @@ -503,17 +501,16 @@ void embedding_bag_max_template( int vec_size = memory::can_vectorize_up_to( (char*)weights.const_data_ptr()); vec_size = feature_dim % vec_size == 0 ? vec_size : 1; - auto num_wg = bag_num * feature_dim / vec_size / - syclDeviceMaxWorkGroupSize(); - auto num_xe = syclGpuEuCount() / syclGpuEUCountPerSubslice(); - if (num_wg < num_xe) { - for (int v = vec_size; v != 1; v = v / 2) { - // We only load one weight[i] in one subgroup, otherwise - // memory load cannot be coalesce - if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0) { - vec_size = v; - break; - } + int num_sub_wg = + bag_num * feature_dim / vec_size / syclMaxSubGroupSize(); + int thread_slots = syclGpuEuCount() * syclGpuHWThreadsPerEU(); + for (int v = vec_size; v != 1; + v = v / 2, num_sub_wg = num_sub_wg * 2) { + if (2 * num_sub_wg > thread_slots) { + // peak occurancy = num_sub_wg / thread_slots + // it should > 50% + vec_size = v; + break; } } switch (vec_size) {