
[release/2.5][ROCm] Indexing backward kernel improvements from mainline (multiple commits) #1937

Merged · 3 commits · Mar 4, 2025
237 changes: 233 additions & 4 deletions aten/src/ATen/native/cuda/Indexing.cu
@@ -124,6 +124,169 @@ __global__ void indexing_backward_kernel(
}
}

#ifdef USE_ROCM
template <typename scalar_t, bool accumulate>
__global__ void indexing_backward_kernel_rocm(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim) {

// This implementation is adapted from indexing_backward_kernel above.
using opmath_t = at::opmath_type<scalar_t>;
for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z){
int64_t idx = blockIdx.x * blockDim.y + threadIdx.y;
if (idx < numel && (idx == 0 || sorted_indices[idx] != sorted_indices[idx - 1])){
do {
// if not accumulate, we only keep the last duplicate index so skip those before it
if constexpr (!accumulate) {
if ((idx < numel - 1) && sorted_indices[idx] == sorted_indices[idx + 1]) {
idx++;
continue;
}
}
const int64_t weight_row = ((int64_t) sorted_indices[idx]) * stride + z * stride_before;
const int64_t grad_row = ((int64_t) indices[idx]) * stride + z * numel * stride;

opmath_t gradient;
opmath_t weight;

int64_t feature_dim = threadIdx.x + blockIdx.y * blockDim.x;
while (feature_dim < stride) {
gradient = static_cast<opmath_t>(grad_output[grad_row + feature_dim]);
if constexpr (accumulate) {
weight = static_cast<opmath_t>(grad_weight[weight_row + feature_dim]);
}

if constexpr (accumulate) {
weight += gradient;
} else {
weight = gradient;
}

grad_weight[weight_row + feature_dim] = static_cast<scalar_t>(weight);
feature_dim += gridDim.y * blockDim.x;
}

idx++;
} while (idx < numel && sorted_indices[idx] == sorted_indices[idx - 1]);
}
}
}
#endif
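For reference, the accumulate=true instantiation of indexing_backward_kernel_rocm produces the same scatter-add result as index_put_ with accumulate=True: gradient rows whose indices collide are summed into a single destination row, one feature element at a time. A minimal host-side sketch of that behavior (the tensor names and values below are illustrative, not taken from the PR):

import torch

# Rows 0 and 2 of grad both target destination row 1.
grad = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
dst = torch.zeros(4, 2)
idx = torch.tensor([1, 3, 1])

# accumulate=True sums every contribution for a duplicated index,
# which is what the accumulate=true kernel path computes per feature element.
dst.index_put_((idx,), grad, accumulate=True)
assert dst[1].tolist() == [4.0, 4.0] and dst[3].tolist() == [2.0, 2.0]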

#ifdef USE_ROCM
#define SKIP 32
template <typename scalar_t>
__global__ void indexing_backward_kernel_stride_1(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
using opmath_t = at::opmath_type<scalar_t>;

int laneIdx = threadIdx.x % C10_WARP_SIZE;

const opmath_t scale = (opmath_t)1.0;
int64_t grad_row = 0;

extern __shared__ unsigned char smem[];
auto smem_dups_cache = reinterpret_cast<int64_t*>(smem);

// Each warp gets a different section of the shared memory allocation:
int smem_offset = threadIdx.y * C10_WARP_SIZE;

// Number of values processed by each thread (grain size)
for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z) {
// Init duplicates every time we compute a new set of entries:
smem_dups_cache[smem_offset + laneIdx] = 0;

int64_t base_idx = blockIdx.x * blockDim.y * C10_WARP_SIZE + threadIdx.y * C10_WARP_SIZE;
int64_t idx = base_idx + laneIdx;

// Each lane calculates the number of duplicates:
if (idx < numel) {
int64_t crnt_sorted_idx = sorted_indices[idx];

if (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]) {
// Determine the number of duplicates in advance:
int64_t num_duplicates = 1;

// Lookahead in case there is a large number of duplicates. Once that is done, handle the tail.
while ((idx + num_duplicates + SKIP - 1) < numel) {
if (sorted_indices[idx + num_duplicates + SKIP - 1] != crnt_sorted_idx) break;
num_duplicates += SKIP;
}
while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) {
num_duplicates++;
}

if (!accumulate) {
const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
grad_row = ((int64_t)indices[idx + num_duplicates - 1]) * stride + z * numel * stride;
grad_weight[weight_row] =
static_cast<scalar_t>(static_cast<opmath_t>(grad_output[grad_row]) * scale);
continue;
}

// Each lane sequentially handles the duplicate elimination:
if (num_duplicates < C10_WARP_SIZE) {
opmath_t gradient = (opmath_t)0.0;
const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
for (int64_t i = 0; i < num_duplicates; ++i) {
grad_row = ((int64_t) indices[idx + i]) * stride + z * numel * stride;
gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
}

grad_weight[weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[weight_row]) + gradient);
} else {
// Add the duplicate count to the cache:
smem_dups_cache[smem_offset + laneIdx] = num_duplicates;
}
}
}

WARP_SYNC();

// All lanes in the warp are still active here. Use them all to reduce duplicates when
// a large number of duplicates is present:
for (int subwarp = 0; subwarp < C10_WARP_SIZE; subwarp++) {
// All lanes read the shared memory entry for number of duplicates
int64_t new_num_duplicates = smem_dups_cache[smem_offset + subwarp];

// Check if the original sub-warp lane had duplicates to eliminate; if not, skip.
if (new_num_duplicates == 0)
continue;

// There are duplicates that need eliminating:
int64_t new_idx = base_idx + subwarp;
int64_t new_crnt_sorted_idx = sorted_indices[new_idx];
const int64_t new_weight_row = new_crnt_sorted_idx * stride + z * stride_before;

// Result of the reduction will be in this variable:
opmath_t gradient = (opmath_t)0.0;

int64_t num_warp_passes = new_num_duplicates / C10_WARP_SIZE;
// Parallel reduction across the array of duplicates using all the lanes in the warp:
for (int64_t i = 0; i < num_warp_passes; ++i) {
grad_row = ((int64_t) indices[new_idx + i * C10_WARP_SIZE + laneIdx]) * stride + z * numel * stride;
gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
}

// Reduce across the lanes of the warp:
WARP_SYNC();
for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) {
gradient += WARP_SHFL_DOWN(gradient, offset);
}

if (laneIdx == 0) {
for (int64_t i = num_warp_passes * C10_WARP_SIZE; i < new_num_duplicates; ++i) {
grad_row = ((int64_t) indices[new_idx + i]) * stride + z * numel * stride;
gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
}

grad_weight[new_weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[new_weight_row]) + gradient);
}
}
}
}
#else
template <typename scalar_t>
__global__ void indexing_backward_kernel_stride_1(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -179,6 +342,7 @@ __global__ void indexing_backward_kernel_stride_1(
}
}
}
#endif
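Both stride==1 kernels above reduce to the same per-run computation: for each run of equal entries in sorted_indices, combine the matching grad_output elements (addressed through indices) into a single grad_weight element, summing when accumulate is set and keeping only the last duplicate otherwise; the ROCm version additionally spreads runs of warp-size length or more across all lanes of a wavefront. A sequential Python reference model of that semantics, assuming flat 1-D tensors (outer_dim == 1) and hypothetical names, might look like:

import torch

def stride1_backward_reference(sorted_indices, indices, grad_output, grad_weight, accumulate=True):
    # Walk each run of duplicates in the sorted index list.
    i, n = 0, sorted_indices.numel()
    while i < n:
        j = i
        while j + 1 < n and sorted_indices[j + 1] == sorted_indices[i]:
            j += 1
        row = sorted_indices[i]
        if accumulate:
            # Sum every duplicate's gradient into the destination element
            # (the ROCm kernel splits this loop across warp lanes for long runs).
            grad_weight[row] += grad_output[indices[i:j + 1]].sum()
        else:
            # Non-accumulate keeps only the gradient of the last duplicate.
            grad_weight[row] = grad_output[indices[j]]
        i = j + 1
    return grad_weight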

template <typename scalar_t>
__global__ void indexing_backward_kernel_small_stride(
@@ -474,7 +638,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
// cub on CUDA <= 11.2 has a bug where, for small sizes,
// cub's sort can be much slower than thrust's merge sort;
// this bug is fixed in CUDA 11.3
#if (defined(CUDA_VERSION) && CUDA_VERSION < 11030) || defined(USE_ROCM)
#if (defined(CUDA_VERSION) && CUDA_VERSION < 11030) && !defined(USE_ROCM)
if (num_indices < 50000) {
index_put_with_sort_kernel_thrust_helper(linearIndex, orig_indices, sorted_indices, num_indices);
} else
@@ -495,7 +659,11 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
linearIndex.numel()*sliceSize*nElemBefore == expandedValue.numel(),
"number of flattened indices did not match number of elements in the value tensor: ",
linearIndex.numel()*sliceSize*nElemBefore, " vs ", expandedValue.numel());
#ifdef USE_ROCM
const int UNROLL = 1;
#else
const int UNROLL = 4;
#endif
const int indices_per_block = 4;
const int warp_size = at::cuda::warp_size();
dim3 grid(ceil_div(num_indices, (int64_t) indices_per_block),
@@ -505,13 +673,24 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten


if (sliceSize == 1) {
#ifdef USE_ROCM
// Adapt grid size to smaller virtual warp size:
dim3 new_grid(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)), grid.y, grid.z);
size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t);
#define KERNEL_GRID new_grid
#define KERNEL_SMEM smem_dups_size
#else
#define KERNEL_GRID grid
#define KERNEL_SMEM 0
#endif
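With indices_per_block = 4, each ROCm block now covers indices_per_block * warp_size sorted indices instead of indices_per_block, so the x-dimension of the grid shrinks accordingly, and the dynamic shared memory holds one int64 duplicate counter per thread in the block. A quick sketch of the arithmetic; the warp size and index count below are illustrative assumptions (a 64-wide wavefront and the index count used by the new test in test_indexing.py), not values queried from a device:

def ceil_div(a, b):
    return (a + b - 1) // b

indices_per_block = 4
warp_size = 64                 # typical ROCm wavefront size; 32 on NVIDIA
num_indices = 401988           # same count as the new test below

grid_x = ceil_div(num_indices, indices_per_block * warp_size)  # 1571 blocks instead of 100497
smem_bytes = indices_per_block * warp_size * 8                 # one int64 per thread -> 2048 bytes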
// This implementation is faster with high amounts of duplicates but could overflow
// if FP16 / BF16 is used
AT_DISPATCH_V2(
expandedValue.scalar_type(),
"indexing_backward_kernel_stride_1",
AT_WRAP([&] {
indexing_backward_kernel_stride_1<scalar_t><<<grid, block, 0, stream>>>(
indexing_backward_kernel_stride_1<scalar_t><<<KERNEL_GRID, block, KERNEL_SMEM, stream>>>
(
sorted_indices.const_data_ptr<int64_t>(),
orig_indices.const_data_ptr<int64_t>(),
expandedValue.const_data_ptr<scalar_t>(),
@@ -529,6 +708,8 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
kHalf,
kBool,
kBFloat16);
#undef KERNEL_GRID
#undef KERNEL_SMEM
} else {
if (sliceSize <= warp_size) {
AT_DISPATCH_V2(
@@ -553,6 +734,54 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
kHalf,
kBool,
kBFloat16);
#ifdef USE_ROCM
} else if (UNROLL == 1) {
if (accumulate) {
AT_DISPATCH_V2(
expandedValue.scalar_type(),
"indexing_backward",
AT_WRAP([&] {
indexing_backward_kernel_rocm<scalar_t, true><<<grid, block, 0, stream>>>(
sorted_indices.const_data_ptr<int64_t>(),
orig_indices.const_data_ptr<int64_t>(),
expandedValue.const_data_ptr<scalar_t>(),
src_.mutable_data_ptr<scalar_t>(),
num_indices,
sliceSize,
strideBefore,
nElemBefore);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}),
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
AT_EXPAND(AT_FLOAT8_TYPES),
kComplexHalf,
kHalf,
kBool,
kBFloat16);
} else {
AT_DISPATCH_V2(
expandedValue.scalar_type(),
"indexing_backward",
AT_WRAP([&] {
indexing_backward_kernel_rocm<scalar_t, false><<<grid, block, 0, stream>>>(
sorted_indices.const_data_ptr<int64_t>(),
orig_indices.const_data_ptr<int64_t>(),
expandedValue.const_data_ptr<scalar_t>(),
src_.mutable_data_ptr<scalar_t>(),
num_indices,
sliceSize,
strideBefore,
nElemBefore);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}),
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
AT_EXPAND(AT_FLOAT8_TYPES),
kComplexHalf,
kHalf,
kBool,
kBFloat16);
}
#endif
} else {
AT_DISPATCH_V2(
expandedValue.scalar_type(),
@@ -576,8 +805,8 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
kHalf,
kBool,
kBFloat16);
}
}
}

if (permuted) {
self.copy_(src_.permute(inversePerm));
@@ -628,7 +857,7 @@ void index_put_with_sort_quantized(Tensor & self, const c10::List<std::optional<
// cub on CUDA <= 11.2 has a bug where, for small sizes,
// cub's sort can be much slower than thrust's merge sort;
// this bug is fixed in CUDA 11.3
#if (defined(CUDA_VERSION) && CUDA_VERSION < 11030) || defined(USE_ROCM)
#if (defined(CUDA_VERSION) && CUDA_VERSION < 11030) && !defined(USE_ROCM)
if (num_indices < 50000) {
index_put_with_sort_kernel_thrust_helper(linearIndex, orig_indices, sorted_indices, num_indices);
} else
31 changes: 31 additions & 0 deletions test/test_indexing.py
@@ -980,6 +980,37 @@ def test_index_put_accumulate_expanded_values(self, device):
out_cpu = t.index_put_(indices, values2d, accumulate=True)
self.assertEqual(out_cuda.cpu(), out_cpu)

@onlyCUDA
def test_index_put_large_indices(self, device):
def generate_indices(num_indices: int, index_range: int):
indices = []
for _ in range(num_indices):
x = random.randint(0, index_range - 1)
indices.append(x)
return torch.tensor(indices)

num_indices = 401988
max_index_range = 2000
results = []
target_index_range = [16, 256, 2000]
for generated_index_range in target_index_range:
# create CPU tensors
a_tensor_size = (max_index_range, 256)
a = torch.randn(a_tensor_size, dtype=torch.bfloat16)
b = generate_indices(
num_indices=num_indices, index_range=generated_index_range
)
c_tensor_size = (num_indices, 256)
c = torch.randn(c_tensor_size, dtype=torch.bfloat16)
# create GPU copies
a_dev = a.to(device)
b_dev = b.to(device)
c_dev = c.to(device)
# run
a.index_put_(indices=[b], values=c, accumulate=True)
a_dev.index_put_(indices=[b_dev], values=c_dev, accumulate=True)
self.assertEqual(a_dev.cpu(), a)

@onlyCUDA
def test_index_put_accumulate_non_contiguous(self, device):
t = torch.zeros((5, 2, 2))