[AUTOGENERATED] [release/2.6] [release/2.5][ROCm] Indexing backward kernel improvements from mainline #1942

Merged (1 commit, Mar 11, 2025)
aten/src/ATen/native/cuda/Indexing.cu (130 changes: 129 additions & 1 deletion)
@@ -182,6 +182,120 @@ __global__ void indexing_backward_kernel_rocm(
}
#endif

#ifdef USE_ROCM
#define SKIP 32
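// Lookahead stride used below when counting runs of duplicate indices.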
template <typename scalar_t>
__global__ void indexing_backward_kernel_stride_1(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
using opmath_t = at::opmath_type<scalar_t>;

int laneIdx = threadIdx.x % C10_WARP_SIZE;

const opmath_t scale = (opmath_t)1.0;
int64_t grad_row = 0;

extern __shared__ unsigned char smem[];
auto smem_dups_cache = reinterpret_cast<int64_t*>(smem);

// Each warp gets a different section of the shared memory allocation:
int smem_offset = threadIdx.y * C10_WARP_SIZE;
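// A warp therefore owns C10_WARP_SIZE int64_t slots, one deferred duplicate count per lane.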

// Grid-stride loop over the outer dimension:
for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z) {
// Init duplicates every time we compute a new set of entries:
smem_dups_cache[smem_offset + laneIdx] = 0;

int64_t base_idx = blockIdx.x * blockDim.y * C10_WARP_SIZE + threadIdx.y * C10_WARP_SIZE;
int64_t idx = base_idx + laneIdx;
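// Each warp handles C10_WARP_SIZE consecutive entries of the sorted index array, one per lane.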

// Each lane calculates the number of duplicates:
if (idx < numel) {
int64_t crnt_sorted_idx = sorted_indices[idx];

if (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]) {
// Determine the number of duplicates in advance:
int64_t num_duplicates = 1;

// Look ahead in steps of SKIP in case there is a long run of duplicates. Once that is done, handle the tail.
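// Because the indices are sorted, a match at offset (num_duplicates + SKIP - 1) means the whole SKIP-wide chunk belongs to the same run.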
while ((idx + num_duplicates + SKIP - 1) < numel) {
if (sorted_indices[idx + num_duplicates + SKIP - 1] != crnt_sorted_idx) break;
num_duplicates += SKIP;
}
while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) {
num_duplicates++;
}

if (!accumulate) {
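// Without accumulation only one gradient per destination row needs to survive, so the whole run collapses to a single copy.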
const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
grad_row = ((int64_t)indices[idx + num_duplicates - 1]) * stride + z * numel * stride;
grad_weight[weight_row] =
static_cast<scalar_t>(static_cast<opmath_t>(grad_output[grad_row]) * scale);
continue;
}

// Each lane sequentially handles the duplicate elimination:
if (num_duplicates < C10_WARP_SIZE) {
opmath_t gradient = (opmath_t)0.0;
const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
for (int64_t i = 0; i < num_duplicates; ++i) {
grad_row = ((int64_t) indices[idx + i]) * stride + z * numel * stride;
gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
}

grad_weight[weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[weight_row]) + gradient);
} else {
// Defer long runs: record the duplicate count so the whole warp can reduce it cooperatively below:
smem_dups_cache[smem_offset + laneIdx] = num_duplicates;
}
}
}

WARP_SYNC();

// All lanes in the warp are still active here. Use them all to reduce duplicates when a
// large number of duplicates is present:
for (int subwarp = 0; subwarp < C10_WARP_SIZE; subwarp++) {
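// Each iteration replays the run deferred by lane 'subwarp' above, with every lane of the warp taking part in the reduction.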
// All lanes read the shared memory entry for number of duplicates
int64_t new_num_duplicates = smem_dups_cache[smem_offset + subwarp];

// Check whether that lane deferred any duplicates; if not, skip.
if (new_num_duplicates == 0)
continue;

// There are duplicates that need eliminating:
int64_t new_idx = base_idx + subwarp;
int64_t new_crnt_sorted_idx = sorted_indices[new_idx];
const int64_t new_weight_row = new_crnt_sorted_idx * stride + z * stride_before;

// Result of the reduction will be in this variable:
opmath_t gradient = (opmath_t)0.0;

int64_t num_warp_passes = new_num_duplicates / C10_WARP_SIZE;
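// Each full pass covers C10_WARP_SIZE consecutive duplicates, one per lane.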
// Parallel reduction across the array of duplicates using all the lanes in the warp:
for (int64_t i = 0; i < num_warp_passes; ++i) {
grad_row = ((int64_t) indices[new_idx + i * C10_WARP_SIZE + laneIdx]) * stride + z * numel * stride;
gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
}

// Reduce across the lanes of the warp:
WARP_SYNC();
for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) {
gradient += WARP_SHFL_DOWN(gradient, offset);
}
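// After the shuffle tree, lane 0 holds the sum of the warp-parallel passes.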

if (laneIdx == 0) {
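// Lane 0 folds in the tail that did not fill a complete warp pass, then writes back the accumulated gradient.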
for (int64_t i = num_warp_passes * C10_WARP_SIZE; i < new_num_duplicates; ++i) {
grad_row = ((int64_t) indices[new_idx + i]) * stride + z * numel * stride;
gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
}

grad_weight[new_weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[new_weight_row]) + gradient);
}
}
}
}
#else
template <typename scalar_t>
__global__ void indexing_backward_kernel_stride_1(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -237,6 +351,7 @@ __global__ void indexing_backward_kernel_stride_1(
}
}
}
#endif

template <typename scalar_t>
__global__ void indexing_backward_kernel_small_stride(
@@ -567,13 +682,24 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten


if (sliceSize == 1) {
#ifdef USE_ROCM
// Adapt grid size to smaller virtual warp size:
dim3 new_grid(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)), grid.y, grid.z);
size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t);
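// One int64_t duplicate-count slot per lane of every warp in the block, matching the per-warp smem_dups_cache layout in the kernel.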
#define KERNEL_GRID new_grid
#define KERNEL_SMEM smem_dups_size
#else
#define KERNEL_GRID grid
#define KERNEL_SMEM 0
#endif
// This implementation is faster when there are many duplicates, but could overflow
// if FP16 / BF16 is used
AT_DISPATCH_V2(
expandedValue.scalar_type(),
"indexing_backward_kernel_stride_1",
AT_WRAP([&] {
indexing_backward_kernel_stride_1<scalar_t><<<grid, block, 0, stream>>>(
indexing_backward_kernel_stride_1<scalar_t><<<KERNEL_GRID, block, KERNEL_SMEM, stream>>>
(
sorted_indices.const_data_ptr<int64_t>(),
orig_indices.const_data_ptr<int64_t>(),
expandedValue.const_data_ptr<scalar_t>(),
@@ -591,6 +717,8 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
kHalf,
kBool,
kBFloat16);
#undef KERNEL_GRID
#undef KERNEL_SMEM
} else {
if (sliceSize <= warp_size) {
AT_DISPATCH_V2(