Skip to content

Commit d9873fb

Browse files
committed
Re-enable and optimize SparseSegmentMeanGrad GPU
- Re-enables this kernel now that the CUB issue has been worked around. - Optimizes the kernel to skip the sort call when output_dim0 = 1. - Adds a test case for when output_dim0 = 1.
1 parent 2cf9810 commit d9873fb

File tree

3 files changed

+46
-19
lines changed

3 files changed

+46
-19
lines changed

tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h

+25-17
Original file line numberDiff line numberDiff line change
@@ -790,23 +790,31 @@ struct SparseSegmentGradFunctor<GPUDevice, T, Index, SegmentId> {
790790
segment_offsets_ptr, weights_ptr));
791791
}
792792

793-
// Sort indices and permute segments.
794-
Tensor sorted_indices;
795-
OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<Index>::value,
796-
TensorShape({nouter}),
797-
&sorted_indices));
798-
Index* sorted_indices_ptr = sorted_indices.flat<Index>().data();
799-
Tensor sorted_segment;
800-
OP_REQUIRES_OK(context, context->allocate_temp(
801-
DataTypeToEnum<SegmentId>::value,
802-
TensorShape({nouter}), &sorted_segment));
803-
SegmentId* sorted_segment_ptr = sorted_segment.flat<SegmentId>().data();
804-
OP_REQUIRES_OK(context, GpuRadixSort(context, nouter,
805-
/*keys_in=*/indices_vec.data(),
806-
/*keys_out=*/sorted_indices_ptr,
807-
/*indices_in=*/segment_vec.data(),
808-
/*indices_out=*/sorted_segment_ptr,
809-
/*num_bits=*/Log2Ceiling64(noutput)));
793+
const Index* sorted_indices_ptr = indices_vec.data();
794+
const SegmentId* sorted_segment_ptr = segment_vec.data();
795+
Tensor tmp_sorted_indices;
796+
Tensor tmp_sorted_segment;
797+
if (noutput > 1) {
798+
// Sort indices and permute segments.
799+
OP_REQUIRES_OK(context, context->allocate_temp(
800+
DataTypeToEnum<Index>::value,
801+
TensorShape({nouter}), &tmp_sorted_indices));
802+
Index* tmp_sorted_indices_ptr = tmp_sorted_indices.flat<Index>().data();
803+
OP_REQUIRES_OK(context, context->allocate_temp(
804+
DataTypeToEnum<SegmentId>::value,
805+
TensorShape({nouter}), &tmp_sorted_segment));
806+
SegmentId* tmp_sorted_segment_ptr =
807+
tmp_sorted_segment.flat<SegmentId>().data();
808+
OP_REQUIRES_OK(context,
809+
GpuRadixSort(context, nouter,
810+
/*keys_in=*/indices_vec.data(),
811+
/*keys_out=*/tmp_sorted_indices_ptr,
812+
/*indices_in=*/segment_vec.data(),
813+
/*indices_out=*/tmp_sorted_segment_ptr,
814+
/*num_bits=*/Log2Ceiling64(noutput)));
815+
sorted_indices_ptr = tmp_sorted_indices_ptr;
816+
sorted_segment_ptr = tmp_sorted_segment_ptr;
817+
}
810818

811819
// Compute the gradient using a weighted SegmentReduceGPU with the segment
812820
// IDs and indices swapped.

tensorflow/core/kernels/segment_reduction_ops_impl_5.cc

-2
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,6 @@ TF_CALL_FLOAT_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
217217
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
218218
#undef REGISTER_GPU_SPARSE_KERNELS
219219

220-
#if 0 // TODO(b/192086735): Enable once bug is fixed.
221220
#define REGISTER_GPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
222221
REGISTER_KERNEL_BUILDER( \
223222
Name("SparseSegmentMeanGrad") \
@@ -229,7 +228,6 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
229228
SparseSegmentMeanGradOp<GPUDevice, type, index_type, segment_ids_type>);
230229
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
231230
#undef REGISTER_GPU_SPARSE_KERNELS
232-
#endif
233231

234232
#define REGISTER_GPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
235233
REGISTER_KERNEL_BUILDER( \

tensorflow/python/kernel_tests/segment_reduction_ops_test.py

+21
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,27 @@ def testGradientExplicit(self):
902902
tf_xgrad = tf_op(tf_ygrad, indices, segment_ids, output_dim0)
903903
self.assertAllClose(tf_xgrad, np_xgrad)
904904

905+
def testGradientExplicitSingleOutput(self):
906+
# The GPU implem has a special case when there is a single output.
907+
for inner_size in (1, 2, 3, 32):
908+
with self.session():
909+
tf_ygrad, np_ygrad = self._input([3, inner_size],
910+
dtype=dtypes_lib.float32)
911+
segment_ids = [0, 1, 2, 2, 2]
912+
indices = [0, 0, 0, 0, 0]
913+
output_dim0 = 1
914+
ops_list = [
915+
(math_ops.sparse_segment_sum_grad, "sum"),
916+
(math_ops.sparse_segment_mean_grad, "mean"),
917+
(math_ops.sparse_segment_sqrt_n_grad, "sqrtn"),
918+
]
919+
for tf_op, mode in ops_list:
920+
np_xgrad = self._sparseSegmentReduceGrad(np_ygrad, indices,
921+
segment_ids, output_dim0,
922+
mode)
923+
tf_xgrad = tf_op(tf_ygrad, indices, segment_ids, output_dim0)
924+
self.assertAllClose(tf_xgrad, np_xgrad)
925+
905926
def testGradientValid(self):
906927
# Baseline for the testGradient*Invalid* methods below.
907928
tf_x, _ = self._input([3, 4], dtype=dtypes_lib.float32)

0 commit comments

Comments
 (0)