Skip to content

Commit 0eb24e2

Browse files
committed
[Bugfix] initialize cuda context properly and reduce test cases
1 parent a3e5971 commit 0eb24e2

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

tests/test_special_ops.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import itertools
2+
import random
23
from typing import Optional
34

45
import numpy as np
@@ -984,7 +985,9 @@ def get_diag_embed_shape_and_dims():
984985

985986
for s in shapes:
986987
dim_pairs = get_dim1_dim2(len(s) + 1)
987-
result.extend([(s, dim1, dim2) for dim1, dim2 in dim_pairs])
988+
if dim_pairs:
989+
dim1, dim2 = random.choice(dim_pairs)
990+
result.append((s, dim1, dim2))
988991

989992
return result
990993

@@ -1019,7 +1022,9 @@ def get_diagonal_backward_shape_and_dims():
10191022

10201023
for s in shapes:
10211024
dim_pairs = get_dim1_dim2(len(s))
1022-
result.extend([(s, dim1, dim2) for dim1, dim2 in dim_pairs])
1025+
if dim_pairs:
1026+
dim1, dim2 = random.choice(dim_pairs)
1027+
result.append((s, dim1, dim2))
10231028

10241029
return result
10251030

@@ -1030,6 +1035,7 @@ def get_diagonal_backward_shape_and_dims():
10301035
@pytest.mark.parametrize("offset", [-1, 0, 1])
10311036
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
10321037
def test_accuracy_diagonal_backward(shape, dtype, dim1, dim2, offset):
1038+
torch.empty(1, device="cuda", requires_grad=True).backward()
10331039
inp = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True)
10341040
ref_inp = to_reference(inp)
10351041

0 commit comments

Comments
 (0)