Skip to content

Commit a08ca6a

Browse files
committed
Add bool indices cases to benchmark
1 parent bb9cd23 commit a08ca6a

File tree

1 file changed

+41
-6
lines changed

1 file changed

+41
-6
lines changed

benchmark/test_select_and_slice_perf.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -351,13 +351,23 @@ def gen_indices(input_shape, indices_shape, accumulate):
351351
def index_put_input_fn(accumulate):
    """Build the input generator used by the ``index_put`` benchmarks.

    Args:
        accumulate: Flag passed through unchanged as the last element of the
            yielded argument tuple.

    Returns:
        A generator function ``inner(shapes, dtype, device)`` that yields at
        most one ``(inp, indices, values, accumulate)`` tuple.
    """

    def inner(shapes, dtype, device):
        # Tensors are placed on ``flag_gems.device``; the ``device`` argument
        # is accepted but not read by this function.
        input_shape, indices_shape, values_shape = shapes

        is_bool = dtype == torch.bool
        if is_bool and accumulate:
            # This combination produces no benchmark case, so leave before
            # allocating tensors or generating indices that would be dropped.
            return

        if is_bool:
            # Bool data is drawn as random 0/1 integers and then cast.
            inp = torch.randint(
                0, 2, input_shape, device=flag_gems.device, requires_grad=False
            ).to(dtype)
            values = torch.randint(
                0, 2, values_shape, device=flag_gems.device, requires_grad=False
            ).to(dtype)
        else:
            inp = torch.randn(
                input_shape, dtype=dtype, device=flag_gems.device, requires_grad=False
            )
            values = torch.randn(
                values_shape, dtype=dtype, device=flag_gems.device, requires_grad=False
            )

        indices = gen_indices(input_shape, indices_shape, accumulate)
        yield inp, indices, values, accumulate

    return inner
@@ -422,6 +432,31 @@ def test_index_put__acc_false_perf():
422432
bench.run()
423433

424434

435+
# Dtype list passed as ``dtypes`` to the bool-only index_put benchmarks.
BOOL_DTYPES = [torch.bool]
436+
437+
438+
@pytest.mark.index_put
def test_index_put_acc_false_bool_perf():
    """Run ``IndexPutAccFalseBenchmark`` for ``torch.index_put`` over ``BOOL_DTYPES``."""
    make_inputs = index_put_input_fn(False)
    benchmark = IndexPutAccFalseBenchmark(
        op_name="index_put",
        torch_op=torch.index_put,
        input_fn=make_inputs,
        dtypes=BOOL_DTYPES,
    )
    benchmark.run()
447+
448+
449+
@pytest.mark.index_put_
def test_index_put__acc_false_bool_perf():
    """Run ``IndexPutAccFalseBenchmark`` for ``torch.index_put_`` over ``BOOL_DTYPES``."""
    make_inputs = index_put_input_fn(False)
    benchmark = IndexPutAccFalseBenchmark(
        op_name="index_put_",
        torch_op=torch.index_put_,
        input_fn=make_inputs,
        dtypes=BOOL_DTYPES,
    )
    benchmark.run()
458+
459+
425460
class IndexPutAccTrueBenchmark(GenericBenchmark):
426461
def set_more_shapes(self):
427462
INDEX_PUT_SHAPE = (

0 commit comments

Comments
 (0)