From 7baf60b2ab709842ccac45576a23da0a784620f2 Mon Sep 17 00:00:00 2001
From: John Aslanides
Date: Sun, 12 Apr 2020 10:09:17 +0100
Subject: [PATCH 1/3] Add type annotations to optix.

---
 jax/experimental/optix.py | 141 ++++++++++++++++++++++++++------------
 1 file changed, 99 insertions(+), 42 deletions(-)

diff --git a/jax/experimental/optix.py b/jax/experimental/optix.py
index b7a7e0a4db4e..dbbdcf923f4a 100644
--- a/jax/experimental/optix.py
+++ b/jax/experimental/optix.py
@@ -40,6 +40,7 @@
 
 
 import collections
+from typing import Any, Callable, NamedTuple, Sequence, Tuple
 
 from jax import numpy as jnp
 from jax import random as jrandom
@@ -50,14 +51,30 @@
 from jax.tree_util import tree_unflatten
 
 
-### Composable gradient transformations. ###
+###
+# Composable gradient transformations.
+
+# TODO(jaslanides): Make these more specific.
+OptState = NamedTuple  # Optimizer state is a (possibly empty) namedtuple.
+Params = Any  # Parameters are nests of `jnp.ndarrays`.
+Updates = Params  # Gradient updates are of the same type as parameters.
 
-InitUpdate = collections.namedtuple("InitUpdate", ("init", "update"))
-ClipState = collections.namedtuple("ClipState", "")
+InitFn = Callable[[Params], OptState]
+UpdateFn = Callable[[Updates, OptState], Tuple[Updates, OptState]]
 
 
-def clip(max_delta):
+
+class InitUpdate(NamedTuple):
+  """Optix optimizers consist of a pair of functions: (initialiser, update)."""
+  init: InitFn
+  update: UpdateFn
+
+
+class ClipState(OptState):
+  """The `clip` transformation is stateless."""
+
+
+def clip(max_delta) -> InitUpdate:
   """Clip updates element-wise, to be between -max_delta and +max_delta.
 
   Args:
@@ -78,14 +95,15 @@ def update_fn(updates, state):
   return InitUpdate(init_fn, update_fn)
 
 
-ClipByGlobalNormState = collections.namedtuple("ClipByGlobalNormState", "")
+def global_norm(updates: Updates) -> Updates:
+  return jnp.sqrt(jnp.sum([jnp.sum(x**2) for x in tree_leaves(updates)]))
 
 
-def global_norm(items):
-  return jnp.sqrt(jnp.sum([jnp.sum(x**2) for x in tree_leaves(items)]))
+class ClipByGlobalNormState(OptState):
+  """The `clip_by_global_norm` transformation is stateless."""
 
 
-def clip_by_global_norm(max_norm):
+def clip_by_global_norm(max_norm) -> InitUpdate:
   """Clip updates using their global norm.
 
   References:
@@ -111,15 +129,17 @@ def update_fn(updates, state):
   return InitUpdate(init_fn, update_fn)
 
 
-TraceState = collections.namedtuple("TraceState", "trace")
+class TraceState(OptState):
+  """Holds an aggregation of past updates."""
+  trace: Params
 
 
-def trace(decay, nesterov):
+def trace(decay: float, nesterov: bool) -> InitUpdate:
   """Compute a trace of past updates.
 
   Args:
     decay: the decay rate for the tracing of past updates.
-    nesterov: whether to use nesterov momentum.
+    nesterov: whether to use Nesterov momentum.
 
   Returns:
     An (init_fn, update_fn) tuple.
@@ -138,7 +158,9 @@ def update_fn(updates, state):
   return InitUpdate(init_fn, update_fn)
 
 
-ScaleByRmsState = collections.namedtuple("ScaleByRmsState", "nu")
+class ScaleByRmsState(OptState):
+  """State for exponential root mean-squared (RMS)-normalized updates."""
+  nu: Updates
 
 
 def _update_moment(updates, moments, decay, order):
@@ -146,7 +168,7 @@ def _update_moment(updates, moments, decay, order):
       lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments)
 
 
-def scale_by_rms(decay=0.9, eps=1e-8):
+def scale_by_rms(decay: float = 0.9, eps: float = 1e-8):
   """Rescale updates by the root of the exp. moving avg of the square.
 
   References:
@@ -172,10 +194,13 @@ def update_fn(updates, state):
   return InitUpdate(init_fn, update_fn)
 
 
-ScaleByRStdDevState = collections.namedtuple("ScaleByRStdDevState", "mu nu")
+class ScaleByRStdDevState(OptState):
+  """State for centered exponential moving average of squares of updates."""
+  mu: Updates
+  nu: Updates
 
 
-def scale_by_stddev(decay=0.9, eps=1e-8):
+def scale_by_stddev(decay: float = 0.9, eps: float = 1e-8) -> InitUpdate:
   """Rescale updates by the root of the centered exp. moving average of squares.
 
   References:
@@ -204,10 +229,16 @@ def update_fn(updates, state):
   return InitUpdate(init_fn, update_fn)
 
 
-ScaleByAdamState = collections.namedtuple("ScaleByAdamState", "count mu nu")
+class ScaleByAdamState(OptState):
+  """State for the Adam algorithm."""
+  count: jnp.ndarray  # shape=(), dtype=jnp.int32.
+  mu: Updates
+  nu: Updates
 
 
-def scale_by_adam(b1=0.9, b2=0.999, eps=1e-8):
+def scale_by_adam(b1: float = 0.9,
+                  b2: float = 0.999,
+                  eps: float = 1e-8) -> InitUpdate:
   """Rescale updates according to the Adam algorithm.
 
   References:
@@ -239,10 +270,11 @@ def update_fn(updates, state):
   return InitUpdate(init_fn, update_fn)
 
 
-ScaleState = collections.namedtuple("ScaleState", "")
+class ScaleState(NamedTuple):
+  """The scale transformation is stateless."""
 
 
-def scale(step_size):
+def scale(step_size: float) -> InitUpdate:
   """Scale updates by some fixed scalar `step_size`.
 
   Args:
@@ -262,10 +294,12 @@ def update_fn(updates, state):
   return InitUpdate(init_fn, update_fn)
 
 
-ScaleByScheduleState = collections.namedtuple("ScaleByScheduleState", "count")
+class ScaleByScheduleState(OptState):
+  """Maintains count for scale scheduling."""
+  count: jnp.ndarray  # shape=(), dtype=jnp.int32
 
 
-def scale_by_schedule(step_size_fn):
+def scale_by_schedule(step_size_fn: Callable[[jnp.ndarray], jnp.ndarray]):
   """Scale updates using a custom schedule for the `step_size`.
 
   Args:
@@ -286,10 +320,13 @@ def update_fn(updates, state):
   return InitUpdate(init_fn, update_fn)
 
 
-AddNoiseState = collections.namedtuple("AddNoiseState", "count rng_key")
+class AddNoiseState(OptState):
+  """State for adding gradient noise. Contains a count for annealing."""
+  count: jnp.ndarray
+  rng_key: jnp.ndarray
 
 
-def add_noise(eta, gamma, seed):
+def add_noise(eta: float, gamma: float, seed: int) -> InitUpdate:
   """Add gradient noise.
 
   References:
@@ -323,10 +360,13 @@ def update_fn(updates, state):  # pylint: disable=missing-docstring
   return InitUpdate(init_fn, update_fn)
 
 
-ApplyEvery = collections.namedtuple("ApplyEvery", "count grad_acc")
+class ApplyEvery(OptState):
+  """Contains a counter and a gradient accumulator."""
+  count: jnp.ndarray
+  grad_acc: Updates
 
 
-def apply_every(k=1):
+def apply_every(k: int = 1) -> InitUpdate:
   """accumulate gradients and apply them every k steps.
 
   Args:
@@ -353,10 +393,11 @@ def update_fn(updates, state):
   return InitUpdate(init_fn, update_fn)
 
 
-### Utilities for building and using custom optimizers. ###
+###
+# Utilities for building and using custom optimizers.
 
 
-def chain(*args):
+def chain(*args: Sequence[InitUpdate]) -> InitUpdate:
   """Applies a list of chainable update transformations.
 
   Given a sequence of chainable transforms, `chain` returns an `init_fn`
@@ -386,7 +427,7 @@ def update_fn(updates, state):
   return InitUpdate(init_fn, update_fn)
 
 
-def apply_updates(params, updates):
+def apply_updates(params: Params, updates: Updates) -> Params:
   """Applies an update to the corresponding parameters.
 
   This is an (optional) utility functions that applies an update, and returns
@@ -404,34 +445,50 @@ def apply_updates(params, updates):
   return tree_multimap(lambda p, u: p + u, params, updates)
 
 
-### Aliases for popular optimizers. ###
+###
+# Aliases for popular optimizers.
 
 
-def sgd(learning_rate, momentum=0., nesterov=False):
+def sgd(learning_rate: float,
+        momentum: float = 0.,
+        nesterov: bool = False) -> InitUpdate:
   return chain(
       trace(decay=momentum, nesterov=nesterov),
-      scale(-learning_rate))
+      scale(-learning_rate),
+  )
 
 
-def noisy_sgd(learning_rate, eta=0.01, gamma=0.55, seed=42):
+def noisy_sgd(learning_rate: float,
+              eta: float = 0.01,
+              gamma: float = 0.55,
+              seed: int = 0) -> InitUpdate:
   return chain(
       trace(decay=0., nesterov=False),
       scale(-learning_rate),
-      add_noise(eta, gamma, seed))
+      add_noise(eta, gamma, seed),
+  )
 
 
-def adam(learning_rate, b1=0.9, b2=0.999, eps=1e-8):
+def adam(learning_rate: float,
+         b1: float = 0.9,
+         b2: float = 0.999,
+         eps: float = 1e-8) -> InitUpdate:
   return chain(
       scale_by_adam(b1=b1, b2=b2, eps=eps),
-      scale(-learning_rate))
+      scale(-learning_rate),
+  )
 
 
-def rmsprop(learning_rate, decay=0.9, eps=1e-8, centered=False):
-  if not centered:
-    return chain(
-        scale_by_rms(decay=decay, eps=eps),
-        scale(-learning_rate))
-  else:
+def rmsprop(learning_rate: float,
+            decay: float = 0.9,
+            eps: float = 1e-8,
+            centered: bool = False) -> InitUpdate:
+  if centered:
     return chain(
         scale_by_stddev(decay=decay, eps=eps),
-        scale(-learning_rate))
+        scale(-learning_rate),
+    )
+  return chain(
+      scale_by_rms(decay=decay, eps=eps),
+      scale(-learning_rate),
+  )

From e05d37999ec5f47e18d7d7ea81f8c9e23aeec3b7 Mon Sep 17 00:00:00 2001
From: John Aslanides
Date: Sun, 12 Apr 2020 12:06:05 +0100
Subject: [PATCH 2/3] Fix function signature for chain() and remove unused collections import.

---
 jax/experimental/optix.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/jax/experimental/optix.py b/jax/experimental/optix.py
index dbbdcf923f4a..d7f349b5eef6 100644
--- a/jax/experimental/optix.py
+++ b/jax/experimental/optix.py
@@ -39,8 +39,7 @@
 """
 
 
-import collections
-from typing import Any, Callable, NamedTuple, Sequence, Tuple
+from typing import Any, Callable, NamedTuple, Tuple
 
 from jax import numpy as jnp
 from jax import random as jrandom
@@ -397,7 +396,7 @@ def update_fn(updates, state):
 # Utilities for building and using custom optimizers.
 
 
-def chain(*args: Sequence[InitUpdate]) -> InitUpdate:
+def chain(*args: InitUpdate) -> InitUpdate:
   """Applies a list of chainable update transformations.
 
   Given a sequence of chainable transforms, `chain` returns an `init_fn`

From 57321cf391d84aa37db46e863035bf2e083e558d Mon Sep 17 00:00:00 2001
From: John Aslanides
Date: Sun, 12 Apr 2020 19:33:42 +0100
Subject: [PATCH 3/3] Include Sequence[OptState] as possible output of Init.

---
 jax/experimental/optix.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/jax/experimental/optix.py b/jax/experimental/optix.py
index d7f349b5eef6..04df8ddbc918 100644
--- a/jax/experimental/optix.py
+++ b/jax/experimental/optix.py
@@ -39,7 +39,7 @@
 """
 
 
-from typing import Any, Callable, NamedTuple, Tuple
+from typing import Any, Callable, NamedTuple, Sequence, Tuple, Union
 
 from jax import numpy as jnp
 from jax import random as jrandom
@@ -59,7 +59,7 @@
 Updates = Params  # Gradient updates are of the same type as parameters.
 
-InitFn = Callable[[Params], OptState]
+InitFn = Callable[[Params], Union[OptState, Sequence[OptState]]]
 UpdateFn = Callable[[Updates, OptState], Tuple[Updates, OptState]]
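
For reviewers who want to see the annotated API exercised end to end, below is a minimal usage sketch (illustrative only, not part of the patch series) that runs the `InitFn`/`UpdateFn` contract on a toy quadratic loss; the hyperparameter values and the `loss` function are arbitrary choices for the example.

# Usage sketch (illustrative, not part of the patch): an InitUpdate pair is
# built with `adam`, initialised via InitFn, and stepped via UpdateFn.
import jax
import jax.numpy as jnp
from jax.experimental import optix

def loss(params):
  # Toy objective: sum of squares, minimised at params == 0.
  return jnp.sum(params ** 2)

params = jnp.arange(3, dtype=jnp.float32)  # Params: a nest of jnp.ndarrays.
opt = optix.adam(learning_rate=0.1)        # InitUpdate: (init, update) pair.
state = opt.init(params)                   # InitFn: Params -> OptState.

for _ in range(100):
  grads = jax.grad(loss)(params)
  # UpdateFn: (Updates, OptState) -> (Updates, OptState).
  updates, state = opt.update(grads, state)
  params = optix.apply_updates(params, updates)

Note that because gradient updates share the pytree structure of the parameters (`Updates = Params` above), `apply_updates` is a plain tree-wise addition, and any composition built with `chain` plugs into the same two-function loop.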