Skip to content

Commit 475ad73

Browse files
committedFeb 27, 2025
Fallback to QR based SVD if GESDD is not converging
1 parent 34b692a commit 475ad73

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed
 

‎varipeps/utils/svd.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,30 @@ def _H(x):
2424
def svd_wrapper(a):
2525
check_arraylike("jnp.linalg.svd", a)
2626
(a,) = promote_dtypes_inexact(jnp.asarray(a))
27-
return lax_svd(a, full_matrices=False, compute_uv=True)
27+
28+
result = lax_svd(a, full_matrices=False, compute_uv=True)
29+
30+
result = lax.cond(
31+
jnp.isnan(jnp.sum(result[1])),
32+
lambda matrix, _: lax_svd(
33+
matrix,
34+
full_matrices=False,
35+
compute_uv=True,
36+
algorithm=lax.linalg.SvdAlgorithm.QR,
37+
),
38+
lambda _, res: res,
39+
a,
40+
result,
41+
)
42+
43+
return result
2844

2945

3046
@svd_wrapper.defjvp
3147
def _svd_jvp_rule(primals, tangents):
3248
(A,) = primals
3349
(dA,) = tangents
34-
U, s, Vt = lax_svd(A, full_matrices=False, compute_uv=True)
50+
U, s, Vt = svd_wrapper(A)
3551

3652
Ut, V = _H(U), _H(Vt)
3753
s_dim = s[..., None, :]

0 commit comments

Comments
 (0)
Please sign in to comment.