diff --git a/estimator/lwe_primal.py b/estimator/lwe_primal.py index 4e8f999..7254ae3 100644 --- a/estimator/lwe_primal.py +++ b/estimator/lwe_primal.py @@ -10,7 +10,7 @@ from sage.all import oo, ceil, sqrt, log, RR, ZZ, binomial, cached_function from .reduction import delta as deltaf from .reduction import cost as costf -from .util import local_minimum +from .util import local_minimum, ternary_search from .cost import Cost from .lwe_parameters import LWEParameters from .simulator import normalize as simulator_normalize @@ -723,11 +723,19 @@ def __call__( ): zeta_max += 1 - with local_minimum(0, min(zeta_max, params.n), log_level=log_level) as it: - for zeta in it: - it.update(f(zeta=zeta, optimize_d=False, **kwds)) - # TODO: this should not be required - cost = min(it.y, f(0, optimize_d=False, **kwds)) + if params.n >= 2 ** 15: + with ternary_search(0, min(zeta_max, params.n), log_level=log_level) as it: + for zeta in it: + it.update(f(zeta=zeta, optimize_d=False, **kwds)) + # TODO: this should not be required + cost = min(it.y, f(0, optimize_d=False, **kwds)) + else: + with local_minimum(0, min(zeta_max, params.n), log_level=log_level) as it: + for zeta in it: + it.update(f(zeta=zeta, optimize_d=False, **kwds)) + # TODO: this should not be required + cost = min(it.y, f(0, optimize_d=False, **kwds)) + else: cost = f(zeta=zeta) diff --git a/estimator/util.py b/estimator/util.py index f726981..8889fa2 100644 --- a/estimator/util.py +++ b/estimator/util.py @@ -275,6 +275,140 @@ def neighborhood(self): return range(start, stop) +class ternary_search: + """ + an iterator context for finding a local minimum using ternary search. + + For an interval [a, b] we evaluate f(x) and the points + x1 = a + (b - a) / 3 + x2 = a + 2 * (b - a) / 3 + if f(x1) < f(x2), we keep [a, x2] + if f(x1) > f(x2), we keep [x1, b] + """ + + def __init__( + self, + start, + stop, + smallerf=lambda x, best: x <= best, + suppress_bounds_warning=False, + log_level=5, + ): + """ + Create a fresh local minimum ternary search context. + + :param start: starting point + :param stop: end point (exclusive) + :param smallerf: a function to decide if ``lhs`` is smaller than ``rhs`` + :param suppress_bounds_warning: do not warn if a boundary is picked as optimal + + """ + + if stop < start: + raise ValueError(f"Incorrect bounds {start} > {stop}.") + + self._suppress_bounds_warning = suppress_bounds_warning + self._log_level = log_level + self._start = start + self._stop = stop - 1 + self._x1 = start + (stop - start) // 3 + self._x2 = start + (2 * (stop - start)) // 3 + self._fx1 = None + self._fx2 = None + self._initial_bounds = Bounds(start, stop - 1) + self._smallerf = smallerf + self._last_x = None + self._best = Bounds(None, None) + self._vals = {} + + def __enter__(self): + """ """ + return self + + def __exit__(self, type, value, traceback): + """ """ + pass + + def __iter__(self): + """ """ + return self + + def __next__(self): + if self._x1 is not None and self._fx1 is None: + self._last_x = self._x1 + return self._last_x + if self._x2 is not None and self._fx2 is None: + self._last_x = self._x2 + return self._last_x + if self._best.low in self._initial_bounds and not self._suppress_bounds_warning: + # We warn the user if the optimal solution is at the edge and thus possibly not optimal. + msg = ( + f'warning: "optimal" solution {self._best.low} matches a bound ∈ {self._initial_bounds}.', + ) + Logging.log("bins", self._log_level, msg) + raise StopIteration + + @property + def x(self): + return self._best.low + + @property + def y(self): + return self._best.high + + def update(self, res): + Logging.log("bins", self._log_level, f"({self._last_x}, {repr(res)})") + + self._vals[self._last_x] = res + + # We got nothing yet + if self._best.low is None: + self._best = Bounds(self._last_x, res) + + # We found something better + if res is not False and self._smallerf(res, self._best.high): + # store it + self._best = Bounds(self._last_x, res) + + if self._last_x == self._x1: + self._fx1 = res + + if self._last_x == self._x2: + self._fx2 = res + + # we need to exit this loop either with something to do, or having calculated f for every point in [start, stop] + # if stop - start > 2, we are guaranteed to shrink + # to avoid getting stuck, we handle the cases stop - start <= 2 separately. + + while self._fx1 is not None and self._fx2 is not None and (self._stop - self._start) > 2: + # drop the right third + if self._smallerf(self._fx1, self._fx2): + self._start = self._start + self._stop = self._x2 + # drop the left third + else: + self._start = self._x1 + self._stop = self._stop + self._x1 = self._start + (self._stop - self._start) // 3 + self._x2 = self._start + (2 * (self._stop - self._start)) // 3 + + # if already seen, load the value: otherwise, mark None + self._fx1 = self._vals.get(self._x1, None) + self._fx2 = self._vals.get(self._x2, None) + + # at most three integers remain: exhaustively search over them + if self._stop - self._start <= 2: + # print(self._start, self._stop) + next = [x for x in range(self._start, self._stop + 1) if x not in self._vals] + if next: + # we assign remaining points arbitrarily to x1 and x2 + self._x1 = next[0] + self._fx1 = None + if len(next) > 1: + self._x2 = next[1] + self._fx2 = None + + class early_abort_range: """ An iterator context for finding a local minimum using linear search.