8
8
import jax .debug as jdebug
9
9
10
10
from varipeps import varipeps_config , varipeps_global_state
11
- from varipeps .peps import PEPS_Tensor , PEPS_Unit_Cell
11
+ from varipeps .peps import PEPS_Tensor , PEPS_Tensor_Split_Transfer , PEPS_Unit_Cell
12
12
from varipeps .utils .debug_print import debug_print
13
- from .absorption import do_absorption_step
13
+ from .absorption import do_absorption_step , do_absorption_step_split_transfer
14
14
15
15
from typing import Sequence , Tuple , List , Optional
16
16
@@ -25,6 +25,14 @@ class CTM_Enum(enum.IntEnum):
25
25
T2 = enum .auto ()
26
26
T3 = enum .auto ()
27
27
T4 = enum .auto ()
28
+ T1_ket = enum .auto ()
29
+ T1_bra = enum .auto ()
30
+ T2_ket = enum .auto ()
31
+ T2_bra = enum .auto ()
32
+ T3_ket = enum .auto ()
33
+ T3_bra = enum .auto ()
34
+ T4_ket = enum .auto ()
35
+ T4_bra = enum .auto ()
28
36
29
37
30
38
class CTMRGNotConvergedError (Exception ):
@@ -61,6 +69,8 @@ def _calc_corner_svds(
61
69
C1_svd , indices_are_sorted = True , unique_indices = True
62
70
)
63
71
72
+ # debug_print("C1: {}", C1_svd)
73
+
64
74
C2_svd = jnp .linalg .svd (t .C2 , full_matrices = False , compute_uv = False )
65
75
step_corner_svd = step_corner_svd .at [ti , 1 , : C2_svd .shape [0 ]].set (
66
76
C2_svd , indices_are_sorted = True , unique_indices = True
@@ -79,15 +89,20 @@ def _calc_corner_svds(
79
89
return step_corner_svd
80
90
81
91
82
- @partial (jit , static_argnums = (3 ,), inline = True )
92
+ @partial (jit , static_argnums = (3 , 4 ), inline = True )
83
93
def _is_element_wise_converged (
84
94
old_peps_tensors : List [PEPS_Tensor ],
85
95
new_peps_tensors : List [PEPS_Tensor ],
86
96
eps : float ,
87
97
verbose : bool = False ,
98
+ split_transfer : bool = False ,
88
99
) -> Tuple [bool , float , Optional [List [Tuple [int , CTM_Enum , float ]]]]:
89
100
result = 0
90
- measure = jnp .zeros ((len (old_peps_tensors ), 8 ), dtype = jnp .float64 )
101
+
102
+ if split_transfer :
103
+ measure = jnp .zeros ((len (old_peps_tensors ), 12 ), dtype = jnp .float64 )
104
+ else :
105
+ measure = jnp .zeros ((len (old_peps_tensors ), 8 ), dtype = jnp .float64 )
91
106
92
107
verbose_data = [] if verbose else None
93
108
@@ -144,73 +159,210 @@ def _is_element_wise_converged(
144
159
if verbose :
145
160
verbose_data .append ((ti , CTM_Enum .C4 , jnp .amax (diff )))
146
161
147
- old_shape = old_peps_tensors [ti ].T1 .shape
148
- new_shape = new_peps_tensors [ti ].T1 .shape
149
- diff = jnp .abs (
150
- new_peps_tensors [ti ].T1 [
151
- : old_shape [0 ], : old_shape [1 ], : old_shape [2 ], : old_shape [3 ]
152
- ]
153
- - old_peps_tensors [ti ].T1 [
154
- : new_shape [0 ], : new_shape [1 ], : new_shape [2 ], : new_shape [3 ]
155
- ]
156
- )
157
- result += jnp .sum (diff > eps )
158
- measure = measure .at [ti , 4 ].set (
159
- jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
160
- )
161
- if verbose :
162
- verbose_data .append ((ti , CTM_Enum .T1 , jnp .amax (diff )))
163
-
164
- old_shape = old_peps_tensors [ti ].T2 .shape
165
- new_shape = new_peps_tensors [ti ].T2 .shape
166
- diff = jnp .abs (
167
- new_peps_tensors [ti ].T2 [
168
- : old_shape [0 ], : old_shape [1 ], : old_shape [2 ], : old_shape [3 ]
169
- ]
170
- - old_peps_tensors [ti ].T2 [
171
- : new_shape [0 ], : new_shape [1 ], : new_shape [2 ], : new_shape [3 ]
172
- ]
173
- )
174
- result += jnp .sum (diff > eps )
175
- measure = measure .at [ti , 5 ].set (
176
- jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
177
- )
178
- if verbose :
179
- verbose_data .append ((ti , CTM_Enum .T2 , jnp .amax (diff )))
180
-
181
- old_shape = old_peps_tensors [ti ].T3 .shape
182
- new_shape = new_peps_tensors [ti ].T3 .shape
183
- diff = jnp .abs (
184
- new_peps_tensors [ti ].T3 [
185
- : old_shape [0 ], : old_shape [1 ], : old_shape [2 ], : old_shape [3 ]
186
- ]
187
- - old_peps_tensors [ti ].T3 [
188
- : new_shape [0 ], : new_shape [1 ], : new_shape [2 ], : new_shape [3 ]
189
- ]
190
- )
191
- result += jnp .sum (diff > eps )
192
- measure = measure .at [ti , 6 ].set (
193
- jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
194
- )
195
- if verbose :
196
- verbose_data .append ((ti , CTM_Enum .T3 , jnp .amax (diff )))
197
-
198
- old_shape = old_peps_tensors [ti ].T4 .shape
199
- new_shape = new_peps_tensors [ti ].T4 .shape
200
- diff = jnp .abs (
201
- new_peps_tensors [ti ].T4 [
202
- : old_shape [0 ], : old_shape [1 ], : old_shape [2 ], : old_shape [3 ]
203
- ]
204
- - old_peps_tensors [ti ].T4 [
205
- : new_shape [0 ], : new_shape [1 ], : new_shape [2 ], : new_shape [3 ]
206
- ]
207
- )
208
- result += jnp .sum (diff > eps )
209
- measure = measure .at [ti , 7 ].set (
210
- jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
211
- )
212
- if verbose :
213
- verbose_data .append ((ti , CTM_Enum .T4 , jnp .amax (diff )))
162
+ if split_transfer :
163
+ old_shape = old_peps_tensors [ti ].T1_ket .shape
164
+ new_shape = new_peps_tensors [ti ].T1_ket .shape
165
+ diff = jnp .abs (
166
+ new_peps_tensors [ti ].T1_ket [
167
+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ]
168
+ ]
169
+ - old_peps_tensors [ti ].T1_ket [
170
+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ]
171
+ ]
172
+ )
173
+ result += jnp .sum (diff > eps )
174
+ measure = measure .at [ti , 4 ].set (
175
+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
176
+ )
177
+ if verbose :
178
+ verbose_data .append ((ti , CTM_Enum .T1_ket , jnp .amax (diff )))
179
+
180
+ old_shape = old_peps_tensors [ti ].T1_bra .shape
181
+ new_shape = new_peps_tensors [ti ].T1_bra .shape
182
+ diff = jnp .abs (
183
+ new_peps_tensors [ti ].T1_bra [
184
+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ]
185
+ ]
186
+ - old_peps_tensors [ti ].T1_bra [
187
+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ]
188
+ ]
189
+ )
190
+ result += jnp .sum (diff > eps )
191
+ measure = measure .at [ti , 5 ].set (
192
+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
193
+ )
194
+ if verbose :
195
+ verbose_data .append ((ti , CTM_Enum .T1_bra , jnp .amax (diff )))
196
+
197
+ old_shape = old_peps_tensors [ti ].T2_ket .shape
198
+ new_shape = new_peps_tensors [ti ].T2_ket .shape
199
+ diff = jnp .abs (
200
+ new_peps_tensors [ti ].T2_ket [
201
+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ]
202
+ ]
203
+ - old_peps_tensors [ti ].T2_ket [
204
+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ]
205
+ ]
206
+ )
207
+ result += jnp .sum (diff > eps )
208
+ measure = measure .at [ti , 6 ].set (
209
+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
210
+ )
211
+ if verbose :
212
+ verbose_data .append ((ti , CTM_Enum .T2_ket , jnp .amax (diff )))
213
+
214
+ old_shape = old_peps_tensors [ti ].T2_bra .shape
215
+ new_shape = new_peps_tensors [ti ].T2_bra .shape
216
+ diff = jnp .abs (
217
+ new_peps_tensors [ti ].T2_bra [
218
+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ]
219
+ ]
220
+ - old_peps_tensors [ti ].T2_bra [
221
+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ]
222
+ ]
223
+ )
224
+ result += jnp .sum (diff > eps )
225
+ measure = measure .at [ti , 7 ].set (
226
+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
227
+ )
228
+ if verbose :
229
+ verbose_data .append ((ti , CTM_Enum .T2_bra , jnp .amax (diff )))
230
+
231
+ old_shape = old_peps_tensors [ti ].T3_ket .shape
232
+ new_shape = new_peps_tensors [ti ].T3_ket .shape
233
+ diff = jnp .abs (
234
+ new_peps_tensors [ti ].T3_ket [
235
+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ]
236
+ ]
237
+ - old_peps_tensors [ti ].T3_ket [
238
+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ]
239
+ ]
240
+ )
241
+ result += jnp .sum (diff > eps )
242
+ measure = measure .at [ti , 8 ].set (
243
+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
244
+ )
245
+ if verbose :
246
+ verbose_data .append ((ti , CTM_Enum .T3_ket , jnp .amax (diff )))
247
+
248
+ old_shape = old_peps_tensors [ti ].T3_bra .shape
249
+ new_shape = new_peps_tensors [ti ].T3_bra .shape
250
+ diff = jnp .abs (
251
+ new_peps_tensors [ti ].T3_bra [
252
+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ]
253
+ ]
254
+ - old_peps_tensors [ti ].T3_bra [
255
+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ]
256
+ ]
257
+ )
258
+ result += jnp .sum (diff > eps )
259
+ measure = measure .at [ti , 9 ].set (
260
+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
261
+ )
262
+ if verbose :
263
+ verbose_data .append ((ti , CTM_Enum .T3_bra , jnp .amax (diff )))
264
+
265
+ old_shape = old_peps_tensors [ti ].T4_ket .shape
266
+ new_shape = new_peps_tensors [ti ].T4_ket .shape
267
+ diff = jnp .abs (
268
+ new_peps_tensors [ti ].T4_ket [
269
+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ]
270
+ ]
271
+ - old_peps_tensors [ti ].T4_ket [
272
+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ]
273
+ ]
274
+ )
275
+ result += jnp .sum (diff > eps )
276
+ measure = measure .at [ti , 10 ].set (
277
+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
278
+ )
279
+ if verbose :
280
+ verbose_data .append ((ti , CTM_Enum .T4_ket , jnp .amax (diff )))
281
+
282
+ old_shape = old_peps_tensors [ti ].T4_bra .shape
283
+ new_shape = new_peps_tensors [ti ].T4_bra .shape
284
+ diff = jnp .abs (
285
+ new_peps_tensors [ti ].T4_bra [
286
+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ]
287
+ ]
288
+ - old_peps_tensors [ti ].T4_bra [
289
+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ]
290
+ ]
291
+ )
292
+ result += jnp .sum (diff > eps )
293
+ measure = measure .at [ti , 11 ].set (
294
+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
295
+ )
296
+ if verbose :
297
+ verbose_data .append ((ti , CTM_Enum .T4_bra , jnp .amax (diff )))
298
+ else :
299
+ old_shape = old_peps_tensors [ti ].T1 .shape
300
+ new_shape = new_peps_tensors [ti ].T1 .shape
301
+ diff = jnp .abs (
302
+ new_peps_tensors [ti ].T1 [
303
+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ], : old_shape [3 ]
304
+ ]
305
+ - old_peps_tensors [ti ].T1 [
306
+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ], : new_shape [3 ]
307
+ ]
308
+ )
309
+ result += jnp .sum (diff > eps )
310
+ measure = measure .at [ti , 4 ].set (
311
+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
312
+ )
313
+ if verbose :
314
+ verbose_data .append ((ti , CTM_Enum .T1 , jnp .amax (diff )))
315
+
316
+ old_shape = old_peps_tensors [ti ].T2 .shape
317
+ new_shape = new_peps_tensors [ti ].T2 .shape
318
+ diff = jnp .abs (
319
+ new_peps_tensors [ti ].T2 [
320
+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ], : old_shape [3 ]
321
+ ]
322
+ - old_peps_tensors [ti ].T2 [
323
+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ], : new_shape [3 ]
324
+ ]
325
+ )
326
+ result += jnp .sum (diff > eps )
327
+ measure = measure .at [ti , 5 ].set (
328
+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
329
+ )
330
+ if verbose :
331
+ verbose_data .append ((ti , CTM_Enum .T2 , jnp .amax (diff )))
332
+
333
+ old_shape = old_peps_tensors [ti ].T3 .shape
334
+ new_shape = new_peps_tensors [ti ].T3 .shape
335
+ diff = jnp .abs (
336
+ new_peps_tensors [ti ].T3 [
337
+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ], : old_shape [3 ]
338
+ ]
339
+ - old_peps_tensors [ti ].T3 [
340
+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ], : new_shape [3 ]
341
+ ]
342
+ )
343
+ result += jnp .sum (diff > eps )
344
+ measure = measure .at [ti , 6 ].set (
345
+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
346
+ )
347
+ if verbose :
348
+ verbose_data .append ((ti , CTM_Enum .T3 , jnp .amax (diff )))
349
+
350
+ old_shape = old_peps_tensors [ti ].T4 .shape
351
+ new_shape = new_peps_tensors [ti ].T4 .shape
352
+ diff = jnp .abs (
353
+ new_peps_tensors [ti ].T4 [
354
+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ], : old_shape [3 ]
355
+ ]
356
+ - old_peps_tensors [ti ].T4 [
357
+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ], : new_shape [3 ]
358
+ ]
359
+ )
360
+ result += jnp .sum (diff > eps )
361
+ measure = measure .at [ti , 7 ].set (
362
+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
363
+ )
364
+ if verbose :
365
+ verbose_data .append ((ti , CTM_Enum .T4 , jnp .amax (diff )))
214
366
215
367
return result == 0 , jnp .linalg .norm (measure ), verbose_data
216
368
@@ -230,16 +382,22 @@ def _ctmrg_body_func(carry):
230
382
config ,
231
383
) = carry
232
384
233
- w_unitcell , norm_smallest_S = do_absorption_step (
234
- w_tensors , w_unitcell_last_step , config , state
235
- )
385
+ if state .ctmrg_split_transfer :
386
+ w_unitcell , norm_smallest_S = do_absorption_step_split_transfer (
387
+ w_tensors , w_unitcell_last_step , config , state
388
+ )
389
+ else :
390
+ w_unitcell , norm_smallest_S = do_absorption_step (
391
+ w_tensors , w_unitcell_last_step , config , state
392
+ )
236
393
237
394
def elementwise_func (old , new , old_corner , conv_eps , config ):
238
395
converged , measure , verbose_data = _is_element_wise_converged (
239
396
old ,
240
397
new ,
241
398
conv_eps ,
242
399
verbose = config .ctmrg_verbose_output ,
400
+ split_transfer = state .ctmrg_split_transfer ,
243
401
)
244
402
return converged , measure , verbose_data , old_corner
245
403
@@ -377,12 +535,22 @@ def calc_ctmrg_env(
377
535
norm_smallest_S = jnp .nan
378
536
already_tried_chi = {working_unitcell [0 , 0 ][0 ][0 ].chi }
379
537
538
+ varipeps_global_state .ctmrg_split_transfer = isinstance (
539
+ unitcell .get_unique_tensors ()[0 ], PEPS_Tensor_Split_Transfer
540
+ )
541
+
380
542
while True :
381
543
tmp_count = 0
382
544
corner_singular_vals = None
383
545
384
546
while any (
385
547
i .C1 .shape [0 ] != i .chi for i in working_unitcell .get_unique_tensors ()
548
+ ) or (
549
+ hasattr (working_unitcell .get_unique_tensors ()[0 ], "T4_ket" )
550
+ and any (
551
+ i .T4_ket .shape [0 ] != i .interlayer_chi
552
+ for i in working_unitcell .get_unique_tensors ()
553
+ )
386
554
):
387
555
(
388
556
_ ,
0 commit comments