@@ -287,6 +287,22 @@ TEST_F(GpuPrimHelpersTest, GpuRadixSort_WithNumBitsZero) {
287
287
test::ExpectTensorEqual<int32>(expected_indices_out, *GetOutput (1 ));
288
288
}
289
289
290
+ TEST_F (GpuPrimHelpersTest, GpuRadixSort_KeysAndIndices_WithNumBitsZero) {
291
+ // Check that num_bits=0 is handled correctly (with indices_in).
292
+ MakeRadixSort (DT_INT32, DT_INT32, /* need_keys_out=*/ true , /* num_bits=*/ 0 );
293
+ AddInputFromArray<int32>(TensorShape ({8 }), {4 , 2 , 6 , 7 , 1 , 3 , 0 , 5 }); // keys
294
+ AddInputFromArray<int32>(TensorShape ({8 }), {7 , 6 , 5 , 4 , 3 , 2 , 1 , 0 }); // inds
295
+ TF_ASSERT_OK (RunOpKernel ());
296
+
297
+ Tensor expected_keys_out (allocator (), DT_INT32, TensorShape ({8 }));
298
+ test::FillValues<int32>(&expected_keys_out, {4 , 2 , 6 , 7 , 1 , 3 , 0 , 5 });
299
+ test::ExpectTensorEqual<int32>(expected_keys_out, *GetOutput (0 ));
300
+
301
+ Tensor expected_indices_out (allocator (), DT_INT32, TensorShape ({8 }));
302
+ test::FillValues<int32>(&expected_indices_out, {7 , 6 , 5 , 4 , 3 , 2 , 1 , 0 });
303
+ test::ExpectTensorEqual<int32>(expected_indices_out, *GetOutput (1 ));
304
+ }
305
+
290
306
TEST_F (GpuPrimHelpersTest, GpuInclusivePrefixSum) {
291
307
MakeInclusivePrefixSum (DT_INT32);
292
308
AddInputFromArray<int32>(TensorShape ({8 }), {4 , 2 , 6 , 7 , 1 , 3 , 0 , 5 });
0 commit comments