Skip to content

Commit 5e655c8

Browse files
Commit message: Add tests for bool type
Parent: 8a72bed · Commit: 5e655c8

File tree

1 file changed: 79 additions, 42 deletions

tests/test_reduction_ops.py

Lines changed: 79 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,21 +1486,25 @@ def test_accuracy_depthwise2d(
14861486

14871487

14881488
INDEX_PUT_SHAPE_ACC_FALSE = (
1489-
((2**28,), ((2**16,),), (2**16,)),
1490-
((32, 32), ((8,), (8,)), (8,)),
1491-
((32, 32), ((8,), (2, 8)), (8,)),
1492-
((32, 32), ((2, 8),), (32,)),
1493-
((512, 512, 512), ((128,), (128,), (128,)), (128,)),
1494-
((512, 512, 512), ((2, 128), (128,), (128,)), (128,)),
1495-
((512, 512, 512), ((2, 128),), (512,)),
1489+
((2**28,), ((2**16,),), (2**16,), False),
1490+
((32, 32), ((8,), (8,)), (8,), False),
1491+
((32, 32), ((8,), (2, 8)), (8,), False),
1492+
((32, 32), ((2, 8),), (32,), False),
1493+
((512, 512, 512), ((128,), (128,), (128,)), (128,), False),
1494+
((512, 512, 512), ((2, 128), (128,), (128,)), (128,), False),
1495+
((512, 512, 512), ((2, 128),), (512,), False),
14961496
(
14971497
(64, 64, 64),
14981498
(
14991499
(2, 8),
15001500
(2, 8),
15011501
),
15021502
(2, 8, 64),
1503+
False,
15031504
),
1505+
((100,), ((100,),), (100,), True),
1506+
((32, 32), ((32, 32),), (32, 32), True),
1507+
((16, 16, 4), ((16, 16, 4),), (16, 16, 4), True),
15041508
)
15051509

15061510
INDEX_ACC_SHAPE = (
@@ -1521,30 +1525,44 @@ def test_accuracy_depthwise2d(
15211525
)
15221526

15231527

1524-
def gen_indices(input_shape, indices_shape, accumulate):
1528+
def gen_indices(input_shape, indices_shape, accumulate, is_bool):
15251529
indices = []
1526-
for i, shape in enumerate(indices_shape):
1527-
index = np.random.choice(
1528-
np.arange(input_shape[i]), size=shape, replace=accumulate
1529-
)
1530-
indices.append(torch.tensor(index, device=flag_gems.device))
1531-
return indices
1530+
1531+
if is_bool:
1532+
mask_shape = indices_shape[0]
1533+
1534+
mask = torch.randint(0, 2, size=mask_shape, dtype=torch.bool, device=flag_gems.device)
1535+
return [mask]
1536+
1537+
else:
1538+
for i, shape in enumerate(indices_shape):
1539+
index = np.random.choice(
1540+
np.arange(input_shape[i]), size=shape, replace=accumulate
1541+
)
1542+
indices.append(torch.tensor(index, device=flag_gems.device))
1543+
return indices
15321544

15331545

15341546
@pytest.mark.index_put
15351547
@pytest.mark.parametrize(
1536-
"input_shape, indices_shape, values_shape", INDEX_PUT_SHAPE_ACC_FALSE
1548+
"input_shape, indices_shape, values_shape, is_bool", INDEX_PUT_SHAPE_ACC_FALSE
15371549
)
15381550
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
1539-
def test_index_put_acc_false(input_shape, indices_shape, values_shape, dtype):
1551+
def test_index_put_acc_false(input_shape, indices_shape, values_shape, is_bool, dtype):
15401552
accumulate = False
15411553
inp = torch.randn(
15421554
input_shape, dtype=dtype, device=flag_gems.device, requires_grad=False
15431555
)
1544-
indices = gen_indices(input_shape, indices_shape, accumulate)
1545-
values = torch.randn(
1546-
values_shape, dtype=dtype, device=flag_gems.device, requires_grad=False
1547-
)
1556+
1557+
indices = gen_indices(input_shape, indices_shape, accumulate, is_bool)
1558+
1559+
if is_bool:
1560+
K = indices[0].sum().item()
1561+
values = torch.randn((K,), dtype=dtype, device=flag_gems.device, requires_grad=False)
1562+
else:
1563+
values = torch.randn(
1564+
values_shape, dtype=dtype, device=flag_gems.device, requires_grad=False
1565+
)
15481566

15491567
ref_inp = to_reference(inp)
15501568
ref_indices = [to_reference(index) for index in indices]
@@ -1555,19 +1573,20 @@ def test_index_put_acc_false(input_shape, indices_shape, values_shape, dtype):
15551573

15561574

15571575
INDEX_PUT_SHAPE_ACC_TRUE = (
1558-
((2**28,), ((2**16,),), (2**16,)),
1559-
((32, 32), ((8,), (8,)), (8,)),
1560-
((512, 512, 512), ((128,), (128,), (128,)), (128,)),
1561-
((64, 64, 64), ((2, 8), (2, 8), (2, 8)), (2, 8)),
1576+
((2**28,), ((2**16,),), (2**16,), False),
1577+
((32, 32), ((8,), (8,)), (8,), False),
1578+
((512, 512, 512), ((128,), (128,), (128,)), (128,), False),
1579+
((64, 64, 64), ((2, 8), (2, 8), (2, 8)), (2, 8), False),
1580+
((32, 32), ((32, 32),), (32 * 32,), True),
15621581
)
15631582

15641583

15651584
@pytest.mark.index_put
15661585
@pytest.mark.parametrize(
1567-
"input_shape, indices_shape, values_shape", INDEX_PUT_SHAPE_ACC_TRUE
1586+
"input_shape, indices_shape, values_shape, is_bool", INDEX_PUT_SHAPE_ACC_TRUE
15681587
)
15691588
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
1570-
def test_index_put_acc_true(input_shape, indices_shape, values_shape, dtype):
1589+
def test_index_put_acc_true(input_shape, indices_shape, values_shape, is_bool, dtype):
15711590
init_seed(0)
15721591
if flag_gems.vendor_name == "cambricon":
15731592
torch.manual_seed(24)
@@ -1576,10 +1595,16 @@ def test_index_put_acc_true(input_shape, indices_shape, values_shape, dtype):
15761595
inp = torch.randn(
15771596
input_shape, dtype=dtype, device=flag_gems.device, requires_grad=False
15781597
)
1579-
indices = gen_indices(input_shape, indices_shape, accumulate)
1580-
values = torch.randn(
1581-
values_shape, dtype=dtype, device=flag_gems.device, requires_grad=False
1582-
)
1598+
1599+
indices = gen_indices(input_shape, indices_shape, accumulate, is_bool)
1600+
1601+
if is_bool:
1602+
K = indices[0].sum().item()
1603+
values = torch.randn((K,), dtype=dtype, device=flag_gems.device, requires_grad=False)
1604+
else:
1605+
values = torch.randn(
1606+
values_shape, dtype=dtype, device=flag_gems.device, requires_grad=False
1607+
)
15831608

15841609
ref_inp = to_reference(inp, upcast=True)
15851610
ref_indices = [to_reference(index) for index in indices]
@@ -1591,18 +1616,24 @@ def test_index_put_acc_true(input_shape, indices_shape, values_shape, dtype):
15911616

15921617
@pytest.mark.index_put_
15931618
@pytest.mark.parametrize(
1594-
"input_shape, indices_shape, values_shape", INDEX_PUT_SHAPE_ACC_FALSE
1619+
"input_shape, indices_shape, values_shape, is_bool", INDEX_PUT_SHAPE_ACC_FALSE
15951620
)
15961621
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
1597-
def test_index_put__acc_false(input_shape, indices_shape, values_shape, dtype):
1622+
def test_index_put__acc_false(input_shape, indices_shape, values_shape, is_bool, dtype):
15981623
accumulate = False
15991624
inp = torch.randn(
16001625
input_shape, dtype=dtype, device=flag_gems.device, requires_grad=False
16011626
)
1602-
indices = gen_indices(input_shape, indices_shape, accumulate)
1603-
values = torch.randn(
1604-
values_shape, dtype=dtype, device=flag_gems.device, requires_grad=False
1605-
)
1627+
1628+
indices = gen_indices(input_shape, indices_shape, accumulate, is_bool)
1629+
1630+
if is_bool:
1631+
K = indices[0].sum().item()
1632+
values = torch.randn((K,), dtype=dtype, device=flag_gems.device, requires_grad=False)
1633+
else:
1634+
values = torch.randn(
1635+
values_shape, dtype=dtype, device=flag_gems.device, requires_grad=False
1636+
)
16061637

16071638
ref_inp = to_reference(inp)
16081639
ref_indices = [to_reference(index) for index in indices]
@@ -1614,10 +1645,10 @@ def test_index_put__acc_false(input_shape, indices_shape, values_shape, dtype):
16141645

16151646
@pytest.mark.index_put_
16161647
@pytest.mark.parametrize(
1617-
"input_shape, indices_shape, values_shape", INDEX_PUT_SHAPE_ACC_TRUE
1648+
"input_shape, indices_shape, values_shape, is_bool", INDEX_PUT_SHAPE_ACC_TRUE
16181649
)
16191650
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
1620-
def test_index_put__acc_true(input_shape, indices_shape, values_shape, dtype):
1651+
def test_index_put__acc_true(input_shape, indices_shape, values_shape, is_bool, dtype):
16211652
if flag_gems.vendor_name == "kunlunxin":
16221653
torch.manual_seed(0)
16231654
torch.cuda.manual_seed_all(0)
@@ -1631,10 +1662,16 @@ def test_index_put__acc_true(input_shape, indices_shape, values_shape, dtype):
16311662
inp = torch.randn(
16321663
input_shape, dtype=dtype, device=flag_gems.device, requires_grad=False
16331664
)
1634-
indices = gen_indices(input_shape, indices_shape, accumulate)
1635-
values = torch.randn(
1636-
values_shape, dtype=dtype, device=flag_gems.device, requires_grad=False
1637-
)
1665+
1666+
indices = gen_indices(input_shape, indices_shape, accumulate, is_bool)
1667+
1668+
if is_bool:
1669+
K = indices[0].sum().item()
1670+
values = torch.randn((K,), dtype=dtype, device=flag_gems.device, requires_grad=False)
1671+
else:
1672+
values = torch.randn(
1673+
values_shape, dtype=dtype, device=flag_gems.device, requires_grad=False
1674+
)
16381675

16391676
ref_inp = to_reference(inp, upcast=True)
16401677
ref_indices = [to_reference(index) for index in indices]

0 commit comments

Comments (0)