diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu
index ef24bc3628e7..17df19a6f198 100644
--- a/aten/src/ATen/native/cuda/Indexing.cu
+++ b/aten/src/ATen/native/cuda/Indexing.cu
@@ -182,6 +182,120 @@ __global__ void indexing_backward_kernel_rocm(
 }
 #endif
+#ifdef USE_ROCM
+#define SKIP 32
+template <typename scalar_t>
+__global__ void indexing_backward_kernel_stride_1(
+  const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
+  int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
+  using opmath_t = at::opmath_type<scalar_t>;
+
+  int laneIdx = threadIdx.x % C10_WARP_SIZE;
+
+  const opmath_t scale = (opmath_t)1.0;
+  int64_t grad_row = 0;
+
+  extern __shared__ unsigned char smem[];
+  auto smem_dups_cache = reinterpret_cast<int64_t*>(smem);
+
+  // Each warp gets a different section of the shared memory allocation:
+  int smem_offset = threadIdx.y * C10_WARP_SIZE;
+
+  // Number of values processed by each thread (grain size)
+  for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z) {
+    // Init duplicates every time we compute a new set of entries:
+    smem_dups_cache[smem_offset + laneIdx] = 0;
+
+    int64_t base_idx = blockIdx.x * blockDim.y * C10_WARP_SIZE + threadIdx.y * C10_WARP_SIZE;
+    int64_t idx = base_idx + laneIdx;
+
+    // Each lane calculates the number of duplicates:
+    if (idx < numel) {
+      int64_t crnt_sorted_idx = sorted_indices[idx];
+
+      if (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]) {
+        // Determine the number of duplicates in advance:
+        int64_t num_duplicates = 1;
+
+        // Lookahead in case there is a large number of duplicates. Once that is done, handle the tail.
+        while ((idx + num_duplicates + SKIP - 1) < numel) {
+          if (sorted_indices[idx + num_duplicates + SKIP - 1] != crnt_sorted_idx) break;
+          num_duplicates += SKIP;
+        }
+        while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) {
+          num_duplicates++;
+        }
+
+        if (!accumulate) {
+          const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
+          grad_row = ((int64_t)indices[idx + num_duplicates - 1]) * stride + z * numel * stride;
+          grad_weight[weight_row] =
+            static_cast<scalar_t>(static_cast<opmath_t>(grad_output[grad_row]) * scale);
+          continue;
+        }
+
+        // Each lane sequentially handles the duplicate elimination:
+        if (num_duplicates < C10_WARP_SIZE) {
+          opmath_t gradient = (opmath_t)0.0;
+          const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
+          for (int64_t i = 0; i < num_duplicates; ++i) {
+            grad_row = ((int64_t) indices[idx + i]) * stride + z * numel * stride;
+            gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
+          }
+
+          grad_weight[weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[weight_row]) + gradient);
+        } else {
+          // Add duplicate to the cache:
+          smem_dups_cache[smem_offset + laneIdx] = num_duplicates;
+        }
+      }
+    }
+
+    WARP_SYNC();
+
+    // All lanes in the warp are still active here. Use them all to reduce duplicates when
+    // large number of duplicates are present:
+    for (int subwarp = 0; subwarp < C10_WARP_SIZE; subwarp++) {
+      // All lanes read the shared memory entry for number of duplicates
+      int64_t new_num_duplicates = smem_dups_cache[smem_offset + subwarp];
+
+      // Check if the original sub-warp had duplicates to eliminate, if not skip.
+      if (new_num_duplicates == 0)
+        continue;
+
+      // There are duplicates that need eliminating:
+      int64_t new_idx = base_idx + subwarp;
+      int64_t new_crnt_sorted_idx = sorted_indices[new_idx];
+      const int64_t new_weight_row = new_crnt_sorted_idx * stride + z * stride_before;
+
+      // Result of the reduction will be in this variable:
+      opmath_t gradient = (opmath_t)0.0;
+
+      int64_t num_warp_passes = new_num_duplicates / C10_WARP_SIZE;
+      // Parallel reduction across the array of duplicates using all the lanes in the warp:
+      for (int64_t i = 0; i < num_warp_passes; ++i) {
+        grad_row = ((int64_t) indices[new_idx + i * C10_WARP_SIZE + laneIdx]) * stride + z * numel * stride;
+        gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
+      }
+
+      // Reduce across the lanes of the warp:
+      WARP_SYNC();
+      for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) {
+        gradient += WARP_SHFL_DOWN(gradient, offset);
+      }
+
+      if (laneIdx == 0) {
+        for (int64_t i = num_warp_passes * C10_WARP_SIZE; i < new_num_duplicates; ++i) {
+          grad_row = ((int64_t) indices[new_idx + i]) * stride + z * numel * stride;
+          gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
+        }
+
+        grad_weight[new_weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[new_weight_row]) + gradient);
+      }
+    }
+  }
+}
+#else
 template <typename scalar_t>
 __global__ void indexing_backward_kernel_stride_1(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
   int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
@@ -237,6 +351,7 @@ __global__ void indexing_backward_kernel_stride_1(
     }
   }
 }
+#endif
 
 template <typename scalar_t>
 __global__ void indexing_backward_kernel_small_stride(
@@ -567,13 +682,24 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Tensor>>& indices, const Tensor & value, bool accumulate, bool unsafe) {
       dim3 block(warp_size, indices_per_block);
 
       if (sliceSize == 1) {
+#ifdef USE_ROCM
+        // Adapt the grid and dynamic shared memory to the per-warp duplicate cache:
+        dim3 new_grid(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)), grid.y, grid.z);
+        size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t);
+#define KERNEL_GRID new_grid
+#define KERNEL_SMEM smem_dups_size
+#else
+#define KERNEL_GRID grid
+#define KERNEL_SMEM 0
+#endif
         // This implementation is faster with high amounts of duplicates but could overflow
         // if FP16 / BF16 is used
         AT_DISPATCH_V2(
           expandedValue.scalar_type(),
           "indexing_backward_kernel_stride_1",
           AT_WRAP([&] {
-            indexing_backward_kernel_stride_1<scalar_t><<<grid, block, 0, stream>>>(
+            indexing_backward_kernel_stride_1<scalar_t><<<KERNEL_GRID, block, KERNEL_SMEM, stream>>>
+            (
               sorted_indices.const_data_ptr<int64_t>(),
               orig_indices.const_data_ptr<int64_t>(),
               expandedValue.const_data_ptr<scalar_t>(),
@@ -591,6 +717,8 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Tensor>>& indices, const Tensor & value, bool accumulate, bool unsafe) {
           kBFloat16);
+#undef KERNEL_GRID
+#undef KERNEL_SMEM
       } else {
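
Note on the launch configuration: the ROCm kernel reserves one int64_t of dynamic shared memory per lane per warp (smem_offset = threadIdx.y * C10_WARP_SIZE) and maps exactly one sorted index to each lane (base_idx = blockIdx.x * blockDim.y * C10_WARP_SIZE + threadIdx.y * C10_WARP_SIZE). The sketch below shows a host-side setup consistent with those assumptions; the names new_grid, smem_dups_size, warp_size and indices_per_block are illustrative and may not match the exact identifiers used at the call site in index_put_with_sort_kernel.

    // Sketch only: size grid.x so each warp owns C10_WARP_SIZE sorted indices, and
    // allocate one int64_t duplicate-count slot per lane per warp of dynamic shared memory.
    const int warp_size = at::cuda::warp_size();        // 64 on most AMD GPUs
    const int indices_per_block = 4;                    // warps per block == blockDim.y (assumed value)
    dim3 block(warp_size, indices_per_block);
    dim3 new_grid(ceil_div(num_indices, (int64_t)(indices_per_block * warp_size)), grid.y, grid.z);
    size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t);

    indexing_backward_kernel_stride_1<scalar_t><<<new_grid, block, smem_dups_size, stream>>>(
        sorted_indices.const_data_ptr<int64_t>(),
        orig_indices.const_data_ptr<int64_t>(),
        expandedValue.const_data_ptr<scalar_t>(),
        src_.mutable_data_ptr<scalar_t>(),
        num_indices, sliceSize, strideBefore, nElemBefore, accumulate);
    C10_CUDA_KERNEL_LAUNCH_CHECK();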
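
The lane-level reduction at the end of the large-duplicate path is a standard warp shuffle-down sum: each lane accumulates a partial gradient, then log2(warp size) shuffle steps fold the partials into lane 0, which writes the result. Below is a self-contained CUDA sketch of that pattern using __shfl_down_sync directly instead of the WARP_SHFL_DOWN/WARP_SYNC portability macros used in ATen (on ROCm the same pattern runs with a warp size of 64).

    #include <cstdio>
    #include <cuda_runtime.h>

    // Sum n floats with a single warp: each lane accumulates a strided partial sum,
    // then a shuffle-down tree folds the 32 partials into lane 0.
    __global__ void warp_sum(const float* data, int n, float* out) {
      const unsigned full_mask = 0xffffffffu;
      int lane = threadIdx.x;                 // launched with exactly 32 threads
      float partial = 0.0f;
      for (int i = lane; i < n; i += warpSize) {
        partial += data[i];
      }
      // Tree reduction across the warp, mirroring the WARP_SHFL_DOWN loop in the kernel above.
      for (int offset = warpSize / 2; offset > 0; offset /= 2) {
        partial += __shfl_down_sync(full_mask, partial, offset);
      }
      if (lane == 0) {
        *out = partial;                       // only lane 0 holds the complete sum
      }
    }

    int main() {
      const int n = 1000;
      float host[n];
      for (int i = 0; i < n; ++i) host[i] = 1.0f;   // expected sum: 1000

      float *d_data, *d_out, result;
      cudaMalloc(&d_data, n * sizeof(float));
      cudaMalloc(&d_out, sizeof(float));
      cudaMemcpy(d_data, host, n * sizeof(float), cudaMemcpyHostToDevice);

      warp_sum<<<1, 32>>>(d_data, n, d_out);
      cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
      printf("sum = %f\n", result);

      cudaFree(d_data);
      cudaFree(d_out);
      return 0;
    }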