Optimize roi_align on BMG #1698

Open · wants to merge 11 commits into base: main

123 changes: 78 additions & 45 deletions src/ATen/native/xpu/sycl/RoiAlignKernels.cpp
@@ -27,28 +27,25 @@ T bilinear_interpolate(
     return 0;
   }

-  if (y <= 0)
-    y = 0;
-  if (x <= 0)
-    x = 0;
+  y = std::max(T(0), y);
+  x = std::max(T(0), x);

   int y_low = (int)y;
   int x_low = (int)x;
   int y_high;
   int x_high;

-  if (y_low >= height - 1) {
-    y_high = y_low = height - 1;
+  y_low = std::min(height - 1, y_low);
+  x_low = std::min(width - 1, x_low);
+  y_high = std::min(y_low + 1, height - 1);
+  x_high = std::min(x_low + 1, width - 1);
+
+  if (y_low == height - 1) {
     y = (T)y_low;
-  } else {
-    y_high = y_low + 1;
   }

-  if (x_low >= width - 1) {
-    x_high = x_low = width - 1;
+  if (x_low == width - 1) {
     x = (T)x_low;
-  } else {
-    x_high = x_low + 1;
   }

   T ly = y - y_low;
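A note on the hunk above: the branchy boundary handling is replaced by std::min/std::max clamping. Below is a minimal standalone sketch (a hypothetical test harness, not code from this PR) checking that both formulations pick the same interpolation neighborhood; the x/width case is symmetric.

#include <algorithm>
#include <cassert>

// Old formulation: branch on whether the low index reaches the last row.
void clamp_old(float y, int height, int& y_low, int& y_high, float& y_out) {
  y_out = std::max(0.0f, y);
  y_low = (int)y_out;
  if (y_low >= height - 1) {
    y_high = y_low = height - 1;
    y_out = (float)y_low;
  } else {
    y_high = y_low + 1;
  }
}

// New formulation from the hunk above: clamp with std::min instead.
void clamp_new(float y, int height, int& y_low, int& y_high, float& y_out) {
  y_out = std::max(0.0f, y);
  y_low = std::min(height - 1, (int)y_out);
  y_high = std::min(y_low + 1, height - 1);
  if (y_low == height - 1)
    y_out = (float)y_low;
}

int main() {
  for (float y : {-1.0f, 0.0f, 2.3f, 5.9f, 6.0f, 6.5f, 9.0f}) {
    int lo1, hi1, lo2, hi2;
    float yo1, yo2;
    clamp_old(y, /*height=*/7, lo1, hi1, yo1);
    clamp_new(y, /*height=*/7, lo2, hi2, yo2);
    assert(lo1 == lo2 && hi1 == hi2 && yo1 == yo2);
  }
  return 0;
}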
@@ -67,24 +64,39 @@ T bilinear_interpolate(
   return val;
 }
 template <typename T>
-struct RoiAlignForwardKernel {
+struct RoiAlignForwardKernel : public __SYCL_KER_CONFIG_CONVENTION__ {
   void operator()(sycl::nd_item<1> item) const {
-    XPU_KERNEL_LOOP(item, index, nthreads_) {
-      // (n, c, ph, pw) is an element in the pooled output
-      int pw = index % pooled_width_;
-      int ph = (index / pooled_width_) % pooled_height_;
-      int c = (index / pooled_width_ / pooled_height_) % channels_;
-      int n = index / pooled_width_ / pooled_height_ / channels_;
-
-      const T* offset_rois = rois_ + n * 5;
-      int roi_batch_ind = offset_rois[0];
+    // each roi will have 5 values: batch_idx, x1, y1, x2, y2
+    constexpr int roi_size = 5;
+    auto wg = item.get_group(0);
Review comment (Copilot AI, May 28, 2025):

Ensure that using the workgroup id divided by wgs_per_roi_ to compute the ROI index accurately reflects the intended work distribution; a clarifying comment here would be helpful.

Suggested change:
    auto wg = item.get_group(0);
    // Compute the ROI index (n) by dividing the workgroup ID (wg) by the
    // number of workgroups per ROI (wgs_per_roi_). This ensures that each
    // ROI is processed by the correct set of workgroups.

+    int n = wg / wgs_per_roi_;
+    int output_index_on_batch_n =
+        (wg - n * wgs_per_roi_) * item.get_local_range(0) +
+        item.get_local_id(0);
+    const T* current_roi = rois_ + n * roi_size;
+    if (item.get_local_id(0) == 0) {
+      cached_roi_[0] = current_roi[0];

       // Do not use rounding; this implementation detail is critical
       T offset = aligned_ ? (T)0.5 : (T)0.0;
-      T roi_start_w = offset_rois[1] * spatial_scale_ - offset;
-      T roi_start_h = offset_rois[2] * spatial_scale_ - offset;
-      T roi_end_w = offset_rois[3] * spatial_scale_ - offset;
-      T roi_end_h = offset_rois[4] * spatial_scale_ - offset;
+      cached_roi_[1] = current_roi[1] * spatial_scale_ - offset;
+      cached_roi_[2] = current_roi[2] * spatial_scale_ - offset;
+      cached_roi_[3] = current_roi[3] * spatial_scale_ - offset;
+      cached_roi_[4] = current_roi[4] * spatial_scale_ - offset;
+    }
+    item.barrier(sycl_local_fence);
+
+    if (output_index_on_batch_n < items_per_roi_) {
+      int pw = output_index_on_batch_n % pooled_width_;
+      int ph = (output_index_on_batch_n / pooled_width_) % pooled_height_;
+      int c = (output_index_on_batch_n / pooled_width_ / pooled_height_) %
+          channels_;
+
+      int roi_batch_ind = cached_roi_[0];
+      T roi_start_w = cached_roi_[1];
+      T roi_start_h = cached_roi_[2];
+      T roi_end_w = cached_roi_[3];
+      T roi_end_h = cached_roi_[4];

       T roi_width = roi_end_w - roi_start_w;
       T roi_height = roi_end_h - roi_start_h;
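The new indexing above replaces the flat XPU_KERNEL_LOOP with a per-ROI workgroup decomposition: each ROI owns wgs_per_roi_ consecutive workgroups, and every work-item derives its ROI index n plus its offset within that ROI's output slice. A host-side sketch (hypothetical sizes, same arithmetic as the kernel) verifying that the decomposition touches each output element exactly once:

#include <cassert>
#include <vector>

int main() {
  // Hypothetical launch geometry mirroring the kernel: each ROI gets
  // wgs_per_roi consecutive workgroups of local_range work-items.
  const int num_rois = 3, channels = 4, pooled_h = 7, pooled_w = 7;
  const int items_per_roi = channels * pooled_h * pooled_w;  // 196
  const int local_range = 64;
  const int wgs_per_roi = (items_per_roi + local_range - 1) / local_range;  // 4

  std::vector<int> hits(num_rois * items_per_roi, 0);
  for (int wg = 0; wg < wgs_per_roi * num_rois; ++wg) {
    int n = wg / wgs_per_roi;  // ROI index, as in the kernel
    for (int lid = 0; lid < local_range; ++lid) {
      int output_index_on_batch_n =
          (wg - n * wgs_per_roi) * local_range + lid;
      if (output_index_on_batch_n < items_per_roi)  // tail guard in the kernel
        ++hits[n * items_per_roi + output_index_on_batch_n];
    }
  }
  for (int h : hits)
    assert(h == 1);  // every output element is written exactly once
  return 0;
}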
@@ -125,20 +137,26 @@ struct RoiAlignForwardKernel {
           static_cast<T>(ix + .5f) * bin_size_w /
               static_cast<T>(roi_bin_grid_w);

-          T val =
-              bilinear_interpolate(offset_input, height_, width_, y, x, index);
+          T val = bilinear_interpolate(
+              offset_input,
+              height_,
+              width_,
+              y,
+              x,
+              output_index_on_batch_n + n * items_per_roi_);
           output_val += val;
         }
       }
       output_val /= count;

-      output_[index] = output_val;
+      output_[output_index_on_batch_n + n * items_per_roi_] = output_val;
     }
   }
   RoiAlignForwardKernel(
-      int nthreads,
       const T* input,
       const T spatial_scale,
+      int items_per_rois,
+      int wgs_per_roi,
       int channels,
       int height,
       int width,
@@ -148,9 +166,10 @@ struct RoiAlignForwardKernel {
       bool aligned,
       const T* rois,
       T* output)
-      : nthreads_(nthreads),
-        input_(input),
+      : input_(input),
         spatial_scale_(spatial_scale),
+        items_per_roi_(items_per_rois),
+        wgs_per_roi_(wgs_per_roi),
         channels_(channels),
         height_(height),
         width_(width),
@@ -160,20 +179,26 @@ struct RoiAlignForwardKernel {
         aligned_(aligned),
         rois_(rois),
         output_(output) {}
+  void sycl_ker_config_convention(sycl::handler& cgh) {
+    // each roi will have 5 values: batch_idx, x1, y1, x2, y2
+    cached_roi_ = sycl_local_acc_t<T>(5, cgh);
+  }

  private:
-  int nthreads_;
   const T* input_;
   const T spatial_scale_;
-  int channels_;
-  int height_;
-  int width_;
-  int pooled_height_;
-  int pooled_width_;
-  int sampling_ratio_;
-  bool aligned_;
+  const int items_per_roi_;
+  const int wgs_per_roi_;
+  const int channels_;
+  const int height_;
+  const int width_;
+  const int pooled_height_;
+  const int pooled_width_;
+  const int sampling_ratio_;
+  const bool aligned_;
   const T* rois_;
   T* output_;
+  sycl_local_acc_t<T> cached_roi_;
 };

 template <typename T>
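The sycl_ker_config_convention hook added above allocates a 5-element local-memory (SLM) accessor so each workgroup caches its ROI box once instead of every work-item re-reading it from global memory. A minimal self-contained sketch of the same pattern in plain SYCL 2020 (a hypothetical kernel; the PR's sycl_local_acc_t and __SYCL_KER_CONFIG_CONVENTION__ helpers are assumed to wrap this machinery):

#include <sycl/sycl.hpp>

int main() {
  sycl::queue q;
  constexpr int roi_size = 5;  // batch_idx, x1, y1, x2, y2
  float roi[roi_size] = {0.f, 1.f, 2.f, 3.f, 4.f};
  sycl::buffer<float, 1> buf{roi, sycl::range<1>{roi_size}};

  q.submit([&](sycl::handler& cgh) {
    sycl::accessor in{buf, cgh, sycl::read_only};
    // The local accessor must be created inside the command-group function,
    // which is what the sycl_ker_config_convention hook encapsulates.
    sycl::local_accessor<float, 1> cached{sycl::range<1>{roi_size}, cgh};
    cgh.parallel_for(
        sycl::nd_range<1>{sycl::range<1>{64}, sycl::range<1>{64}},
        [=](sycl::nd_item<1> item) {
          if (item.get_local_id(0) == 0)
            for (int i = 0; i < roi_size; ++i)
              cached[i] = in[i];  // one work-item fills shared local memory
          sycl::group_barrier(item.get_group());
          // every work-item in the group can now read cached[...] cheaply
        });
  });
  q.wait();
  return 0;
}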
@@ -415,11 +440,7 @@ Tensor roi_align_kernel(
   at::Tensor output = at::zeros(
       {num_rois, channels, pooled_height, pooled_width}, input.options());

-  auto output_size = num_rois * pooled_height * pooled_width * channels;
-  int64_t global_range =
-      ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512));
-  int64_t local_range = 512;

   if (output.numel() == 0) {
     return output;
@@ -433,10 +454,22 @@ Tensor roi_align_kernel(
       input.scalar_type(),
       "roi_align_forward_kernel_xpu",
       [&] {
+        int64_t local_range =
+            syclMaxWorkGroupSize<RoiAlignForwardKernel<scalar_t>>();
+        int items_per_roi = pooled_height * pooled_width * channels;
+        if (items_per_roi < local_range) {
+          constexpr int simd_len = 32;
+          local_range = std::min(
+              local_range,
+              int64_t(items_per_roi + simd_len - 1) / simd_len * simd_len);
+        }
+        int wgs_per_roi = (items_per_roi + local_range - 1) / local_range;
+        int64_t global_range = wgs_per_roi * num_rois;
         auto kfn = RoiAlignForwardKernel<scalar_t>(
-            output_size,
             input_.data_ptr<scalar_t>(),
             spatial_scale,
+            items_per_roi,
+            wgs_per_roi,
             channels,
             height,
             width,
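A note on the launch math above: the workgroup size starts at the device maximum for this kernel and, when a single workgroup can already cover one ROI, is shrunk to the smallest multiple of the assumed SIMD width (32) that still covers items_per_roi, so no sub-group sits mostly idle. A standalone sketch of the arithmetic (hypothetical sizes; syclMaxWorkGroupSize stubbed as a constant):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Stub for syclMaxWorkGroupSize<Kernel>(), which is device-dependent.
  const int64_t max_wg_size = 512;
  const int pooled_height = 7, pooled_width = 7, channels = 4, num_rois = 100;

  int64_t local_range = max_wg_size;
  int items_per_roi = pooled_height * pooled_width * channels;  // 196
  if (items_per_roi < local_range) {
    // Round up to a SIMD-width multiple so the tail sub-group stays full.
    constexpr int simd_len = 32;
    local_range = std::min(
        local_range,
        int64_t(items_per_roi + simd_len - 1) / simd_len * simd_len);  // 224
  }
  int wgs_per_roi = (items_per_roi + local_range - 1) / local_range;  // 1
  int64_t global_range = wgs_per_roi * num_rois;                      // 100

  std::printf("local=%ld wgs_per_roi=%d groups=%ld\n",
              (long)local_range, wgs_per_roi, (long)global_range);
  return 0;
}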
18 changes: 11 additions & 7 deletions src/ATen/native/xpu/sycl/UpSampleBilinear2dKernels.cpp
@@ -425,9 +425,9 @@ struct UpsampleBilinear2dBackwardNotAlignKernelFunctor {
       // scale is 1 if on boundary
       distance_w =
           distance_w + is_boundary_w * (output_width_ * 2 - distance_w);
-      bool is_boundary_h =
-          !((point_h >= output_height_) &&
-            (point_h <= output_height_ * input_height_ * 2 - output_height_));
+      bool is_boundary_h = !(
+          (point_h >= output_height_) &&
+          (point_h <= output_height_ * input_height_ * 2 - output_height_));
       distance_h =
           distance_h + is_boundary_h * (output_height_ * 2 - distance_h);
       accscalar_t scale =
@@ -606,8 +606,10 @@ void launch_upsample_bilinear2d_backward_kernel(
   // TODO: when input is 3x3 and scale is 1.5, output is 4x4;
   // pytorch prefers to use 1/1.5, but my implementation treats it as 3/4...
   // I also have to skip double because of rounding issues; it will not pass UT.
-  can_optimize = can_optimize && (align_corners || (input_width == (rwidth * output_width) &&
-                     input_height == (rheight * output_height))) &&
+  can_optimize = can_optimize &&
+      (align_corners ||
+       (input_width == (rwidth * output_width) &&
+        input_height == (rheight * output_height))) &&
       !std::is_same<scalar_t, double>::value;
   if (can_optimize) {
     if (align_corners) {
@@ -790,8 +792,10 @@ void launch_upsample_bilinear2d_backward_nhwc_kernel(
   // TODO: when input is 3x3 and scale is 1.5, output is 4x4;
   // pytorch prefers to use 1/1.5, but my implementation treats it as 3/4...
   // I also have to skip double because of rounding issues; it will not pass UT.
-  can_optimize = can_optimize && (align_corners || (input_width == (rwidth * output_width) &&
-                     input_height == (rheight * output_height))) &&
+  can_optimize = can_optimize &&
+      (align_corners ||
+       (input_width == (rwidth * output_width) &&
+        input_height == (rheight * output_height))) &&
       !std::is_same<scalar_t, double>::value;
   if (can_optimize) {
     if (align_corners) {
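On the TODO in the two hunks above: the can_optimize guard takes the fast path only when the reciprocal scale reproduces the input size exactly (input == rwidth * output), which is what rules out the 3x3 to 4x4, scale-1.5 case. A tiny sketch of that exactness check (hypothetical sizes; it assumes rwidth is the reciprocal scale the backward kernel uses, which is my reading of the TODO):

#include <cstdio>

int main() {
  // Hypothetical sizes from the TODO: input 3x3, output 4x4, scale 1.5.
  int input_w = 3, output_w = 4;

  // Reciprocal scale derived from the user-provided factor (PyTorch's view):
  float rwidth_from_scale = 1.0f / 1.5f;                // 0.666...
  // Reciprocal scale derived from the integer sizes (this kernel's view):
  float rwidth_from_sizes = float(input_w) / output_w;  // 0.75

  // The guard input_width == rwidth * output_width only passes when the
  // scale reproduces the input size exactly, so the 1.5 case falls back.
  std::printf("from scale: %d, from sizes: %d\n",
              input_w == rwidth_from_scale * output_w,   // 0 (fast path off)
              input_w == rwidth_from_sizes * output_w);  // 1 (fast path on)
  return 0;
}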