-
Notifications
You must be signed in to change notification settings - Fork 40
Optimize roi_align on BMG #1698
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
a56ddb6
3866176
57fbe22
60a466d
378a035
2b07b17
8ea16f8
834be4c
5153f32
a764ebe
1775046
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -69,22 +69,33 @@ T bilinear_interpolate( | |
template <typename T> | ||
struct RoiAlignForwardKernel { | ||
void operator()(sycl::nd_item<1> item) const { | ||
XPU_KERNEL_LOOP(item, index, nthreads_) { | ||
// (n, c, ph, pw) is an element in the pooled output | ||
auto wg = item.get_group(0); | ||
auto idx = item.get_local_id(0); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pls. rename this variable or the variable name of line 75 - https://github.com/intel/torch-xpu-ops/pull/1698/files#diff-5d6dc19a588e273ebfc8bf9dcc23fdc67ff9e961075e5d8e1385c7e896ef3ce9R75. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I delete this variable |
||
int n = wg / wg_per_roi_; | ||
int index = (wg - n * wg_per_roi_) * item.get_local_range(0); | ||
if (index < item_per_roi_) { | ||
int pw = index % pooled_width_; | ||
int ph = (index / pooled_width_) % pooled_height_; | ||
int c = (index / pooled_width_ / pooled_height_) % channels_; | ||
int n = index / pooled_width_ / pooled_height_ / channels_; | ||
|
||
const T* offset_rois = rois_ + n * 5; | ||
int roi_batch_ind = offset_rois[0]; | ||
if (idx == 0) { | ||
cache_roi_[0] = offset_rois[0]; | ||
|
||
// Do not using rounding; this implementation detail is critical | ||
T offset = aligned_ ? (T)0.5 : (T)0.0; | ||
cache_roi_[1] = offset_rois[1] * spatial_scale_ - offset; | ||
cache_roi_[2] = offset_rois[2] * spatial_scale_ - offset; | ||
cache_roi_[3] = offset_rois[3] * spatial_scale_ - offset; | ||
cache_roi_[4] = offset_rois[4] * spatial_scale_ - offset; | ||
} | ||
item.barrier(sycl_local_fence); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The barrier may be bypassed for some work items. If so, it will trigger hw hang. Pls. refine the logic of line 76 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updated |
||
|
||
// Do not using rounding; this implementation detail is critical | ||
T offset = aligned_ ? (T)0.5 : (T)0.0; | ||
T roi_start_w = offset_rois[1] * spatial_scale_ - offset; | ||
T roi_start_h = offset_rois[2] * spatial_scale_ - offset; | ||
T roi_end_w = offset_rois[3] * spatial_scale_ - offset; | ||
T roi_end_h = offset_rois[4] * spatial_scale_ - offset; | ||
int roi_batch_ind = cache_roi_[0]; | ||
T roi_start_w = cache_roi_[1]; | ||
T roi_start_h = cache_roi_[2]; | ||
T roi_end_w = cache_roi_[3]; | ||
T roi_end_h = cache_roi_[4]; | ||
|
||
T roi_width = roi_end_w - roi_start_w; | ||
T roi_height = roi_end_h - roi_start_h; | ||
|
@@ -136,9 +147,10 @@ struct RoiAlignForwardKernel { | |
} | ||
} | ||
RoiAlignForwardKernel( | ||
int nthreads, | ||
const T* input, | ||
const T spatial_scale, | ||
int item_per_rois, | ||
int wg_per_roi, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should the variable names be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Modified variable names and add SYCL_KER_CONFIG_CONVENTION inherit |
||
int channels, | ||
int height, | ||
int width, | ||
|
@@ -148,9 +160,10 @@ struct RoiAlignForwardKernel { | |
bool aligned, | ||
const T* rois, | ||
T* output) | ||
: nthreads_(nthreads), | ||
input_(input), | ||
: input_(input), | ||
spatial_scale_(spatial_scale), | ||
item_per_roi_(item_per_rois), | ||
wg_per_roi_(wg_per_roi), | ||
channels_(channels), | ||
height_(height), | ||
width_(width), | ||
|
@@ -160,20 +173,25 @@ struct RoiAlignForwardKernel { | |
aligned_(aligned), | ||
rois_(rois), | ||
output_(output) {} | ||
void sycl_ker_config_convention(sycl::handler& cgh) { | ||
cache_roi_ = sycl_local_acc_t<T>(5, cgh); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pls. define a variable for the magic value - There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added comment |
||
} | ||
|
||
private: | ||
int nthreads_; | ||
const T* input_; | ||
const T spatial_scale_; | ||
int channels_; | ||
int height_; | ||
int width_; | ||
int pooled_height_; | ||
int pooled_width_; | ||
int sampling_ratio_; | ||
bool aligned_; | ||
const int item_per_roi_; | ||
const int wg_per_roi_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Comments are required. Please add informative description for each variable. |
||
const int channels_; | ||
const int height_; | ||
const int width_; | ||
const int pooled_height_; | ||
const int pooled_width_; | ||
const int sampling_ratio_; | ||
const bool aligned_; | ||
const T* rois_; | ||
T* output_; | ||
sycl_local_acc_t<T> cache_roi_; | ||
}; | ||
|
||
template <typename T> | ||
|
@@ -415,11 +433,7 @@ Tensor roi_align_kernel( | |
|
||
at::Tensor output = at::zeros( | ||
{num_rois, channels, pooled_height, pooled_width}, input.options()); | ||
|
||
auto output_size = num_rois * pooled_height * pooled_width * channels; | ||
int64_t global_range = | ||
ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)); | ||
int64_t local_range = 512; | ||
|
||
if (output.numel() == 0) { | ||
return output; | ||
|
@@ -433,10 +447,20 @@ Tensor roi_align_kernel( | |
input.scalar_type(), | ||
"roi_align_forward_kernel_xpu", | ||
[&] { | ||
int64_t local_range = | ||
syclMaxWorkGroupSize<RoiAlignForwardKernel<scalar_t>>(); | ||
int item_per_roi = pooled_height * pooled_width * channels; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
if (item_per_roi < local_range) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updated, now local range will not larger than max group size |
||
local_range = (item_per_roi + 32 - 1) / 32 * | ||
32; // wg can be smaller but it better to be a mutiple of 32 | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Frankly speaking, I cannot quite understand what's the motivation of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's SIMD len, @xytintel can our block size to be a random number? |
||
int wg_per_roi = (item_per_roi + local_range - 1) / local_range; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
int64_t global_range = wg_per_roi * num_rois; | ||
auto kfn = RoiAlignForwardKernel<scalar_t>( | ||
output_size, | ||
input_.data_ptr<scalar_t>(), | ||
spatial_scale, | ||
item_per_roi, | ||
wg_per_roi, | ||
channels, | ||
height, | ||
width, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only format changes, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, only format changes. auto lint changes this during build. Why our lint ci is green? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ensure that using the workgroup id divided by wgs_per_roi_ to compute the ROI index accurately reflects the intended work distribution; a clarifying comment here would be helpful.
Copilot uses AI. Check for mistakes.