@@ -271,6 +271,38 @@ TEST_F(GpuPrimHelpersTest, GpuRadixSort_WithNumBits) {
271
271
test::ExpectTensorEqual<int32>(expected_indices_out, *GetOutput (1 ));
272
272
}
273
273
274
+ TEST_F (GpuPrimHelpersTest, GpuRadixSort_WithNumBitsZero) {
275
+ // Check that num_bits=0 is handled correctly.
276
+ MakeRadixSort (DT_INT32, DT_INT32, /* need_keys_out=*/ true , /* num_bits=*/ 0 );
277
+ AddInputFromArray<int32>(TensorShape ({8 }), {4 , 2 , 6 , 7 , 1 , 3 , 0 , 5 }); // keys
278
+ AddInputFromArray<int32>(TensorShape ({0 }), {}); // inds
279
+ TF_ASSERT_OK (RunOpKernel ());
280
+
281
+ Tensor expected_keys_out (allocator (), DT_INT32, TensorShape ({8 }));
282
+ test::FillValues<int32>(&expected_keys_out, {4 , 2 , 6 , 7 , 1 , 3 , 0 , 5 });
283
+ test::ExpectTensorEqual<int32>(expected_keys_out, *GetOutput (0 ));
284
+
285
+ Tensor expected_indices_out (allocator (), DT_INT32, TensorShape ({8 }));
286
+ test::FillValues<int32>(&expected_indices_out, {0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 });
287
+ test::ExpectTensorEqual<int32>(expected_indices_out, *GetOutput (1 ));
288
+ }
289
+
290
+ TEST_F (GpuPrimHelpersTest, GpuRadixSort_KeysAndIndices_WithNumBitsZero) {
291
+ // Check that num_bits=0 is handled correctly (with indices_in).
292
+ MakeRadixSort (DT_INT32, DT_INT32, /* need_keys_out=*/ true , /* num_bits=*/ 0 );
293
+ AddInputFromArray<int32>(TensorShape ({8 }), {4 , 2 , 6 , 7 , 1 , 3 , 0 , 5 }); // keys
294
+ AddInputFromArray<int32>(TensorShape ({8 }), {7 , 6 , 5 , 4 , 3 , 2 , 1 , 0 }); // inds
295
+ TF_ASSERT_OK (RunOpKernel ());
296
+
297
+ Tensor expected_keys_out (allocator (), DT_INT32, TensorShape ({8 }));
298
+ test::FillValues<int32>(&expected_keys_out, {4 , 2 , 6 , 7 , 1 , 3 , 0 , 5 });
299
+ test::ExpectTensorEqual<int32>(expected_keys_out, *GetOutput (0 ));
300
+
301
+ Tensor expected_indices_out (allocator (), DT_INT32, TensorShape ({8 }));
302
+ test::FillValues<int32>(&expected_indices_out, {7 , 6 , 5 , 4 , 3 , 2 , 1 , 0 });
303
+ test::ExpectTensorEqual<int32>(expected_indices_out, *GetOutput (1 ));
304
+ }
305
+
274
306
TEST_F (GpuPrimHelpersTest, GpuInclusivePrefixSum) {
275
307
MakeInclusivePrefixSum (DT_INT32);
276
308
AddInputFromArray<int32>(TensorShape ({8 }), {4 , 2 , 6 , 7 , 1 , 3 , 0 , 5 });
0 commit comments