Group Norm Backward Optimization with vectorization and parallel reduction #1652
Conversation
Please show performance impact.
Pull Request Overview
This PR adds a vectorized functor version for the Group Norm Backward kernel to improve performance on systems supporting vectorized operations. Key changes include:
- Addition of ComputeInternalGradientsVectorizedFunctor with vectorized reduction logic.
- Conditional kernel launch based on vectorization capability.
- Updated work-group size computation to accommodate the vectorized implementation.
sum1_vec[v] = static_cast<T_ACC>(vec_dY_[iv] * vec_X_[iv]);
sum2_vec[v] = static_cast<T_ACC>(vec_dY_[iv]);
It appears that inside the inner loop the value of sum1_vec[v] is overwritten in each iteration rather than accumulated. Consider using '+=' to aggregate results across iterations if that was the intended behavior.
Suggested change:
- sum1_vec[v] = static_cast<T_ACC>(vec_dY_[iv] * vec_X_[iv]);
- sum2_vec[v] = static_cast<T_ACC>(vec_dY_[iv]);
+ sum1_vec[v] += static_cast<T_ACC>(vec_dY_[iv] * vec_X_[iv]);
+ sum2_vec[v] += static_cast<T_ACC>(vec_dY_[iv]);
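The difference the suggestion points at can be shown with a minimal, self-contained example (not the actual kernel): with `=`, only the last chunk of the outer loop survives in the lanes, while `+=` accumulates every chunk.

```cpp
// Minimal illustration of '=' vs '+=' inside a lane-wise reduction loop.
// VEC_SIZE and the function names are illustrative.
constexpr int VEC_SIZE = 4;

float reduce_overwrite(const float* x, int n) {
  float lanes[VEC_SIZE] = {};
  for (int i = 0; i < n; i += VEC_SIZE)
    for (int v = 0; v < VEC_SIZE; ++v)
      lanes[v] = x[i + v];           // overwrite: keeps only the last chunk
  float s = 0.f;
  for (int v = 0; v < VEC_SIZE; ++v) s += lanes[v];
  return s;
}

float reduce_accumulate(const float* x, int n) {
  float lanes[VEC_SIZE] = {};
  for (int i = 0; i < n; i += VEC_SIZE)
    for (int v = 0; v < VEC_SIZE; ++v)
      lanes[v] += x[i + v];          // accumulate: sums every chunk
  float s = 0.f;
  for (int v = 0; v < VEC_SIZE; ++v) s += lanes[v];
  return s;
}
```

Whether `=` is actually a bug in the kernel depends on how many chunks each work-item visits; if each item processes exactly one chunk, the overwrite form is correct.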
sum1_vec[v] = static_cast<T_ACC>(vec_dY_[iv] * vec_X_[iv]);
sum2_vec[v] = static_cast<T_ACC>(vec_dY_[iv]);
Similar to the sum1_vec update, sum2_vec[v] is overwritten on each iteration of the inner loop instead of accumulating the results. If accumulation is intended, replace '=' with '+='.
Suggested change:
- sum1_vec[v] = static_cast<T_ACC>(vec_dY_[iv] * vec_X_[iv]);
- sum2_vec[v] = static_cast<T_ACC>(vec_dY_[iv]);
+ sum1_vec[v] += static_cast<T_ACC>(vec_dY_[iv] * vec_X_[iv]);
+ sum2_vec[v] += static_cast<T_ACC>(vec_dY_[iv]);
Please update the PR description to explain why these changes improve performance, and include the detailed performance data.
Informative PR description and comments are required.
In general, the optimization looks good to me. However, please address two common issues:
- Please avoid using non-common abbreviations.
- Update the PR description to elaborate on the detailed optimization ideas and the measured performance improvements.
using vec_t = memory::aligned_vector<T, VEC_SIZE>;
using vec_td = memory::aligned_vector<T_ACC, VEC_SIZE>;

[[intel::reqd_sub_group_size(SIMD)]] void operator()(
@xytintel, @fengyuan14, @gujinghui, could you help check the behavior of [[intel::reqd_sub_group_size(SIMD)]] on the latest XE?
using T_ACC = acc_type_device<T, kXPU>;
using vec_t = memory::aligned_vector<T, VEC_SIZE>;
using vec_td = memory::aligned_vector<T_ACC, VEC_SIZE>;
What's the rule for choosing UPPER case versus lower case in these using type aliases?
using T_ACC = acc_type_device<T, kXPU>;
using vec_t = memory::aligned_vector<T, VEC_SIZE>;
using vec_td = memory::aligned_vector<T_ACC, VEC_SIZE>;
What are the meanings of the suffixes _t and _td, respectively?
Use acc_vec_t instead, to align with the overall code. vec_t and acc_vec_t represent vectors created with the corresponding datatype.
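The naming convention agreed on above can be sketched as follows. This is a hypothetical, self-contained stand-in: aligned_vector here only mimics the shape of the real memory::aligned_vector, and the concrete types for T and T_ACC are illustrative (the real code derives T_ACC via acc_type_device).

```cpp
// Stand-in for memory::aligned_vector<T, N>: N elements of T with
// vector-width alignment, indexable like an array.
template <typename T, int N>
struct aligned_vector {
  alignas(sizeof(T) * N) T val[N];
  T& operator[](int i) { return val[i]; }
  const T& operator[](int i) const { return val[i]; }
};

// Template-parameter-style names stay UPPER case...
using T = float;
using T_ACC = float;  // accumulation type; illustrative choice
constexpr int VEC_SIZE = 4;

// ...while type aliases use lowercase with a standard-library-style
// '_t' suffix. 'acc_vec_t' replaces the less readable 'vec_td'.
using vec_t = aligned_vector<T, VEC_SIZE>;         // vector of the input type
using acc_vec_t = aligned_vector<T_ACC, VEC_SIZE>; // vector of the accumulation type
```

The lowercase `_t` suffix follows the standard-library convention for type aliases, while the prefix (`vec` vs `acc_vec`) names the element type the vector holds.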
sycl::nd_item<1> item) const {
  vec_td sum1_vec = {};
  vec_td sum2_vec = {};
  auto g_start = item.get_group(0) * VEC_SIZE;
What's the meaning of the g_ prefix? group or global?
It means group; use group_start instead.
#pragma unroll
for (int v = 0; v < VEC_SIZE; ++v) {
  const int64_t nc = g_start + v;
v is a variable, so why is nc a constant variable?
What is nc an abbreviation of?
nc is not an abbreviation; it means n*c in NCHW, and CUDA also uses this variable name in this context.
Although v is a variable, it remains unchanged within a single loop iteration, so nc is constant there.
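The n*c convention mentioned above can be made concrete with a small, self-contained indexing example (illustrative values, not from the kernel): for a contiguous NCHW tensor, the batch and channel dimensions are often flattened into a single index nc = n * C + c, so each work-group can own one (n, c) plane of H*W elements.

```cpp
#include <cstdint>

// Illustrative NCHW tensor shape.
constexpr int64_t N = 2, C = 3, H = 4, W = 5;

// Offset of element (n, c, h, w) in contiguous NCHW storage,
// computed via the combined batch-channel index nc = n * C + c.
int64_t flat_index(int64_t n, int64_t c, int64_t h, int64_t w) {
  int64_t nc = n * C + c;        // the 'nc' index from the review discussion
  return (nc * H + h) * W + w;   // row-major offset within that plane
}
```

With this layout, consecutive nc values address consecutive H*W planes, which is what lets a group-per-nc kernel read its plane with contiguous (and hence vectorizable) loads.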
#pragma unroll
for (int v = 0; v < VEC_SIZE; ++v) {
  sum1_vec[v] = GroupReduceSumWithoutBroadcast<T_ACC, SIMD>(
GroupReduceSumWithoutBroadcast represents a sum reduction within a subgroup, right? If so, why has the function been defined as GroupXXX?
GroupReduceSumWithoutBroadcast represents a sum reduction within a group, and SubgroupReduceSumWithoutBroadcast represents a sum reduction within a subgroup.
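The group/subgroup distinction clarified above can be sketched in plain C++ (a simulation only: the real kernels use SYCL sub-group shuffles and local memory, and SIMD here is an illustrative width).

```cpp
#include <vector>
#include <cstddef>

// Illustrative subgroup width; on Intel GPUs this matches the SIMD size.
constexpr int SIMD = 8;

// "Subgroup" reduction: sums one SIMD-wide slice. In SYCL this would be
// done with sub-group shuffle operations rather than a loop.
float subgroup_reduce_sum(const float* vals) {
  float s = 0.f;
  for (int i = 0; i < SIMD; ++i) s += vals[i];
  return s;
}

// "Group" reduction: the work-group is split into SIMD-sized subgroups;
// each subgroup reduces its slice, then the partial sums are combined
// (the real code exchanges partials through local memory).
float group_reduce_sum(const std::vector<float>& vals) {
  float total = 0.f;
  for (std::size_t i = 0; i < vals.size(); i += SIMD)
    total += subgroup_reduce_sum(vals.data() + i);
  return total;
}
```

The "WithoutBroadcast" suffix indicates that only one work-item ends up holding the final sum, rather than broadcasting it back to every item in the group.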
All the requested changes have been applied.