Optimize roi_align on BMG #1698
base: main
XPU_KERNEL_LOOP(item, index, nthreads_) {
  // (n, c, ph, pw) is an element in the pooled output
  auto wg = item.get_group(0);
  auto idx = item.get_local_id(0);
Please rename this variable or the variable on line 75: https://github.com/intel/torch-xpu-ops/pull/1698/files#diff-5d6dc19a588e273ebfc8bf9dcc23fdc67ff9e961075e5d8e1385c7e896ef3ce9R75.
I deleted this variable.
int item_per_rois,
int wg_per_roi,
Should the variable names be `items_per_roi` and `wgs_per_roi`, accordingly?
The `RoiAlignForwardKernel` does not inherit from `__SYCL_KER_CONFIG_CONVENTION__`. When will `sycl_ker_config_convention` be invoked? @xytintel
Renamed the variables and made the kernel inherit from `__SYCL_KER_CONFIG_CONVENTION__`.
@@ -160,20 +173,25 @@ struct RoiAlignForwardKernel {
        aligned_(aligned),
        rois_(rois),
        output_(output) {}
  void sycl_ker_config_convention(sycl::handler& cgh) {
    cache_roi_ = sycl_local_acc_t<T>(5, cgh);
Please define a named constant for the magic value `5`. Also, please add a comment explaining why the value should be `5` rather than some other value.
I added a comment.
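To make the role of the `5` concrete, here is a minimal host-side sketch of the ROI cache, assuming each ROI row is laid out as batch_idx, x1, y1, x2, y2 (as a later comment in the diff states) and that `offset` is 0.5 when `aligned` is set and 0 otherwise. The names `kRoiSize` and `make_cached_roi` are hypothetical, and copying the batch index unscaled into slot 0 is an assumption; the diff only shows slots 3 and 4 being written.

```cpp
#include <array>
#include <cassert>

// Each ROI row holds 5 values: batch_idx, x1, y1, x2, y2.
// (Hypothetical constant name; it replaces the bare literal 5.)
constexpr int kRoiSize = 5;

// Hypothetical helper: scales one ROI row into feature-map coordinates,
// mirroring what a single work item would write into local memory
// before the rest of the work-group reads it.
template <typename T>
std::array<T, kRoiSize> make_cached_roi(
    const T* offset_rois, T spatial_scale, bool aligned) {
  assert(offset_rois != nullptr);
  // Assumption: half-pixel shift only in "aligned" mode.
  const T offset = aligned ? static_cast<T>(0.5) : static_cast<T>(0);
  std::array<T, kRoiSize> cache{};
  cache[0] = offset_rois[0];  // batch index: copied as-is (assumption)
  for (int i = 1; i < kRoiSize; ++i) {
    cache[i] = offset_rois[i] * spatial_scale - offset;  // x1, y1, x2, y2
  }
  return cache;
}
```

With the row `{2, 4, 8, 12, 16}`, `spatial_scale = 0.25`, and `aligned = true`, the cached coordinates come out as 0.5, 1.5, 2.5, and 3.5, and every work item of the ROI can then read them from local memory instead of LLC.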
local_range = (item_per_roi + 32 - 1) / 32 *
    32; // wg can be smaller but it is better to be a multiple of 32
}
Frankly speaking, I do not quite understand the motivation for `32`. If it represents the SIMD length, please define a named constant. Please also elaborate on why it is better for the value to be a multiple of `32`.
It's the SIMD length. @xytintel, can our block size be an arbitrary number?
  cache_roi_[3] = offset_rois[3] * spatial_scale_ - offset;
  cache_roi_[4] = offset_rois[4] * spatial_scale_ - offset;
}
item.barrier(sycl_local_fence);
The barrier may be bypassed by some work items. If so, it will trigger a hardware hang. Please refine the logic of line 76, `if (index < item_per_roi_)`, to ensure that no work item skips the `barrier`.
updated
const int item_per_roi_;
const int wg_per_roi_;
Comments are required. Please add an informative description for each variable.
int64_t local_range =
    syclMaxWorkGroupSize<RoiAlignForwardKernel<scalar_t>>();
int item_per_roi = pooled_height * pooled_width * channels;
if (item_per_roi < local_range) {
The `local_range` is the maximum number of work items that can be in a single work-group. Please assert that `local_range` is always a multiple of 32. Otherwise, rounding up may push `local_range` past the maximum number of work items. Or does the SYCL spec guarantee that the value of `syclMaxWorkGroupSize` is always divisible by 32?
Updated. The local range will now never be larger than the maximum work-group size.
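The sizing logic discussed in this thread can be sketched on the host as follows. This is only an illustration under stated assumptions: `kSimdWidth`, `LaunchConfig`, `ceil_div`, and `make_launch_config` are hypothetical names, the value 32 is taken from the thread's statement that it is the SIMD length, and the final clamp is one possible way to keep the rounded value within the device limit, not necessarily what the PR does.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Assumed SIMD (sub-group) width; named instead of a bare 32.
constexpr int64_t kSimdWidth = 32;

struct LaunchConfig {
  int64_t local_range;  // work items per work-group
  int64_t wgs_per_roi;  // work-groups needed to cover one ROI
};

inline int64_t ceil_div(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

// Hypothetical helper: picks a work-group size for one ROI.
inline LaunchConfig make_launch_config(
    int64_t items_per_roi, int64_t max_wg_size) {
  assert(items_per_roi > 0 && max_wg_size > 0);
  // Round the per-ROI work up to a multiple of the SIMD width ...
  int64_t local_range = ceil_div(items_per_roi, kSimdWidth) * kSimdWidth;
  // ... but never exceed the device limit, even if that limit
  // is not itself a multiple of the SIMD width.
  local_range = std::min(local_range, max_wg_size);
  return {local_range, ceil_div(items_per_roi, local_range)};
}
```

For example, 49 items with a limit of 1024 give one work-group of 64 items, while 40 items with a (hypothetical) limit of 48 are rounded up to 64 and then clamped back to 48, which is exactly the case the reviewer was worried about.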
@@ -433,10 +447,20 @@ Tensor roi_align_kernel(
      input.scalar_type(),
      "roi_align_forward_kernel_xpu",
      [&] {
        int64_t local_range =
            syclMaxWorkGroupSize<RoiAlignForwardKernel<scalar_t>>();
        int item_per_roi = pooled_height * pooled_width * channels;
`item_per_roi` -> `items_per_roi`
  local_range = (item_per_roi + 32 - 1) / 32 *
      32; // wg can be smaller but it is better to be a multiple of 32
}
int wg_per_roi = (item_per_roi + local_range - 1) / local_range;
`wg_per_roi` -> `wgs_per_roi`.
Only format changes, right?
Yes, only format changes. The auto-linter makes these changes during the build. Why is our lint CI green?
Pull Request Overview
This PR aims to optimize the roi_align performance on BMG by reducing repeated LLC memory accesses and streamlining conditional execution. Key changes include refactoring boundaries and conditional checks in the upsample kernels, and enhancing workgroup-based caching and indexing in the roi_align implementation.
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
File | Description |
---|---|
src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp | Refined boundary condition handling and restructured the can_optimize condition |
src/ATen/native/xpu/sycl/RoiAlignKernels.cpp | Updated bilinear interpolation clamping and improved ROI workgroup indexing with shared memory caching |
Comments suppressed due to low confidence (1)
src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp:608
- Consider refactoring this compound conditional for 'can_optimize' to improve readability and maintainability, perhaps by extracting it into a helper function if it is reused.
can_optimize = can_optimize && (align_corners || (input_width == (rwidth * output_width) &&
int roi_batch_ind = offset_rois[0];
// each roi will have 5 values, batch_idx,x1,y1,x2,y2
constexpr int roi_size = 5;
auto wg = item.get_group(0);
Ensure that using the workgroup id divided by wgs_per_roi_ to compute the ROI index accurately reflects the intended work distribution; a clarifying comment here would be helpful.
auto wg = item.get_group(0);
auto wg = item.get_group(0);
// Compute the ROI index (n) by dividing the workgroup ID (wg) by the number of workgroups per ROI (wgs_per_roi_).
// This ensures that each ROI is processed by the correct set of workgroups.
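The work distribution described in this comment can be checked on the host with a small sketch. It assumes that consecutive work-groups are assigned to the same ROI, that the element index inside an ROI is `(group_id % wgs_per_roi) * local_range + local_id`, and that elements are ordered as (c, ph, pw). `PooledIndex` and `decode_index` are hypothetical names, and the layout is an assumption rather than a statement of what the kernel does.

```cpp
#include <cassert>
#include <cstdint>

struct PooledIndex {
  int64_t n;    // ROI index
  int64_t c;    // channel
  int64_t ph;   // pooled row
  int64_t pw;   // pooled column
  bool active;  // false for padding work items past the end of the ROI
};

// Hypothetical helper: maps (work-group id, local id) to an output element.
inline PooledIndex decode_index(
    int64_t group_id, int64_t local_id, int64_t local_range,
    int64_t wgs_per_roi, int64_t channels,
    int64_t pooled_height, int64_t pooled_width) {
  assert(wgs_per_roi > 0 && local_range > 0);
  const int64_t items_per_roi = channels * pooled_height * pooled_width;
  // Every wgs_per_roi consecutive work-groups share one ROI.
  const int64_t n = group_id / wgs_per_roi;
  // Position of this work item inside the ROI's flattened output.
  const int64_t index = (group_id % wgs_per_roi) * local_range + local_id;
  PooledIndex out{n, 0, 0, 0, index < items_per_roi};
  if (out.active) {
    out.pw = index % pooled_width;
    out.ph = (index / pooled_width) % pooled_height;
    out.c = index / (pooled_width * pooled_height);
  }
  return out;
}
```

Inactive work items still belong to a work-group, so in a real kernel they would skip only the computation and still take part in any barrier.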
For input [1, 2048, 50, 75] and rois [1000, 5], roi_align takes 4.7 ms on PVC but 75 ms on BMG. For each ROI, 2048 x output_h x output_w work items read the same values from LLC, which is very slow on BMG. After caching those values in shared local memory, PVC takes 4.0 ms and BMG takes 7.5 ms. I also replaced some if/else branches with min/max and fixed a code style issue.
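As an illustration of replacing if/else branches with min/max, the sketch below clamps one bilinear-interpolation coordinate in two ways. The branchy version follows the clamping pattern commonly used in roi_align implementations; the min/max version is an assumed equivalent, not a copy of the PR's code. `Corner`, `clamp_branchy`, and `clamp_minmax` are hypothetical names, and the input is assumed to be finite and small enough to convert to `int`.

```cpp
#include <algorithm>
#include <cassert>

struct Corner {
  int low;      // lower neighbour index
  int high;     // upper neighbour index
  float coord;  // clamped sampling coordinate
};

// Reference version with explicit branches.
inline Corner clamp_branchy(float y, int size) {
  assert(size >= 1);
  if (y <= 0) y = 0;
  int low = static_cast<int>(y);
  int high;
  if (low >= size - 1) {
    high = low = size - 1;
    y = static_cast<float>(low);
  } else {
    high = low + 1;
  }
  return {low, high, y};
}

// Same result expressed with min/max, which typically compiles to
// select instructions instead of divergent branches.
inline Corner clamp_minmax(float y, int size) {
  assert(size >= 1);
  y = std::max(y, 0.0f);
  const int low = std::min(static_cast<int>(y), size - 1);
  const int high = std::min(low + 1, size - 1);
  y = std::min(y, static_cast<float>(size - 1));
  return {low, high, y};
}
```

Both functions agree for coordinates below zero, inside the valid range, and past the last row or column, so the rewrite changes only the control flow, not the sampled positions.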