Skip to content

Commit bea0116

Browse files
author
intel
committed
save
1 parent 57ec998 commit bea0116

File tree

1 file changed

+40
-9
lines changed

1 file changed

+40
-9
lines changed

src/ATen/native/xpu/sycl/EmbeddingBag.cpp

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ void embedding_bag(
7272
vec_idx_t* max_idx_vec = reinterpret_cast<vec_idx_t*>(max_index);
7373

7474
int vectorized_feature_dim = feature_dim / vec_size;
75-
int64_t work_group_size = syclMaxWorkGroupSize<KernelClass>();
75+
int64_t work_group_size = syclDeviceMaxWorkGroupSize();
7676
// TODO: we can set a smaller num_work_group and add a for loop in the kernel
7777
int64_t num_work_group = ceil_div(
7878
static_cast<int64_t>(bag_num * vectorized_feature_dim),
@@ -342,14 +342,19 @@ void embedding_bag_sum_template(
342342
int vec_size = memory::can_vectorize_up_to<scalar_t>(
343343
(char*)weights.const_data_ptr());
344344
vec_size = feature_dim % vec_size == 0 ? vec_size : 1;
345-
// for (int v : {8, 4, 2, 1}) {
346-
// // We only load one weight[i] in one subgroup, otherwise memory
347-
// // load cannot be coalesce
348-
// if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0) {
349-
// vec_size = v;
350-
// break;
351-
// }
352-
// }
345+
auto num_wg = bag_num * feature_dim / vec_size /
346+
syclDeviceMaxWorkGroupSize();
347+
auto num_xe = syclGpuEuCount() / syclGpuEUCountPerSubslice();
348+
if (num_wg < num_xe) {
349+
for (int v = vec_size; v != 1; v = v / 2) {
350+
// We only load one weight[i] in one subgroup, otherwise
351+
// memory load cannot be coalesced
352+
if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0) {
353+
vec_size = v;
354+
break;
355+
}
356+
}
357+
}
353358
switch (vec_size) {
354359
case 8:
355360
EXTEND_EMBBAG_SUM_KERNEL_VEC(8);
@@ -418,6 +423,19 @@ void embedding_bag_mean_template(
418423
int vec_size = memory::can_vectorize_up_to<scalar_t>(
419424
(char*)weights.const_data_ptr());
420425
vec_size = feature_dim % vec_size == 0 ? vec_size : 1;
426+
auto num_wg = bag_num * feature_dim / vec_size /
427+
syclDeviceMaxWorkGroupSize();
428+
auto num_xe = syclGpuEuCount() / syclGpuEUCountPerSubslice();
429+
if (num_wg < num_xe) {
430+
for (int v = vec_size; v != 1; v = v / 2) {
431+
// We only load one weight[i] in one subgroup, otherwise
432+
// memory load cannot be coalesced
433+
if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0) {
434+
vec_size = v;
435+
break;
436+
}
437+
}
438+
}
421439
switch (vec_size) {
422440
case 8:
423441
EXTEND_EMBBAG_MEAN_KERNEL_VEC(8);
@@ -485,6 +503,19 @@ void embedding_bag_max_template(
485503
int vec_size = memory::can_vectorize_up_to<scalar_t>(
486504
(char*)weights.const_data_ptr());
487505
vec_size = feature_dim % vec_size == 0 ? vec_size : 1;
506+
auto num_wg = bag_num * feature_dim / vec_size /
507+
syclDeviceMaxWorkGroupSize();
508+
auto num_xe = syclGpuEuCount() / syclGpuEUCountPerSubslice();
509+
if (num_wg < num_xe) {
510+
for (int v = vec_size; v != 1; v = v / 2) {
511+
// We only load one weight[i] in one subgroup, otherwise
512+
// memory load cannot be coalesced
513+
if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0) {
514+
vec_size = v;
515+
break;
516+
}
517+
}
518+
}
488519
switch (vec_size) {
489520
case 8:
490521
EXTEND_EMBBAG_MAX_KERNEL_VEC(8);

0 commit comments

Comments
 (0)