Commit c040e57

doru1004 authored and pruthvistony committed
[ROCm] Optimize the stride one indexing backwards kernel (pytorch#146420)
This patch makes several changes to the stride 1 backwards indexing kernel, as follows:

- Enables the computation across the `sorted_indices` array to happen in parallel by all the lanes in the warp; this means that the accesses to `sorted_indices` are now fully coalesced.
- The duplicate counting now happens in parallel: each lane in the warp counts the duplicates of a different `idx`.
- Enables skipping during the duplicate count: this optimization ensures that for a large number of duplicates we can skip 32 values at a time to speed up the count.
- For a low number of duplicates, i.e. fewer than `warp-size` duplicates, just perform the tail reduction, which avoids the wasteful parallel reduction across the warp for this case (it would only add zero values).
- For a high number of duplicates, i.e. more than `warp-size` duplicates, we still use the full warp of lanes to compute the reduced value with as much parallelism as possible. This is done by making sure that all lanes stick around and cooperatively execute the reduction in case there is a single `idx` which has a large number of duplicates (i.e. a duplicate spike). For this to happen, we use shared memory to pass the duplicate count computed in parallel in the first part of the kernel to the cooperative reduction part of the kernel.

Benchmarks on examples extracted from workloads show a 3.6x to 10x speed-up.

co-author: Hashem Hashemi <[email protected]>

Pull Request resolved: pytorch#146420
Approved by: https://github.com/pruthvistony, https://github.com/jeffdaily
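The second and third bullets (parallel duplicate counting with a 32-entry lookahead) boil down to the loop below. This is a hedged, host-side sketch of the same logic, valid under the assumption that `idx` points at the first element of its run of equal values, exactly as the kernel guarantees by checking `sorted_indices[idx - 1]`; the helper name `count_duplicates` is made up for illustration and is not part of the patch.

```cuda
// Host-side sketch of the lookahead duplicate count; SKIP, sorted_indices,
// idx and numel mirror the names used in the kernel added by this patch.
#include <cstdint>
#include <vector>

constexpr int64_t SKIP = 32;

int64_t count_duplicates(const std::vector<int64_t>& sorted_indices, int64_t idx) {
  const int64_t numel = static_cast<int64_t>(sorted_indices.size());
  const int64_t crnt = sorted_indices[idx];
  int64_t num_duplicates = 1;
  // Lookahead: jump SKIP entries at a time while the run of duplicates continues.
  while ((idx + num_duplicates + SKIP - 1) < numel &&
         sorted_indices[idx + num_duplicates + SKIP - 1] == crnt) {
    num_duplicates += SKIP;
  }
  // Tail: walk one entry at a time until the run ends.
  while ((idx + num_duplicates) < numel &&
         sorted_indices[idx + num_duplicates] == crnt) {
    num_duplicates++;
  }
  return num_duplicates;
}
```

Because the lookahead advances in strides of `SKIP`, a spike of N duplicates costs roughly N/32 + 32 comparisons instead of N.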
1 parent 0e782cb commit c040e57

File tree

1 file changed: +129 −1 lines changed


aten/src/ATen/native/cuda/Indexing.cu

Lines changed: 129 additions & 1 deletion
@@ -173,6 +173,120 @@ __global__ void indexing_backward_kernel_rocm(
 }
 #endif
 
+#ifdef USE_ROCM
+#define SKIP 32
+template <typename scalar_t>
+__global__ void indexing_backward_kernel_stride_1(
+  const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
+  int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
+  using opmath_t = at::opmath_type<scalar_t>;
+
+  int laneIdx = threadIdx.x % C10_WARP_SIZE;
+
+  const opmath_t scale = (opmath_t)1.0;
+  int64_t grad_row = 0;
+
+  extern __shared__ unsigned char smem[];
+  auto smem_dups_cache = reinterpret_cast<int64_t*>(smem);
+
+  // Each warp gets a different section of the share memory allocation:
+  int smem_offset = threadIdx.y * C10_WARP_SIZE;
+
+  // Number of values processed by each thread (grain size)
+  for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z) {
+    // Init duplicates every time we compute a new set of entries:
+    smem_dups_cache[smem_offset + laneIdx] = 0;
+
+    int64_t base_idx = blockIdx.x * blockDim.y * C10_WARP_SIZE + threadIdx.y * C10_WARP_SIZE;
+    int64_t idx = base_idx + laneIdx;
+
+    // Each lane calculates the number of duplicates:
+    if (idx < numel) {
+      int64_t crnt_sorted_idx = sorted_indices[idx];
+
+      if (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]) {
+        // Determine the number of duplicates in advance:
+        int64_t num_duplicates = 1;
+
+        // Lookahead in case there is a large number of duplicates. Once that is done, handle the tail.
+        while ((idx + num_duplicates + SKIP - 1) < numel) {
+          if (sorted_indices[idx + num_duplicates + SKIP - 1] != crnt_sorted_idx) break;
+          num_duplicates += SKIP;
+        }
+        while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) {
+          num_duplicates++;
+        }
+
+        if (!accumulate) {
+          const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
+          grad_row = ((int64_t)indices[idx + num_duplicates - 1]) * stride + z * numel * stride;
+          grad_weight[weight_row] =
+            static_cast<scalar_t>(static_cast<opmath_t>(grad_output[grad_row]) * scale);
+          continue;
+        }
+
+        // Each lane sequentially handles the duplicate elimination:
+        if (num_duplicates < C10_WARP_SIZE) {
+          opmath_t gradient = (opmath_t)0.0;
+          const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
+          for (int64_t i = 0; i < num_duplicates; ++i) {
+            grad_row = ((int64_t) indices[idx + i]) * stride + z * numel * stride;
+            gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
+          }
+
+          grad_weight[weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[weight_row]) + gradient);
+        } else {
+          // Add duplicate to the cache:
+          smem_dups_cache[smem_offset + laneIdx] = num_duplicates;
+        }
+      }
+    }
+
+    WARP_SYNC();
+
+    // All lanes in the warp are still active here. Use them all to reduce duplicates when
+    // large number of duplicates are present:
+    for (int subwarp = 0; subwarp < C10_WARP_SIZE; subwarp++) {
+      // All lanes read the shared memory entry for number of duplicates
+      int64_t new_num_duplicates = smem_dups_cache[smem_offset + subwarp];
+
+      // Check if the original sub-warp had duplicates to eliminate, if not skip.
+      if (new_num_duplicates == 0)
+        continue;
+
+      // There are duplicates that need eliminating:
+      int64_t new_idx = base_idx + subwarp;
+      int64_t new_crnt_sorted_idx = sorted_indices[new_idx];
+      const int64_t new_weight_row = new_crnt_sorted_idx * stride + z * stride_before;
+
+      // Result of the reduction will be in this variable:
+      opmath_t gradient = (opmath_t)0.0;
+
+      int64_t num_warp_passes = new_num_duplicates / C10_WARP_SIZE;
+      // Parallel reduction across the array of duplicates using all the lanes in the warp:
+      for (int64_t i = 0; i < num_warp_passes; ++i) {
+        grad_row = ((int64_t) indices[new_idx + i * C10_WARP_SIZE + laneIdx]) * stride + z * numel * stride;
+        gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
+      }
+
+      // Reduce across the lanes of the warp:
+      WARP_SYNC();
+      for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) {
+        gradient += WARP_SHFL_DOWN(gradient, offset);
+      }
+
+      if (laneIdx == 0) {
+        for (int64_t i = num_warp_passes * C10_WARP_SIZE; i < new_num_duplicates; ++i) {
+          grad_row = ((int64_t) indices[new_idx + i]) * stride + z * numel * stride;
+          gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
+        }
+
+        grad_weight[new_weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[new_weight_row]) + gradient);
+      }
+    }
+  }
+}
+#else
 template <typename scalar_t>
 __global__ void indexing_backward_kernel_stride_1(
   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
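The heart of the high-duplicate path above is a three-phase warp reduction: full warp-width passes over the run, a shuffle-down tree, and a lane-0 tail. Below is a minimal standalone CUDA sketch of that pattern, not ATen code; the kernel name `warp_reduce_run`, its parameters, the direct use of `__shfl_down_sync` in place of the ATen `WARP_SHFL_DOWN`/`WARP_SYNC` macros, and the fixed warp size of 32 (rather than `C10_WARP_SIZE`, which is 64 on AMD GPUs) are assumptions made for illustration.

```cuda
// Standalone sketch of the warp-cooperative reduction over one long run of
// duplicates: every lane accumulates a strided slice, a shuffle-down tree
// folds the partials into lane 0, and lane 0 finishes the sub-warp tail.
#include <cstdint>
#include <cuda_runtime.h>

constexpr int kWarpSize = 32;

__global__ void warp_reduce_run(const float* grad, int64_t run_len, float* out) {
  const int lane = threadIdx.x % kWarpSize;

  // Phase 1: full warp-width passes over the run, one element per lane per pass.
  float partial = 0.0f;
  const int64_t full_passes = run_len / kWarpSize;
  for (int64_t i = 0; i < full_passes; ++i) {
    partial += grad[i * kWarpSize + lane];
  }

  // Phase 2: shuffle-down tree reduction across the lanes of the warp.
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    partial += __shfl_down_sync(0xffffffff, partial, offset);
  }

  // Phase 3: lane 0 adds the tail (< kWarpSize leftover elements) and writes out.
  if (lane == 0) {
    for (int64_t i = full_passes * kWarpSize; i < run_len; ++i) {
      partial += grad[i];
    }
    *out = partial;
  }
}
```

Launched as `warp_reduce_run<<<1, kWarpSize>>>(d_grad, run_len, d_out)`, this mirrors the `num_warp_passes` loop, the `WARP_SHFL_DOWN` tree, and the lane-0 tail handling in the kernel added above, minus the indirection through `indices` and `grad_row`.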
@@ -228,6 +342,7 @@ __global__ void indexing_backward_kernel_stride_1(
     }
   }
 }
+#endif
 
 template <typename scalar_t>
 __global__ void indexing_backward_kernel_small_stride(
@@ -558,13 +673,24 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
 
 
   if (sliceSize == 1) {
+#ifdef USE_ROCM
+    // Adapt grid size to smaller virtual warp size:
+    dim3 new_grid(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)), grid.y, grid.z);
+    size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t);
+#define KERNEL_GRID new_grid
+#define KERNEL_SMEM smem_dups_size
+#else
+#define KERNEL_GRID grid
+#define KERNEL_SMEM 0
+#endif
     // This implementation is faster with high amounts of duplicates but could overflow
     // if FP16 / BF16 is used
     AT_DISPATCH_V2(
       expandedValue.scalar_type(),
       "indexing_backward_kernel_stride_1",
       AT_WRAP([&] {
-        indexing_backward_kernel_stride_1<scalar_t><<<grid, block, 0, stream>>>(
+        indexing_backward_kernel_stride_1<scalar_t><<<KERNEL_GRID, block, KERNEL_SMEM, stream>>>
+        (
           sorted_indices.const_data_ptr<int64_t>(),
           orig_indices.const_data_ptr<int64_t>(),
           expandedValue.const_data_ptr<scalar_t>(),
@@ -582,6 +708,8 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
       kHalf,
       kBool,
       kBFloat16);
+#undef KERNEL_GRID
+#undef KERNEL_SMEM
   } else {
     if (sliceSize <= warp_size) {
       AT_DISPATCH_V2(
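The launch-site change above gives the ROCm kernel a grid sized in units of whole warps and enough dynamic shared memory for one `int64_t` duplicate count per lane per warp. A rough, hedged sketch of that arithmetic follows, with assumed values for `num_indices` and `indices_per_block` (the real values come from the surrounding ATen code, and `warp_size` is the 64-lane AMD wavefront).

```cuda
// Illustrative host-side arithmetic for the ROCm launch configuration above;
// the concrete numbers are assumptions, not values taken from ATen.
#include <cstdint>
#include <cstdio>

int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t num_indices = 1 << 20;   // assumed number of sorted indices
  const int64_t indices_per_block = 4;   // assumed warps (threadIdx.y) per block
  const int64_t warp_size = 64;          // AMD wavefront size
  // One block covers indices_per_block * warp_size sorted indices.
  const int64_t grid_x = ceil_div(num_indices, indices_per_block * warp_size);
  // One int64_t duplicate-count slot per lane, per warp, per block.
  const size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t);
  std::printf("grid.x = %lld, dynamic smem = %zu bytes per block\n",
              static_cast<long long>(grid_x), smem_dups_size);
  return 0;
}
```

With these assumed values, each block reserves 4 * 64 * 8 = 2048 bytes of dynamic shared memory, which is what `KERNEL_SMEM` passes to the launch on ROCm, while the CUDA build keeps the original grid and zero dynamic shared memory.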
