Skip to content

Commit 481f325

Browse files
committed
Fix missing function part in other optimizer inner functions
1 parent 03785dd commit 481f325

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

varipeps/optimization/inner_function.py

+36
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,24 @@ def calc_ctmrg_expectation(
9999
peps_tensors, unitcell, spiral_vectors = _map_tensors(
100100
input_tensors, unitcell, convert_to_unitcell_func, True
101101
)
102+
103+
if any(i.size == 1 for i in spiral_vectors):
104+
spiral_vectors_x = additional_input.get("spiral_vectors_x")
105+
spiral_vectors_y = additional_input.get("spiral_vectors_y")
106+
if spiral_vectors_x is not None:
107+
if isinstance(spiral_vectors_x, jnp.ndarray):
108+
spiral_vectors_x = (spiral_vectors_x,)
109+
spiral_vectors = tuple(
110+
jnp.array((sx, sy))
111+
for sx, sy in safe_zip(spiral_vectors_x, spiral_vectors)
112+
)
113+
elif spiral_vectors_y is not None:
114+
if isinstance(spiral_vectors_y, jnp.ndarray):
115+
spiral_vectors_y = (spiral_vectors_y,)
116+
spiral_vectors = tuple(
117+
jnp.array((sx, sy))
118+
for sx, sy in safe_zip(spiral_vectors, spiral_vectors_y)
119+
)
102120
else:
103121
peps_tensors, unitcell = _map_tensors(
104122
input_tensors, unitcell, convert_to_unitcell_func, False
@@ -175,6 +193,24 @@ def calc_preconverged_ctmrg_value_and_grad(
175193
peps_tensors, unitcell, spiral_vectors = _map_tensors(
176194
input_tensors, unitcell, convert_to_unitcell_func, True
177195
)
196+
197+
if any(i.size == 1 for i in spiral_vectors):
198+
spiral_vectors_x = additional_input.get("spiral_vectors_x")
199+
spiral_vectors_y = additional_input.get("spiral_vectors_y")
200+
if spiral_vectors_x is not None:
201+
if isinstance(spiral_vectors_x, jnp.ndarray):
202+
spiral_vectors_x = (spiral_vectors_x,)
203+
spiral_vectors = tuple(
204+
jnp.array((sx, sy))
205+
for sx, sy in safe_zip(spiral_vectors_x, spiral_vectors)
206+
)
207+
elif spiral_vectors_y is not None:
208+
if isinstance(spiral_vectors_y, jnp.ndarray):
209+
spiral_vectors_y = (spiral_vectors_y,)
210+
spiral_vectors = tuple(
211+
jnp.array((sx, sy))
212+
for sx, sy in safe_zip(spiral_vectors, spiral_vectors_y)
213+
)
178214
else:
179215
peps_tensors, unitcell = _map_tensors(
180216
input_tensors, unitcell, convert_to_unitcell_func, False

0 commit comments

Comments
 (0)