@@ -351,13 +351,23 @@ def gen_indices(input_shape, indices_shape, accumulate):
def index_put_input_fn(accumulate):
    """Create the input generator used by the ``index_put`` benchmarks.

    Args:
        accumulate: Flag passed to ``gen_indices`` and yielded as the last
            element of each generated argument tuple.

    Returns:
        A generator function ``inner(shapes, dtype, device)`` that yields at
        most one ``(inp, indices, values, accumulate)`` tuple per call.
    """

    def inner(shapes, dtype, device):
        """Yield benchmark arguments for one ``(shapes, dtype)`` combination.

        ``shapes`` is unpacked as ``(input_shape, indices_shape,
        values_shape)``. The ``device`` argument is not used: tensors are
        always created on ``flag_gems.device``.
        """
        input_shape, indices_shape, values_shape = shapes
        is_bool = dtype == torch.bool

        # Bool inputs combined with accumulate=True yield nothing.
        # NOTE(review): presumably accumulation is unsupported for bool
        # tensors -- confirm. Checking first avoids allocating tensors and
        # generating indices that would be thrown away.
        if is_bool and accumulate:
            return

        def make_tensor(shape):
            # randn cannot produce bool tensors, so draw 0/1 integers and
            # cast them to the requested dtype instead.
            if is_bool:
                return torch.randint(
                    0, 2, shape, device=flag_gems.device, requires_grad=False
                ).to(dtype)
            return torch.randn(
                shape, dtype=dtype, device=flag_gems.device, requires_grad=False
            )

        # Keep the original draw order (input, values, indices) so the random
        # stream is consumed in the same sequence as before.
        inp = make_tensor(input_shape)
        values = make_tensor(values_shape)
        indices = gen_indices(input_shape, indices_shape, accumulate)
        yield inp, indices, values, accumulate

    return inner
@@ -422,6 +432,31 @@ def test_index_put__acc_false_perf():
422432 bench .run ()
423433
424434
# Dtype list for the bool-only index_put benchmarks defined below.
BOOL_DTYPES = [torch.bool]
436+
437+
@pytest.mark.index_put
def test_index_put_acc_false_bool_perf():
    """Benchmark ``torch.index_put`` on bool tensors with accumulate=False."""
    bench = IndexPutAccFalseBenchmark(
        op_name="index_put",
        torch_op=torch.index_put,
        input_fn=index_put_input_fn(False),
        dtypes=BOOL_DTYPES,
    )
    bench.run()
447+
448+
@pytest.mark.index_put_
def test_index_put__acc_false_bool_perf():
    """Benchmark in-place ``torch.index_put_`` on bool tensors with accumulate=False."""
    bench = IndexPutAccFalseBenchmark(
        op_name="index_put_",
        torch_op=torch.index_put_,
        input_fn=index_put_input_fn(False),
        dtypes=BOOL_DTYPES,
    )
    bench.run()
458+
459+
425460class IndexPutAccTrueBenchmark (GenericBenchmark ):
426461 def set_more_shapes (self ):
427462 INDEX_PUT_SHAPE = (
0 commit comments