Skip to content

Commit 4f748a6

Browse files
committedFeb 27, 2025
Fix formating of source files
1 parent c19b3df commit 4f748a6

File tree

6 files changed

+63
-40
lines changed

6 files changed

+63
-40
lines changed
 

‎varipeps/ctmrg/absorption.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ class definition for details.
274274
right_proj, smallest_S = calc_right_projectors(
275275
*_get_ctmrg_2x2_structure(peps_tensors, view, "top-right"),
276276
config,
277-
state
277+
state,
278278
)
279279
right_projectors[(x, y)] = right_proj
280280
smallest_S_list.append(smallest_S)
@@ -474,7 +474,7 @@ class definition for details.
474474
bottom_proj, smallest_S = calc_bottom_projectors(
475475
*_get_ctmrg_2x2_structure(peps_tensors, view, "bottom-left"),
476476
config,
477-
state
477+
state,
478478
)
479479
bottom_projectors[(x, y)] = bottom_proj
480480
smallest_S_list.append(smallest_S)
@@ -894,7 +894,7 @@ def do_right_absorption_split_transfer(
894894
) = calc_right_projectors_split_transfer(
895895
*_get_ctmrg_2x2_structure(peps_tensors, view, "top-right"),
896896
config,
897-
state
897+
state,
898898
)
899899
right_projectors[(x, y)] = right_proj
900900
smallest_S_list.append(smallest_S_ket)
@@ -1226,7 +1226,7 @@ def do_bottom_absorption_split_transfer(
12261226
) = calc_bottom_projectors_split_transfer(
12271227
*_get_ctmrg_2x2_structure(peps_tensors, view, "bottom-left"),
12281228
config,
1229-
state
1229+
state,
12301230
)
12311231
bottom_projectors[(x, y)] = bottom_proj
12321232
smallest_S_list.append(smallest_S_ket)

‎varipeps/ctmrg/routine.py

+17-15
Original file line numberDiff line numberDiff line change
@@ -579,22 +579,24 @@ def calc_ctmrg_env(
579579
)
580580

