@@ -790,23 +790,31 @@ struct SparseSegmentGradFunctor<GPUDevice, T, Index, SegmentId> {
790
790
segment_offsets_ptr, weights_ptr));
791
791
}
792
792
793
- // Sort indices and permute segments.
794
- Tensor sorted_indices;
795
- OP_REQUIRES_OK (context, context->allocate_temp (DataTypeToEnum<Index>::value,
796
- TensorShape ({nouter}),
797
- &sorted_indices));
798
- Index* sorted_indices_ptr = sorted_indices.flat <Index>().data ();
799
- Tensor sorted_segment;
800
- OP_REQUIRES_OK (context, context->allocate_temp (
801
- DataTypeToEnum<SegmentId>::value,
802
- TensorShape ({nouter}), &sorted_segment));
803
- SegmentId* sorted_segment_ptr = sorted_segment.flat <SegmentId>().data ();
804
- OP_REQUIRES_OK (context, GpuRadixSort (context, nouter,
805
- /* keys_in=*/ indices_vec.data (),
806
- /* keys_out=*/ sorted_indices_ptr,
807
- /* indices_in=*/ segment_vec.data (),
808
- /* indices_out=*/ sorted_segment_ptr,
809
- /* num_bits=*/ Log2Ceiling64 (noutput)));
793
+ const Index* sorted_indices_ptr = indices_vec.data ();
794
+ const SegmentId* sorted_segment_ptr = segment_vec.data ();
795
+ Tensor tmp_sorted_indices;
796
+ Tensor tmp_sorted_segment;
797
+ if (noutput > 1 ) {
798
+ // Sort indices and permute segments.
799
+ OP_REQUIRES_OK (context, context->allocate_temp (
800
+ DataTypeToEnum<Index>::value,
801
+ TensorShape ({nouter}), &tmp_sorted_indices));
802
+ Index* tmp_sorted_indices_ptr = tmp_sorted_indices.flat <Index>().data ();
803
+ OP_REQUIRES_OK (context, context->allocate_temp (
804
+ DataTypeToEnum<SegmentId>::value,
805
+ TensorShape ({nouter}), &tmp_sorted_segment));
806
+ SegmentId* tmp_sorted_segment_ptr =
807
+ tmp_sorted_segment.flat <SegmentId>().data ();
808
+ OP_REQUIRES_OK (context,
809
+ GpuRadixSort (context, nouter,
810
+ /* keys_in=*/ indices_vec.data (),
811
+ /* keys_out=*/ tmp_sorted_indices_ptr,
812
+ /* indices_in=*/ segment_vec.data (),
813
+ /* indices_out=*/ tmp_sorted_segment_ptr,
814
+ /* num_bits=*/ Log2Ceiling64 (noutput)));
815
+ sorted_indices_ptr = tmp_sorted_indices_ptr;
816
+ sorted_segment_ptr = tmp_sorted_segment_ptr;
817
+ }
810
818
811
819
// Compute the gradient using a weighted SegmentReduceGPU with the segment
812
820
// IDs and indices swapped.
0 commit comments