@@ -72,7 +72,7 @@ void embedding_bag(
72
72
vec_idx_t * max_idx_vec = reinterpret_cast <vec_idx_t *>(max_index);
73
73
74
74
int vectorized_feature_dim = feature_dim / vec_size;
75
- int64_t work_group_size = syclMaxWorkGroupSize<KernelClass> ();
75
+ int64_t work_group_size = syclDeviceMaxWorkGroupSize ();
76
76
// TODO: we can set a smaller num_work_group and add for loop in kernel
77
77
int64_t num_work_group = ceil_div (
78
78
static_cast <int64_t >(bag_num * vectorized_feature_dim),
@@ -342,14 +342,19 @@ void embedding_bag_sum_template(
342
342
int vec_size = memory::can_vectorize_up_to<scalar_t >(
343
343
(char *)weights.const_data_ptr ());
344
344
vec_size = feature_dim % vec_size == 0 ? vec_size : 1 ;
345
- // for (int v : {8, 4, 2, 1}) {
346
- // // We only load one weight[i] in one subgroup, otherwise memory
347
- // // load cannot be coalesce
348
- // if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0) {
349
- // vec_size = v;
350
- // break;
351
- // }
352
- // }
345
+ auto num_wg = bag_num * feature_dim / vec_size /
346
+ syclDeviceMaxWorkGroupSize ();
347
+ auto num_xe = syclGpuEuCount () / syclGpuEUCountPerSubslice ();
348
+ if (num_wg < num_xe) {
349
+ for (int v = vec_size; v != 1 ; v = v / 2 ) {
350
+ // We only load one weight[i] in one subgroup, otherwise
351
+ // memory load cannot be coalesced
352
+ if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0 ) {
353
+ vec_size = v;
354
+ break ;
355
+ }
356
+ }
357
+ }
353
358
switch (vec_size) {
354
359
case 8 :
355
360
EXTEND_EMBBAG_SUM_KERNEL_VEC (8 );
@@ -418,6 +423,19 @@ void embedding_bag_mean_template(
418
423
int vec_size = memory::can_vectorize_up_to<scalar_t >(
419
424
(char *)weights.const_data_ptr ());
420
425
vec_size = feature_dim % vec_size == 0 ? vec_size : 1 ;
426
+ auto num_wg = bag_num * feature_dim / vec_size /
427
+ syclDeviceMaxWorkGroupSize ();
428
+ auto num_xe = syclGpuEuCount () / syclGpuEUCountPerSubslice ();
429
+ if (num_wg < num_xe) {
430
+ for (int v = vec_size; v != 1 ; v = v / 2 ) {
431
+ // We only load one weight[i] in one subgroup, otherwise
432
+ // memory load cannot be coalesced
433
+ if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0 ) {
434
+ vec_size = v;
435
+ break ;
436
+ }
437
+ }
438
+ }
421
439
switch (vec_size) {
422
440
case 8 :
423
441
EXTEND_EMBBAG_MEAN_KERNEL_VEC (8 );
@@ -485,6 +503,19 @@ void embedding_bag_max_template(
485
503
int vec_size = memory::can_vectorize_up_to<scalar_t >(
486
504
(char *)weights.const_data_ptr ());
487
505
vec_size = feature_dim % vec_size == 0 ? vec_size : 1 ;
506
+ auto num_wg = bag_num * feature_dim / vec_size /
507
+ syclDeviceMaxWorkGroupSize ();
508
+ auto num_xe = syclGpuEuCount () / syclGpuEUCountPerSubslice ();
509
+ if (num_wg < num_xe) {
510
+ for (int v = vec_size; v != 1 ; v = v / 2 ) {
511
+ // We only load one weight[i] in one subgroup, otherwise
512
+ // memory load cannot be coalesced
513
+ if (feature_dim % v == 0 && (feature_dim / v) % 32 == 0 ) {
514
+ vec_size = v;
515
+ break ;
516
+ }
517
+ }
518
+ }
488
519
switch (vec_size) {
489
520
case 8 :
490
521
EXTEND_EMBBAG_MAX_KERNEL_VEC (8 );
0 commit comments