Skip to content

Commit 664fc4c

Browse files
committed
Store per-step energy, chi, and convergence
1 parent fdcb0fb commit 664fc4c

File tree

1 file changed

+36
-5
lines changed

1 file changed

+36
-5
lines changed

varipeps/optimization/optimizer.py

+36-5
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ def random_noise(a):
277277
best_value = jnp.inf
278278
best_tensors = None
279279
best_unitcell = None
280+
best_run = None
280281

281282
random_noise_retries = 0
282283

@@ -307,7 +308,10 @@ def random_noise(a):
307308
varipeps_config.line_search_initial_step_size
308309
)
309310
working_value: Union[float, jnp.ndarray]
310-
max_trunc_error_list = []
311+
max_trunc_error_list = {random_noise_retries: []}
312+
step_energies = {random_noise_retries: []}
313+
step_chi = {random_noise_retries: []}
314+
step_conv = {random_noise_retries: []}
311315

312316
if (
313317
varipeps_config.optimizer_preconverge_with_half_projectors
@@ -417,6 +421,7 @@ def random_noise(a):
417421
descent_dir = [-elem for elem in working_gradient]
418422

419423
conv = jnp.linalg.norm(ravel_pytree(working_gradient)[0])
424+
step_conv[random_noise_retries].append(conv)
420425

421426
try:
422427
(
@@ -460,6 +465,7 @@ def random_noise(a):
460465
best_value = working_value
461466
best_tensors = working_tensors
462467
best_unitcell = working_unitcell
468+
best_run = random_noise_retries
463469

464470
if isinstance(input_tensors, PEPS_Unit_Cell) or (
465471
isinstance(input_tensors, collections.abc.Sequence)
@@ -497,12 +503,22 @@ def random_noise(a):
497503
signal_reset_descent_dir = True
498504
count = -1
499505
random_noise_retries += 1
506+
507+
step_energies[random_noise_retries] = []
508+
step_chi[random_noise_retries] = []
509+
step_conv[random_noise_retries] = []
510+
max_trunc_error_list[random_noise_retries] = []
511+
500512
pbar.reset()
501513
pbar.refresh()
502514
else:
503515
conv = 0
504-
505-
max_trunc_error_list.append(max_trunc_error)
516+
else:
517+
max_trunc_error_list[random_noise_retries].append(max_trunc_error)
518+
step_energies[random_noise_retries].append(working_value)
519+
step_chi[random_noise_retries].append(
520+
working_unitcell.get_unique_tensors()[0].chi
521+
)
506522

507523
if conv < varipeps_config.optimizer_convergence_eps:
508524
working_value, (
@@ -517,7 +533,8 @@ def random_noise(a):
517533
enforce_elementwise_convergence=varipeps_config.ad_use_custom_vjp,
518534
)
519535
varipeps_global_state.ctmrg_projector_method = None
520-
max_trunc_error_list[-1] = max_trunc_error
536+
max_trunc_error_list[random_noise_retries][-1] = max_trunc_error
537+
step_energies[random_noise_retries][-1] = working_value
521538
break
522539

523540
if (
@@ -561,7 +578,16 @@ def random_noise(a):
561578

562579
if count % varipeps_config.optimizer_autosave_step_count == 0:
563580
auxiliary_data = {
564-
"max_trunc_error_list": max_trunc_error_list,
581+
"max_trunc_error_list": tuple(
582+
max_trunc_error_list[k]
583+
for k in sorted(max_trunc_error_list.keys())
584+
),
585+
"step_energies": tuple(
586+
step_energies[k] for k in sorted(step_energies.keys())
587+
),
588+
"step_chi": tuple(step_chi[k] for k in sorted(step_chi.keys())),
589+
"step_conv": tuple(step_conv[k] for k in sorted(step_conv.keys())),
590+
"best_run": best_run if best_run is not None else 0,
565591
}
566592

567593
if spiral_indices is not None:
@@ -589,6 +615,7 @@ def random_noise(a):
589615
best_value = working_value
590616
best_tensors = working_tensors
591617
best_unitcell = working_unitcell
618+
best_run = random_noise_retries
592619

593620
print(f"Best energy result found: {best_value}")
594621

@@ -599,4 +626,8 @@ def random_noise(a):
599626
unitcell=best_unitcell,
600627
nit=count,
601628
max_trunc_error_list=max_trunc_error_list,
629+
step_energies=step_energies,
630+
step_chi=step_chi,
631+
step_conv=step_conv,
632+
best_run=best_run,
602633
)

0 commit comments

Comments
 (0)