Skip to content

Add type-hints to BaseLearner #374

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 26 additions & 16 deletions adaptive/learner/base_learner.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

import abc
from contextlib import suppress
from typing import Any, Callable

import cloudpickle

from adaptive.utils import _RequireAttrsABCMeta, load, save


def uses_nth_neighbors(n: int):
def uses_nth_neighbors(n: int) -> Callable[[int], Callable[[BaseLearner], float]]:
"""Decorator to specify how many neighboring intervals the loss function uses.

Wraps loss functions to indicate that they expect intervals together
Expand Down Expand Up @@ -53,7 +56,9 @@ def uses_nth_neighbors(n: int):
... return loss
"""

def _wrapped(loss_per_interval):
def _wrapped(
loss_per_interval: Callable[[BaseLearner], float]
) -> Callable[[BaseLearner], float]:
loss_per_interval.nth_neighbors = n
return loss_per_interval

Expand Down Expand Up @@ -82,10 +87,15 @@ class BaseLearner(metaclass=_RequireAttrsABCMeta):
"""

data: dict
npoints: int
pending_points: set
function: Callable

@property
@abc.abstractmethod
def npoints(self) -> int:
"""Number of learned points."""

def tell(self, x, y):
def tell(self, x: Any, y: Any) -> None:
"""Tell the learner about a single value.

Parameters
Expand All @@ -95,7 +105,7 @@ def tell(self, x, y):
"""
self.tell_many([x], [y])

def tell_many(self, xs, ys):
def tell_many(self, xs: Any, ys: Any) -> None:
"""Tell the learner about some values.

Parameters
Expand All @@ -107,16 +117,16 @@ def tell_many(self, xs, ys):
self.tell(x, y)

@abc.abstractmethod
def tell_pending(self, x):
def tell_pending(self, x: Any) -> None:
"""Tell the learner that 'x' has been requested such
that it's not suggested again."""

@abc.abstractmethod
def remove_unfinished(self):
def remove_unfinished(self) -> None:
"""Remove uncomputed data from the learner."""

@abc.abstractmethod
def loss(self, real=True):
def loss(self, real: bool = True) -> float:
"""Return the loss for the current state of the learner.

Parameters
Expand All @@ -128,7 +138,7 @@ def loss(self, real=True):
"""

@abc.abstractmethod
def ask(self, n, tell_pending=True):
def ask(self, n: int, tell_pending: bool = True) -> tuple[list[Any], list[float]]:
"""Choose the next 'n' points to evaluate.

Parameters
Expand All @@ -142,19 +152,19 @@ def ask(self, n, tell_pending=True):
"""

@abc.abstractmethod
def _get_data(self):
def _get_data(self) -> Any:
pass

@abc.abstractmethod
def _set_data(self):
def _set_data(self, data: Any):
pass

@abc.abstractmethod
def new(self):
"""Return a new learner with the same function and parameters."""
pass

def copy_from(self, other):
def copy_from(self, other: BaseLearner) -> None:
"""Copy over the data from another learner.

Parameters
Expand All @@ -164,7 +174,7 @@ def copy_from(self, other):
"""
self._set_data(other._get_data())

def save(self, fname, compress=True):
def save(self, fname: str, compress: bool = True) -> None:
"""Save the data of the learner into a pickle file.

Parameters
Expand All @@ -178,7 +188,7 @@ def save(self, fname, compress=True):
data = self._get_data()
save(fname, data, compress)

def load(self, fname, compress=True):
def load(self, fname: str, compress: bool = True) -> None:
"""Load the data of a learner from a pickle file.

Parameters
Expand All @@ -193,8 +203,8 @@ def load(self, fname, compress=True):
data = load(fname, compress)
self._set_data(data)

def __getstate__(self):
def __getstate__(self) -> bytes:
return cloudpickle.dumps(self.__dict__)

def __setstate__(self, state):
def __setstate__(self, state: bytes) -> None:
self.__dict__ = cloudpickle.loads(state)
Loading