
Commit 01972e0

tensorflow: Add legacy optimizers (#9997)
1 parent: f7443a7

File tree

10 files changed: +288 −4 lines

stubs/tensorflow/@tests/stubtest_allowlist.txt (+2 −1)
@@ -16,12 +16,13 @@ tensorflow.Graph.__getattr__
 tensorflow.Operation.__getattr__
 tensorflow.Variable.__getattr__
 tensorflow.keras.layers.Layer.__getattr__
+tensorflow.GradientTape.__getattr__
 # Internal undocumented API
 tensorflow.RaggedTensor.__init__
 # Has an undocumented extra argument that tf.Variable which acts like subclass
 # (by dynamically patching tf.Tensor methods) does not preserve.
 tensorflow.Tensor.__getitem__
-# stub internal utility
+# stub internal utilities
 tensorflow._aliases

 # Tensorflow imports are cursed.

stubs/tensorflow/tensorflow/__init__.pyi (+53 −3)
@@ -1,7 +1,7 @@
 from _typeshed import Incomplete, Unused
 from abc import ABCMeta
 from builtins import bool as _bool
-from collections.abc import Callable, Iterable, Iterator, Sequence
+from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence
 from contextlib import contextmanager
 from enum import Enum
 from types import TracebackType
@@ -10,6 +10,7 @@ from typing_extensions import ParamSpec, Self, TypeAlias

 import numpy
 from tensorflow import initializers as initializers, keras as keras, math as math
+from tensorflow._aliases import ContainerGradients, ContainerTensors, ContainerTensorsLike, Gradients, TensorLike

 # Explicit import of DType is covered by the wildcard, but
 # is necessary to avoid a crash in pytype.
@@ -53,7 +54,8 @@ from tensorflow.math import (
     subtract as subtract,
     tanh as tanh,
 )
-from tensorflow.sparse import SparseTensor
+from tensorflow.python.trackable.autotrackable import AutoTrackable
+from tensorflow.sparse import SparseTensor as SparseTensor

 # Tensors ideally should be a generic type, but properly typing data type/shape
 # will be a lot of work. Until we have good non-generic tensorflow stubs,
@@ -263,7 +265,7 @@ class name_scope:
 _P = ParamSpec("_P")
 _R = TypeVar("_R")

-class Module:
+class Module(AutoTrackable):
     def __init__(self, name: str | None = None) -> None: ...
     @property
     def name(self) -> str: ...
@@ -282,4 +284,52 @@ class Module:
     @classmethod
     def with_name_scope(cls, method: Callable[_P, _R]) -> Callable[_P, _R]: ...

+class UnconnectedGradients(Enum):
+    NONE = "none"
+    ZERO = "zero"
+
+class GradientTape:
+    def __init__(self, persistent: _bool = False, watch_accessed_variables: _bool = True) -> None: ...
+    def __enter__(self) -> Self: ...
+    def __exit__(self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None) -> None: ...
+    # Higher kinded types would be nice here and these overloads are a way to simulate some of them.
+    @overload
+    def gradient(
+        self,
+        target: ContainerTensors,
+        sources: TensorLike,
+        output_gradients: list[Tensor] | None = None,
+        unconnected_gradients: UnconnectedGradients = ...,
+    ) -> Gradients: ...
+    @overload
+    def gradient(
+        self,
+        target: ContainerTensors,
+        sources: Sequence[Tensor],
+        output_gradients: list[Tensor] | None = None,
+        unconnected_gradients: UnconnectedGradients = ...,
+    ) -> list[Gradients]: ...
+    @overload
+    def gradient(
+        self,
+        target: ContainerTensors,
+        sources: Mapping[str, Tensor],
+        output_gradients: list[Tensor] | None = None,
+        unconnected_gradients: UnconnectedGradients = ...,
+    ) -> dict[str, Gradients]: ...
+    @overload
+    def gradient(
+        self,
+        target: ContainerTensors,
+        sources: ContainerTensors,
+        output_gradients: list[Tensor] | None = None,
+        unconnected_gradients: UnconnectedGradients = ...,
+    ) -> ContainerGradients: ...
+    @contextmanager
+    def stop_recording(self) -> Generator[None, None, None]: ...
+    def reset(self) -> None: ...
+    def watch(self, tensor: ContainerTensorsLike) -> None: ...
+    def watched_variables(self) -> tuple[Variable, ...]: ...
+    def __getattr__(self, name: str) -> Incomplete: ...
+
 def __getattr__(name: str) -> Incomplete: ...
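
A minimal sketch of how these gradient() overloads are meant to resolve at call sites. The variables and loss below are invented for illustration; note that at runtime the typical sources are tf.Variable objects, even though the overloads are written in terms of tensors:

import tensorflow as tf

x = tf.Variable(2.0)
y = tf.Variable(3.0)

with tf.GradientTape(persistent=True) as tape:
    loss = x * x + y * y

dx = tape.gradient(loss, x)                    # single source -> Gradients
dx, dy = tape.gradient(loss, [x, y])           # Sequence -> list[Gradients]
named = tape.gradient(loss, {"x": x, "y": y})  # Mapping -> dict[str, Gradients]
del tape  # a persistent tape holds resources until it is dropped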

stubs/tensorflow/tensorflow/_aliases.pyi (+8)
@@ -7,8 +7,16 @@ from typing import Any, TypeVar
 from typing_extensions import TypeAlias

 import numpy
+import tensorflow as tf

 _T1 = TypeVar("_T1")
 ContainerGeneric: TypeAlias = Mapping[str, ContainerGeneric[_T1]] | Sequence[ContainerGeneric[_T1]] | _T1

+TensorLike: TypeAlias = tf.Tensor | tf.RaggedTensor | tf.SparseTensor
+Gradients: TypeAlias = tf.Tensor | tf.IndexedSlices
+
+ContainerTensorsLike: TypeAlias = ContainerGeneric[TensorLike]
+ContainerTensors: TypeAlias = ContainerGeneric[tf.Tensor]
+ContainerGradients: TypeAlias = ContainerGeneric[Gradients]
+
 AnyArray: TypeAlias = numpy.ndarray[Any, Any]
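
Since _aliases is stub-only (it has no runtime module, hence its stubtest allowlist entry above), user code that wants these names must guard the import. A sketch, with a hypothetical sum_all helper; ContainerGeneric is recursive, so a bare tensor, a list, or a nested dict of tensors all match the annotation:

from __future__ import annotations

from typing import TYPE_CHECKING

import tensorflow as tf

if TYPE_CHECKING:
    from tensorflow._aliases import ContainerTensors  # stub-only module

def sum_all(batch: ContainerTensors) -> tf.Tensor:
    # Flatten whatever nesting the caller used, then add everything up.
    return tf.add_n(tf.nest.flatten(batch))

print(sum_all({"a": tf.constant(1.0), "b": [tf.constant(2.0)]}))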

stubs/tensorflow/tensorflow/keras/__init__.pyi (+1)
@@ -5,6 +5,7 @@ from tensorflow.keras import (
     constraints as constraints,
     initializers as initializers,
     layers as layers,
+    optimizers as optimizers,
     regularizers as regularizers,
 )

stubs/tensorflow/tensorflow/keras/optimizers/__init__.pyi (new file, +5)

@@ -0,0 +1,5 @@
+from _typeshed import Incomplete
+
+from tensorflow.keras.optimizers import legacy as legacy, schedules as schedules
+
+def __getattr__(name: str) -> Incomplete: ...
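
With these re-exports, both spellings of the submodules resolve for type checkers and at runtime (assuming TensorFlow 2.11+, where the legacy namespace exists). A quick sanity check:

import tensorflow as tf
from tensorflow.keras.optimizers import legacy, schedules

# Both access paths should name the same classes:
print(legacy.Adam is tf.keras.optimizers.legacy.Adam)
print(schedules.ExponentialDecay is tf.keras.optimizers.schedules.ExponentialDecay)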
stubs/tensorflow/tensorflow/keras/optimizers/legacy/__init__.pyi (new file, +112)

@@ -0,0 +1,112 @@
+from _typeshed import Incomplete
+from abc import abstractmethod
+from collections.abc import Callable, Iterable
+from typing import Any
+from typing_extensions import Self, TypeAlias
+
+import tensorflow as tf
+from tensorflow._aliases import Gradients
+from tensorflow.keras.optimizers import schedules as schedules
+from tensorflow.python.trackable.base import Trackable
+
+_Initializer: TypeAlias = str | Callable[[], tf.Tensor] | dict[str, Any]
+_Shape: TypeAlias = tf.TensorShape | Iterable[int | None]
+_Dtype: TypeAlias = tf.DType | str | None
+_LearningRate: TypeAlias = float | tf.Tensor | schedules.LearningRateSchedule | Callable[[], float | tf.Tensor]
+_GradientAggregator: TypeAlias = Callable[[list[tuple[Gradients, tf.Variable]]], list[tuple[Gradients, tf.Variable]]] | None
+_GradientTransformer: TypeAlias = (
+    Iterable[Callable[[list[tuple[Gradients, tf.Variable]]], list[tuple[Gradients, tf.Variable]]]] | None
+)
+
+# kwargs here and in the other optimizers can be given a better type once Unpack[TypedDict] (PEP 692) is supported.
+class Optimizer(Trackable):
+    _name: str
+    _iterations: tf.Variable | None
+    _weights: list[tf.Variable]
+    gradient_aggregator: _GradientAggregator
+    gradient_transformers: _GradientTransformer
+    learning_rate: _LearningRate
+    def __init__(
+        self,
+        name: str,
+        gradient_aggregator: _GradientAggregator = None,
+        gradient_transformers: _GradientTransformer = None,
+        **kwargs: Any,
+    ) -> None: ...
+    def _create_all_weights(self, var_list: Iterable[tf.Variable]) -> None: ...
+    @property
+    def iterations(self) -> tf.Variable: ...
+    @iterations.setter
+    def iterations(self, variable: tf.Variable) -> None: ...
+    def add_slot(
+        self, var: tf.Variable, slot_name: str, initializer: _Initializer = "zeros", shape: tf.TensorShape | None = None
+    ) -> tf.Variable: ...
+    def add_weight(
+        self,
+        name: str,
+        shape: _Shape,
+        dtype: _Dtype = None,
+        initializer: _Initializer = "zeros",
+        trainable: None | bool = None,
+        synchronization: tf.VariableSynchronization = ...,
+        aggregation: tf.VariableAggregation = ...,
+    ) -> tf.Variable: ...
+    def apply_gradients(
+        self,
+        grads_and_vars: Iterable[tuple[Gradients, tf.Variable]],
+        name: str | None = None,
+        experimental_aggregate_gradients: bool = True,
+    ) -> tf.Operation | None: ...
+    @classmethod
+    def from_config(cls, config: dict[str, Any], custom_objects: dict[str, type] | None = None) -> Self: ...
+    # Missing ABC is intentional as the class is not abstract at runtime.
+    @abstractmethod
+    def get_config(self) -> dict[str, Any]: ...
+    def get_slot(self, var: tf.Variable, slot_name: str) -> tf.Variable: ...
+    def get_slot_names(self) -> list[str]: ...
+    def get_gradients(self, loss: tf.Tensor, params: list[tf.Variable]) -> list[Gradients]: ...
+    def minimize(
+        self,
+        loss: tf.Tensor | Callable[[], tf.Tensor],
+        var_list: list[tf.Variable] | tuple[tf.Variable, ...] | Callable[[], list[tf.Variable] | tuple[tf.Variable, ...]],
+        grad_loss: tf.Tensor | None = None,
+        name: str | None = None,
+        tape: tf.GradientTape | None = None,
+    ) -> tf.Operation: ...
+    def variables(self) -> list[tf.Variable]: ...
+    @property
+    def weights(self) -> list[tf.Variable]: ...
+
+class Adam(Optimizer):
+    def __init__(
+        self,
+        learning_rate: _LearningRate = 0.001,
+        beta_1: float = 0.9,
+        beta_2: float = 0.999,
+        epsilon: float = 1e-07,
+        amsgrad: bool = False,
+        name: str = "Adam",
+        **kwargs: Any,
+    ) -> None: ...
+    def get_config(self) -> dict[str, Any]: ...
+
+class Adagrad(Optimizer):
+    _initial_accumulator_value: float
+
+    def __init__(
+        self,
+        learning_rate: _LearningRate = 0.001,
+        initial_accumulator_value: float = 0.1,
+        epsilon: float = 1e-7,
+        name: str = "Adagrad",
+        **kwargs: Any,
+    ) -> None: ...
+    def get_config(self) -> dict[str, Any]: ...
+
+class SGD(Optimizer):
+    def __init__(
+        self, learning_rate: _LearningRate = 0.01, momentum: float = 0.0, nesterov: bool = False, name: str = "SGD", **kwargs: Any
+    ) -> None: ...
+    def get_config(self) -> dict[str, Any]: ...
+
+def __getattr__(name: str) -> Incomplete: ...
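
A short sketch of the apply_gradients/minimize surface typed above, exercising the legacy Adam optimizer on a toy scalar problem (the variable and loss are invented for illustration):

import tensorflow as tf

w = tf.Variable(1.0)
opt = tf.keras.optimizers.legacy.Adam(learning_rate=0.1)

# Explicit tape + apply_gradients, typed as Iterable[tuple[Gradients, tf.Variable]].
with tf.GradientTape() as tape:
    loss = (w - 3.0) ** 2
grads = tape.gradient(loss, [w])
opt.apply_gradients(zip(grads, [w]))

# Or let minimize() run the tape itself; var_list may also be a callable.
opt.minimize(lambda: (w - 3.0) ** 2, var_list=[w])
print(opt.iterations.numpy())  # 2 update steps applied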
stubs/tensorflow/tensorflow/keras/optimizers/schedules/__init__.pyi (new file, +99)

@@ -0,0 +1,99 @@
+from abc import abstractmethod
+from collections.abc import Sequence
+from typing import Any
+from typing_extensions import Self
+
+import tensorflow as tf
+
+class LearningRateSchedule:
+    # At runtime these methods are abstract even though the class is not an ABC.
+    @abstractmethod
+    def __call__(self, step: int | tf.Tensor) -> float | tf.Tensor: ...
+    @abstractmethod
+    def get_config(self) -> dict[str, Any]: ...
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> Self: ...
+
+class PiecewiseConstantDecay(LearningRateSchedule):
+    def __init__(
+        self,
+        boundaries: Sequence[tf.Tensor] | Sequence[float],
+        values: Sequence[float] | Sequence[tf.Tensor],
+        name: str | None = None,
+    ) -> None: ...
+    def __call__(self, step: int | tf.Tensor) -> float | tf.Tensor: ...
+    def get_config(self) -> dict[str, Any]: ...
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> Self: ...
+
+class InverseTimeDecay(LearningRateSchedule):
+    def __init__(
+        self,
+        initial_learning_rate: float | tf.Tensor,
+        decay_steps: int,
+        decay_rate: float,
+        staircase: bool = False,
+        name: str | None = None,
+    ) -> None: ...
+    def __call__(self, step: int | tf.Tensor) -> float | tf.Tensor: ...
+    def get_config(self) -> dict[str, Any]: ...
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> Self: ...
+
+class PolynomialDecay(LearningRateSchedule):
+    def __init__(
+        self,
+        initial_learning_rate: float | tf.Tensor,
+        decay_steps: int,
+        end_learning_rate: float | tf.Tensor = 0.0001,
+        power: float = 1.0,
+        cycle: bool = False,
+        name: str | None = None,
+    ) -> None: ...
+    def __call__(self, step: int | tf.Tensor) -> float | tf.Tensor: ...
+    def get_config(self) -> dict[str, Any]: ...
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> Self: ...
+
+class CosineDecay(LearningRateSchedule):
+    def __init__(
+        self, initial_learning_rate: float | tf.Tensor, decay_steps: int, alpha: float | tf.Tensor = 0.0, name: str | None = None
+    ) -> None: ...
+    def __call__(self, step: int | tf.Tensor) -> float | tf.Tensor: ...
+    def get_config(self) -> dict[str, Any]: ...
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> Self: ...
+
+class CosineDecayRestarts(LearningRateSchedule):
+    def __init__(
+        self,
+        initial_learning_rate: float | tf.Tensor,
+        first_decay_steps: int | tf.Tensor,
+        t_mul: float | tf.Tensor = 2.0,
+        m_mul: float | tf.Tensor = 1.0,
+        alpha: float | tf.Tensor = 0.0,
+        name: str | None = None,
+    ) -> None: ...
+    def __call__(self, step: int | tf.Tensor) -> float | tf.Tensor: ...
+    def get_config(self) -> dict[str, Any]: ...
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> Self: ...
+
+class ExponentialDecay(LearningRateSchedule):
+    def __init__(
+        self,
+        initial_learning_rate: float | tf.Tensor,
+        decay_steps: int | tf.Tensor,
+        decay_rate: float | tf.Tensor,
+        staircase: bool = False,
+        name: str | None = None,
+    ) -> None: ...
+    def __call__(self, step: int | tf.Tensor) -> float | tf.Tensor: ...
+    def get_config(self) -> dict[str, Any]: ...
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> Self: ...
+
+def deserialize(
+    config: dict[str, Any], custom_objects: dict[str, type] | None = None, use_legacy_format: bool = False
+) -> LearningRateSchedule: ...
+def serialize(learning_rate_schedule: LearningRateSchedule, use_legacy_format: bool = False) -> dict[str, Any]: ...
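
How the schedules typed above are used: a LearningRateSchedule is callable on a step count, and it can be passed wherever _LearningRate is accepted, e.g. to a legacy optimizer. A sketch:

import tensorflow as tf

schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.96, staircase=True
)
opt = tf.keras.optimizers.legacy.SGD(learning_rate=schedule)

# __call__(step) returns the decayed rate as a tensor.
print(float(schedule(0)))     # 0.1
print(float(schedule(1000)))  # 0.1 * 0.96 = 0.096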

stubs/tensorflow/tensorflow/python/trackable/__init__.pyi (new, empty)

Whitespace-only changes (an empty package marker).
stubs/tensorflow/tensorflow/python/trackable/autotrackable.pyi (new file, +3)

@@ -0,0 +1,3 @@
+from tensorflow.python.trackable.base import Trackable
+
+class AutoTrackable(Trackable): ...
stubs/tensorflow/tensorflow/python/trackable/base.pyi (new file, +5)

@@ -0,0 +1,5 @@
+# Internal type that is commonly used as a base class
+# and that some public APIs need in their signatures.
+# As the type is internal, the exact module it lives
+# in is unstable across versions.
+class Trackable: ...
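
Why these two tiny classes matter: tf.Module is now typed as an AutoTrackable, which is what lets modules (and Keras layers built on them) participate in object-based checkpointing. A minimal demonstration; the attribute name, checkpoint keyword, and prefix below are illustrative:

import tensorflow as tf

m = tf.Module()
m.v = tf.Variable(42.0)  # attributes of an AutoTrackable are tracked automatically

ckpt = tf.train.Checkpoint(module=m)
path = ckpt.save("/tmp/trackable_demo")  # returns the written checkpoint prefix
print(path)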
