Commit 1ac80d7

Add type annotations to optix. (#2687)
* Add type annotations to optix.
* Fix function signature for chain() and remove unused collections import.
* Include Sequence[OptState] as possible output of Init.
1 parent bbf7a43 commit 1ac80d7
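
For context (not part of this commit's diff): a minimal sketch of how the newly annotated interface is used. The toy `params`/`grads` pytrees below are made-up examples; the optix calls follow the signatures introduced by this change.

from jax import numpy as jnp
from jax.experimental import optix

# Toy parameters and gradients (illustrative only); any matching pytrees work.
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
grads = {"w": jnp.full((3,), 0.1), "b": jnp.asarray(0.5)}

# adam() is a chain of transforms; it returns an InitUpdate pair whose
# `init` has type InitFn and whose `update` has type UpdateFn.
opt = optix.adam(learning_rate=1e-2)
opt_state = opt.init(params)                       # OptState, or a sequence of them
updates, opt_state = opt.update(grads, opt_state)  # (Updates, OptState)
new_params = optix.apply_updates(params, updates)  # Params -> Params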

File tree

1 file changed: +99 −43 lines changed


jax/experimental/optix.py

+99 −43
@@ -39,7 +39,7 @@
 """
 
 
-import collections
+from typing import Any, Callable, NamedTuple, Sequence, Tuple, Union
 
 from jax import numpy as jnp
 from jax import random as jrandom
@@ -50,14 +50,30 @@
 from jax.tree_util import tree_unflatten
 
 
-### Composable gradient transformations. ###
+###
+# Composable gradient transformations.
 
+# TODO(jaslanides): Make these more specific.
+OptState = NamedTuple  # Optimizer state is a (possibly empty) namedtuple.
+Params = Any  # Parameters are nests of `jnp.ndarrays`.
+Updates = Params  # Gradient updates are of the same type as parameters.
 
-InitUpdate = collections.namedtuple("InitUpdate", ("init", "update"))
-ClipState = collections.namedtuple("ClipState", "")
 
+InitFn = Callable[[Params], Union[OptState, Sequence[OptState]]]
+UpdateFn = Callable[[Updates, OptState], Tuple[Updates, OptState]]
 
-def clip(max_delta):
+
+class InitUpdate(NamedTuple):
+  """Optix optimizers consists of a pair of functions: (initialiser, update)."""
+  init: InitFn
+  update: UpdateFn
+
+
+class ClipState(OptState):
+  """The `clip` transformation is stateless."""
+
+
+def clip(max_delta) -> InitUpdate:
   """Clip updates element-wise, to be between -max_delta and +max_delta.
 
   Args:
@@ -78,14 +94,15 @@ def update_fn(updates, state):
   return InitUpdate(init_fn, update_fn)
 
 
-ClipByGlobalNormState = collections.namedtuple("ClipByGlobalNormState", "")
+def global_norm(updates: Updates) -> Updates:
+  return jnp.sqrt(jnp.sum([jnp.sum(x**2) for x in tree_leaves(updates)]))
 
 
-def global_norm(items):
-  return jnp.sqrt(jnp.sum([jnp.sum(x**2) for x in tree_leaves(items)]))
+class ClipByGlobalNormState(OptState):
+  """The `clip_by_global_norm` transformation is stateless."""
 
 
-def clip_by_global_norm(max_norm):
+def clip_by_global_norm(max_norm) -> InitUpdate:
   """Clip updates using their global norm.
 
   References:
@@ -111,15 +128,17 @@ def update_fn(updates, state):
   return InitUpdate(init_fn, update_fn)
 
 
-TraceState = collections.namedtuple("TraceState", "trace")
+class TraceState(OptState):
+  """Holds an aggregation of past updates."""
+  trace: Params
 
 
-def trace(decay, nesterov):
+def trace(decay: float, nesterov: bool) -> InitUpdate:
   """Compute a trace of past updates.
 
   Args:
     decay: the decay rate for the tracing of past updates.
-    nesterov: whether to use nesterov momentum.
+    nesterov: whether to use Nesterov momentum.
 
   Returns:
     An (init_fn, update_fn) tuple.
@@ -138,15 +157,17 @@ def update_fn(updates, state):
   return InitUpdate(init_fn, update_fn)
 
 
-ScaleByRmsState = collections.namedtuple("ScaleByRmsState", "nu")
+class ScaleByRmsState(OptState):
+  """State for exponential root mean-squared (RMS)-normalized updates."""
+  nu: Updates
 
 
 def _update_moment(updates, moments, decay, order):
   return tree_multimap(
       lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments)
 
 
-def scale_by_rms(decay=0.9, eps=1e-8):
+def scale_by_rms(decay: float = 0.9, eps: float = 1e-8):
   """Rescale updates by the root of the exp. moving avg of the square.
 
   References:
@@ -172,10 +193,13 @@ def update_fn(updates, state):
   return InitUpdate(init_fn, update_fn)
 
 
-ScaleByRStdDevState = collections.namedtuple("ScaleByRStdDevState", "mu nu")
+class ScaleByRStdDevState(OptState):
+  """State for centered exponential moving average of squares of updates."""
+  mu: Updates
+  nu: Updates
 
 
-def scale_by_stddev(decay=0.9, eps=1e-8):
+def scale_by_stddev(decay: float = 0.9, eps: float = 1e-8) -> InitUpdate:
   """Rescale updates by the root of the centered exp. moving average of squares.
 
   References:
@@ -204,10 +228,16 @@ def update_fn(updates, state):
   return InitUpdate(init_fn, update_fn)
 
 
-ScaleByAdamState = collections.namedtuple("ScaleByAdamState", "count mu nu")
+class ScaleByAdamState(OptState):
+  """State for the Adam algorithm."""
+  count: jnp.ndarray  # shape=(), dtype=jnp.int32.
+  mu: Updates
+  nu: Updates
 
 
-def scale_by_adam(b1=0.9, b2=0.999, eps=1e-8):
+def scale_by_adam(b1: float = 0.9,
+                  b2: float = 0.999,
+                  eps: float = 1e-8) -> InitUpdate:
   """Rescale updates according to the Adam algorithm.
 
   References:
@@ -239,10 +269,11 @@ def update_fn(updates, state):
   return InitUpdate(init_fn, update_fn)
 
 
-ScaleState = collections.namedtuple("ScaleState", "")
+class ScaleState(NamedTuple):
+  """The scale transformation is stateless."""
 
 
-def scale(step_size):
+def scale(step_size: float) -> InitUpdate:
   """Scale updates by some fixed scalar `step_size`.
 
   Args:
@@ -262,10 +293,12 @@ def update_fn(updates, state):
   return InitUpdate(init_fn, update_fn)
 
 
-ScaleByScheduleState = collections.namedtuple("ScaleByScheduleState", "count")
+class ScaleByScheduleState(OptState):
+  """Maintains count for scale scheduling."""
+  count: jnp.ndarray  # shape=(), dtype=jnp.int32
 
 
-def scale_by_schedule(step_size_fn):
+def scale_by_schedule(step_size_fn: Callable[[jnp.ndarray], jnp.ndarray]):
   """Scale updates using a custom schedule for the `step_size`.
 
   Args:
@@ -286,10 +319,13 @@ def update_fn(updates, state):
   return InitUpdate(init_fn, update_fn)
 
 
-AddNoiseState = collections.namedtuple("AddNoiseState", "count rng_key")
+class AddNoiseState(OptState):
+  """State for adding gradient noise. Contains a count for annealing."""
+  count: jnp.ndarray
+  rng_key: jnp.ndarray
 
 
-def add_noise(eta, gamma, seed):
+def add_noise(eta: float, gamma: float, seed: int) -> InitUpdate:
   """Add gradient noise.
 
   References:
@@ -323,10 +359,13 @@ def update_fn(updates, state):  # pylint: disable=missing-docstring
   return InitUpdate(init_fn, update_fn)
 
 
-ApplyEvery = collections.namedtuple("ApplyEvery", "count grad_acc")
+class ApplyEvery(OptState):
+  """Contains a counter and a gradient accumulator."""
+  count: jnp.ndarray
+  grad_acc: Updates
 
 
-def apply_every(k=1):
+def apply_every(k: int = 1) -> InitUpdate:
   """accumulate gradients and apply them every k steps.
 
   Args:
@@ -353,10 +392,11 @@ def update_fn(updates, state):
   return InitUpdate(init_fn, update_fn)
 
 
-### Utilities for building and using custom optimizers. ###
+###
+# Utilities for building and using custom optimizers.
 
 
-def chain(*args):
+def chain(*args: InitUpdate) -> InitUpdate:
   """Applies a list of chainable update transformations.
 
   Given a sequence of chainable transforms, `chain` returns an `init_fn`
@@ -386,7 +426,7 @@ def update_fn(updates, state):
   return InitUpdate(init_fn, update_fn)
 
 
-def apply_updates(params, updates):
+def apply_updates(params: Params, updates: Updates) -> Params:
   """Applies an update to the corresponding parameters.
 
   This is an (optional) utility functions that applies an update, and returns
@@ -404,34 +444,50 @@ def apply_updates(params, updates):
   return tree_multimap(lambda p, u: p + u, params, updates)
 
 
-### Aliases for popular optimizers. ###
+###
+# Aliases for popular optimizers.
 
 
-def sgd(learning_rate, momentum=0., nesterov=False):
+def sgd(learning_rate: float,
+        momentum: float = 0.,
+        nesterov: bool = False) -> InitUpdate:
   return chain(
       trace(decay=momentum, nesterov=nesterov),
-      scale(-learning_rate))
+      scale(-learning_rate),
+  )
 
 
-def noisy_sgd(learning_rate, eta=0.01, gamma=0.55, seed=42):
+def noisy_sgd(learning_rate: float,
+              eta: float = 0.01,
+              gamma: float = 0.55,
+              seed: int = 0) -> InitUpdate:
   return chain(
       trace(decay=0., nesterov=False),
       scale(-learning_rate),
-      add_noise(eta, gamma, seed))
+      add_noise(eta, gamma, seed),
+  )
 
 
-def adam(learning_rate, b1=0.9, b2=0.999, eps=1e-8):
+def adam(learning_rate: float,
+         b1: float = 0.9,
+         b2: float = 0.999,
+         eps: float = 1e-8) -> InitUpdate:
   return chain(
       scale_by_adam(b1=b1, b2=b2, eps=eps),
-      scale(-learning_rate))
+      scale(-learning_rate),
+  )
 
 
-def rmsprop(learning_rate, decay=0.9, eps=1e-8, centered=False):
-  if not centered:
-    return chain(
-        scale_by_rms(decay=decay, eps=eps),
-        scale(-learning_rate))
-  else:
+def rmsprop(learning_rate: float,
+            decay: float = 0.9,
+            eps: float = 1e-8,
+            centered: bool = False) -> InitUpdate:
+  if centered:
     return chain(
         scale_by_stddev(decay=decay, eps=eps),
-        scale(-learning_rate))
+        scale(-learning_rate),
+    )
+  return chain(
+      scale_by_rms(decay=decay, eps=eps),
+      scale(-learning_rate),
+  )
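
A side note on the new aliases (not part of the diff): because `OptState`, `Params`, `Updates`, `InitFn`, `UpdateFn`, and `InitUpdate` are now defined at module level, custom transformations can be written against the same typed interface. The sketch below is illustrative only; `scale_by_sign` and `ScaleBySignState` are invented names, not optix APIs.

from jax import numpy as jnp
from jax.experimental import optix
from jax.tree_util import tree_map

# Illustrative stateless transform written against the new typed interface.
class ScaleBySignState(optix.OptState):
  """The `scale_by_sign` transformation is stateless."""

def scale_by_sign() -> optix.InitUpdate:
  """Replace each update leaf with its sign (a toy transform, not in optix)."""

  def init_fn(_) -> ScaleBySignState:
    return ScaleBySignState()

  def update_fn(updates, state):
    return tree_map(jnp.sign, updates), state

  return optix.InitUpdate(init_fn, update_fn)

# It chains with the built-in transforms like any other InitUpdate pair.
sign_sgd = optix.chain(scale_by_sign(), optix.scale(-0.1))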

0 commit comments
