1
1
import itertools
2
+ import random
2
3
from typing import Optional
3
4
4
5
import numpy as np
@@ -984,7 +985,9 @@ def get_diag_embed_shape_and_dims():
984
985
985
986
for s in shapes :
986
987
dim_pairs = get_dim1_dim2 (len (s ) + 1 )
987
- result .extend ([(s , dim1 , dim2 ) for dim1 , dim2 in dim_pairs ])
988
+ if dim_pairs :
989
+ dim1 , dim2 = random .choice (dim_pairs )
990
+ result .append ((s , dim1 , dim2 ))
988
991
989
992
return result
990
993
@@ -1019,7 +1022,9 @@ def get_diagonal_backward_shape_and_dims():
1019
1022
1020
1023
for s in shapes :
1021
1024
dim_pairs = get_dim1_dim2 (len (s ))
1022
- result .extend ([(s , dim1 , dim2 ) for dim1 , dim2 in dim_pairs ])
1025
+ if dim_pairs :
1026
+ dim1 , dim2 = random .choice (dim_pairs )
1027
+ result .append ((s , dim1 , dim2 ))
1023
1028
1024
1029
return result
1025
1030
@@ -1030,6 +1035,7 @@ def get_diagonal_backward_shape_and_dims():
1030
1035
@pytest .mark .parametrize ("offset" , [- 1 , 0 , 1 ])
1031
1036
@pytest .mark .parametrize ("dtype" , FLOAT_DTYPES )
1032
1037
def test_accuracy_diagonal_backward (shape , dtype , dim1 , dim2 , offset ):
1038
+ torch .empty (1 , device = "cuda" , requires_grad = True ).backward ()
1033
1039
inp = torch .randn (shape , dtype = dtype , device = flag_gems .device , requires_grad = True )
1034
1040
ref_inp = to_reference (inp )
1035
1041
0 commit comments