@@ -277,6 +277,7 @@ def random_noise(a):
277
277
best_value = jnp .inf
278
278
best_tensors = None
279
279
best_unitcell = None
280
+ best_run = None
280
281
281
282
random_noise_retries = 0
282
283
@@ -307,7 +308,10 @@ def random_noise(a):
307
308
varipeps_config .line_search_initial_step_size
308
309
)
309
310
working_value : Union [float , jnp .ndarray ]
310
- max_trunc_error_list = []
311
+ max_trunc_error_list = {random_noise_retries : []}
312
+ step_energies = {random_noise_retries : []}
313
+ step_chi = {random_noise_retries : []}
314
+ step_conv = {random_noise_retries : []}
311
315
312
316
if (
313
317
varipeps_config .optimizer_preconverge_with_half_projectors
@@ -417,6 +421,7 @@ def random_noise(a):
417
421
descent_dir = [- elem for elem in working_gradient ]
418
422
419
423
conv = jnp .linalg .norm (ravel_pytree (working_gradient )[0 ])
424
+ step_conv [random_noise_retries ].append (conv )
420
425
421
426
try :
422
427
(
@@ -460,6 +465,7 @@ def random_noise(a):
460
465
best_value = working_value
461
466
best_tensors = working_tensors
462
467
best_unitcell = working_unitcell
468
+ best_run = random_noise_retries
463
469
464
470
if isinstance (input_tensors , PEPS_Unit_Cell ) or (
465
471
isinstance (input_tensors , collections .abc .Sequence )
@@ -497,12 +503,22 @@ def random_noise(a):
497
503
signal_reset_descent_dir = True
498
504
count = - 1
499
505
random_noise_retries += 1
506
+
507
+ step_energies [random_noise_retries ] = []
508
+ step_chi [random_noise_retries ] = []
509
+ step_conv [random_noise_retries ] = []
510
+ max_trunc_error_list [random_noise_retries ] = []
511
+
500
512
pbar .reset ()
501
513
pbar .refresh ()
502
514
else :
503
515
conv = 0
504
-
505
- max_trunc_error_list .append (max_trunc_error )
516
+ else :
517
+ max_trunc_error_list [random_noise_retries ].append (max_trunc_error )
518
+ step_energies [random_noise_retries ].append (working_value )
519
+ step_chi [random_noise_retries ].append (
520
+ working_unitcell .get_unique_tensors ()[0 ].chi
521
+ )
506
522
507
523
if conv < varipeps_config .optimizer_convergence_eps :
508
524
working_value , (
@@ -517,7 +533,8 @@ def random_noise(a):
517
533
enforce_elementwise_convergence = varipeps_config .ad_use_custom_vjp ,
518
534
)
519
535
varipeps_global_state .ctmrg_projector_method = None
520
- max_trunc_error_list [- 1 ] = max_trunc_error
536
+ max_trunc_error_list [random_noise_retries ][- 1 ] = max_trunc_error
537
+ step_energies [random_noise_retries ][- 1 ] = working_value
521
538
break
522
539
523
540
if (
@@ -561,7 +578,16 @@ def random_noise(a):
561
578
562
579
if count % varipeps_config .optimizer_autosave_step_count == 0 :
563
580
auxiliary_data = {
564
- "max_trunc_error_list" : max_trunc_error_list ,
581
+ "max_trunc_error_list" : tuple (
582
+ max_trunc_error_list [k ]
583
+ for k in sorted (max_trunc_error_list .keys ())
584
+ ),
585
+ "step_energies" : tuple (
586
+ step_energies [k ] for k in sorted (step_energies .keys ())
587
+ ),
588
+ "step_chi" : tuple (step_chi [k ] for k in sorted (step_chi .keys ())),
589
+ "step_conv" : tuple (step_conv [k ] for k in sorted (step_conv .keys ())),
590
+ "best_run" : best_run if best_run is not None else 0 ,
565
591
}
566
592
567
593
if spiral_indices is not None :
@@ -589,6 +615,7 @@ def random_noise(a):
589
615
best_value = working_value
590
616
best_tensors = working_tensors
591
617
best_unitcell = working_unitcell
618
+ best_run = random_noise_retries
592
619
593
620
print (f"Best energy result found: { best_value } " )
594
621
@@ -599,4 +626,8 @@ def random_noise(a):
599
626
unitcell = best_unitcell ,
600
627
nit = count ,
601
628
max_trunc_error_list = max_trunc_error_list ,
629
+ step_energies = step_energies ,
630
+ step_chi = step_chi ,
631
+ step_conv = step_conv ,
632
+ best_run = best_run ,
602
633
)
0 commit comments