@@ -1486,21 +1486,25 @@ def test_accuracy_depthwise2d(
14861486
14871487
# Test cases for index_put with accumulate=False.
# Each entry: (input_shape, indices_shape, values_shape, is_bool).
# When is_bool is True, indices_shape[0] is the shape of a single boolean
# mask and values_shape is ignored by the tests (values are sized from the
# mask's popcount at runtime).
INDEX_PUT_SHAPE_ACC_FALSE = (
    ((2**28,), ((2**16,),), (2**16,), False),
    ((32, 32), ((8,), (8,)), (8,), False),
    ((32, 32), ((8,), (2, 8)), (8,), False),
    ((32, 32), ((2, 8),), (32,), False),
    ((512, 512, 512), ((128,), (128,), (128,)), (128,), False),
    ((512, 512, 512), ((2, 128), (128,), (128,)), (128,), False),
    ((512, 512, 512), ((2, 128),), (512,), False),
    (
        (64, 64, 64),
        (
            (2, 8),
            (2, 8),
        ),
        (2, 8, 64),
        False,
    ),
    ((100,), ((100,),), (100,), True),
    ((32, 32), ((32, 32),), (32, 32), True),
    ((16, 16, 4), ((16, 16, 4),), (16, 16, 4), True),
)
15051509
15061510INDEX_ACC_SHAPE = (
@@ -1521,30 +1525,44 @@ def test_accuracy_depthwise2d(
15211525)
15221526
15231527
def gen_indices(input_shape, indices_shape, accumulate, is_bool):
    """Build the index list for an ``index_put`` test case.

    Args:
        input_shape: shape of the tensor being indexed; dimension ``i``
            bounds the values drawn for the ``i``-th integer index tensor.
        indices_shape: tuple of index-tensor shapes. When ``is_bool`` is
            True only ``indices_shape[0]`` is used, as the mask shape.
        accumulate: forwarded as ``replace`` to ``np.random.choice``.
            With ``accumulate=False`` the drawn indices are unique, so a
            non-accumulating ``index_put`` has no conflicting writes and
            its result is deterministic; with ``accumulate=True``
            duplicates are allowed (additions commute).
        is_bool: if True, generate a single boolean mask (boolean advanced
            indexing) instead of integer index tensors.

    Returns:
        A list of index tensors on the flag_gems device: either one bool
        mask, or one integer tensor per entry of ``indices_shape``.
    """
    if is_bool:
        # Boolean advanced indexing uses a single mask; elements where the
        # mask is True are selected.
        mask = torch.randint(
            0, 2, size=indices_shape[0], dtype=torch.bool, device=flag_gems.device
        )
        return [mask]

    indices = []
    for dim, shape in enumerate(indices_shape):
        # An int first argument to np.random.choice samples from arange(n)
        # without materializing it.
        chosen = np.random.choice(input_shape[dim], size=shape, replace=accumulate)
        indices.append(torch.tensor(chosen, device=flag_gems.device))
    return indices
15321544
15331545
15341546@pytest .mark .index_put
15351547@pytest .mark .parametrize (
1536- "input_shape, indices_shape, values_shape" , INDEX_PUT_SHAPE_ACC_FALSE
1548+ "input_shape, indices_shape, values_shape, is_bool " , INDEX_PUT_SHAPE_ACC_FALSE
15371549)
15381550@pytest .mark .parametrize ("dtype" , FLOAT_DTYPES )
1539- def test_index_put_acc_false (input_shape , indices_shape , values_shape , dtype ):
1551+ def test_index_put_acc_false (input_shape , indices_shape , values_shape , is_bool , dtype ):
15401552 accumulate = False
15411553 inp = torch .randn (
15421554 input_shape , dtype = dtype , device = flag_gems .device , requires_grad = False
15431555 )
1544- indices = gen_indices (input_shape , indices_shape , accumulate )
1545- values = torch .randn (
1546- values_shape , dtype = dtype , device = flag_gems .device , requires_grad = False
1547- )
1556+
1557+ indices = gen_indices (input_shape , indices_shape , accumulate , is_bool )
1558+
1559+ if is_bool :
1560+ K = indices [0 ].sum ().item ()
1561+ values = torch .randn ((K ,), dtype = dtype , device = flag_gems .device , requires_grad = False )
1562+ else :
1563+ values = torch .randn (
1564+ values_shape , dtype = dtype , device = flag_gems .device , requires_grad = False
1565+ )
15481566
15491567 ref_inp = to_reference (inp )
15501568 ref_indices = [to_reference (index ) for index in indices ]
@@ -1555,19 +1573,20 @@ def test_index_put_acc_false(input_shape, indices_shape, values_shape, dtype):
15551573
15561574
# Test cases for index_put with accumulate=True.
# Each entry: (input_shape, indices_shape, values_shape, is_bool).
# When is_bool is True, indices_shape[0] is the shape of a single boolean
# mask; the tests size the values tensor from the mask's popcount.
INDEX_PUT_SHAPE_ACC_TRUE = (
    ((2**28,), ((2**16,),), (2**16,), False),
    ((32, 32), ((8,), (8,)), (8,), False),
    ((512, 512, 512), ((128,), (128,), (128,)), (128,), False),
    ((64, 64, 64), ((2, 8), (2, 8), (2, 8)), (2, 8), False),
    ((32, 32), ((32, 32),), (32 * 32,), True),
)
15631582
15641583
15651584@pytest .mark .index_put
15661585@pytest .mark .parametrize (
1567- "input_shape, indices_shape, values_shape" , INDEX_PUT_SHAPE_ACC_TRUE
1586+ "input_shape, indices_shape, values_shape, is_bool " , INDEX_PUT_SHAPE_ACC_TRUE
15681587)
15691588@pytest .mark .parametrize ("dtype" , [torch .float16 , torch .float32 ])
1570- def test_index_put_acc_true (input_shape , indices_shape , values_shape , dtype ):
1589+ def test_index_put_acc_true (input_shape , indices_shape , values_shape , is_bool , dtype ):
15711590 init_seed (0 )
15721591 if flag_gems .vendor_name == "cambricon" :
15731592 torch .manual_seed (24 )
@@ -1576,10 +1595,16 @@ def test_index_put_acc_true(input_shape, indices_shape, values_shape, dtype):
15761595 inp = torch .randn (
15771596 input_shape , dtype = dtype , device = flag_gems .device , requires_grad = False
15781597 )
1579- indices = gen_indices (input_shape , indices_shape , accumulate )
1580- values = torch .randn (
1581- values_shape , dtype = dtype , device = flag_gems .device , requires_grad = False
1582- )
1598+
1599+ indices = gen_indices (input_shape , indices_shape , accumulate , is_bool )
1600+
1601+ if is_bool :
1602+ K = indices [0 ].sum ().item ()
1603+ values = torch .randn ((K ,), dtype = dtype , device = flag_gems .device , requires_grad = False )
1604+ else :
1605+ values = torch .randn (
1606+ values_shape , dtype = dtype , device = flag_gems .device , requires_grad = False
1607+ )
15831608
15841609 ref_inp = to_reference (inp , upcast = True )
15851610 ref_indices = [to_reference (index ) for index in indices ]
@@ -1591,18 +1616,24 @@ def test_index_put_acc_true(input_shape, indices_shape, values_shape, dtype):
15911616
15921617@pytest .mark .index_put_
15931618@pytest .mark .parametrize (
1594- "input_shape, indices_shape, values_shape" , INDEX_PUT_SHAPE_ACC_FALSE
1619+ "input_shape, indices_shape, values_shape, is_bool " , INDEX_PUT_SHAPE_ACC_FALSE
15951620)
15961621@pytest .mark .parametrize ("dtype" , FLOAT_DTYPES )
1597- def test_index_put__acc_false (input_shape , indices_shape , values_shape , dtype ):
1622+ def test_index_put__acc_false (input_shape , indices_shape , values_shape , is_bool , dtype ):
15981623 accumulate = False
15991624 inp = torch .randn (
16001625 input_shape , dtype = dtype , device = flag_gems .device , requires_grad = False
16011626 )
1602- indices = gen_indices (input_shape , indices_shape , accumulate )
1603- values = torch .randn (
1604- values_shape , dtype = dtype , device = flag_gems .device , requires_grad = False
1605- )
1627+
1628+ indices = gen_indices (input_shape , indices_shape , accumulate , is_bool )
1629+
1630+ if is_bool :
1631+ K = indices [0 ].sum ().item ()
1632+ values = torch .randn ((K ,), dtype = dtype , device = flag_gems .device , requires_grad = False )
1633+ else :
1634+ values = torch .randn (
1635+ values_shape , dtype = dtype , device = flag_gems .device , requires_grad = False
1636+ )
16061637
16071638 ref_inp = to_reference (inp )
16081639 ref_indices = [to_reference (index ) for index in indices ]
@@ -1614,10 +1645,10 @@ def test_index_put__acc_false(input_shape, indices_shape, values_shape, dtype):
16141645
16151646@pytest .mark .index_put_
16161647@pytest .mark .parametrize (
1617- "input_shape, indices_shape, values_shape" , INDEX_PUT_SHAPE_ACC_TRUE
1648+ "input_shape, indices_shape, values_shape, is_bool " , INDEX_PUT_SHAPE_ACC_TRUE
16181649)
16191650@pytest .mark .parametrize ("dtype" , [torch .float16 , torch .float32 ])
1620- def test_index_put__acc_true (input_shape , indices_shape , values_shape , dtype ):
1651+ def test_index_put__acc_true (input_shape , indices_shape , values_shape , is_bool , dtype ):
16211652 if flag_gems .vendor_name == "kunlunxin" :
16221653 torch .manual_seed (0 )
16231654 torch .cuda .manual_seed_all (0 )
@@ -1631,10 +1662,16 @@ def test_index_put__acc_true(input_shape, indices_shape, values_shape, dtype):
16311662 inp = torch .randn (
16321663 input_shape , dtype = dtype , device = flag_gems .device , requires_grad = False
16331664 )
1634- indices = gen_indices (input_shape , indices_shape , accumulate )
1635- values = torch .randn (
1636- values_shape , dtype = dtype , device = flag_gems .device , requires_grad = False
1637- )
1665+
1666+ indices = gen_indices (input_shape , indices_shape , accumulate , is_bool )
1667+
1668+ if is_bool :
1669+ K = indices [0 ].sum ().item ()
1670+ values = torch .randn ((K ,), dtype = dtype , device = flag_gems .device , requires_grad = False )
1671+ else :
1672+ values = torch .randn (
1673+ values_shape , dtype = dtype , device = flag_gems .device , requires_grad = False
1674+ )
16381675
16391676 ref_inp = to_reference (inp , upcast = True )
16401677 ref_indices = [to_reference (index ) for index in indices ]
0 commit comments