Skip to content

Commit a2e6ff3

Browse files
committed
Add split-transfer awareness to optimizer
1 parent fec420d commit a2e6ff3

File tree

5 files changed

+129
-24
lines changed

5 files changed

+129
-24
lines changed

varipeps/ctmrg/routine.py

+27-16
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,6 @@ def _calc_corner_svds(
6969
C1_svd, indices_are_sorted=True, unique_indices=True
7070
)
7171

72-
# debug_print("C1: {}", C1_svd)
73-
7472
C2_svd = jnp.linalg.svd(t.C2, full_matrices=False, compute_uv=False)
7573
step_corner_svd = step_corner_svd.at[ti, 1, : C2_svd.shape[0]].set(
7674
C2_svd, indices_are_sorted=True, unique_indices=True
@@ -382,7 +380,7 @@ def _ctmrg_body_func(carry):
382380
config,
383381
) = carry
384382

385-
if state.ctmrg_split_transfer:
383+
if w_unitcell_last_step.is_split_transfer():
386384
w_unitcell, norm_smallest_S = do_absorption_step_split_transfer(
387385
w_tensors, w_unitcell_last_step, config, state
388386
)
@@ -397,7 +395,7 @@ def elementwise_func(old, new, old_corner, conv_eps, config):
397395
new,
398396
conv_eps,
399397
verbose=config.ctmrg_verbose_output,
400-
split_transfer=state.ctmrg_split_transfer,
398+
split_transfer=w_unitcell.is_split_transfer(),
401399
)
402400
return converged, measure, verbose_data, old_corner
403401

@@ -535,10 +533,6 @@ def calc_ctmrg_env(
535533
norm_smallest_S = jnp.nan
536534
already_tried_chi = {working_unitcell[0, 0][0][0].chi}
537535

538-
varipeps_global_state.ctmrg_split_transfer = isinstance(
539-
unitcell.get_unique_tensors()[0], PEPS_Tensor_Split_Transfer
540-
)
541-
542536
while True:
543537
tmp_count = 0
544538
corner_singular_vals = None
@@ -776,6 +770,7 @@ def _ctmrg_rev_while_body(carry):
776770
bar_fixed_point.get_unique_tensors(),
777771
config.ad_custom_convergence_eps,
778772
verbose=config.ad_custom_verbose_output,
773+
split_transfer=bar_fixed_point.is_split_transfer(),
779774
)
780775

781776
count += 1
@@ -796,15 +791,31 @@ def _ctmrg_rev_while_body(carry):
796791

797792
@jit
798793
def _ctmrg_rev_workhorse(peps_tensors, new_unitcell, new_unitcell_bar, config, state):
799-
_, vjp_peps_tensors = vjp(
800-
lambda t: do_absorption_step(t, new_unitcell, config, state), peps_tensors
801-
)
794+
if new_unitcell.is_split_transfer():
795+
_, vjp_peps_tensors = vjp(
796+
lambda t: do_absorption_step_split_transfer(t, new_unitcell, config, state),
797+
peps_tensors,
798+
)
802799

803-
vjp_env = tree_util.Partial(
804-
vjp(lambda u: do_absorption_step(peps_tensors, u, config, state), new_unitcell)[
805-
1
806-
]
807-
)
800+
vjp_env = tree_util.Partial(
801+
vjp(
802+
lambda u: do_absorption_step_split_transfer(
803+
peps_tensors, u, config, state
804+
),
805+
new_unitcell,
806+
)[1]
807+
)
808+
else:
809+
_, vjp_peps_tensors = vjp(
810+
lambda t: do_absorption_step(t, new_unitcell, config, state), peps_tensors
811+
)
812+
813+
vjp_env = tree_util.Partial(
814+
vjp(
815+
lambda u: do_absorption_step(peps_tensors, u, config, state),
816+
new_unitcell,
817+
)[1]
818+
)
808819

809820
def cond_func(carry):
810821
_, _, _, converged, count, config, state = carry

varipeps/global_state.py

-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ class should not be modified by users.
2121

2222
ctmrg_effective_truncation_eps: Optional[float] = None
2323
ctmrg_projector_method: Optional[Projector_Method] = None
24-
ctmrg_split_transfer: bool = False
2524
basinhopping_disable_half_projector: Optional[bool] = None
2625

2726
def tree_flatten(self) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:

