From 8448168b8c9cd6c2b8f1bc2db23b2b533875748b Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Mon, 24 Feb 2025 21:19:06 +0000 Subject: [PATCH] [ROCm] Optimize the stride one indexing backwards kernel (#146420) This patch makes several changes to the stride 1 backwards indexing kernel as follows: - enables the computation across the `sorted_indices` array to happen in parallel by all the lanes in the warp, this means that the accesses to `sorted_indices` are now fully coalesced. - the duplicate counting now happens in parallel: each lane in the warp counts the duplicates of a different `idx`. - enable skipping during duplicate count: this optimization ensures that for large number of duplicates we can skip 32 values at time to speed up the count. - for low number of duplicates i.e. we have less than `warp-size` duplicates then just perform the tail reduction which avoid the wasteful parallel reduction across the warp for this case (it would only add zero values). - for high number of duplicates i.e. when we have more than `warp-size` duplicates then we still use the full warp of lanes to compute the reduced value with as much parallelism as possible. This is done by making sure that all lanes stick around and cooperatively execute the reduction in case there is a single `idx` which has a large number of duplicates (i.e. a duplicate spike). For this to happen we use shared memory to pass the duplicate count computed in parallel in the first part of the kernel to the cooperative reduction part of the kernel. Benefits on examples extracted from workloads show a 3.6x to 10x speed-up. co-author: Hashem Hashemi Pull Request resolved: https://github.com/pytorch/pytorch/pull/146420 Approved by: https://github.com/pruthvistony, https://github.com/jeffdaily --- aten/src/ATen/native/cuda/Indexing.cu | 130 +++++++++++++++++++++++++- 1 file changed, 129 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index ef24bc3628e7..17df19a6f198 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -182,6 +182,120 @@ __global__ void indexing_backward_kernel_rocm( } #endif +#ifdef USE_ROCM +#define SKIP 32 +template +__global__ void indexing_backward_kernel_stride_1( + const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, + int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) { + using opmath_t = at::opmath_type; + + int laneIdx = threadIdx.x % C10_WARP_SIZE; + + const opmath_t scale = (opmath_t)1.0; + int64_t grad_row = 0; + + extern __shared__ unsigned char smem[]; + auto smem_dups_cache = reinterpret_cast(smem); + + // Each warp gets a different section of the share memory allocation: + int smem_offset = threadIdx.y * C10_WARP_SIZE; + + // Number of values processed by each thread (grain size) + for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z) { + // Init duplicates every time we compute a new set of entries: + smem_dups_cache[smem_offset + laneIdx] = 0; + + int64_t base_idx = blockIdx.x * blockDim.y * C10_WARP_SIZE + threadIdx.y * C10_WARP_SIZE; + int64_t idx = base_idx + laneIdx; + + // Each lane calculates the number of duplicates: + if (idx < numel) { + int64_t crnt_sorted_idx = sorted_indices[idx]; + + if (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]) { + // Determine the number of duplicates in advance: + int64_t num_duplicates = 1; + + // Lookahead in case there is a large number of duplicates. Once that is done, handle the tail. + while ((idx + num_duplicates + SKIP - 1) < numel) { + if (sorted_indices[idx + num_duplicates + SKIP - 1] != crnt_sorted_idx) break; + num_duplicates += SKIP; + } + while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) { + num_duplicates++; + } + + if (!accumulate) { + const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before; + grad_row = ((int64_t)indices[idx + num_duplicates - 1]) * stride + z * numel * stride; + grad_weight[weight_row] = + static_cast(static_cast(grad_output[grad_row]) * scale); + continue; + } + + // Each lane sequentially handles the duplicate elimination: + if (num_duplicates < C10_WARP_SIZE) { + opmath_t gradient = (opmath_t)0.0; + const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before; + for (int64_t i = 0; i < num_duplicates; ++i) { + grad_row = ((int64_t) indices[idx + i]) * stride + z * numel * stride; + gradient += static_cast(grad_output[grad_row]) * scale; + } + + grad_weight[weight_row] = static_cast(static_cast(grad_weight[weight_row]) + gradient); + } else { + // Add duplicate to the cache: + smem_dups_cache[smem_offset + laneIdx] = num_duplicates; + } + } + } + + WARP_SYNC(); + + // All lanes in the warp are still active here. Use them all to reduce duplicates when + // large number of duplicates are present: + for (int subwarp = 0; subwarp < C10_WARP_SIZE; subwarp++) { + // All lanes read the shared memory entry for number of duplicates + int64_t new_num_duplicates = smem_dups_cache[smem_offset + subwarp]; + + // Check if the original sub-warp had duplicates to eliminate, if not skip. + if (new_num_duplicates == 0) + continue; + + // There are duplicates that need eliminating: + int64_t new_idx = base_idx + subwarp; + int64_t new_crnt_sorted_idx = sorted_indices[new_idx]; + const int64_t new_weight_row = new_crnt_sorted_idx * stride + z * stride_before; + + // Result of the reduction will be in this variable: + opmath_t gradient = (opmath_t)0.0; + + int64_t num_warp_passes = new_num_duplicates / C10_WARP_SIZE; + // Parallel reduction across the array of duplicates using all the lanes in the warp: + for (int64_t i = 0; i < num_warp_passes; ++i) { + grad_row = ((int64_t) indices[new_idx + i * C10_WARP_SIZE + laneIdx]) * stride + z * numel * stride; + gradient += static_cast(grad_output[grad_row]) * scale; + } + + // Reduce across the lanes of the warp: + WARP_SYNC(); + for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) { + gradient += WARP_SHFL_DOWN(gradient, offset); + } + + if (laneIdx == 0) { + for (int64_t i = num_warp_passes * C10_WARP_SIZE; i < new_num_duplicates; ++i) { + grad_row = ((int64_t) indices[new_idx + i]) * stride + z * numel * stride; + gradient += static_cast(grad_output[grad_row]) * scale; + } + + grad_weight[new_weight_row] = static_cast(static_cast(grad_weight[new_weight_row]) + gradient); + } + } + } +} +#else template __global__ void indexing_backward_kernel_stride_1( const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, @@ -237,6 +351,7 @@ __global__ void indexing_backward_kernel_stride_1( } } } +#endif template __global__ void indexing_backward_kernel_small_stride( @@ -567,13 +682,24 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<<>>( + indexing_backward_kernel_stride_1<<>> + ( sorted_indices.const_data_ptr(), orig_indices.const_data_ptr(), expandedValue.const_data_ptr(), @@ -591,6 +717,8 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List