Skip to content

Commit 8cca9d1

Browse files
Merge pull request tensorflow#51094 from benbarsdell:gpu-war-radix-sort-sparse-segment-reduce-grad
PiperOrigin-RevId: 390177302 Change-Id: Idd8e881b54b6a3693de0471df9cee106498162b0
2 parents b68fe60 + e116bc9 commit 8cca9d1

5 files changed

+109
-19
lines changed

tensorflow/core/kernels/gpu_prim_helpers.h

+31
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@ limitations under the License.
2424
#include "tensorflow/core/kernels/gpu_prim.h"
2525
#include "tensorflow/core/lib/core/status.h"
2626
#include "tensorflow/core/util/gpu_kernel_helper.h"
27+
#include "tensorflow/stream_executor/stream.h"
2728

2829
namespace tensorflow {
2930

@@ -57,6 +58,36 @@ Status GpuRadixSort(OpKernelContext* context, int size, const Tkey* keys_in,
5758
const Tindex* indices_in, // Optional
5859
Tindex* indices_out, int num_bits = sizeof(Tkey) * 8) {
5960
if (size == 0) return Status::OK();
61+
if (num_bits == 0) {
62+
// Workaround for CUB failing when begin_bit = end_bit = 0 (e.g., when all
63+
// keys are 0, so no sorting is needed).
64+
se::Stream* stream = context->op_device_context()->stream();
65+
if (keys_out) {
66+
// Copy keys_in to keys_out.
67+
size_t num_bytes = size * sizeof(Tkey);
68+
se::DeviceMemoryBase src(const_cast<Tkey*>(keys_in), num_bytes);
69+
se::DeviceMemoryBase dst(keys_out, num_bytes);
70+
if (!stream->ThenMemcpy(&dst, src, num_bytes).ok()) {
71+
return errors::Internal("Failed to copy keys_in to keys_out");
72+
}
73+
}
74+
if (indices_in) {
75+
// Copy indices_in to indices_out.
76+
size_t num_bytes = size * sizeof(Tindex);
77+
se::DeviceMemoryBase src(const_cast<Tindex*>(indices_in), num_bytes);
78+
se::DeviceMemoryBase dst(indices_out, num_bytes);
79+
if (!stream->ThenMemcpy(&dst, src, num_bytes).ok()) {
80+
return errors::Internal("Failed to copy indices_in to indices_out");
81+
}
82+
} else {
83+
// Set output indices to range.
84+
const Eigen::GpuDevice& device =
85+
context->eigen_device<Eigen::GpuDevice>();
86+
TF_RETURN_IF_ERROR(detail::RangeInit(device, Tindex(0), Tindex(1),
87+
Tindex(size), indices_out));
88+
}
89+
return Status::OK();
90+
}
6091
// Allocate temporary inputs/outputs if necessary.
6192
Tensor tmp_indices_in;
6293
if (!indices_in) {

tensorflow/core/kernels/gpu_prim_helpers_test.cu.cc

+32
Original file line number | Diff line number | Diff line change
@@ -271,6 +271,38 @@ TEST_F(GpuPrimHelpersTest, GpuRadixSort_WithNumBits) {
271271
test::ExpectTensorEqual<int32>(expected_indices_out, *GetOutput(1));
272272
}
273273

274+
TEST_F(GpuPrimHelpersTest, GpuRadixSort_WithNumBitsZero) {
275+
// Check that num_bits=0 is handled correctly.
276+
MakeRadixSort(DT_INT32, DT_INT32, /*need_keys_out=*/true, /*num_bits=*/0);
277+
AddInputFromArray<int32>(TensorShape({8}), {4, 2, 6, 7, 1, 3, 0, 5}); // keys
278+
AddInputFromArray<int32>(TensorShape({0}), {}); // inds
279+
TF_ASSERT_OK(RunOpKernel());
280+
281+
Tensor expected_keys_out(allocator(), DT_INT32, TensorShape({8}));
282+
test::FillValues<int32>(&expected_keys_out, {4, 2, 6, 7, 1, 3, 0, 5});
283+
test::ExpectTensorEqual<int32>(expected_keys_out, *GetOutput(0));
284+
285+
Tensor expected_indices_out(allocator(), DT_INT32, TensorShape({8}));
286+
test::FillValues<int32>(&expected_indices_out, {0, 1, 2, 3, 4, 5, 6, 7});
287+
test::ExpectTensorEqual<int32>(expected_indices_out, *GetOutput(1));
288+
}
289+
290+
TEST_F(GpuPrimHelpersTest, GpuRadixSort_KeysAndIndices_WithNumBitsZero) {
291+
// Check that num_bits=0 is handled correctly (with indices_in).
292+
MakeRadixSort(DT_INT32, DT_INT32, /*need_keys_out=*/true, /*num_bits=*/0);
293+
AddInputFromArray<int32>(TensorShape({8}), {4, 2, 6, 7, 1, 3, 0, 5}); // keys
294+
AddInputFromArray<int32>(TensorShape({8}), {7, 6, 5, 4, 3, 2, 1, 0}); // inds
295+
TF_ASSERT_OK(RunOpKernel());
296+
297+
Tensor expected_keys_out(allocator(), DT_INT32, TensorShape({8}));
298+
test::FillValues<int32>(&expected_keys_out, {4, 2, 6, 7, 1, 3, 0, 5});
299+
test::ExpectTensorEqual<int32>(expected_keys_out, *GetOutput(0));
300+
301+
Tensor expected_indices_out(allocator(), DT_INT32, TensorShape({8}));
302+
test::FillValues<int32>(&expected_indices_out, {7, 6, 5, 4, 3, 2, 1, 0});
303+
test::ExpectTensorEqual<int32>(expected_indices_out, *GetOutput(1));
304+
}
305+
274306
TEST_F(GpuPrimHelpersTest, GpuInclusivePrefixSum) {
275307
MakeInclusivePrefixSum(DT_INT32);
276308
AddInputFromArray<int32>(TensorShape({8}), {4, 2, 6, 7, 1, 3, 0, 5});

tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h

+25-17
Original file line number | Diff line number | Diff line change
@@ -790,23 +790,31 @@ struct SparseSegmentGradFunctor<GPUDevice, T, Index, SegmentId> {
790790
segment_offsets_ptr, weights_ptr));
791791
}
792792

793-
// Sort indices and permute segments.
794-
Tensor sorted_indices;
795-
OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<Index>::value,
796-
TensorShape({nouter}),
797-
&sorted_indices));
798-
Index* sorted_indices_ptr = sorted_indices.flat<Index>().data();
799-
Tensor sorted_segment;
800-
OP_REQUIRES_OK(context, context->allocate_temp(
801-
DataTypeToEnum<SegmentId>::value,
802-
TensorShape({nouter}), &sorted_segment));
803-
SegmentId* sorted_segment_ptr = sorted_segment.flat<SegmentId>().data();
804-
OP_REQUIRES_OK(context, GpuRadixSort(context, nouter,
805-
/*keys_in=*/indices_vec.data(),
806-
/*keys_out=*/sorted_indices_ptr,
807-
/*indices_in=*/segment_vec.data(),
808-
/*indices_out=*/sorted_segment_ptr,
809-
/*num_bits=*/Log2Ceiling64(noutput)));
793+
const Index* sorted_indices_ptr = indices_vec.data();
794+
const SegmentId* sorted_segment_ptr = segment_vec.data();
795+
Tensor tmp_sorted_indices;
796+
Tensor tmp_sorted_segment;
797+
if (noutput > 1) {
798+
// Sort indices and permute segments.
799+
OP_REQUIRES_OK(context, context->allocate_temp(
800+
DataTypeToEnum<Index>::value,
801+
TensorShape({nouter}), &tmp_sorted_indices));
802+
Index* tmp_sorted_indices_ptr = tmp_sorted_indices.flat<Index>().data();
803+
OP_REQUIRES_OK(context, context->allocate_temp(
804+
DataTypeToEnum<SegmentId>::value,
805+
TensorShape({nouter}), &tmp_sorted_segment));
806+
SegmentId* tmp_sorted_segment_ptr =
807+
tmp_sorted_segment.flat<SegmentId>().data();
808+
OP_REQUIRES_OK(context,
809+
GpuRadixSort(context, nouter,
810+
/*keys_in=*/indices_vec.data(),
811+
/*keys_out=*/tmp_sorted_indices_ptr,
812+
/*indices_in=*/segment_vec.data(),
813+
/*indices_out=*/tmp_sorted_segment_ptr,
814+
/*num_bits=*/Log2Ceiling64(noutput)));
815+
sorted_indices_ptr = tmp_sorted_indices_ptr;
816+
sorted_segment_ptr = tmp_sorted_segment_ptr;
817+
}
810818

