from _typeshed import Incomplete
from abc import abstractmethod
from collections.abc import Callable, Iterable
from typing import Any
from typing_extensions import Self, TypeAlias

import tensorflow as tf
from tensorflow._aliases import Gradients
from tensorflow.keras.optimizers import schedules as schedules
from tensorflow.python.trackable.base import Trackable

_Initializer: TypeAlias = str | Callable[[], tf.Tensor] | dict[str, Any]
_Shape: TypeAlias = tf.TensorShape | Iterable[int | None]
_Dtype: TypeAlias = tf.DType | str | None
_LearningRate: TypeAlias = float | tf.Tensor | schedules.LearningRateSchedule | Callable[[], float | tf.Tensor]
_GradientAggregator: TypeAlias = Callable[[list[tuple[Gradients, tf.Variable]]], list[tuple[Gradients, tf.Variable]]] | None
_GradientTransformer: TypeAlias = (
    Iterable[Callable[[list[tuple[Gradients, tf.Variable]]], list[tuple[Gradients, tf.Variable]]]] | None
)

# kwargs here and in the other optimizers can be given a more precise type once Unpack[TypedDict] (PEP 692) is supported.
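# A hypothetical sketch (not part of these stubs) of what that could look like; the
# TypedDict name and keys below are assumptions based on keyword arguments the runtime
# Optimizer is commonly documented to accept, and Unpack would come from typing_extensions:
#
#     class _OptimizerKwargs(TypedDict, total=False):
#         clipnorm: float
#         clipvalue: float
#         global_clipnorm: float
#
#     def __init__(
#         self,
#         name: str,
#         gradient_aggregator: _GradientAggregator = None,
#         gradient_transformers: _GradientTransformer = None,
#         **kwargs: Unpack[_OptimizerKwargs],
#     ) -> None: ...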
class Optimizer(Trackable):
    _name: str
    _iterations: tf.Variable | None
    _weights: list[tf.Variable]
    gradient_aggregator: _GradientAggregator
    gradient_transformers: _GradientTransformer
    learning_rate: _LearningRate
    def __init__(
        self,
        name: str,
        gradient_aggregator: _GradientAggregator = None,
        gradient_transformers: _GradientTransformer = None,
        **kwargs: Any,
    ) -> None: ...
    def _create_all_weights(self, var_list: Iterable[tf.Variable]) -> None: ...
    @property
    def iterations(self) -> tf.Variable: ...
    @iterations.setter
    def iterations(self, variable: tf.Variable) -> None: ...
    def add_slot(
        self, var: tf.Variable, slot_name: str, initializer: _Initializer = "zeros", shape: tf.TensorShape | None = None
    ) -> tf.Variable: ...
    def add_weight(
        self,
        name: str,
        shape: _Shape,
        dtype: _Dtype = None,
        initializer: _Initializer = "zeros",
        trainable: None | bool = None,
        synchronization: tf.VariableSynchronization = ...,
        aggregation: tf.VariableAggregation = ...,
    ) -> tf.Variable: ...
    def apply_gradients(
        self,
        grads_and_vars: Iterable[tuple[Gradients, tf.Variable]],
        name: str | None = None,
        experimental_aggregate_gradients: bool = True,
    ) -> tf.Operation | None: ...
    @classmethod
    def from_config(cls, config: dict[str, Any], custom_objects: dict[str, type] | None = None) -> Self: ...
    # The missing ABC base class is intentional, as the class is not abstract at runtime.
    @abstractmethod
    def get_config(self) -> dict[str, Any]: ...
    def get_slot(self, var: tf.Variable, slot_name: str) -> tf.Variable: ...
    def get_slot_names(self) -> list[str]: ...
    def get_gradients(self, loss: tf.Tensor, params: list[tf.Variable]) -> list[Gradients]: ...
    def minimize(
        self,
        loss: tf.Tensor | Callable[[], tf.Tensor],
        var_list: list[tf.Variable] | tuple[tf.Variable, ...] | Callable[[], list[tf.Variable] | tuple[tf.Variable, ...]],
        grad_loss: tf.Tensor | None = None,
        name: str | None = None,
        tape: tf.GradientTape | None = None,
    ) -> tf.Operation: ...
    def variables(self) -> list[tf.Variable]: ...
    @property
    def weights(self) -> list[tf.Variable]: ...

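# A minimal usage sketch (illustrative only, not part of the stubs): per the aliases above,
# learning_rate may be a float, a Tensor, a LearningRateSchedule, or a zero-argument callable,
# and minimize() takes either a Tensor loss or a zero-argument callable returning one.
# `compute_loss` and `model` are assumed to exist in the caller's code:
#
#     schedule = schedules.ExponentialDecay(0.01, decay_steps=1000, decay_rate=0.9)
#     opt = SGD(learning_rate=schedule, momentum=0.9)
#     op = opt.minimize(lambda: compute_loss(), var_list=model.trainable_variables)
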
class Adam(Optimizer):
    def __init__(
        self,
        learning_rate: _LearningRate = 0.001,
        beta_1: float = 0.9,
        beta_2: float = 0.999,
        epsilon: float = 1e-07,
        amsgrad: bool = False,
        name: str = "Adam",
        **kwargs: Any,
    ) -> None: ...
    def get_config(self) -> dict[str, Any]: ...

class Adagrad(Optimizer):
    _initial_accumulator_value: float

    def __init__(
        self,
        learning_rate: _LearningRate = 0.001,
        initial_accumulator_value: float = 0.1,
        epsilon: float = 1e-7,
        name: str = "Adagrad",
        **kwargs: Any,
    ) -> None: ...
    def get_config(self) -> dict[str, Any]: ...

class SGD(Optimizer):
    def __init__(
        self, learning_rate: _LearningRate = 0.01, momentum: float = 0.0, nesterov: bool = False, name: str = "SGD", **kwargs: Any
    ) -> None: ...
    def get_config(self) -> dict[str, Any]: ...

def __getattr__(name: str) -> Incomplete: ...