"""

-import collections
+from typing import Any, Callable, NamedTuple, Sequence, Tuple, Union

from jax import numpy as jnp
from jax import random as jrandom
from jax.tree_util import tree_unflatten


-### Composable gradient transformations. ###
+###
+# Composable gradient transformations.

+# TODO(jaslanides): Make these more specific.
+OptState = NamedTuple  # Optimizer state is a (possibly empty) namedtuple.
+Params = Any  # Parameters are nests of `jnp.ndarrays`.
+Updates = Params  # Gradient updates are of the same type as parameters.

-InitUpdate = collections.namedtuple("InitUpdate", ("init", "update"))
-ClipState = collections.namedtuple("ClipState", "")

+InitFn = Callable[[Params], Union[OptState, Sequence[OptState]]]
+UpdateFn = Callable[[Updates, OptState], Tuple[Updates, OptState]]


-def clip(max_delta):
+
+class InitUpdate(NamedTuple):
+  """Optix optimizers consists of a pair of functions: (initialiser, update)."""
+  init: InitFn
+  update: UpdateFn
+
+
+class ClipState(OptState):
+  """The `clip` transformation is stateless."""
+
+
+def clip(max_delta) -> InitUpdate:
  """Clip updates element-wise, to be between -max_delta and +max_delta.

  Args:
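To make the new (init, update) pattern concrete, here is a minimal custom transformation written against these types. The names `identity` and `IdentityState` are illustrative only and are not part of this change; the sketch assumes the definitions introduced above.

class IdentityState(OptState):
  """The identity transformation is stateless."""


def identity() -> InitUpdate:
  """Passes gradient updates through unchanged."""

  def init_fn(_):
    # No statistics to track, so the state is an empty namedtuple.
    return IdentityState()

  def update_fn(updates, state):
    return updates, state

  return InitUpdate(init_fn, update_fn)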
@@ -78,14 +94,15 @@ def update_fn(updates, state):
  return InitUpdate(init_fn, update_fn)


-ClipByGlobalNormState = collections.namedtuple("ClipByGlobalNormState", "")
+def global_norm(updates: Updates) -> Updates:
+  return jnp.sqrt(jnp.sum([jnp.sum(x ** 2) for x in tree_leaves(updates)]))


-def global_norm(items):
-  return jnp.sqrt(jnp.sum([jnp.sum(x ** 2) for x in tree_leaves(items)]))
+class ClipByGlobalNormState(OptState):
+  """The `clip_by_global_norm` transformation is stateless."""


-def clip_by_global_norm(max_norm):
+def clip_by_global_norm(max_norm) -> InitUpdate:
  """Clip updates using their global norm.

  References:
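A quick usage sketch of the helper moved above: `global_norm` returns the l2 norm taken over all leaves of the update tree, which `clip_by_global_norm` then uses to rescale updates. The import path below assumes this file remains importable as `jax.experimental.optix`.

from jax import numpy as jnp
from jax.experimental import optix  # assumed import path for this module

params = {"w": jnp.array([3.0]), "b": jnp.array([4.0])}
# l2 norm over all leaves: sqrt(3**2 + 4**2) == 5.0
print(optix.global_norm(params))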
@@ -111,15 +128,17 @@ def update_fn(updates, state):
  return InitUpdate(init_fn, update_fn)


-TraceState = collections.namedtuple("TraceState", "trace")
+class TraceState(OptState):
+  """Holds an aggregation of past updates."""
+  trace: Params


-def trace(decay, nesterov):
+def trace(decay: float, nesterov: bool) -> InitUpdate:
  """Compute a trace of past updates.

  Args:
    decay: the decay rate for the tracing of past updates.
-    nesterov: whether to use nesterov momentum.
+    nesterov: whether to use Nesterov momentum.

  Returns:
    An (init_fn, update_fn) tuple.
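For orientation: `trace` is the momentum building block used by the `sgd` alias later in this diff. A sketch of the recurrence its `update_fn` typically applies, written with the `tree_multimap` helper already used in this file (illustrative only; the exact body is outside the displayed hunks):

# new_trace accumulates a decayed sum of past updates.
new_trace = tree_multimap(lambda g, t: g + decay * t, updates, state.trace)
if nesterov:
  # Nesterov momentum applies the decayed trace a second time.
  updates = tree_multimap(lambda g, t: g + decay * t, updates, new_trace)
else:
  updates = new_trace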
@@ -138,15 +157,17 @@ def update_fn(updates, state):
  return InitUpdate(init_fn, update_fn)


-ScaleByRmsState = collections.namedtuple("ScaleByRmsState", "nu")
+class ScaleByRmsState(OptState):
+  """State for exponential root mean-squared (RMS)-normalized updates."""
+  nu: Updates


def _update_moment(updates, moments, decay, order):
  return tree_multimap(
      lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments)


-def scale_by_rms(decay=0.9, eps=1e-8):
+def scale_by_rms(decay: float = 0.9, eps: float = 1e-8):
  """Rescale updates by the root of the exp. moving avg of the square.

  References:
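In terms of `_update_moment` above, `scale_by_rms` keeps `nu` as an exponential moving average of squared gradients and normalizes each update by its root. A sketch of the per-step computation (illustrative; the exact epsilon placement lives in the elided body):

nu = _update_moment(updates, state.nu, decay, 2)  # nu <- (1 - decay) * g**2 + decay * nu
updates = tree_multimap(lambda g, n: g / (jnp.sqrt(n) + eps), updates, nu)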
@@ -172,10 +193,13 @@ def update_fn(updates, state):
  return InitUpdate(init_fn, update_fn)


-ScaleByRStdDevState = collections.namedtuple("ScaleByRStdDevState", "mu nu")
+class ScaleByRStdDevState(OptState):
+  """State for centered exponential moving average of squares of updates."""
+  mu: Updates
+  nu: Updates


-def scale_by_stddev(decay=0.9, eps=1e-8):
+def scale_by_stddev(decay: float = 0.9, eps: float = 1e-8) -> InitUpdate:
  """Rescale updates by the root of the centered exp. moving average of squares.

  References:
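The centered variant additionally tracks the first moment `mu`, so the normalizer estimates a standard deviation rather than an RMS. A sketch of the recurrence (illustrative only):

mu = _update_moment(updates, state.mu, decay, 1)  # first moment
nu = _update_moment(updates, state.nu, decay, 2)  # second moment
updates = tree_multimap(
    lambda g, m, n: g / (jnp.sqrt(n - m ** 2) + eps), updates, mu, nu)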
@@ -204,10 +228,16 @@ def update_fn(updates, state):
  return InitUpdate(init_fn, update_fn)


-ScaleByAdamState = collections.namedtuple("ScaleByAdamState", "count mu nu")
+class ScaleByAdamState(OptState):
+  """State for the Adam algorithm."""
+  count: jnp.ndarray  # shape=(), dtype=jnp.int32.
+  mu: Updates
+  nu: Updates


-def scale_by_adam(b1=0.9, b2=0.999, eps=1e-8):
+def scale_by_adam(b1: float = 0.9,
+                  b2: float = 0.999,
+                  eps: float = 1e-8) -> InitUpdate:
  """Rescale updates according to the Adam algorithm.

  References:
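For reference (Kingma and Ba, 2015), this transformation tracks bias-corrected first and second moments and uses their ratio to rescale updates. A sketch of the expected per-step computation, using the state fields defined above (illustrative; the exact helpers are in the elided body):

mu = _update_moment(updates, state.mu, b1, 1)
nu = _update_moment(updates, state.nu, b2, 2)
count = state.count + 1
mu_hat = tree_multimap(lambda m: m / (1 - b1 ** count), mu)  # bias correction
nu_hat = tree_multimap(lambda n: n / (1 - b2 ** count), nu)
updates = tree_multimap(lambda m, n: m / (jnp.sqrt(n) + eps), mu_hat, nu_hat)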
@@ -239,10 +269,11 @@ def update_fn(updates, state):
  return InitUpdate(init_fn, update_fn)


-ScaleState = collections.namedtuple("ScaleState", "")
+class ScaleState(NamedTuple):
+  """The scale transformation is stateless."""


-def scale(step_size):
+def scale(step_size: float) -> InitUpdate:
  """Scale updates by some fixed scalar `step_size`.

  Args:
@@ -262,10 +293,12 @@ def update_fn(updates, state):
  return InitUpdate(init_fn, update_fn)


-ScaleByScheduleState = collections.namedtuple("ScaleByScheduleState", "count")
+class ScaleByScheduleState(OptState):
+  """Maintains count for scale scheduling."""
+  count: jnp.ndarray  # shape=(), dtype=jnp.int32


-def scale_by_schedule(step_size_fn):
+def scale_by_schedule(step_size_fn: Callable[[jnp.ndarray], jnp.ndarray]):
  """Scale updates using a custom schedule for the `step_size`.

  Args:
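`step_size_fn` maps the current step count (held in `ScaleByScheduleState.count`) to a multiplier applied to the updates. A sketch of a schedule one could pass in; the name `linear_warmup` is illustrative, and the negative sign follows the same convention as `scale(-learning_rate)` elsewhere in this file:

def linear_warmup(count: jnp.ndarray) -> jnp.ndarray:
  # Ramp the (negative) step size from 0 to -1e-3 over the first 1000 steps.
  return -1e-3 * jnp.minimum(count / 1000.0, 1.0)

schedule_tx = scale_by_schedule(linear_warmup)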
@@ -286,10 +319,13 @@ def update_fn(updates, state):
  return InitUpdate(init_fn, update_fn)


-AddNoiseState = collections.namedtuple("AddNoiseState", "count rng_key")
+class AddNoiseState(OptState):
+  """State for adding gradient noise. Contains a count for annealing."""
+  count: jnp.ndarray
+  rng_key: jnp.ndarray


-def add_noise(eta, gamma, seed):
+def add_noise(eta: float, gamma: float, seed: int) -> InitUpdate:
  """Add gradient noise.

  References:
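The referenced technique (Neelakantan et al., 2015) adds zero-mean Gaussian noise whose variance is annealed with the step count, roughly eta / (1 + t) ** gamma; the `count` and `rng_key` fields above hold the annealing step and the JAX PRNG state used to draw that noise. A rough sketch of one step (illustrative; the elided body may differ in detail, e.g. it would split a fresh key per leaf):

count = state.count + 1
variance = eta / (1 + count) ** gamma  # annealing schedule from the paper
key, sample_key = jrandom.split(state.rng_key)
noise = tree_multimap(
    lambda g: jnp.sqrt(variance) * jrandom.normal(sample_key, g.shape), updates)
updates = tree_multimap(lambda g, n: g + n, updates, noise)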
@@ -323,10 +359,13 @@ def update_fn(updates, state): # pylint: disable=missing-docstring
  return InitUpdate(init_fn, update_fn)


-ApplyEvery = collections.namedtuple("ApplyEvery", "count grad_acc")
+class ApplyEvery(OptState):
+  """Contains a counter and a gradient accumulator."""
+  count: jnp.ndarray
+  grad_acc: Updates


-def apply_every(k=1):
+def apply_every(k: int = 1) -> InitUpdate:
  """accumulate gradients and apply them every k steps.

  Args:
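`apply_every` emits the accumulated gradient only on every k-th call and zero updates otherwise, which is useful for simulating larger batches. An illustrative composition using the `chain` combinator introduced in the next hunk (names taken from this file, values arbitrary):

# Accumulate raw gradients and apply an SGD-style step once every 4 calls.
optimizer = chain(
    apply_every(k=4),
    scale(-1e-2),
)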
@@ -353,10 +392,11 @@ def update_fn(updates, state):
  return InitUpdate(init_fn, update_fn)


-### Utilities for building and using custom optimizers. ###
+###
+# Utilities for building and using custom optimizers.


-def chain(*args):
+def chain(*args: InitUpdate) -> InitUpdate:
  """Applies a list of chainable update transformations.

  Given a sequence of chainable transforms, `chain` returns an `init_fn`
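For example, a gradient-clipped Adam optimizer can be assembled from the transformations defined above; each update function in the chain receives the output of the previous one (values are illustrative):

optimizer = chain(
    clip_by_global_norm(1.0),
    scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
    scale(-1e-3),
)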
@@ -386,7 +426,7 @@ def update_fn(updates, state):
  return InitUpdate(init_fn, update_fn)


-def apply_updates(params, updates):
+def apply_updates(params: Params, updates: Updates) -> Params:
  """Applies an update to the corresponding parameters.

  This is an (optional) utility functions that applies an update, and returns
@@ -404,34 +444,50 @@ def apply_updates(params, updates):
  return tree_multimap(lambda p, u: p + u, params, updates)


-### Aliases for popular optimizers. ###
+###
+# Aliases for popular optimizers.


-def sgd(learning_rate, momentum=0., nesterov=False):
+def sgd(learning_rate: float,
+        momentum: float = 0.,
+        nesterov: bool = False) -> InitUpdate:
  return chain(
      trace(decay=momentum, nesterov=nesterov),
-      scale(-learning_rate))
+      scale(-learning_rate),
+  )


-def noisy_sgd(learning_rate, eta=0.01, gamma=0.55, seed=42):
+def noisy_sgd(learning_rate: float,
+              eta: float = 0.01,
+              gamma: float = 0.55,
+              seed: int = 0) -> InitUpdate:
  return chain(
      trace(decay=0., nesterov=False),
      scale(-learning_rate),
-      add_noise(eta, gamma, seed))
+      add_noise(eta, gamma, seed),
+  )


-def adam(learning_rate, b1=0.9, b2=0.999, eps=1e-8):
+def adam(learning_rate: float,
+         b1: float = 0.9,
+         b2: float = 0.999,
+         eps: float = 1e-8) -> InitUpdate:
  return chain(
      scale_by_adam(b1=b1, b2=b2, eps=eps),
-      scale(-learning_rate))
+      scale(-learning_rate),
+  )


-def rmsprop(learning_rate, decay=0.9, eps=1e-8, centered=False):
-  if not centered:
-    return chain(
-        scale_by_rms(decay=decay, eps=eps),
-        scale(-learning_rate))
-  else:
+def rmsprop(learning_rate: float,
+            decay: float = 0.9,
+            eps: float = 1e-8,
+            centered: bool = False) -> InitUpdate:
+  if centered:
    return chain(
        scale_by_stddev(decay=decay, eps=eps),
-        scale(-learning_rate))
+        scale(-learning_rate),
+    )
+  return chain(
+      scale_by_rms(decay=decay, eps=eps),
+      scale(-learning_rate),
+  )
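Putting the pieces together, a minimal training-step sketch using the `adam` alias; the loss, data, and parameter tree are illustrative and not part of the change, and the import path assumes this module stays at `jax.experimental.optix`:

import jax
from jax import numpy as jnp
from jax.experimental import optix  # assumed import path for this module

def loss(params, batch):
  x, y = batch
  return jnp.mean((x @ params["w"] - y) ** 2)

params = {"w": jnp.zeros((3,))}
optimizer = optix.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

def step(params, opt_state, batch):
  grads = jax.grad(loss)(params, batch)
  # The transformation rescales gradients; apply_updates adds them to params.
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optix.apply_updates(params, updates)
  return params, opt_state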