diff --git a/src/ATen/native/xpu/sycl/GroupNormKernels.cpp b/src/ATen/native/xpu/sycl/GroupNormKernels.cpp index 935ab99f7..4038007e6 100644 --- a/src/ATen/native/xpu/sycl/GroupNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/GroupNormKernels.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -1408,7 +1409,7 @@ void group_norm_backward_kernel_impl( Tensor c3 = at::empty({N, G}, X.options().dtype(kAccType)); T_ACC* c2_data = c2.mutable_data_ptr(); T_ACC* c3_data = c3.mutable_data_ptr(); - + Tensor dummy_gamma = at::ones({1, G, D}, X.options().dtype(kAccType)); if (gamma.defined()) { auto iter = TensorIteratorConfig() .check_all_same_dtype(std::is_same::value) @@ -1417,6 +1418,14 @@ void group_norm_backward_kernel_impl( .add_owned_const_input(gamma.view({1, G, D})) .build(); gpu_kernel(iter, GroupNormBackwardC1Functor()); + } else { + auto iter = TensorIteratorConfig() + .check_all_same_dtype(std::is_same::value) + .add_output(c1) + .add_owned_const_input(rstd.view({N, G, 1})) + .add_owned_const_input(dummy_gamma.view({1, G, D})) + .build(); + gpu_kernel(iter, GroupNormBackwardC1Functor()); } wg_size = (C / G) < get_group_reduce_group_size(simd) @@ -1440,31 +1449,17 @@ void group_norm_backward_kernel_impl( c2_data, c3_data); - if (gamma.defined()) { - auto iter = TensorIteratorConfig() - .check_all_same_dtype(std::is_same::value) - .resize_outputs(false) - .add_owned_output(dX.view({N * G, D, HxW})) - .add_owned_const_input(dY.view({N * G, D, HxW})) - .add_owned_const_input(X.view({N * G, D, HxW})) - .add_owned_const_input(c1.view({N * G, D, 1})) - .add_owned_const_input(c2.view({N * G, 1, 1})) - .add_owned_const_input(c3.view({N * G, 1, 1})) - .build(); - gpu_kernel(iter, GroupNormBackwardDXFunctor()); - } else { - auto iter = TensorIteratorConfig() - .check_all_same_dtype(std::is_same::value) - .resize_outputs(false) - .add_owned_output(dX.view({N * G, D * HxW})) - .add_owned_const_input(dY.view({N * G, D * HxW})) - .add_owned_const_input(X.view({N * G, D * HxW})) - .add_owned_const_input(rstd.view({N * G, 1})) - .add_owned_const_input(c2.view({N * G, 1})) - .add_owned_const_input(c3.view({N * G, 1})) - .build(); - gpu_kernel(iter, GroupNormBackwardDXFunctor()); - } + auto iter = TensorIteratorConfig() + .check_all_same_dtype(std::is_same::value) + .resize_outputs(false) + .add_owned_output(dX.view({N * G, D, HxW})) + .add_owned_const_input(dY.view({N * G, D, HxW})) + .add_owned_const_input(X.view({N * G, D, HxW})) + .add_owned_const_input(c1.view({N * G, D, 1})) + .add_owned_const_input(c2.view({N * G, 1, 1})) + .add_owned_const_input(c3.view({N * G, 1, 1})) + .build(); + gpu_kernel(iter, GroupNormBackwardDXFunctor()); } if (dgamma.defined() || dbeta.defined()) {