Skip to content

Commit 2630785

Browse files
committed
Implement separate bra and ket absorption
1 parent dbc05f7 commit 2630785

File tree

8 files changed

+3817
-85
lines changed

8 files changed

+3817
-85
lines changed

varipeps/contractions/definitions.py

+848
Large diffs are not rendered by default.

varipeps/ctmrg/absorption.py

+717
Large diffs are not rendered by default.

varipeps/ctmrg/projectors.py

+1,101-3
Large diffs are not rendered by default.

varipeps/ctmrg/routine.py

+242-74
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
import jax.debug as jdebug
99

1010
from varipeps import varipeps_config, varipeps_global_state
11-
from varipeps.peps import PEPS_Tensor, PEPS_Unit_Cell
11+
from varipeps.peps import PEPS_Tensor, PEPS_Tensor_Split_Transfer, PEPS_Unit_Cell
1212
from varipeps.utils.debug_print import debug_print
13-
from .absorption import do_absorption_step
13+
from .absorption import do_absorption_step, do_absorption_step_split_transfer
1414

1515
from typing import Sequence, Tuple, List, Optional
1616

@@ -25,6 +25,14 @@ class CTM_Enum(enum.IntEnum):
2525
T2 = enum.auto()
2626
T3 = enum.auto()
2727
T4 = enum.auto()
28+
T1_ket = enum.auto()
29+
T1_bra = enum.auto()
30+
T2_ket = enum.auto()
31+
T2_bra = enum.auto()
32+
T3_ket = enum.auto()
33+
T3_bra = enum.auto()
34+
T4_ket = enum.auto()
35+
T4_bra = enum.auto()
2836

2937

3038
class CTMRGNotConvergedError(Exception):
@@ -61,6 +69,8 @@ def _calc_corner_svds(
6169
C1_svd, indices_are_sorted=True, unique_indices=True
6270
)
6371