varipeps/optimization/inner_function.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ def calc_ctmrg_expectation(
9494
:obj:`tuple`\ (:obj:`jax.numpy.ndarray`, :obj:`~varipeps.peps.PEPS_Unit_Cell`):
9595
Tuple consisting of the calculated expectation value and the new unitcell.
9696
"""
97+
state_split_transfer = unitcell.is_split_transfer()
98+
9799
spiral_vectors = additional_input.get("spiral_vectors")
98100
if expectation_func.is_spiral_peps and spiral_vectors is None:
99101
peps_tensors, unitcell, spiral_vectors = _map_tensors(
@@ -122,20 +124,25 @@ def calc_ctmrg_expectation(
122124
input_tensors, unitcell, convert_to_unitcell_func, False
123125
)
124126

127+
if state_split_transfer != unitcell.is_split_transfer():
128+
raise ValueError("Map function is not split transfer aware. Please fix that!")
129+
125130
new_unitcell, max_trunc_error = calc_ctmrg_env(
126131
peps_tensors,
127132
unitcell,
128133
enforce_elementwise_convergence=enforce_elementwise_convergence,
129134
)
130135

136+
exp_unitcell = new_unitcell.convert_to_full_transfer()
137+
131138
if expectation_func.is_spiral_peps:
132139
return cast(
133-
jnp.ndarray, expectation_func(peps_tensors, new_unitcell, spiral_vectors)
140+
jnp.ndarray, expectation_func(peps_tensors, exp_unitcell, spiral_vectors)
134141
), (
135142
new_unitcell,
136143
max_trunc_error,
137144
)
138-
return cast(jnp.ndarray, expectation_func(peps_tensors, new_unitcell)), (
145+
return cast(jnp.ndarray, expectation_func(peps_tensors, exp_unitcell)), (
139146
new_unitcell,
140147
max_trunc_error,
141148
)
@@ -188,6 +195,8 @@ def calc_preconverged_ctmrg_value_and_grad(
188195
unitcell.
189196
2. The calculated gradient.
190197
"""
198+
state_split_transfer = unitcell.is_split_transfer()
199+
191200
spiral_vectors = additional_input.get("spiral_vectors")
192201
if expectation_func.is_spiral_peps and spiral_vectors is None:
193202
peps_tensors, unitcell, spiral_vectors = _map_tensors(
@@ -216,6 +225,9 @@ def calc_preconverged_ctmrg_value_and_grad(
216225
input_tensors, unitcell, convert_to_unitcell_func, False
217226
)
218227

228+
if state_split_transfer != unitcell.is_split_transfer():
229+
raise ValueError("Map function is not split transfer aware. Please fix that!")
230+
219231
if calc_preconverged:
220232
preconverged_unitcell, _ = calc_ctmrg_env(
221233
peps_tensors,
@@ -265,6 +277,8 @@ def calc_ctmrg_expectation_custom(
265277
:obj:`tuple`\ (:obj:`jax.numpy.ndarray`, :obj:`~varipeps.peps.PEPS_Unit_Cell`):
266278
Tuple consisting of the calculated expectation value and the new unitcell.
267279
"""
280+
state_split_transfer = unitcell.is_split_transfer()
281+
268282
spiral_vectors = additional_input.get("spiral_vectors")
269283
if expectation_func.is_spiral_peps and spiral_vectors is None:
270284
peps_tensors, unitcell, spiral_vectors = _map_tensors(
@@ -293,16 +307,21 @@ def calc_ctmrg_expectation_custom(
293307
input_tensors, unitcell, convert_to_unitcell_func, False
294308
)
295309

310+
if state_split_transfer != unitcell.is_split_transfer():
311+
raise ValueError("Map function is not split transfer aware. Please fix that!")
312+
296313
new_unitcell, max_trunc_error = calc_ctmrg_env_custom_rule(peps_tensors, unitcell)
297314

315+
exp_unitcell = new_unitcell.convert_to_full_transfer()
316+
298317
if expectation_func.is_spiral_peps:
299318
return cast(
300-
jnp.ndarray, expectation_func(peps_tensors, new_unitcell, spiral_vectors)
319+
jnp.ndarray, expectation_func(peps_tensors, exp_unitcell, spiral_vectors)
301320
), (
302321
new_unitcell,
303322
max_trunc_error,
304323
)
305-
return cast(jnp.ndarray, expectation_func(peps_tensors, new_unitcell)), (
324+
return cast(jnp.ndarray, expectation_func(peps_tensors, exp_unitcell)), (
306325
new_unitcell,
307326
max_trunc_error,
308327
)

varipeps/peps/tensor.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -1105,7 +1105,16 @@ def load_from_group(cls: Type[T_PEPS_Tensor], grp: h5py.Group) -> T_PEPS_Tensor:
11051105
max_chi=max_chi,
11061106
)
11071107

1108-
def convert_to_split_transfer(self: T_PEPS_Tensor) -> T_PEPS_Tensor_Split_Transfer:
1108+
@property
1109+
def is_split_transfer(self: T_PEPS_Tensor) -> bool:
1110+
return False
1111+
1112+
def convert_to_split_transfer(
1113+
self: T_PEPS_Tensor, interlayer_chi: Optional[int] = None
1114+
) -> T_PEPS_Tensor_Split_Transfer:
1115+
if interlayer_chi is None:
1116+
interlayer_chi = self.chi
1117+
11091118
return PEPS_Tensor_Split_Transfer(
11101119
tensor=self.tensor,
11111120
C1=self.C1,
@@ -1120,9 +1129,12 @@ def convert_to_split_transfer(self: T_PEPS_Tensor) -> T_PEPS_Tensor_Split_Transf
11201129
D=self.D,
11211130
chi=self.chi,
11221131
max_chi=self.max_chi,
1123-
interlayer_chi=self.chi,
1132+
interlayer_chi=interlayer_chi,
11241133
)
11251134

1135+
def convert_to_full_transfer(self: T_PEPS_Tensor) -> T_PEPS_Tensor:
1136+
return self
1137+
11261138
def tree_flatten(self) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
11271139
data = (
11281140
self.tensor,
@@ -2467,7 +2479,7 @@ def __add__(
24672479
"Both PEPS tensors must have the same tensor, d, D and chi values."
24682480
)
24692481

2470-
return PEPS_Tensor(
2482+
return type(self)(
24712483
tensor=self.tensor,
24722484
C1=self.C1 + other.C1,
24732485
C2=self.C2 + other.C2,
@@ -2621,6 +2633,10 @@ def load_from_group(
26212633
interlayer_chi=interlayer_chi,
26222634
)
26232635

2636+
@property
2637+
def is_split_transfer(self: T_PEPS_Tensor_Split_Transfer) -> bool:
2638+
return True
2639+
26242640
def convert_to_split_transfer(
26252641
self: T_PEPS_Tensor_Split_Transfer,
26262642
) -> T_PEPS_Tensor_Split_Transfer:

varipeps/peps/unitcell.py

+60
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,66 @@ def replace_unique_tensors(
450450
sanity_checks=False,
451451
)
452452

453+
def is_split_transfer(self: T_PEPS_Unit_Cell) -> bool:
454+
return all(t.is_split_transfer for t in self.data.peps_tensors)
455+
456+
def convert_to_split_transfer(
457+
self: T_PEPS_Unit_Cell, interlayer_chi: Optional[int] = None
458+
) -> T_PEPS_Unit_Cell:
459+
"""
460+
Convert the list of unique tensors to the split transfer ansatz.
461+
462+
Args:
463+
interlayer_chi (:obj:`int`, optional):
464+
Bond dimension for the interlayer index in the split transfer
465+
ansatz. If set to None, the same value as for the enviroment
466+
bond dimension is used.
467+
Returns:
468+
PEPS_Unit_Cell:
469+
New instance of PEPS unit cell with the new unique tensor list.
470+
"""
471+
if self.is_split_transfer():
472+
return self
473+
474+
new_unique_tensors = type(self.data.peps_tensors)(
475+
t.convert_to_split_transfer(interlayer_chi) for t in self.data.peps_tensors
476+
)
477+
478+
new_data = self.data.replace_peps_tensors(new_unique_tensors)
479+
480+
return type(self)(
481+
data=new_data,
482+
real_ix=self.real_ix,
483+
real_iy=self.real_iy,
484+
sanity_checks=False,
485+
)
486+
487+
def convert_to_full_transfer(
488+
self: T_PEPS_Unit_Cell, interlayer_chi: Optional[int] = None
489+
) -> T_PEPS_Unit_Cell:
490+
"""
491+
Convert the list of unique tensors to the full transfer ansatz.
492+
493+
Returns:
494+
PEPS_Unit_Cell:
495+
New instance of PEPS unit cell with the new unique tensor list.
496+
"""
497+
if not self.is_split_transfer():
498+
return self
499+
500+
new_unique_tensors = type(self.data.peps_tensors)(
501+
t.convert_to_full_transfer() for t in self.data.peps_tensors
502+
)
503+
504+
new_data = self.data.replace_peps_tensors(new_unique_tensors)
505+
506+
return type(self)(
507+
data=new_data,
508+
real_ix=self.real_ix,
509+
real_iy=self.real_iy,
510+
sanity_checks=False,
511+
)
512+
453513
def change_chi(
454514
self: T_PEPS_Unit_Cell,
455515
new_chi: int,

0 commit comments

Comments
 (0)