Skip to content

Commit 832e290

Browse files
committed
Make gradient of Fishman sqrt happy
1 parent 2862b88 commit 832e290

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

varipeps/ctmrg/projectors.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,11 @@ def _split_transfer_fishman(first_tensor, second_tensor, truncation_eps):
293293
(first_ketbra_S / first_ketbra_S[0]) >= truncation_eps, first_ketbra_S, 0
294294
)
295295
first_ketbra_S /= jnp.sum(first_ketbra_S)
296-
first_ketbra_S = jnp.sqrt(first_ketbra_S)
296+
first_ketbra_S = jnp.where(
297+
first_ketbra_S == 0,
298+
0,
299+
jnp.sqrt(jnp.where(first_ketbra_S == 0, 1, first_ketbra_S)),
300+
)
297301
if first_tensor.ndim == 5:
298302
first_ketbra_U = first_ketbra_U.reshape(
299303
first_tensor.shape[0], first_tensor.shape[1], first_ketbra_U.shape[-1]
@@ -332,7 +336,11 @@ def _split_transfer_fishman(first_tensor, second_tensor, truncation_eps):
332336
(second_ketbra_S / second_ketbra_S[0]) >= truncation_eps, second_ketbra_S, 0
333337
)
334338
second_ketbra_S /= jnp.sum(second_ketbra_S)
335-
second_ketbra_S = jnp.sqrt(second_ketbra_S)
339+
second_ketbra_S = jnp.where(
340+
second_ketbra_S == 0,
341+
0,
342+
jnp.sqrt(jnp.where(second_ketbra_S == 0, 1, second_ketbra_S)),
343+
)
336344
if second_tensor.ndim == 5:
337345
second_ketbra_U = second_ketbra_U.reshape(
338346
second_tensor.shape[0],

varipeps/peps/unitcell.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -495,9 +495,7 @@ def convert_to_split_transfer(
495495
sanity_checks=False,
496496
)
497497

498-
def convert_to_full_transfer(
499-
self: T_PEPS_Unit_Cell
500-
) -> T_PEPS_Unit_Cell:
498+
def convert_to_full_transfer(self: T_PEPS_Unit_Cell) -> T_PEPS_Unit_Cell:
501499
"""
502500
Convert the list of unique tensors to the full transfer ansatz.
503501

varipeps/utils/svd.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def _svd_jvp_rule(primals, tangents):
5555
ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))
5656

5757
s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim))
58+
# s_diffs = jnp.where(s_diffs / (s[0] ** 2) >= 1e-12, s_diffs, 0)
5859
s_diffs_zeros = jnp.ones((), dtype=A.dtype) * (
5960
s_diffs == 0.0
6061
) # is 1. where s_diffs is 0. and is 0. everywhere else

0 commit comments

Comments
 (0)