Skip to content

Commit 5a4cece

Browse files
committed
Fix handling of retries if the CTMRG did not converge
1 parent b81dbe3 commit 5a4cece

File tree

1 file changed

+69
-11
lines changed

1 file changed

+69
-11
lines changed

varipeps/optimization/optimizer.py

+69-11
Original file line numberDiff line numberDiff line change
@@ -400,15 +400,70 @@ def random_noise(a):
400400
except (CTMRGNotConvergedError, CTMRGGradientNotConvergedError) as e:
401401
varipeps_global_state.ctmrg_projector_method = None
402402

403-
return OptimizeResult(
404-
success=False,
405-
message=str(type(e)),
406-
x=working_tensors,
407-
fun=working_value,
408-
unitcell=working_unitcell,
409-
nit=count,
410-
max_trunc_error_list=max_trunc_error_list,
411-
)
403+
if random_noise_retries == 0:
404+
return OptimizeResult(
405+
success=False,
406+
message=str(type(e)),
407+
x=working_tensors,
408+
fun=working_value,
409+
unitcell=working_unitcell,
410+
nit=count,
411+
max_trunc_error_list=max_trunc_error_list,
412+
step_energies=step_energies,
413+
step_chi=step_chi,
414+
step_conv=step_conv,
415+
best_run=0,
416+
)
417+
elif (
418+
random_noise_retries
419+
>= varipeps_config.optimizer_random_noise_max_retries
420+
):
421+
working_value = jnp.inf
422+
break
423+
else:
424+
if isinstance(input_tensors, PEPS_Unit_Cell) or (
425+
isinstance(input_tensors, collections.abc.Sequence)
426+
and isinstance(input_tensors[0], PEPS_Unit_Cell)
427+
):
428+
working_tensors = (
429+
cast(
430+
List[jnp.ndarray],
431+
[i.tensor for i in best_unitcell.get_unique_tensors()],
432+
)
433+
+ best_tensors[best_unitcell.get_len_unique_tensors() :]
434+
)
435+
436+
working_tensors = [random_noise(i) for i in working_tensors]
437+
438+
working_tensors_obj = [
439+
e.replace_tensor(working_tensors[i])
440+
for i, e in enumerate(best_unitcell.get_unique_tensors())
441+
]
442+
443+
working_unitcell = best_unitcell.replace_unique_tensors(
444+
working_tensors_obj
445+
)
446+
else:
447+
working_tensors = [random_noise(i) for i in best_tensors]
448+
working_unitcell = None
449+
450+
descent_dir = None
451+
working_gradient = None
452+
signal_reset_descent_dir = True
453+
count = 0
454+
random_noise_retries += 1
455+
old_descent_dir = descent_dir
456+
old_gradient = working_gradient
457+
458+
step_energies[random_noise_retries] = []
459+
step_chi[random_noise_retries] = []
460+
step_conv[random_noise_retries] = []
461+
max_trunc_error_list[random_noise_retries] = []
462+
463+
pbar.reset()
464+
pbar.refresh()
465+
466+
continue
412467

413468
working_gradient = [elem.conj() for elem in working_gradient_seq]
414469

@@ -567,9 +622,10 @@ def random_noise(a):
567622
descent_dir = None
568623
working_gradient = None
569624
signal_reset_descent_dir = True
570-
count = -1
625+
count = 0
571626
random_noise_retries += 1
572-
conv = jnp.inf
627+
old_descent_dir = descent_dir
628+
old_gradient = working_gradient
573629

574630
step_energies[random_noise_retries] = []
575631
step_chi[random_noise_retries] = []
@@ -578,6 +634,8 @@ def random_noise(a):
578634

579635
pbar.reset()
580636
pbar.refresh()
637+
638+
continue
581639
else:
582640
conv = 0
583641
else:

0 commit comments

Comments
 (0)