72+
# debug_print("C1: {}", C1_svd)
73+
6474
C2_svd = jnp.linalg.svd(t.C2, full_matrices=False, compute_uv=False)
6575
step_corner_svd = step_corner_svd.at[ti, 1, : C2_svd.shape[0]].set(
6676
C2_svd, indices_are_sorted=True, unique_indices=True
@@ -79,15 +89,20 @@ def _calc_corner_svds(
7989
return step_corner_svd
8090

8191

82-
@partial(jit, static_argnums=(3,), inline=True)
92+
@partial(jit, static_argnums=(3, 4), inline=True)
8393
def _is_element_wise_converged(
8494
old_peps_tensors: List[PEPS_Tensor],
8595
new_peps_tensors: List[PEPS_Tensor],
8696
eps: float,
8797
verbose: bool = False,
98+
split_transfer: bool = False,
8899
) -> Tuple[bool, float, Optional[List[Tuple[int, CTM_Enum, float]]]]:
89100
result = 0
90-
measure = jnp.zeros((len(old_peps_tensors), 8), dtype=jnp.float64)
101+
102+
if split_transfer:
103+
measure = jnp.zeros((len(old_peps_tensors), 12), dtype=jnp.float64)
104+
else:
105+
measure = jnp.zeros((len(old_peps_tensors), 8), dtype=jnp.float64)
91106

92107
verbose_data = [] if verbose else None
93108

@@ -144,73 +159,210 @@ def _is_element_wise_converged(
144159
if verbose:
145160
verbose_data.append((ti, CTM_Enum.C4, jnp.amax(diff)))
146161

147-
old_shape = old_peps_tensors[ti].T1.shape
148-
new_shape = new_peps_tensors[ti].T1.shape
149-
diff = jnp.abs(
150-
new_peps_tensors[ti].T1[
151-
: old_shape[0], : old_shape[1], : old_shape[2], : old_shape[3]
152-
]
153-
- old_peps_tensors[ti].T1[
154-
: new_shape[0], : new_shape[1], : new_shape[2], : new_shape[3]
155-
]
156-
)
157-
result += jnp.sum(diff > eps)
158-
measure = measure.at[ti, 4].set(
159-
jnp.linalg.norm(diff), indices_are_sorted=True, unique_indices=True
160-
)
161-
if verbose:
162-
verbose_data.append((ti, CTM_Enum.T1, jnp.amax(diff)))
163-
164-
old_shape = old_peps_tensors[ti].T2.shape
165-
new_shape = new_peps_tensors[ti].T2.shape
166-
diff = jnp.abs(
167-
new_peps_tensors[ti].T2[
168-
: old_shape[0], : old_shape[1], : old_shape[2], : old_shape[3]
169-
]
170-
- old_peps_tensors[ti].T2[
171-
: new_shape[0], : new_shape[1], : new_shape[2], : new_shape[3]
172-
]
173-
)
174-
result += jnp.sum(diff > eps)
175-
measure = measure.at[ti, 5].set(
176-
jnp.linalg.norm(diff), indices_are_sorted=True, unique_indices=True
177-
)
178-
if verbose:
179-
verbose_data.append((ti, CTM_Enum.T2, jnp.amax(diff)))
180-
181-
old_shape = old_peps_tensors[ti].T3.shape
182-
new_shape = new_peps_tensors[ti].T3.shape
183-
diff = jnp.abs(
184-
new_peps_tensors[ti].T3[
185-
: old_shape[0], : old_shape[1], : old_shape[2], : old_shape[3]
186-
]
187-
- old_peps_tensors[ti].T3[
188-
: new_shape[0], : new_shape[1], : new_shape[2], : new_shape[3]
189-
]
190-
)
191-
result += jnp.sum(diff > eps)
192-
measure = measure.at[ti, 6].set(
193-
jnp.linalg.norm(diff), indices_are_sorted=True, unique_indices=True
194-
)
195-
if verbose:
196-
verbose_data.append((ti, CTM_Enum.T3, jnp.amax(diff)))
197-
198-
old_shape = old_peps_tensors[ti].T4.shape
199-
new_shape = new_peps_tensors[ti].T4.shape
200-
diff = jnp.abs(
201-
new_peps_tensors[ti].T4[
202-
: old_shape[0], : old_shape[1], : old_shape[2], : old_shape[3]
203-
]
204-
- old_peps_tensors[ti].T4[
205-
: new_shape[0], : new_shape[1], : new_shape[2], : new_shape[3]
206-
]
207-
)
208-
result += jnp.sum(diff > eps)
209-
measure = measure.at[ti, 7].set(
210-
jnp.linalg.norm(diff), indices_are_sorted=True, unique_indices=True
211-
)
212-
if verbose:
213-
verbose_data.append((ti, CTM_Enum.T4, jnp.amax(diff)))
162+
if split_transfer:
163+
old_shape = old_peps_tensors[ti].T1_ket.shape
164+
new_shape = new_peps_tensors[ti].T1_ket.shape
165+
diff = jnp.abs(
166+
new_peps_tensors[ti].T1_ket[
167+
: old_shape[0], : old_shape[1], : old_shape[2]
168+
]
169+
- old_peps_tensors[ti].T1_ket[
170+
: new_shape[0], : new_shape[1], : new_shape[2]
171+
]
172+
)
173+
result += jnp.sum(diff > eps)
174+
measure = measure.at[ti, 4].set(
175+
jnp.linalg.norm(diff), indices_are_sorted=True, unique_indices=True
176+
)
177+
if verbose:
178+
verbose_data.append((ti, CTM_Enum.T1_ket, jnp.amax(diff)))
179+
180+
old_shape = old_peps_tensors[ti].T1_bra.shape
181+
new_shape = new_peps_tensors[ti].T1_bra.shape
182+
diff = jnp.abs(
183+
new_peps_tensors[ti].T1_bra[
184+
: old_shape[0], : old_shape[1], : old_shape[2]
185+
]
186+
- old_peps_tensors[ti].T1_bra[
187+
: new_shape[0], : new_shape[1], : new_shape[2]
188+
]
189+
)
190+
result += jnp.sum(diff > eps)
191+
measure = measure.at[ti, 5].set(
192+
jnp.linalg.norm(diff), indices_are_sorted=True, unique_indices=True
193+
)
194+
if verbose:
195+
verbose_data.append((ti, CTM_Enum.T1_bra, jnp.amax(diff)))
196+
197+
old_shape = old_peps_tensors[ti].T2_ket.shape
198+
new_shape = new_peps_tensors[ti].T2_ket.shape
199+
diff = jnp.abs(
200+
new_peps_tensors[ti].T2_ket[
201+
: old_shape[0], : old_shape[1], : old_shape[2]
202+
]
203+
- old_peps_tensors[ti].T2_ket[
204+
: new_shape[0], : new_shape[1], : new_shape[2]
205+
]
206+
)
207+
result += jnp.sum(diff > eps)
208+
measure = measure.at[ti, 6].set(
209+
jnp.linalg.norm(diff), indices_are_sorted=True, unique_indices=True
210+
)
211+
if verbose:
212+
verbose_data.append((ti, CTM_Enum.T2_ket, jnp.amax(diff)))
213+
214+
old_shape = old_peps_tensors[ti].T2_bra.shape
215+
new_shape = new_peps_tensors[ti].T2_bra.shape
216+
diff = jnp.abs(
217+
new_peps_tensors[ti].T2_bra[
218+
: old_shape[0], : old_shape[1], : old_shape[2]
219+
]
220+
- old_peps_tensors[ti].T2_bra[
221+
: new_shape[0], : new_shape[1], : new_shape[2]
222+
]
223+
)
224+
result += jnp.sum(diff > eps)
225+
measure = measure.at[ti, 7].set(
226+
jnp.linalg.norm(diff), indices_are_sorted=True, unique_indices=True
227+
)
228+
if verbose:
229+
verbose_data.append((ti, CTM_Enum.T2_bra, jnp.amax(diff)))
230+
231+
old_shape = old_peps_tensors[ti].T3_ket.shape
232+
new_shape = new_peps_tensors[ti].T3_ket.shape
233+
diff = jnp.abs(
234+
new_peps_tensors[ti].T3_ket[
235+
: old_shape[0], : old_shape[1], : old_shape[2]
236+
]
237+
- old_peps_tensors[ti].T3_ket[
238+
: new_shape[0], : new_shape[1], : new_shape[2]
239+
]
240+
)
241+
result += jnp.sum(diff > eps)
242+
measure = measure.at[ti, 8].set(
243+
jnp.linalg.norm(diff), indices_are_sorted=True, unique_indices=True
244+
)
245+
if verbose:
246+
verbose_data.append((ti, CTM_Enum.T3_ket, jnp.amax(diff)))
247+
248+
old_shape = old_peps_tensors[ti].T3_bra.shape
249+
new_shape = new_peps_tensors[ti].T3_bra.shape
250+
diff = jnp.abs(
251+
new_peps_tensors[ti].T3_bra[
252+
: old_shape[0], : old_shape[1], : old_shape[2]
253+
]
254+
- old_peps_tensors[ti].T3_bra[
255+
: new_shape[0], : new_shape[1], : new_shape[2]
256+
]
257+
)
258+
result += jnp.sum(diff > eps)
259+
measure = measure.at[ti, 9].set(
260+
jnp.linalg.norm(diff), indices_are_sorted=True, unique_indices=True
261+
)
262+
if verbose:
263+
verbose_data.append((ti, CTM_Enum.T3_bra, jnp.amax(diff)))
264+
265+
old_shape = old_peps_tensors[ti].T4_ket.shape
266+
new_shape = new_peps_tensors[ti].T4_ket.shape
267+
diff = jnp.abs(
268+
new_peps_tensors[ti].T4_ket[
269+
: old_shape[0], : old_shape[1], : old_shape[2]
270+
]
271+
- old_peps_tensors[ti].T4_ket[
272+
: new_shape[0], : new_shape[1], : new_shape[2]
273+
]
274+
)
275+
result += jnp.sum(diff > eps)
276+
measure = measure.at[ti, 10].set(
277+
jnp.linalg.norm(diff), indices_are_sorted=True, unique_indices=True
278+
)
279+
if verbose:
280+
verbose_data.append((ti, CTM_Enum.T4_ket, jnp.amax(diff)))
281+
282+
old_shape = old_peps_tensors[ti].T4_bra.shape
283+
new_shape = new_peps_tensors[ti].T4_bra.shape
284+
diff = jnp.abs(
285+
new_peps_tensors[ti].T4_bra[
286+
: old_shape[0], : old_shape[1], : old_shape[2]
287+
]
288+
- old_peps_tensors[ti].T4_bra[
289+
: new_shape[0], : new_shape[1], : new_shape[2]
290+
]
291+
)
292+
result += jnp.sum(diff > eps)
293+
measure = measure.at[ti, 11].set(
294+
jnp.linalg.norm(diff), indices_are_sorted=True, unique_indices=True
295+
)
296+
if verbose:
297+
verbose_data.append((ti, CTM_Enum.T4_bra, jnp.amax(diff)))
298+
else:
299+
old_shape = old_peps_tensors[ti].T1.shape
300+
new_shape = new_peps_tensors[ti].T1.shape
301+
diff = jnp.abs(
302+
new_peps_tensors[ti].T1[
303+
: old_shape[0], : old_shape[1], : old_shape[2], : old_shape[3]
304+
]
305+
- old_peps_tensors[ti].T1[
306+
: new_shape[0], : new_shape[1], : new_shape[2], : new_shape[3]
307+
]
308+
)
309+
result += jnp.sum(diff > eps)
310+
measure = measure.at[ti, 4].set(
311+
jnp.linalg.norm(diff), indices_are_sorted=True, unique_indices=True
312+
)
313+
if verbose:
314+
verbose_data.append((ti, CTM_Enum.T1, jnp.amax(diff)))
315+
316+
old_shape = old_peps_tensors[ti].T2.shape
317+
new_shape = new_peps_tensors[ti].T2.shape
318+
diff = jnp.abs(
319+
new_peps_tensors[ti].T2[
320+
: old_shape[0], : old_shape[1], : old_shape[2], : old_shape[3]
321+
]
322+
- old_peps_tensors[ti].T2[
323+
: new_shape[0], : new_shape[1], : new_shape[2], : new_shape[3]
324+
]
325+
)
326+
result += jnp.sum(diff > eps)
327+
measure = measure.at[ti, 5].set(
328+
jnp.linalg.norm(diff), indices_are_sorted=True, unique_indices=True
329+
)
330+
if verbose:
331+
verbose_data.append((ti, CTM_Enum.T2, jnp.amax(diff)))
332+
333+
old_shape = old_peps_tensors[ti].T3.shape
334+
new_shape = new_peps_tensors[ti].T3.shape
335+
diff = jnp.abs(
336+
new_peps_tensors[ti].T3[
337+
: old_shape[0], : old_shape[1], : old_shape[2], : old_shape[3]
338+
]
339+
- old_peps_tensors[ti].T3[
340+
: new_shape[0], : new_shape[1], : new_shape[2], : new_shape[3]
341+
]
342+
)
343+
result += jnp.sum(diff > eps)
344+
measure = measure.at[ti, 6].set(
345+
jnp.linalg.norm(diff), indices_are_sorted=True, unique_indices=True
346+
)
347+
if verbose:
348+
verbose_data.append((ti, CTM_Enum.T3, jnp.amax(diff)))
349+
350+
old_shape = old_peps_tensors[ti].T4.shape
351+
new_shape = new_peps_tensors[ti].T4.shape
352+
diff = jnp.abs(
353+
new_peps_tensors[ti].T4[
354+
: old_shape[0], : old_shape[1], : old_shape[2], : old_shape[3]
355+
]
356+
- old_peps_tensors[ti].T4[
357+
: new_shape[0], : new_shape[1], : new_shape[2], : new_shape[3]
358+
]
359+
)
360+
result += jnp.sum(diff > eps)
361+
measure = measure.at[ti, 7].set(
362+
jnp.linalg.norm(diff), indices_are_sorted=True, unique_indices=True
363+
)
364+
if verbose:
365+
verbose_data.append((ti, CTM_Enum.T4, jnp.amax(diff)))
214366

215367
return result == 0, jnp.linalg.norm(measure), verbose_data
216368

@@ -230,16 +382,22 @@ def _ctmrg_body_func(carry):
230382
config,
231383
) = carry
232384

233-
w_unitcell, norm_smallest_S = do_absorption_step(
234-
w_tensors, w_unitcell_last_step, config, state
235-
)
385+
if state.ctmrg_split_transfer:
386+
w_unitcell, norm_smallest_S = do_absorption_step_split_transfer(
387+
w_tensors, w_unitcell_last_step, config, state
388+
)
389+
else:
390+
w_unitcell, norm_smallest_S = do_absorption_step(
391+
w_tensors, w_unitcell_last_step, config, state
392+
)
236393

237394
def elementwise_func(old, new, old_corner, conv_eps, config):
238395
converged, measure, verbose_data = _is_element_wise_converged(
239396
old,
240397
new,
241398
conv_eps,
242399
verbose=config.ctmrg_verbose_output,
400+
split_transfer=state.ctmrg_split_transfer,
243401
)
244402
return converged, measure, verbose_data, old_corner
245403

@@ -377,12 +535,22 @@ def calc_ctmrg_env(
377535
norm_smallest_S = jnp.nan
378536
already_tried_chi = {working_unitcell[0, 0][0][0].chi}
379537

538+
varipeps_global_state.ctmrg_split_transfer = isinstance(
539+
unitcell.get_unique_tensors()[0], PEPS_Tensor_Split_Transfer
540+
)
541+
380542
while True:
381543
tmp_count = 0
382544
corner_singular_vals = None
383545

384546
while any(
385547
i.C1.shape[0] != i.chi for i in working_unitcell.get_unique_tensors()
548+
) or (
549+
hasattr(working_unitcell.get_unique_tensors()[0], "T4_ket")
550+
and any(
551+
i.T4_ket.shape[0] != i.interlayer_chi
552+
for i in working_unitcell.get_unique_tensors()
553+
)
386554
):
387555
(
388556
_,

varipeps/global_state.py

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class should not be modified by users.
2121

2222
ctmrg_effective_truncation_eps: Optional[float] = None
2323
ctmrg_projector_method: Optional[Projector_Method] = None
24+
ctmrg_split_transfer: bool = False
2425
basinhopping_disable_half_projector: Optional[bool] = None
2526

2627
def tree_flatten(self) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:

0 commit comments

Comments
 (0)