Skip to content

Commit 69ffadc

Browse files
committed
Implement basinhopping for the PEPS optimizer
1 parent fcf57fd commit 69ffadc

File tree

5 files changed

+196
-0
lines changed

5 files changed

+196
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
.. _peps_ad_optimization_basinhopping:
2+
3+
.. currentmodule:: peps_ad.optimization.basinhopping
4+
5+
Implementation of the basinhopping optimizer for the PEPS model (:mod:`peps_ad.optimization.basinhopping`)
6+
==========================================================================================================
7+
8+
.. automodule:: peps_ad.optimization.basinhopping
9+
:members:
10+
:undoc-members:
11+
:show-inheritance:

docs/source/api/optimization/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Variational optimizer for the PEPS network (:mod:`peps_ad.optimization`)
1010
.. toctree::
1111
:maxdepth: 2
1212

13+
basinhopping
1314
inner_function
1415
line_search
1516
optimizer

peps_ad/config.py

+14
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,15 @@ class PEPS_AD_Config:
125125
line_search_use_last_step_size (:obj:`bool`):
126126
Flag if the line search should start from the step size of the
127127
previous optimizer step.
128+
basinhopping_niter (:obj:`int`):
129+
Value for parameter `niter` of :obj:`scipy.optimize.basinhopping`.
130+
See this function for details.
131+
basinhopping_T (:obj:`float`):
132+
Value for parameter `T` of :obj:`scipy.optimize.basinhopping`.
133+
See this function for details.
134+
basinhopping_niter_success (:obj:`int`):
135+
Value for parameter `niter_success` of
136+
:obj:`scipy.optimize.basinhopping`. See this function for details.
128137
"""
129138

130139
# AD config
@@ -172,6 +181,11 @@ class PEPS_AD_Config:
172181
line_search_wolfe_const: float = 0.9
173182
line_search_use_last_step_size: bool = False
174183

184+
# Basinhopping
185+
basinhopping_niter: int = 20
186+
basinhopping_T: float = 0.01
187+
basinhopping_niter_success: int = 5
188+
175189
def tree_flatten(self) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
176190
aux_data = (
177191
{name: getattr(self, name) for name in self.__dataclass_fields__.keys()},

peps_ad/optimization/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
from . import inner_function
22
from . import line_search
33
from . import optimizer
4+
from . import basinhopping
5+
6+
from .optimizer import optimize_peps_network

peps_ad/optimization/basinhopping.py

+167
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
from dataclasses import dataclass
2+
from os import PathLike
3+
4+
import numpy
5+
import scipy.optimize as spo
6+
7+
import jax.numpy as jnp
8+
from jax.flatten_util import ravel_pytree
9+
10+
from peps_ad import peps_ad_config
11+
from peps_ad.peps import PEPS_Unit_Cell
12+
from peps_ad.expectation import Expectation_Model
13+
from peps_ad.mapping import Map_To_PEPS_Model
14+
from peps_ad.ctmrg import CTMRGNotConvergedError, CTMRGGradientNotConvergedError
15+
16+
from .line_search import NoSuitableStepSizeError
17+
from .optimizer import optimize_peps_network, autosave_function
18+
19+
from typing import List, Union, Tuple, cast, Sequence, Callable, Optional
20+
21+
22+
@dataclass
class PEPS_AD_Basinhopping:
    """
    Class to wrap the basinhopping algorithm for the variational update
    of PEPS or mapped structures.

    The local-minimization step of :obj:`scipy.optimize.basinhopping` is
    replaced by the project's own optimizer
    :obj:`~peps_ad.optimization.optimize_peps_network`; the parameters of
    the class initialization are the same as for that function.

    Args:
      initial_guess (:obj:`~peps_ad.peps.PEPS_Unit_Cell` or :term:`sequence` of :obj:`jax.numpy.ndarray`):
        The PEPS unitcell to work on or the tensors which should be mapped by
        `convert_to_unitcell_func` to a PEPS unitcell.
      expectation_func (:obj:`~peps_ad.expectation.Expectation_Model`):
        Callable to calculate one expectation value which is used as loss
        function of the model. Likely the function to calculate the energy.
      convert_to_unitcell_func (:obj:`~peps_ad.mapping.Map_To_PEPS_Model`):
        Function to convert the `input_tensors` to a PEPS unitcell. If omitted,
        it is assumed that a PEPS unitcell is the first input parameter.
      autosave_filename (:obj:`os.PathLike`):
        Filename where intermediate results are automatically saved.
      autosave_func (:term:`callable`):
        Function which is called to autosave the intermediate results.
        The function has to accept the arguments `(filename, tensors, unitcell)`.
    """

    initial_guess: Union[PEPS_Unit_Cell, Sequence[jnp.ndarray]]
    expectation_func: Expectation_Model
    convert_to_unitcell_func: Optional[Map_To_PEPS_Model] = None
    autosave_filename: PathLike = "data/autosave.hdf5"
    autosave_func: Callable[
        [PathLike, Sequence[jnp.ndarray], PEPS_Unit_Cell], None
    ] = autosave_function

    def __post_init__(self) -> None:
        # Extract the list of unique tensors from the unitcell or copy the
        # plain tensor sequence passed by the caller.
        if isinstance(self.initial_guess, PEPS_Unit_Cell):
            initial_guess_tensors = [
                i.tensor for i in self.initial_guess.get_unique_tensors()
            ]
        else:
            initial_guess_tensors = list(self.initial_guess)

        # Flatten the pytree of tensors into one 1d vector and keep the
        # inverse mapping to restore the tensor structure later on.
        initial_guess_flatten_tensors, self._map_pytree_func = ravel_pytree(
            initial_guess_tensors
        )

        initial_guess_tensors_numpy = numpy.asarray(initial_guess_flatten_tensors)

        # scipy's basinhopping operates on real vectors only: represent a
        # complex vector as its real part concatenated with its imaginary
        # part and remember the original (complex) length to undo the split.
        if numpy.iscomplexobj(initial_guess_tensors_numpy):
            self._initial_guess_tensors_numpy = numpy.concatenate(
                (
                    numpy.real(initial_guess_tensors_numpy),
                    numpy.imag(initial_guess_tensors_numpy),
                )
            )
            self._iscomplex = True
            self._initial_guess_complex_length = initial_guess_flatten_tensors.size
        else:
            self._initial_guess_tensors_numpy = initial_guess_tensors_numpy
            self._iscomplex = False

    def _to_jax_tensors(self, x):
        """
        Convert a flat real numpy vector back into the list of jax tensors.

        Undoes both the real/imaginary split performed in
        :obj:`__post_init__` (if the initial guess was complex) and the
        pytree flattening.
        """
        if self._iscomplex:
            x_jax = jnp.asarray(
                x[: self._initial_guess_complex_length]
                + 1j * x[self._initial_guess_complex_length :]
            )
        else:
            x_jax = jnp.asarray(x)
        return self._map_pytree_func(x_jax)

    def _wrapper_own_optimizer(
        self,
        fun,
        x0,
        *args,
        **kwargs,
    ):
        """
        Custom local minimizer passed to :obj:`scipy.optimize.basinhopping`.

        `fun` (the dummy identity objective) and any extra arguments are
        deliberately ignored: the actual loss is evaluated inside
        :obj:`~peps_ad.optimization.optimize_peps_network`.
        """
        x0_jax = self._to_jax_tensors(x0)

        if isinstance(self.initial_guess, PEPS_Unit_Cell):
            # Rebuild a unitcell with the same structure as the initial guess.
            input_obj = PEPS_Unit_Cell.from_tensor_list(
                x0_jax, self.initial_guess.data.structure
            )
        else:
            input_obj = x0_jax

        opt_result = optimize_peps_network(
            input_obj,
            self.expectation_func,
            self.convert_to_unitcell_func,
            self.autosave_filename,
            self.autosave_func,
        )

        # Map the optimizer result back into the flat real numpy format
        # basinhopping expects for the ``x`` and ``fun`` fields.
        result_tensors, _ = ravel_pytree(opt_result.x)
        result_tensors_numpy = numpy.asarray(result_tensors)
        if self._iscomplex:
            result_tensors_numpy = numpy.concatenate(
                (numpy.real(result_tensors_numpy), numpy.imag(result_tensors_numpy))
            )

        opt_result["x"] = result_tensors_numpy
        opt_result["fun"] = numpy.asarray(opt_result.fun)

        return opt_result

    @staticmethod
    def _dummy_func(x, *args, **kwargs):
        # Identity objective: basinhopping requires an objective function,
        # but the real loss evaluation happens in the custom minimizer
        # wrapper (_wrapper_own_optimizer), which ignores this function.
        return x

    def run(self) -> spo.OptimizeResult:
        """
        Run the basinhopping algorithm for the setup initialized in the class
        object.

        For details see :obj:`scipy.optimize.basinhopping`.

        Returns:
          :obj:`scipy.optimize.OptimizeResult`:
            Result from the basinhopping algorithm with additional fields
            ``unitcell`` and ``result_tensors`` for the result tensors and
            unitcell in the normal format of this library.
        """
        result = spo.basinhopping(
            self._dummy_func,
            self._initial_guess_tensors_numpy,
            niter=peps_ad_config.basinhopping_niter,
            T=peps_ad_config.basinhopping_T,
            niter_success=peps_ad_config.basinhopping_niter_success,
            disp=True,
            minimizer_kwargs={"method": self._wrapper_own_optimizer},
        )

        # NOTE(review): assumes the result objects returned by
        # optimize_peps_network carry a ``unitcell`` attribute — confirm
        # against that function's implementation.
        result["unitcell"] = result.lowest_optimization_result.unitcell

        result["result_tensors"] = self._to_jax_tensors(result.x)

        return result

0 commit comments

Comments
 (0)