@@ -400,15 +400,70 @@ def random_noise(a):
400
400
except (CTMRGNotConvergedError , CTMRGGradientNotConvergedError ) as e :
401
401
varipeps_global_state .ctmrg_projector_method = None
402
402
403
- return OptimizeResult (
404
- success = False ,
405
- message = str (type (e )),
406
- x = working_tensors ,
407
- fun = working_value ,
408
- unitcell = working_unitcell ,
409
- nit = count ,
410
- max_trunc_error_list = max_trunc_error_list ,
411
- )
403
+ if random_noise_retries == 0 :
404
+ return OptimizeResult (
405
+ success = False ,
406
+ message = str (type (e )),
407
+ x = working_tensors ,
408
+ fun = working_value ,
409
+ unitcell = working_unitcell ,
410
+ nit = count ,
411
+ max_trunc_error_list = max_trunc_error_list ,
412
+ step_energies = step_energies ,
413
+ step_chi = step_chi ,
414
+ step_conv = step_conv ,
415
+ best_run = 0 ,
416
+ )
417
+ elif (
418
+ random_noise_retries
419
+ >= varipeps_config .optimizer_random_noise_max_retries
420
+ ):
421
+ working_value = jnp .inf
422
+ break
423
+ else :
424
+ if isinstance (input_tensors , PEPS_Unit_Cell ) or (
425
+ isinstance (input_tensors , collections .abc .Sequence )
426
+ and isinstance (input_tensors [0 ], PEPS_Unit_Cell )
427
+ ):
428
+ working_tensors = (
429
+ cast (
430
+ List [jnp .ndarray ],
431
+ [i .tensor for i in best_unitcell .get_unique_tensors ()],
432
+ )
433
+ + best_tensors [best_unitcell .get_len_unique_tensors () :]
434
+ )
435
+
436
+ working_tensors = [random_noise (i ) for i in working_tensors ]
437
+
438
+ working_tensors_obj = [
439
+ e .replace_tensor (working_tensors [i ])
440
+ for i , e in enumerate (best_unitcell .get_unique_tensors ())
441
+ ]
442
+
443
+ working_unitcell = best_unitcell .replace_unique_tensors (
444
+ working_tensors_obj
445
+ )
446
+ else :
447
+ working_tensors = [random_noise (i ) for i in best_tensors ]
448
+ working_unitcell = None
449
+
450
+ descent_dir = None
451
+ working_gradient = None
452
+ signal_reset_descent_dir = True
453
+ count = 0
454
+ random_noise_retries += 1
455
+ old_descent_dir = descent_dir
456
+ old_gradient = working_gradient
457
+
458
+ step_energies [random_noise_retries ] = []
459
+ step_chi [random_noise_retries ] = []
460
+ step_conv [random_noise_retries ] = []
461
+ max_trunc_error_list [random_noise_retries ] = []
462
+
463
+ pbar .reset ()
464
+ pbar .refresh ()
465
+
466
+ continue
412
467
413
468
working_gradient = [elem .conj () for elem in working_gradient_seq ]
414
469
@@ -567,9 +622,10 @@ def random_noise(a):
567
622
descent_dir = None
568
623
working_gradient = None
569
624
signal_reset_descent_dir = True
570
- count = - 1
625
+ count = 0
571
626
random_noise_retries += 1
572
- conv = jnp .inf
627
+ old_descent_dir = descent_dir
628
+ old_gradient = working_gradient
573
629
574
630
step_energies [random_noise_retries ] = []
575
631
step_chi [random_noise_retries ] = []
@@ -578,6 +634,8 @@ def random_noise(a):
578
634
579
635
pbar .reset ()
580
636
pbar .refresh ()
637
+
638
+ continue
581
639
else :
582
640
conv = 0
583
641
else :
0 commit comments