@@ -69,8 +69,6 @@ def _calc_corner_svds(
69
69
C1_svd , indices_are_sorted = True , unique_indices = True
70
70
)
71
71
72
- # debug_print("C1: {}", C1_svd)
73
-
74
72
C2_svd = jnp .linalg .svd (t .C2 , full_matrices = False , compute_uv = False )
75
73
step_corner_svd = step_corner_svd .at [ti , 1 , : C2_svd .shape [0 ]].set (
76
74
C2_svd , indices_are_sorted = True , unique_indices = True
@@ -382,7 +380,7 @@ def _ctmrg_body_func(carry):
382
380
config ,
383
381
) = carry
384
382
385
- if state . ctmrg_split_transfer :
383
+ if w_unitcell_last_step . is_split_transfer () :
386
384
w_unitcell , norm_smallest_S = do_absorption_step_split_transfer (
387
385
w_tensors , w_unitcell_last_step , config , state
388
386
)
@@ -397,7 +395,7 @@ def elementwise_func(old, new, old_corner, conv_eps, config):
397
395
new ,
398
396
conv_eps ,
399
397
verbose = config .ctmrg_verbose_output ,
400
- split_transfer = state . ctmrg_split_transfer ,
398
+ split_transfer = w_unitcell . is_split_transfer () ,
401
399
)
402
400
return converged , measure , verbose_data , old_corner
403
401
@@ -535,10 +533,6 @@ def calc_ctmrg_env(
535
533
norm_smallest_S = jnp .nan
536
534
already_tried_chi = {working_unitcell [0 , 0 ][0 ][0 ].chi }
537
535
538
- varipeps_global_state .ctmrg_split_transfer = isinstance (
539
- unitcell .get_unique_tensors ()[0 ], PEPS_Tensor_Split_Transfer
540
- )
541
-
542
536
while True :
543
537
tmp_count = 0
544
538
corner_singular_vals = None
@@ -776,6 +770,7 @@ def _ctmrg_rev_while_body(carry):
776
770
bar_fixed_point .get_unique_tensors (),
777
771
config .ad_custom_convergence_eps ,
778
772
verbose = config .ad_custom_verbose_output ,
773
+ split_transfer = bar_fixed_point .is_split_transfer (),
779
774
)
780
775
781
776
count += 1
@@ -796,15 +791,31 @@ def _ctmrg_rev_while_body(carry):
796
791
797
792
@jit
798
793
def _ctmrg_rev_workhorse (peps_tensors , new_unitcell , new_unitcell_bar , config , state ):
799
- _ , vjp_peps_tensors = vjp (
800
- lambda t : do_absorption_step (t , new_unitcell , config , state ), peps_tensors
801
- )
794
+ if new_unitcell .is_split_transfer ():
795
+ _ , vjp_peps_tensors = vjp (
796
+ lambda t : do_absorption_step_split_transfer (t , new_unitcell , config , state ),
797
+ peps_tensors ,
798
+ )
802
799
803
- vjp_env = tree_util .Partial (
804
- vjp (lambda u : do_absorption_step (peps_tensors , u , config , state ), new_unitcell )[
805
- 1
806
- ]
807
- )
800
+ vjp_env = tree_util .Partial (
801
+ vjp (
802
+ lambda u : do_absorption_step_split_transfer (
803
+ peps_tensors , u , config , state
804
+ ),
805
+ new_unitcell ,
806
+ )[1 ]
807
+ )
808
+ else :
809
+ _ , vjp_peps_tensors = vjp (
810
+ lambda t : do_absorption_step (t , new_unitcell , config , state ), peps_tensors
811
+ )
812
+
813
+ vjp_env = tree_util .Partial (
814
+ vjp (
815
+ lambda u : do_absorption_step (peps_tensors , u , config , state ),
816
+ new_unitcell ,
817
+ )[1 ]
818
+ )
808
819
809
820
def cond_func (carry ):
810
821
_ , _ , _ , converged , count , config , state = carry
0 commit comments