Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions estimator/lwe_primal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sage.all import oo, ceil, sqrt, log, RR, ZZ, binomial, cached_function
from .reduction import delta as deltaf
from .reduction import cost as costf
from .util import local_minimum
from .util import local_minimum, ternary_search
from .cost import Cost
from .lwe_parameters import LWEParameters
from .simulator import normalize as simulator_normalize
Expand Down Expand Up @@ -723,11 +723,19 @@ def __call__(
):
zeta_max += 1

with local_minimum(0, min(zeta_max, params.n), log_level=log_level) as it:
for zeta in it:
it.update(f(zeta=zeta, optimize_d=False, **kwds))
# TODO: this should not be required
cost = min(it.y, f(0, optimize_d=False, **kwds))
if params.n >= 2 ** 15:
with ternary_search(0, min(zeta_max, params.n), log_level=log_level) as it:
for zeta in it:
it.update(f(zeta=zeta, optimize_d=False, **kwds))
# TODO: this should not be required
cost = min(it.y, f(0, optimize_d=False, **kwds))
else:
with local_minimum(0, min(zeta_max, params.n), log_level=log_level) as it:
for zeta in it:
it.update(f(zeta=zeta, optimize_d=False, **kwds))
# TODO: this should not be required
cost = min(it.y, f(0, optimize_d=False, **kwds))

else:
cost = f(zeta=zeta)

Expand Down
134 changes: 134 additions & 0 deletions estimator/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,140 @@ def neighborhood(self):
return range(start, stop)


class ternary_search:
"""
an iterator context for finding a local minimum using ternary search.

For an interval [a, b] we evaluate f(x) and the points
x1 = a + (b - a) / 3
x2 = a + 2 * (b - a) / 3
if f(x1) < f(x2), we keep [a, x2]
if f(x1) > f(x2), we keep [x1, b]
"""

def __init__(
self,
start,
stop,
smallerf=lambda x, best: x <= best,
suppress_bounds_warning=False,
log_level=5,
):
"""
Create a fresh local minimum ternary search context.

:param start: starting point
:param stop: end point (exclusive)
:param smallerf: a function to decide if ``lhs`` is smaller than ``rhs``
:param suppress_bounds_warning: do not warn if a boundary is picked as optimal

"""

if stop < start:
raise ValueError(f"Incorrect bounds {start} > {stop}.")

self._suppress_bounds_warning = suppress_bounds_warning
self._log_level = log_level
self._start = start
self._stop = stop - 1
self._x1 = start + (stop - start) // 3
self._x2 = start + (2 * (stop - start)) // 3
self._fx1 = None
self._fx2 = None
self._initial_bounds = Bounds(start, stop - 1)
self._smallerf = smallerf
self._last_x = None
self._best = Bounds(None, None)
self._vals = {}

def __enter__(self):
""" """
return self

def __exit__(self, type, value, traceback):
""" """
pass

def __iter__(self):
""" """
return self

def __next__(self):
if self._x1 is not None and self._fx1 is None:
self._last_x = self._x1
return self._last_x
if self._x2 is not None and self._fx2 is None:
self._last_x = self._x2
return self._last_x
if self._best.low in self._initial_bounds and not self._suppress_bounds_warning:
# We warn the user if the optimal solution is at the edge and thus possibly not optimal.
msg = (
f'warning: "optimal" solution {self._best.low} matches a bound ∈ {self._initial_bounds}.',
)
Logging.log("bins", self._log_level, msg)
raise StopIteration

@property
def x(self):
return self._best.low

@property
def y(self):
return self._best.high

def update(self, res):
Logging.log("bins", self._log_level, f"({self._last_x}, {repr(res)})")

self._vals[self._last_x] = res

# We got nothing yet
if self._best.low is None:
self._best = Bounds(self._last_x, res)

# We found something better
if res is not False and self._smallerf(res, self._best.high):
# store it
self._best = Bounds(self._last_x, res)

if self._last_x == self._x1:
self._fx1 = res

if self._last_x == self._x2:
self._fx2 = res

# we need to exit this loop either with something to do, or having calculated f for every point in [start, stop]
# if stop - start > 2, we are guaranteed to shrink
# to avoid getting stuck, we handle the cases stop - start <= 2 separately.

while self._fx1 is not None and self._fx2 is not None and (self._stop - self._start) > 2:
# drop the right third
if self._smallerf(self._fx1, self._fx2):
self._start = self._start
self._stop = self._x2
# drop the left third
else:
self._start = self._x1
self._stop = self._stop
self._x1 = self._start + (self._stop - self._start) // 3
self._x2 = self._start + (2 * (self._stop - self._start)) // 3

# if already seen, load the value: otherwise, mark None
self._fx1 = self._vals.get(self._x1, None)
self._fx2 = self._vals.get(self._x2, None)

# at most three integers remain: exhaustively search over them
if self._stop - self._start <= 2:
# print(self._start, self._stop)
next = [x for x in range(self._start, self._stop + 1) if x not in self._vals]
if next:
# we assign remaining points arbitrarily to x1 and x2
self._x1 = next[0]
self._fx1 = None
if len(next) > 1:
self._x2 = next[1]
self._fx2 = None


class early_abort_range:
"""
An iterator context for finding a local minimum using linear search.
Expand Down