Skip to content

Commit a3e5971

Browse files
committed
[Bugfix] diagonal_backward
1 parent b8ec29a commit a3e5971

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tests/test_special_ops.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1034,7 +1034,8 @@ def test_accuracy_diagonal_backward(shape, dtype, dim1, dim2, offset):
10341034
ref_inp = to_reference(inp)
10351035

10361036
ref_out = torch.diagonal(ref_inp, offset, dim1, dim2)
1037-
res_out = torch.diagonal(inp, offset, dim1, dim2)
1037+
with flag_gems.use_gems():
1038+
res_out = torch.diagonal(inp, offset, dim1, dim2)
10381039

10391040
out_grad = torch.randn_like(res_out)
10401041
ref_grad = to_reference(out_grad)

0 commit comments

Comments
 (0)