581581
if tmp_count < varipeps_config.ctmrg_max_steps:
582-
working_unitcell, converged, end_count, norm_smallest_S = _ctmrg_while_wrapper(
583-
(
584-
peps_tensors,
585-
working_unitcell,
586-
False,
582+
working_unitcell, converged, end_count, norm_smallest_S = (
583+
_ctmrg_while_wrapper(
587584
(
588-
corner_singular_vals
589-
if corner_singular_vals is not None
590-
else init_corner_singular_vals
591-
),
592-
eps,
593-
tmp_count,
594-
enforce_elementwise_convergence,
595-
jnp.inf,
596-
varipeps_global_state,
597-
varipeps_config,
585+
peps_tensors,
586+
working_unitcell,
587+
False,
588+
(
589+
corner_singular_vals
590+
if corner_singular_vals is not None
591+
else init_corner_singular_vals
592+
),
593+
eps,
594+
tmp_count,
595+
enforce_elementwise_convergence,
596+
jnp.inf,
597+
varipeps_global_state,
598+
varipeps_config,
599+
)
598600
)
599601
)
600602
else:

‎varipeps/ctmrg/structure_factor_absorption.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ class definition for details.
313313
right_proj, smallest_S = calc_right_projectors(
314314
*_get_ctmrg_2x2_structure(peps_tensors, view, "top-right"),
315315
config,
316-
state
316+
state,
317317
)
318318
right_projectors[(x, y)] = right_proj
319319
smallest_S_list.append(smallest_S)
@@ -683,7 +683,7 @@ class definition for details.
683683
bottom_proj, smallest_S = calc_bottom_projectors(
684684
*_get_ctmrg_2x2_structure(peps_tensors, view, "bottom-left"),
685685
config,
686-
state
686+
state,
687687
)
688688
bottom_projectors[(x, y)] = bottom_proj
689689
smallest_S_list.append(smallest_S)

‎varipeps/mapping/triangular.py

+34-17
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,15 @@ class Triangular_Expectation_Value(Expectation_Model):
6464
\\
6565
6666
Args:
67-
nearest_neighbor_gates (:term:`sequence` of :obj:`jax.numpy.ndarray`):
68-
Sequence with the gates that should be applied to each nearest
67+
horizontal_gates (:term:`sequence` of :obj:`jax.numpy.ndarray`):
68+
Sequence with the gates that should be applied to each nearest horizontal
69+
neighbor.
70+
vertical_gates (:term:`sequence` of :obj:`jax.numpy.ndarray`):
71+
Sequence with the gates that should be applied to each nearest vertical
72+
neighbor.
73+
diagonal_gates (:term:`sequence` of :obj:`jax.numpy.ndarray`):
74+
Sequence with the gates that should be applied to each nearest diagonal
6975
neighbor.
70-
downward_triangle_gates (:term:`sequence` of :obj:`jax.numpy.ndarray`):
71-
Sequence with the gates that should be applied to the downward
72-
triangles.
7376
normalization_factor (:obj:`int`):
7477
Factor which should be used to normalize the calculated values.
7578
If for example three sites are mapped into one PEPS site this
@@ -81,24 +84,38 @@ class Triangular_Expectation_Value(Expectation_Model):
8184
if spiral iPEPS ansatz is used.
8285
"""
8386

84-
nearest_neighbor_gates: Sequence[jnp.ndarray]
87+
horizontal_gates: Sequence[jnp.ndarray]
88+
vertical_gates: Sequence[jnp.ndarray]
89+
diagonal_gates: Sequence[jnp.ndarray]
8590
real_d: int
8691
normalization_factor: int = 1
8792

8893
is_spiral_peps: bool = False
8994
spiral_unitary_operator: Optional[jnp.ndarray] = None
9095

9196
def __post_init__(self) -> None:
92-
if isinstance(self.nearest_neighbor_gates, jnp.ndarray):
93-
self.nearest_neighbor_gates = (self.nearest_neighbor_gates,)
97+
if isinstance(self.horizontal_gates, jnp.ndarray):
98+
self.horizontal_gates = (self.horizontal_gates,)
99+
else:
100+
self.horizontal_gates = tuple(self.horizontal_gates)
101+
102+
if isinstance(self.vertical_gates, jnp.ndarray):
103+
self.vertical_gates = (self.vertical_gates,)
94104
else:
95-
self.nearest_neighbor_gates = tuple(self.nearest_neighbor_gates)
105+
self.vertical_gates = tuple(self.vertical_gates)
106+
107+
if isinstance(self.diagonal_gates, jnp.ndarray):
108+
self.diagonal_gates = (self.diagonal_gates,)
109+
else:
110+
self.diagonal_gates = tuple(self.diagonal_gates)
96111

97112
self._result_type = (
98113
jnp.float64
99114
if all(
100115
jnp.allclose(g, g.T.conj())
101-
for g in self.nearest_neighbor_gates
116+
for g in self.horizontal_gates
117+
+ self.vertical_gates
118+
+ self.diagonal_gates
102119
)
103120
else jnp.complex128
104121
)
@@ -120,7 +137,7 @@ def __call__(
120137
) -> Union[jnp.ndarray, List[jnp.ndarray]]:
121138
result = [
122139
jnp.array(0, dtype=self._result_type)
123-
for _ in range(len(self.nearest_neighbor_gates))
140+
for _ in range(len(self.horizontal_gates))
124141
]
125142

126143
if self.is_spiral_peps:
@@ -145,7 +162,7 @@ def __call__(
145162
(1,),
146163
varipeps_config.spiral_wavevector_type,
147164
)
148-
for h in self.nearest_neighbor_gates
165+
for h in self.horizontal_gates
149166
)
150167
working_v_gates = tuple(
151168
apply_unitary(
@@ -159,7 +176,7 @@ def __call__(
159176
(1,),
160177
varipeps_config.spiral_wavevector_type,
161178
)
162-
for v in self.nearest_neighbor_gates
179+
for v in self.vertical_gates
163180
)
164181
working_d_gates = tuple(
165182
apply_unitary(
@@ -173,12 +190,12 @@ def __call__(
173190
(1,),
174191
varipeps_config.spiral_wavevector_type,
175192
)
176-
for d in self.nearest_neighbor_gates
193+
for d in self.diagonal_gates
177194
)
178195
else:
179-
working_h_gates = self.nearest_neighbor_gates
180-
working_v_gates = self.nearest_neighbor_gates
181-
working_d_gates = self.nearest_neighbor_gates
196+
working_h_gates = self.horizontal_gates
197+
working_v_gates = self.vertical_gates
198+
working_d_gates = self.diagonal_gates
182199

183200
for x, iter_rows in unitcell.iter_all_rows(only_unique=only_unique):
184201
for y, view in iter_rows:

‎varipeps/optimization/optimizer.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,11 @@ def random_noise(a):
755755
"Energy": f"{working_value:0.10f}",
756756
"Retries": random_noise_retries,
757757
"Convergence": f"{conv:0.8f}",
758-
"Line search step": f"{linesearch_step:0.8f}" if linesearch_step is not None else "0",
758+
"Line search step": (
759+
f"{linesearch_step:0.8f}"
760+
if linesearch_step is not None
761+
else "0"
762+
),
759763
"Max. trunc. err.": f"{max_trunc_error:0.8g}",
760764
}
761765
)

‎varipeps/peps/unitcell.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def __post_init__(self):
211211

212212
@staticmethod
213213
def _check_structure(
214-
structure: Tuple[Tuple[int, ...], ...]
214+
structure: Tuple[Tuple[int, ...], ...],
215215
) -> Tuple[Tuple[Tuple[int, ...], ...], jnp.ndarray]:
216216
structure = np.array(structure)
217217

0 commit comments

Comments
 (0)
Please sign in to comment.