diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu
index 53d64ace1d85..4ceaae260e33 100644
--- a/aten/src/ATen/native/cuda/Indexing.cu
+++ b/aten/src/ATen/native/cuda/Indexing.cu
@@ -124,6 +124,169 @@ __global__ void indexing_backward_kernel(
   }
 }
 
+#ifdef USE_ROCM
+template <typename scalar_t, bool accumulate>
+__global__ void indexing_backward_kernel_rocm(
+  const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
+  int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim) {
+
+  // This implementation is adapted from indexing_backward_kernel above.
+  using opmath_t = at::opmath_type<scalar_t>;
+  for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z){
+    int64_t idx = blockIdx.x * blockDim.y + threadIdx.y;
+    if (idx < numel && (idx == 0 || sorted_indices[idx] != sorted_indices[idx - 1])){
+      do {
+        // if not accumulate, we only keep the last duplicate index so skip those before it
+        if constexpr (!accumulate) {
+          if ((idx < numel - 1) && sorted_indices[idx] == sorted_indices[idx + 1]) {
+            idx++;
+            continue;
+          }
+        }
+        const int64_t weight_row = ((int64_t) sorted_indices[idx]) * stride + z * stride_before;
+        const int64_t grad_row = ((int64_t) indices[idx]) * stride + z * numel * stride;
+
+        opmath_t gradient;
+        opmath_t weight;
+
+        int64_t feature_dim = threadIdx.x + blockIdx.y * blockDim.x;
+        while (feature_dim < stride) {
+          gradient = static_cast<opmath_t>(grad_output[grad_row + feature_dim]);
+          if constexpr (accumulate) {
+            weight = static_cast<opmath_t>(grad_weight[weight_row + feature_dim]);
+          }
+
+          if constexpr (accumulate) {
+            weight += gradient;
+          } else {
+            weight = gradient;
+          }
+
+          grad_weight[weight_row + feature_dim] = static_cast<scalar_t>(weight);
+          feature_dim += gridDim.y * blockDim.x;
+        }
+
+        idx++;
+      } while (idx < numel && sorted_indices[idx] == sorted_indices[idx - 1]);
+    }
+  }
+}
+#endif
+
+#ifdef USE_ROCM
+#define SKIP 32
+template <typename scalar_t>
+__global__ void indexing_backward_kernel_stride_1(
+  const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
+  int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
+  using opmath_t = at::opmath_type<scalar_t>;
+
+  int laneIdx = threadIdx.x % C10_WARP_SIZE;
+
+  const opmath_t scale = (opmath_t)1.0;
+  int64_t grad_row = 0;
+
+  extern __shared__ unsigned char smem[];
+  auto smem_dups_cache = reinterpret_cast<int64_t*>(smem);
+
+  // Each warp gets a different section of the shared memory allocation:
+  int smem_offset = threadIdx.y * C10_WARP_SIZE;
+
+  // Number of values processed by each thread (grain size)
+  for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z) {
+    // Init duplicates every time we compute a new set of entries:
+    smem_dups_cache[smem_offset + laneIdx] = 0;
+
+    int64_t base_idx = blockIdx.x * blockDim.y * C10_WARP_SIZE + threadIdx.y * C10_WARP_SIZE;
+    int64_t idx = base_idx + laneIdx;
+
+    // Each lane calculates the number of duplicates:
+    if (idx < numel) {
+      int64_t crnt_sorted_idx = sorted_indices[idx];
+
+      if (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]) {
+        // Determine the number of duplicates in advance:
+        int64_t num_duplicates = 1;
+
+        // Look ahead in case there is a large number of duplicates. Once that is done, handle the tail.
+        while ((idx + num_duplicates + SKIP - 1) < numel) {
+          if (sorted_indices[idx + num_duplicates + SKIP - 1] != crnt_sorted_idx) break;
+          num_duplicates += SKIP;
+        }
+        while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) {
+          num_duplicates++;
+        }
+
+        if (!accumulate) {
+          const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
+          grad_row = ((int64_t)indices[idx + num_duplicates - 1]) * stride + z * numel * stride;
+          grad_weight[weight_row] =
+            static_cast<scalar_t>(static_cast<opmath_t>(grad_output[grad_row]) * scale);
+          continue;
+        }
+
+        // Each lane sequentially handles the duplicate elimination:
+        if (num_duplicates < C10_WARP_SIZE) {
+          opmath_t gradient = (opmath_t)0.0;
+          const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
+          for (int64_t i = 0; i < num_duplicates; ++i) {
+            grad_row = ((int64_t) indices[idx + i]) * stride + z * numel * stride;
+            gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
+          }
+
+          grad_weight[weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[weight_row]) + gradient);
+        } else {
+          // Add the duplicate count to the cache:
+          smem_dups_cache[smem_offset + laneIdx] = num_duplicates;
+        }
+      }
+    }
+
+    WARP_SYNC();
+
+    // All lanes in the warp are still active here. Use them all to reduce duplicates when
+    // a large number of duplicates is present:
+    for (int subwarp = 0; subwarp < C10_WARP_SIZE; subwarp++) {
+      // All lanes read the shared memory entry for the number of duplicates:
+      int64_t new_num_duplicates = smem_dups_cache[smem_offset + subwarp];
+
+      // Check if the original sub-warp had duplicates to eliminate; if not, skip.
+      if (new_num_duplicates == 0)
+        continue;
+
+      // There are duplicates that need eliminating:
+      int64_t new_idx = base_idx + subwarp;
+      int64_t new_crnt_sorted_idx = sorted_indices[new_idx];
+      const int64_t new_weight_row = new_crnt_sorted_idx * stride + z * stride_before;
+
+      // Result of the reduction will be in this variable:
+      opmath_t gradient = (opmath_t)0.0;
+
+      int64_t num_warp_passes = new_num_duplicates / C10_WARP_SIZE;
+      // Parallel reduction across the array of duplicates using all the lanes in the warp:
+      for (int64_t i = 0; i < num_warp_passes; ++i) {
+        grad_row = ((int64_t) indices[new_idx + i * C10_WARP_SIZE + laneIdx]) * stride + z * numel * stride;
+        gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
+      }
+
+      // Reduce across the lanes of the warp:
+      WARP_SYNC();
+      for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) {
+        gradient += WARP_SHFL_DOWN(gradient, offset);
+      }
+
+      if (laneIdx == 0) {
+        for (int64_t i = num_warp_passes * C10_WARP_SIZE; i < new_num_duplicates; ++i) {
+          grad_row = ((int64_t) indices[new_idx + i]) * stride + z * numel * stride;
+          gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
+        }
+
+        grad_weight[new_weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[new_weight_row]) + gradient);
+      }
+    }
+  }
+}
+#else
 template <typename scalar_t>
 __global__ void indexing_backward_kernel_stride_1(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -179,6 +342,7 @@ __global__ void indexing_backward_kernel_stride_1(
     }
   }
 }
+#endif
 
 template <typename scalar_t>
 __global__ void indexing_backward_kernel_small_stride(
@@ -474,7 +638,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<<>>(
+          indexing_backward_kernel_stride_1<<>>
+          (
             sorted_indices.const_data_ptr<int64_t>(),
             orig_indices.const_data_ptr<int64_t>(),
             expandedValue.const_data_ptr<scalar_t>(),
             src_.mutable_data_ptr<scalar_t>(),
@@ -529,6 +708,8 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<<>>(
+            sorted_indices.const_data_ptr<int64_t>(),
+            orig_indices.const_data_ptr<int64_t>(),
+            expandedValue.const_data_ptr<scalar_t>(),
+            src_.mutable_data_ptr<scalar_t>(),
+            num_indices,
+            sliceSize,
+            strideBefore,
+            nElemBefore);
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+        }),
+        AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+        AT_EXPAND(AT_FLOAT8_TYPES),
+        kComplexHalf,
+        kHalf,
+        kBool,
+        kBFloat16);
+    } else {
+      AT_DISPATCH_V2(
+        expandedValue.scalar_type(),
+        "indexing_backward",
+        AT_WRAP([&] {
+          indexing_backward_kernel_rocm<<>>(
+            sorted_indices.const_data_ptr<int64_t>(),
+            orig_indices.const_data_ptr<int64_t>(),
+            expandedValue.const_data_ptr<scalar_t>(),
+            src_.mutable_data_ptr<scalar_t>(),
+            num_indices,
+            sliceSize,
+            strideBefore,
+            nElemBefore);
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+        }),
+        AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+        AT_EXPAND(AT_FLOAT8_TYPES),
+        kComplexHalf,
+        kHalf,
+        kBool,
+        kBFloat16);
+    }
+#endif
   } else {
     AT_DISPATCH_V2(
       expandedValue.scalar_type(),
@@ -576,8 +805,8 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List
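
Reviewer note: to sanity-check the duplicate handling in the new ROCm indexing_backward_kernel_stride_1 without a GPU build, the sketch below mirrors its core logic on the host: the SKIP-strided lookahead that counts a run of equal sorted indices, the element-wise tail scan that finishes the count, and the accumulation of every gradient in that run into a single grad_weight slot. This is a minimal illustration, not part of the patch; it assumes stride == 1, a single outer slice (z == 0), and accumulate == true, and the helper name count_duplicates, the constant kSkip, and the toy input values are invented for the example.

// Host-side mirror of the stride-1 duplicate handling, for sanity checking only.
// Assumes the indices are already sorted, stride == 1, and accumulate == true.
#include <cstdint>
#include <cstdio>
#include <vector>

constexpr int kSkip = 32;  // plays the same role as SKIP in the kernel

// Count how many entries starting at idx share sorted_indices[idx], using the
// same SKIP-strided lookahead followed by an element-wise tail scan.
int64_t count_duplicates(const std::vector<int64_t>& sorted_indices, int64_t idx) {
  const int64_t numel = static_cast<int64_t>(sorted_indices.size());
  const int64_t current = sorted_indices[idx];
  int64_t num_duplicates = 1;
  while ((idx + num_duplicates + kSkip - 1) < numel &&
         sorted_indices[idx + num_duplicates + kSkip - 1] == current) {
    num_duplicates += kSkip;
  }
  while ((idx + num_duplicates) < numel &&
         sorted_indices[idx + num_duplicates] == current) {
    num_duplicates++;
  }
  return num_duplicates;
}

int main() {
  // grad_weight[w] accumulates grad_output[orig_indices[i]] over every i whose
  // sorted index equals w, which is what the kernel does per run of duplicates.
  const std::vector<int64_t> sorted_indices = {0, 0, 0, 2, 2, 5};
  const std::vector<int64_t> orig_indices   = {3, 1, 0, 4, 2, 5};
  const std::vector<float>   grad_output    = {0.5f, 1.0f, 1.5f, 2.0f, 2.5f, 3.0f};
  std::vector<float> grad_weight(6, 0.0f);

  for (int64_t idx = 0; idx < static_cast<int64_t>(sorted_indices.size());) {
    const int64_t n = count_duplicates(sorted_indices, idx);
    float gradient = 0.0f;  // opmath-style accumulator
    for (int64_t i = 0; i < n; ++i) {
      gradient += grad_output[orig_indices[idx + i]];
    }
    grad_weight[sorted_indices[idx]] += gradient;
    idx += n;  // jump past the whole run of duplicates
  }

  for (size_t w = 0; w < grad_weight.size(); ++w) {
    std::printf("grad_weight[%zu] = %f\n", w, grad_weight[w]);
  }
  return 0;
}

The kernel itself distributes this work differently: each warp lane owns the run that starts at its own index, runs shorter than C10_WARP_SIZE are summed sequentially by that lane, and longer runs are recorded in the per-warp shared-memory cache and reduced by the whole warp with WARP_SHFL_DOWN, while the non-accumulate path writes only the gradient of the last duplicate. The host mirror collapses all of that into one sequential loop purely for verification.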