Add type annotations to optix. #2687

Merged: 3 commits, Apr 14, 2020
Changes from all commits
142 changes: 99 additions & 43 deletions jax/experimental/optix.py
@@ -39,7 +39,7 @@
"""


import collections
from typing import Any, Callable, NamedTuple, Sequence, Tuple, Union

from jax import numpy as jnp
from jax import random as jrandom
@@ -50,14 +50,30 @@
from jax.tree_util import tree_unflatten


### Composable gradient transformations. ###
###
# Composable gradient transformations.

# TODO(jaslanides): Make these more specific.
OptState = NamedTuple # Optimizer state is a (possibly empty) namedtuple.
Params = Any # Parameters are nests of `jnp.ndarrays`.
Updates = Params # Gradient updates are of the same type as parameters.

InitUpdate = collections.namedtuple("InitUpdate", ("init", "update"))
ClipState = collections.namedtuple("ClipState", "")

InitFn = Callable[[Params], Union[OptState, Sequence[OptState]]]
UpdateFn = Callable[[Updates, OptState], Tuple[Updates, OptState]]

def clip(max_delta):

class InitUpdate(NamedTuple):
"""Optix optimizers consists of a pair of functions: (initialiser, update)."""
init: InitFn
update: UpdateFn
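
As an inline note (not part of this diff): a minimal sketch of a transformation conforming to the (init, update) protocol above. The `scale_by_sign` transformation and `SignState` are hypothetical, introduced only to make `InitFn`/`UpdateFn` concrete.

# Hypothetical example, not in optix: a stateless transformation that replaces
# each gradient leaf by its sign.
from typing import NamedTuple
from jax import numpy as jnp
from jax.tree_util import tree_map


class SignState(NamedTuple):
  """This made-up transformation keeps no state."""


def scale_by_sign() -> InitUpdate:
  def init_fn(params):
    del params  # Nothing to initialise.
    return SignState()

  def update_fn(updates, state):
    # Transform the updates; pass the (empty) state through unchanged.
    return tree_map(jnp.sign, updates), state

  return InitUpdate(init_fn, update_fn)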


class ClipState(OptState):
"""The `clip` transformation is stateless."""


def clip(max_delta) -> InitUpdate:
"""Clip updates element-wise, to be between -max_delta and +max_delta.

Args:
Expand All @@ -78,14 +94,15 @@ def update_fn(updates, state):
return InitUpdate(init_fn, update_fn)


ClipByGlobalNormState = collections.namedtuple("ClipByGlobalNormState", "")
def global_norm(updates: Updates) -> Updates:
return jnp.sqrt(jnp.sum([jnp.sum(x**2) for x in tree_leaves(updates)]))


def global_norm(items):
return jnp.sqrt(jnp.sum([jnp.sum(x**2) for x in tree_leaves(items)]))
class ClipByGlobalNormState(OptState):
"""The `clip_by_global_norm` transformation is stateless."""


def clip_by_global_norm(max_norm):
def clip_by_global_norm(max_norm) -> InitUpdate:
"""Clip updates using their global norm.

References:
@@ -111,15 +128,17 @@ def update_fn(updates, state):
return InitUpdate(init_fn, update_fn)


TraceState = collections.namedtuple("TraceState", "trace")
class TraceState(OptState):
"""Holds an aggregation of past updates."""
trace: Params


def trace(decay, nesterov):
def trace(decay: float, nesterov: bool) -> InitUpdate:
"""Compute a trace of past updates.

Args:
decay: the decay rate for the tracing of past updates.
nesterov: whether to use nesterov momentum.
nesterov: whether to use Nesterov momentum.

Returns:
An (init_fn, update_fn) tuple.
@@ -138,15 +157,17 @@ def update_fn(updates, state):
return InitUpdate(init_fn, update_fn)
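
Since the `update_fn` body is collapsed in this view, here is a hedged sketch of the standard momentum trace it corresponds to; treat it as orientation, not the literal diff contents. It reuses the module's `tree_multimap` import.

# Approximate sketch of a momentum trace with an optional Nesterov look-ahead.
def _trace_update_sketch(updates, trace, decay, nesterov):
  new_trace = tree_multimap(lambda g, t: g + decay * t, updates, trace)
  if nesterov:
    # Nesterov momentum takes one extra look-ahead step along the refreshed trace.
    updates = tree_multimap(lambda g, t: g + decay * t, updates, new_trace)
  else:
    updates = new_trace
  return updates, new_trace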


ScaleByRmsState = collections.namedtuple("ScaleByRmsState", "nu")
class ScaleByRmsState(OptState):
"""State for exponential root mean-squared (RMS)-normalized updates."""
nu: Updates


def _update_moment(updates, moments, decay, order):
return tree_multimap(
lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments)


def scale_by_rms(decay=0.9, eps=1e-8):
def scale_by_rms(decay: float = 0.9, eps: float = 1e-8):
"""Rescale updates by the root of the exp. moving avg of the square.

References:
@@ -172,10 +193,13 @@ def update_fn(updates, state):
return InitUpdate(init_fn, update_fn)


ScaleByRStdDevState = collections.namedtuple("ScaleByRStdDevState", "mu nu")
class ScaleByRStdDevState(OptState):
"""State for centered exponential moving average of squares of updates."""
mu: Updates
nu: Updates


def scale_by_stddev(decay=0.9, eps=1e-8):
def scale_by_stddev(decay: float = 0.9, eps: float = 1e-8) -> InitUpdate:
"""Rescale updates by the root of the centered exp. moving average of squares.

References:
@@ -204,10 +228,16 @@ def update_fn(updates, state):
return InitUpdate(init_fn, update_fn)


ScaleByAdamState = collections.namedtuple("ScaleByAdamState", "count mu nu")
class ScaleByAdamState(OptState):
"""State for the Adam algorithm."""
count: jnp.ndarray # shape=(), dtype=jnp.int32.
mu: Updates
nu: Updates


def scale_by_adam(b1=0.9, b2=0.999, eps=1e-8):
def scale_by_adam(b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-8) -> InitUpdate:
"""Rescale updates according to the Adam algorithm.

References:
@@ -239,10 +269,11 @@ def update_fn(updates, state):
return InitUpdate(init_fn, update_fn)
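
The merged `update_fn` is collapsed above; for orientation, a hedged sketch of the textbook Adam rescaling (with bias correction) this transformation is named after, written for a single gradient leaf.

# Textbook Adam on a single leaf `g`; an approximation of the collapsed
# update_fn, not a copy of it.
def _adam_leaf_sketch(g, mu, nu, count, b1, b2, eps):
  mu = (1 - b1) * g + b1 * mu           # First moment (cf. _update_moment, order=1).
  nu = (1 - b2) * (g ** 2) + b2 * nu    # Second moment (order=2).
  mu_hat = mu / (1 - b1 ** count)       # Bias correction for the warm-up phase.
  nu_hat = nu / (1 - b2 ** count)
  return mu_hat / (jnp.sqrt(nu_hat) + eps), mu, nu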


ScaleState = collections.namedtuple("ScaleState", "")
class ScaleState(NamedTuple):
"""The scale transformation is stateless."""


def scale(step_size):
def scale(step_size: float) -> InitUpdate:
"""Scale updates by some fixed scalar `step_size`.

Args:
@@ -262,10 +293,12 @@ def update_fn(updates, state):
return InitUpdate(init_fn, update_fn)


ScaleByScheduleState = collections.namedtuple("ScaleByScheduleState", "count")
class ScaleByScheduleState(OptState):
"""Maintains count for scale scheduling."""
count: jnp.ndarray # shape=(), dtype=jnp.int32


def scale_by_schedule(step_size_fn):
def scale_by_schedule(step_size_fn: Callable[[jnp.ndarray], jnp.ndarray]):
"""Scale updates using a custom schedule for the `step_size`.

Args:
@@ -286,10 +319,13 @@ def update_fn(updates, state):
return InitUpdate(init_fn, update_fn)
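
A small usage sketch (the decay schedule below is illustrative, not from this PR): the callable receives the step count and returns the multiplier applied to the updates at that step.

# Illustrative schedule: the multiplier decays exponentially with the step count.
exp_decay = lambda count: 0.99 ** count
scheduled = scale_by_schedule(exp_decay)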


AddNoiseState = collections.namedtuple("AddNoiseState", "count rng_key")
class AddNoiseState(OptState):
"""State for adding gradient noise. Contains a count for annealing."""
count: jnp.ndarray
rng_key: jnp.ndarray


def add_noise(eta, gamma, seed):
def add_noise(eta: float, gamma: float, seed: int) -> InitUpdate:
"""Add gradient noise.

References:
@@ -323,10 +359,13 @@ def update_fn(updates, state): # pylint: disable=missing-docstring
return InitUpdate(init_fn, update_fn)
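
The collapsed body adds annealed Gaussian noise to the updates. As a hedged sketch inferred from the eta/gamma parameterisation (not copied from the diff), the noise variance is typically annealed as eta / (1 + t)**gamma.

# Hedged sketch of the annealed noise variance commonly used with this scheme.
def _noise_variance_sketch(eta, gamma, count):
  return eta / ((1 + count) ** gamma)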


ApplyEvery = collections.namedtuple("ApplyEvery", "count grad_acc")
class ApplyEvery(OptState):
"""Contains a counter and a gradient accumulator."""
count: jnp.ndarray
grad_acc: Updates


def apply_every(k=1):
def apply_every(k: int = 1) -> InitUpdate:
"""Accumulate gradients and apply them every k steps.

Args:
@@ -353,10 +392,11 @@ def update_fn(updates, state):
return InitUpdate(init_fn, update_fn)


### Utilities for building and using custom optimizers. ###
###
# Utilities for building and using custom optimizers.


def chain(*args):
def chain(*args: InitUpdate) -> InitUpdate:
"""Applies a list of chainable update transformations.

Given a sequence of chainable transforms, `chain` returns an `init_fn`
@@ -386,7 +426,7 @@ def update_fn(updates, state):
return InitUpdate(init_fn, update_fn)
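
An inline usage sketch (hyperparameters are illustrative): composing clipping, Adam scaling, and a fixed learning-rate scale into one optimizer.

# Illustrative composition of the transformations defined above.
my_optimizer = chain(
    clip_by_global_norm(1.0),
    scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
    scale(-1e-3),  # Negative step size => gradient descent.
)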


def apply_updates(params, updates):
def apply_updates(params: Params, updates: Updates) -> Params:
"""Applies an update to the corresponding parameters.

This is an (optional) utility function that applies an update, and returns
@@ -404,34 +444,50 @@ def apply_updates(params, updates):
return tree_multimap(lambda p, u: p + u, params, updates)
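
A sketch of the intended usage, assuming a `loss_fn`, a `batch`, and initial `params` already exist (the names are illustrative); it uses the `sgd` alias defined further down.

# Illustrative training step built from init / update / apply_updates.
import jax
from jax.experimental import optix

opt = optix.sgd(learning_rate=1e-2, momentum=0.9)
opt_state = opt.init(params)

def train_step(params, opt_state, batch):
  grads = jax.grad(loss_fn)(params, batch)           # loss_fn is assumed to exist.
  updates, opt_state = opt.update(grads, opt_state)
  new_params = optix.apply_updates(params, updates)
  return new_params, opt_state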


### Aliases for popular optimizers. ###
###
# Aliases for popular optimizers.


def sgd(learning_rate, momentum=0., nesterov=False):
def sgd(learning_rate: float,
momentum: float = 0.,
nesterov: bool = False) -> InitUpdate:
return chain(
trace(decay=momentum, nesterov=nesterov),
scale(-learning_rate))
scale(-learning_rate),
)


def noisy_sgd(learning_rate, eta=0.01, gamma=0.55, seed=42):
def noisy_sgd(learning_rate: float,
eta: float = 0.01,
gamma: float = 0.55,
seed: int = 0) -> InitUpdate:
return chain(
trace(decay=0., nesterov=False),
scale(-learning_rate),
add_noise(eta, gamma, seed))
add_noise(eta, gamma, seed),
)


def adam(learning_rate, b1=0.9, b2=0.999, eps=1e-8):
def adam(learning_rate: float,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-8) -> InitUpdate:
  return chain(
      scale_by_adam(b1=b1, b2=b2, eps=eps),
      scale(-learning_rate))
      scale(-learning_rate),
  )

Review thread on `adam` (from the PR conversation):

Collaborator: Should we have AdamState (etc.) for the common cases (e.g. AdamState = Sequence[ScaleByAdamState, ScaleState, ScaleState])?

Contributor (author): I'll leave that for a follow-up. I think we should improve the other types before we do that -- e.g. by making InitUpdate parametric/generic w.r.t. the state, so that we can have InitUpdate[Adam] rather than shoving unions of things into InitUpdate. WDYT?

Collaborator: Sure, I think that will require quite a bit of refactoring though -- e.g. I'm pretty sure I've tried to subclass NamedTuple and Generic before and it failed because of a metaclass conflict. Up to you if you think it's worth it.

Contributor (author): Ack -- will return to this at some point if I have time.
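
Tying back to the AdamState discussion above, a hedged illustration of what the chained optimizer's state looks like in practice (the exact container returned by `chain`'s collapsed `init_fn` may differ):

# Hedged illustration: chain() keeps one state entry per transformation, so for
# adam() the state holds roughly (ScaleByAdamState(...), ScaleState()).
opt = adam(learning_rate=1e-3)
opt_state = opt.init(params)  # `params` is assumed to be an existing pytree.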


def rmsprop(learning_rate, decay=0.9, eps=1e-8, centered=False):
if not centered:
return chain(
scale_by_rms(decay=decay, eps=eps),
scale(-learning_rate))
else:
def rmsprop(learning_rate: float,
decay: float = 0.9,
eps: float = 1e-8,
centered: bool = False) -> InitUpdate:
if centered:
return chain(
scale_by_stddev(decay=decay, eps=eps),
scale(-learning_rate))
scale(-learning_rate),
)
return chain(
scale_by_rms(decay=decay, eps=eps),
scale(-learning_rate),
)