|
| 1 | +from dataclasses import dataclass |
| 2 | +from os import PathLike |
| 3 | + |
| 4 | +import numpy |
| 5 | +import scipy.optimize as spo |
| 6 | + |
| 7 | +import jax.numpy as jnp |
| 8 | +from jax.flatten_util import ravel_pytree |
| 9 | + |
| 10 | +from peps_ad import peps_ad_config |
| 11 | +from peps_ad.peps import PEPS_Unit_Cell |
| 12 | +from peps_ad.expectation import Expectation_Model |
| 13 | +from peps_ad.mapping import Map_To_PEPS_Model |
| 14 | +from peps_ad.ctmrg import CTMRGNotConvergedError, CTMRGGradientNotConvergedError |
| 15 | + |
| 16 | +from .line_search import NoSuitableStepSizeError |
| 17 | +from .optimizer import optimize_peps_network, autosave_function |
| 18 | + |
| 19 | +from typing import List, Union, Tuple, cast, Sequence, Callable, Optional |
| 20 | + |
| 21 | + |
| 22 | +@dataclass |
| 23 | +class PEPS_AD_Basinhopping: |
| 24 | + """ |
| 25 | + Class to wrap the basinhopping algorithm for the variational update |
| 26 | + of PEPS or mapped structures. |
| 27 | +
|
| 28 | + The parameters of the class initialization are the same as for |
| 29 | + :obj:`~peps_ad.optimization.optimize_peps_network`. |
| 30 | +
|
| 31 | + Args: |
| 32 | + initial_guess (:obj:`~peps_ad.peps.PEPS_Unit_Cell` or :term:`sequence` of :obj:`jax.numpy.ndarray`): |
| 33 | + The PEPS unitcell to work on or the tensors which should be mapped by |
| 34 | + `convert_to_unitcell_func` to a PEPS unitcell. |
| 35 | + expectation_func (:obj:`~peps_ad.expectation.Expectation_Model`): |
| 36 | + Callable to calculate one expectation value which is used as loss |
| 37 | + loss function of the model. Likely the function to calculate the energy. |
| 38 | + convert_to_unitcell_func (:obj:`~peps_ad.mapping.Map_To_PEPS_Model`): |
| 39 | + Function to convert the `input_tensors` to a PEPS unitcell. If ommited, |
| 40 | + it is assumed that a PEPS unitcell is the first input parameter. |
| 41 | + autosave_filename (:obj:`os.PathLike`): |
| 42 | + Filename where intermediate results are automatically saved. |
| 43 | + autosave_func (:term:`callable`): |
| 44 | + Function which is called to autosave the intermediate results. |
| 45 | + The function has to accept the arguments `(filename, tensors, unitcell)`.data (:obj:`Unit_Cell_Data`): |
| 46 | + Instance of unit cell data class |
| 47 | + """ |
| 48 | + |
| 49 | + initial_guess: Union[PEPS_Unit_Cell, Sequence[jnp.ndarray]] |
| 50 | + expectation_func: Expectation_Model |
| 51 | + convert_to_unitcell_func: Optional[Map_To_PEPS_Model] = None |
| 52 | + autosave_filename: PathLike = "data/autosave.hdf5" |
| 53 | + autosave_func: Callable[ |
| 54 | + [PathLike, Sequence[jnp.ndarray], PEPS_Unit_Cell], None |
| 55 | + ] = autosave_function |
| 56 | + |
| 57 | + def __post_init__(self): |
| 58 | + if isinstance(self.initial_guess, PEPS_Unit_Cell): |
| 59 | + initial_guess_tensors = [ |
| 60 | + i.tensor for i in self.initial_guess.get_unique_tensors() |
| 61 | + ] |
| 62 | + else: |
| 63 | + initial_guess_tensors = list(self.initial_guess) |
| 64 | + |
| 65 | + initial_guess_flatten_tensors, self._map_pytree_func = ravel_pytree( |
| 66 | + initial_guess_tensors |
| 67 | + ) |
| 68 | + |
| 69 | + initial_guess_tensors_numpy = numpy.asarray(initial_guess_flatten_tensors) |
| 70 | + |
| 71 | + if numpy.iscomplexobj(initial_guess_tensors_numpy): |
| 72 | + self._initial_guess_tensors_numpy = numpy.concatenate( |
| 73 | + ( |
| 74 | + numpy.real(initial_guess_tensors_numpy), |
| 75 | + numpy.imag(initial_guess_tensors_numpy), |
| 76 | + ) |
| 77 | + ) |
| 78 | + self._iscomplex = True |
| 79 | + self._initial_guess_complex_length = initial_guess_flatten_tensors.size |
| 80 | + else: |
| 81 | + self._initial_guess_tensors_numpy = initial_guess_tensors_numpy |
| 82 | + self._iscomplex = False |
| 83 | + |
| 84 | + def _wrapper_own_optimizer( |
| 85 | + self, |
| 86 | + fun, |
| 87 | + x0, |
| 88 | + *args, |
| 89 | + **kwargs, |
| 90 | + ): |
| 91 | + if self._iscomplex: |
| 92 | + x0_jax = jnp.asarray( |
| 93 | + x0[: self._initial_guess_complex_length] |
| 94 | + + 1j * x0[self._initial_guess_complex_length :] |
| 95 | + ) |
| 96 | + else: |
| 97 | + x0_jax = jnp.asarray(x0) |
| 98 | + x0_jax = self._map_pytree_func(x0_jax) |
| 99 | + |
| 100 | + if isinstance(self.initial_guess, PEPS_Unit_Cell): |
| 101 | + input_obj = PEPS_Unit_Cell.from_tensor_list( |
| 102 | + x0_jax, self.initial_guess.data.structure |
| 103 | + ) |
| 104 | + else: |
| 105 | + input_obj = x0_jax |
| 106 | + |
| 107 | + opt_result = optimize_peps_network( |
| 108 | + input_obj, |
| 109 | + self.expectation_func, |
| 110 | + self.convert_to_unitcell_func, |
| 111 | + self.autosave_filename, |
| 112 | + self.autosave_func, |
| 113 | + ) |
| 114 | + |
| 115 | + result_tensors, _ = ravel_pytree(opt_result.x) |
| 116 | + result_tensors_numpy = numpy.asarray(result_tensors) |
| 117 | + if self._iscomplex: |
| 118 | + result_tensors_numpy = numpy.concatenate( |
| 119 | + (numpy.real(result_tensors_numpy), numpy.imag(result_tensors_numpy)) |
| 120 | + ) |
| 121 | + |
| 122 | + opt_result["x"] = result_tensors_numpy |
| 123 | + opt_result["fun"] = numpy.asarray(opt_result.fun) |
| 124 | + |
| 125 | + return opt_result |
| 126 | + |
| 127 | + @staticmethod |
| 128 | + def _dummy_func(x, *args, **kwargs): |
| 129 | + return x |
| 130 | + |
| 131 | + def run(self) -> spo.OptimizeResult: |
| 132 | + """ |
| 133 | + Run the basinhopping algorithm for the setup initialized in the class |
| 134 | + object. |
| 135 | +
|
| 136 | + For details see :obj:`scipy.optimize.basinhopping`. |
| 137 | +
|
| 138 | + Returns: |
| 139 | + :obj:`scipy.optimize.OptimizeResult`: |
| 140 | + Result from the basinhopping algorithm with additional fields |
| 141 | + ``unitcell`` and ``result_tensors`` for the result tensors and |
| 142 | + unitcell in the normal format of this library. |
| 143 | + """ |
| 144 | + result = spo.basinhopping( |
| 145 | + self._dummy_func, |
| 146 | + self._initial_guess_tensors_numpy, |
| 147 | + niter=peps_ad_config.basinhopping_niter, |
| 148 | + T=peps_ad_config.basinhopping_T, |
| 149 | + niter_success=peps_ad_config.basinhopping_niter_success, |
| 150 | + disp=True, |
| 151 | + minimizer_kwargs={"method": self._wrapper_own_optimizer}, |
| 152 | + ) |
| 153 | + |
| 154 | + result["unitcell"] = result.lowest_optimization_result.unitcell |
| 155 | + |
| 156 | + if self._iscomplex: |
| 157 | + x_jax = jnp.asarray( |
| 158 | + result.x[: self._initial_guess_complex_length] |
| 159 | + + 1j * result.x[self._initial_guess_complex_length :] |
| 160 | + ) |
| 161 | + else: |
| 162 | + x_jax = jnp.asarray(result.x) |
| 163 | + x_jax = self._map_pytree_func(x_jax) |
| 164 | + |
| 165 | + result["result_tensors"] = x_jax |
| 166 | + |
| 167 | + return result |
0 commit comments