811819
// Compute the gradient using a weighted SegmentReduceGPU with the segment
812820
// IDs and indices swapped.

tensorflow/core/kernels/segment_reduction_ops_impl_5.cc

-2
Original file line number | Diff line number | Diff line change
@@ -217,7 +217,6 @@ TF_CALL_FLOAT_TYPES(REGISTER_CPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
217217
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
218218
#undef REGISTER_GPU_SPARSE_KERNELS
219219

220-
#if 0 // TODO(b/192086735): Enable once bug is fixed.
221220
#define REGISTER_GPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
222221
REGISTER_KERNEL_BUILDER( \
223222
Name("SparseSegmentMeanGrad") \
@@ -229,7 +228,6 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
229228
SparseSegmentMeanGradOp<GPUDevice, type, index_type, segment_ids_type>);
230229
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SPARSE_KERNELS_FOR_EACH_INDEX_TYPE);
231230
#undef REGISTER_GPU_SPARSE_KERNELS
232-
#endif
233231

234232
#define REGISTER_GPU_SPARSE_KERNELS(type, index_type, segment_ids_type) \
235233
REGISTER_KERNEL_BUILDER( \

tensorflow/python/kernel_tests/segment_reduction_ops_test.py

+21
Original file line number | Diff line number | Diff line change
@@ -902,6 +902,27 @@ def testGradientExplicit(self):
902902
tf_xgrad = tf_op(tf_ygrad, indices, segment_ids, output_dim0)
903903
self.assertAllClose(tf_xgrad, np_xgrad)
904904

905+
def testGradientExplicitSingleOutput(self):
906+
# The GPU implem has a special case when there is a single output.
907+
for inner_size in (1, 2, 3, 32):
908+
with self.session():
909+
tf_ygrad, np_ygrad = self._input([3, inner_size],
910+
dtype=dtypes_lib.float32)
911+
segment_ids = [0, 1, 2, 2, 2]
912+
indices = [0, 0, 0, 0, 0]
913+
output_dim0 = 1
914+
ops_list = [
915+
(math_ops.sparse_segment_sum_grad, "sum"),
916+
(math_ops.sparse_segment_mean_grad, "mean"),
917+
(math_ops.sparse_segment_sqrt_n_grad, "sqrtn"),
918+
]
919+
for tf_op, mode in ops_list:
920+
np_xgrad = self._sparseSegmentReduceGrad(np_ygrad, indices,
921+
segment_ids, output_dim0,
922+
mode)
923+
tf_xgrad = tf_op(tf_ygrad, indices, segment_ids, output_dim0)
924+
self.assertAllClose(tf_xgrad, np_xgrad)
925+
905926
def testGradientValid(self):
906927
# Baseline for the testGradient*Invalid* methods below.
907928
tf_x, _ = self._input([3, 4], dtype=dtypes_lib.float32)

0 commit comments

Comments
 (0)