Optimize roi_align on BMG #1698

Open · wants to merge 11 commits into base: main
Changes from 4 commits
76 changes: 50 additions & 26 deletions src/ATen/native/xpu/sycl/RoiAlignKernels.cpp
@@ -69,22 +69,33 @@ T bilinear_interpolate(
template <typename T>
struct RoiAlignForwardKernel {
void operator()(sycl::nd_item<1> item) const {
XPU_KERNEL_LOOP(item, index, nthreads_) {
// (n, c, ph, pw) is an element in the pooled output
auto wg = item.get_group(0);
Copilot AI (May 28, 2025): Ensure that using the workgroup id divided by wgs_per_roi_ to compute the ROI index accurately reflects the intended work distribution; a clarifying comment here would be helpful.

Suggested change:
    auto wg = item.get_group(0);
    // Compute the ROI index (n) by dividing the workgroup ID (wg) by the number of workgroups per ROI (wgs_per_roi_).
    // This ensures that each ROI is processed by the correct set of workgroups.

auto idx = item.get_local_id(0);
Contributor Author: I deleted this variable.

int n = wg / wg_per_roi_;
int index = (wg - n * wg_per_roi_) * item.get_local_range(0);
if (index < item_per_roi_) {
int pw = index % pooled_width_;
int ph = (index / pooled_width_) % pooled_height_;
int c = (index / pooled_width_ / pooled_height_) % channels_;
int n = index / pooled_width_ / pooled_height_ / channels_;

const T* offset_rois = rois_ + n * 5;
int roi_batch_ind = offset_rois[0];
if (idx == 0) {
cache_roi_[0] = offset_rois[0];

// Do not use rounding; this implementation detail is critical
T offset = aligned_ ? (T)0.5 : (T)0.0;
cache_roi_[1] = offset_rois[1] * spatial_scale_ - offset;
cache_roi_[2] = offset_rois[2] * spatial_scale_ - offset;
cache_roi_[3] = offset_rois[3] * spatial_scale_ - offset;
cache_roi_[4] = offset_rois[4] * spatial_scale_ - offset;
}
item.barrier(sycl_local_fence);
EikanWang (Contributor, May 26, 2025): The barrier may be bypassed for some work items. If so, it will trigger a hardware hang. Please refine the logic of line 76, `if (index < item_per_roi_)`, to ensure the barrier is not bypassed.

Contributor Author: Updated.


// Do not use rounding; this implementation detail is critical
T offset = aligned_ ? (T)0.5 : (T)0.0;
T roi_start_w = offset_rois[1] * spatial_scale_ - offset;
T roi_start_h = offset_rois[2] * spatial_scale_ - offset;
T roi_end_w = offset_rois[3] * spatial_scale_ - offset;
T roi_end_h = offset_rois[4] * spatial_scale_ - offset;
int roi_batch_ind = cache_roi_[0];
T roi_start_w = cache_roi_[1];
T roi_start_h = cache_roi_[2];
T roi_end_w = cache_roi_[3];
T roi_end_h = cache_roi_[4];

T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
@@ -136,9 +147,10 @@ struct RoiAlignForwardKernel {
}
}
RoiAlignForwardKernel(
int nthreads,
const T* input,
const T spatial_scale,
int item_per_rois,
int wg_per_roi,
Contributor: Should the variable names be items_per_roi and wgs_per_roi accordingly?

Contributor: The RoiAlignForwardKernel does not inherit from __SYCL_KER_CONFIG_CONVENTION__. When will sycl_ker_config_convention be invoked? @xytintel

Contributor Author: Modified the variable names and added the __SYCL_KER_CONFIG_CONVENTION__ inheritance.

int channels,
int height,
int width,
@@ -148,9 +160,10 @@ struct RoiAlignForwardKernel {
bool aligned,
const T* rois,
T* output)
: nthreads_(nthreads),
input_(input),
: input_(input),
spatial_scale_(spatial_scale),
item_per_roi_(item_per_rois),
wg_per_roi_(wg_per_roi),
channels_(channels),
height_(height),
width_(width),
@@ -160,20 +173,25 @@ struct RoiAlignForwardKernel {
aligned_(aligned),
rois_(rois),
output_(output) {}
void sycl_ker_config_convention(sycl::handler& cgh) {
cache_roi_ = sycl_local_acc_t<T>(5, cgh);
Contributor: Please define a variable for the magic value 5, and add informative comments elaborating on why the value should be 5 rather than other values.

Contributor Author: I added a comment.

}

private:
int nthreads_;
const T* input_;
const T spatial_scale_;
int channels_;
int height_;
int width_;
int pooled_height_;
int pooled_width_;
int sampling_ratio_;
bool aligned_;
const int item_per_roi_;
const int wg_per_roi_;
Contributor: Comments are required; please add an informative description for each variable.

const int channels_;
const int height_;
const int width_;
const int pooled_height_;
const int pooled_width_;
const int sampling_ratio_;
const bool aligned_;
const T* rois_;
T* output_;
sycl_local_acc_t<T> cache_roi_;
};

template <typename T>
@@ -415,11 +433,7 @@ Tensor roi_align_kernel(

at::Tensor output = at::zeros(
{num_rois, channels, pooled_height, pooled_width}, input.options());

auto output_size = num_rois * pooled_height * pooled_width * channels;
int64_t global_range =
ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512));
int64_t local_range = 512;

if (output.numel() == 0) {
return output;
@@ -433,10 +447,20 @@ Tensor roi_align_kernel(
input.scalar_type(),
"roi_align_forward_kernel_xpu",
[&] {
int64_t local_range =
syclMaxWorkGroupSize<RoiAlignForwardKernel<scalar_t>>();
int item_per_roi = pooled_height * pooled_width * channels;
Contributor: item_per_roi -> items_per_roi

if (item_per_roi < local_range) {
Contributor: local_range is the maximum number of work items in a single work-group. Please assert that local_range is always a multiple of 32; otherwise the round-up below may push it past that maximum. Or does the SYCL spec guarantee that syclMaxWorkGroupSize is always divisible by 32?

Contributor Author: Updated; the local range will no longer exceed the max group size.

local_range = (item_per_roi + 32 - 1) / 32 *
32; // wg can be smaller, but it is better to be a multiple of 32
}
Contributor: Frankly speaking, I cannot quite understand the motivation for 32. If it represents the SIMD length, please define a constant variable and elaborate on why the size is better as a multiple of 32.

Contributor Author: It's the SIMD length. @xytintel, can our block size be an arbitrary number?

int wg_per_roi = (item_per_roi + local_range - 1) / local_range;
Contributor: wg_per_roi -> wgs_per_roi.

int64_t global_range = wg_per_roi * num_rois;
auto kfn = RoiAlignForwardKernel<scalar_t>(
output_size,
input_.data_ptr<scalar_t>(),
spatial_scale,
item_per_roi,
wg_per_roi,
channels,
height,
width,
18 changes: 11 additions & 7 deletions src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp
Contributor: Only format changes, right?

Contributor Author: Yes, only format changes; auto-lint changed this during the build. Why is our lint CI green?

@@ -425,9 +425,9 @@ struct UpsampleBilinear2dBackwardNotAlignKernelFunctor {
// scale is 1 if on boundary
distance_w =
distance_w + is_boundary_w * (output_width_ * 2 - distance_w);
bool is_boundary_h =
!((point_h >= output_height_) &&
(point_h <= output_height_ * input_height_ * 2 - output_height_));
bool is_boundary_h = !(
(point_h >= output_height_) &&
(point_h <= output_height_ * input_height_ * 2 - output_height_));
distance_h =
distance_h + is_boundary_h * (output_height_ * 2 - distance_h);
accscalar_t scale =
@@ -606,8 +606,10 @@ void launch_upsample_bilinear2d_backward_kernel(
// TODO: when the input is 3x3, the scale is 1.5, and the output is 4x4,
// PyTorch prefers 1/1.5, but my implementation treats it as 3/4...
// I also have to skip double because of rounding issues; it will not pass UT
can_optimize = can_optimize && (align_corners || (input_width == (rwidth * output_width) &&
input_height == (rheight * output_height))) &&
can_optimize = can_optimize &&
(align_corners ||
(input_width == (rwidth * output_width) &&
input_height == (rheight * output_height))) &&
!std::is_same<scalar_t, double>::value;
if (can_optimize) {
if (align_corners) {
@@ -790,8 +792,10 @@ void launch_upsample_bilinear2d_backward_nhwc_kernel(
// TODO: when the input is 3x3, the scale is 1.5, and the output is 4x4,
// PyTorch prefers 1/1.5, but my implementation treats it as 3/4...
// I also have to skip double because of rounding issues; it will not pass UT
can_optimize = can_optimize && (align_corners || (input_width == (rwidth * output_width) &&
input_height == (rheight * output_height))) &&
can_optimize = can_optimize &&
(align_corners ||
(input_width == (rwidth * output_width) &&
input_height == (rheight * output_height))) &&
!std::is_same<scalar_t, double>::value;
if (can_optimize) {
if (align_corners) {