Merge pull request #50 from patrick-kidger/v030
Version 0.3.0 -- BatchNorm and stateful
Showing 22 changed files with 829 additions and 900 deletions.
@@ -1,4 +1,4 @@
# Utilities
# Helpers for trees

::: equinox.apply_updates
@@ -5,3 +5,11 @@
        members:
            - __init__
            - __call__

---

::: equinox.experimental.BatchNorm
    selection:
        members:
            - __init__
            - __call__
@@ -0,0 +1,45 @@
# Stateful operations

These operations can be used to save/load JAX arrays as a side-effect of JAX operations, even under JIT.

!!! warning

    This is considered experimental.

Use cases:

- Something like [`equinox.experimental.BatchNorm`][], for which we would like to save the running statistics as a side-effect.
- Implicitly passing information between loop iterations -- i.e. rather than explicitly via the `carry` argument to `lax.scan`. Perhaps you're using a third-party library that handles the `lax.scan`, but you want to pass your own information between repeated invocations.

Example:
```python
import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp

index = eqx.experimental.StateIndex()
init = jnp.array(0)
eqx.experimental.set_state(index, init)

def scan_fun(_, __):
    val = eqx.experimental.get_state(index, like=init)
    val = val + 1
    eqx.experimental.set_state(index, val)
    return None, val

_, out = lax.scan(scan_fun, None, xs=None, length=5)
print(out)  # [1 2 3 4 5]
```
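
A further illustrative sketch, not part of the original documentation: because the state is stored as a side-effect rather than threaded through the computation, it also persists across separate calls to a jitted function. This assumes `get_state`/`set_state` behave as in the example above (including the `like` argument).

```python
import equinox as eqx
import jax
import jax.numpy as jnp

counter = eqx.experimental.StateIndex()
eqx.experimental.set_state(counter, jnp.array(0))

@jax.jit
def increment():
    # Read the current value as a side-effect, bump it, and store it back.
    value = eqx.experimental.get_state(counter, like=jnp.array(0)) + 1
    eqx.experimental.set_state(counter, value)
    return value

print(increment())  # expected: 1
print(increment())  # expected: 2
```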

---

::: equinox.experimental.StateIndex
    selection:
        members: false

---

::: equinox.experimental.get_state

---

::: equinox.experimental.set_state
@@ -0,0 +1,38 @@
# FAQ

## Optax is throwing an error.

Probably you're writing code that looks like
```python
optim = optax.adam(learning_rate)
optim.init(model)
```
and getting an error that looks like
```
TypeError: zeros_like requires ndarray or scalar arguments, got <class 'jax._src.custom_derivatives.custom_jvp'> at position 0.
```

This can be fixed by doing
```python
optim.init(eqx.filter(model, eqx.is_array))
```
which after a little thought should make sense: Optax can only optimise JAX arrays, so it's not meaningful to ask it to optimise whatever other arbitrary Python objects may be part of your model (e.g. the activation function of an `eqx.nn.MLP`).
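
For completeness, a minimal sketch of the fix in context (the model and hyperparameters here are hypothetical, purely for illustration):

```python
import equinox as eqx
import jax.random as jr
import optax

# A hypothetical model: an MLP whose activation function is a plain Python
# callable, i.e. not something Optax can optimise.
model = eqx.nn.MLP(in_size=2, out_size=1, width_size=8, depth=2, key=jr.PRNGKey(0))

optim = optax.adam(1e-3)
# Initialise Optax on the JAX arrays only; everything else is filtered out.
opt_state = optim.init(eqx.filter(model, eqx.is_array))
```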

## A module saved in two places has become two independent copies.

Probably you're doing something like
```python
class Module(eqx.Module):
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear

    def __init__(...):
        shared_linear = eqx.nn.Linear(...)
        self.linear1 = shared_linear
        self.linear2 = shared_linear
```
in which the same object is saved multiple times in the model. After making some gradient updates you'll find that `self.linear1` and `self.linear2` are now different.

Recall that in Equinox, models are PyTrees. Meanwhile, JAX treats all PyTrees as *trees*: that is, the same object does not appear in the tree more than once. (If it did, the structure would be a *directed acyclic graph* instead.) If JAX ever encounters the same object multiple times then it will unwittingly make independent copies of the object whenever it transforms the overall PyTree.

The resolution is simple: just don't store the same object in multiple places in the PyTree.
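
For instance, one way to keep genuinely shared weights is to store the layer once and call it from both places. A minimal sketch; the module below is hypothetical:

```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

class Module(eqx.Module):
    # Stored exactly once, so the PyTree really is a tree.
    shared_linear: eqx.nn.Linear

    def __init__(self, key):
        self.shared_linear = eqx.nn.Linear(3, 3, key=key)

    def __call__(self, x):
        # Both uses refer to the single stored copy, so gradient updates
        # cannot cause them to drift apart.
        x = self.shared_linear(x)
        x = self.shared_linear(x)
        return x

model = Module(jr.PRNGKey(0))
model(jnp.zeros(3))
```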
@@ -1,26 +1,18 @@
from . import nn
from . import experimental, nn
from .filters import (
    combine,
    filter,
    is_array,
    is_array_like,
    is_inexact_array,
    is_inexact_array_like,
    merge,
    partition,
    split,
)
from .grad import (
    filter_custom_vjp,
    filter_grad,
    filter_value_and_grad,
    gradf,
    value_and_grad_f,
)
from .jit import filter_jit, jitf
from .grad import filter_custom_vjp, filter_grad, filter_value_and_grad
from .jit import filter_jit
from .module import Module, static_field
from .tree import tree_at, tree_equal, tree_pformat
from .update import apply_updates


__version__ = "0.2.2"
__version__ = "0.3.0"
This file was deleted.
@@ -0,0 +1,2 @@
from .normalisation import BatchNorm
from .stateful import get_state, set_state, StateIndex
@@ -0,0 +1,174 @@
from typing import Optional

import jax
import jax.lax as lax
import jax.numpy as jnp

from ..custom_types import Array
from ..module import Module, static_field
from .stateful import get_state, set_state, StateIndex


# This is marked experimental because it uses the experimental stateful functionality.
class BatchNorm(Module):
r"""Computes a mean and standard deviation over the batch and spatial | ||
dimensions of an array, and uses these to normalise the whole array. Optionally | ||
applies a channelwise affine transformation afterwards. | ||
Given an input array $x = [x_1, ... x_C]$ with $C$ channels, this layer computes | ||
$$\frac{x_i - \mathbb{E}[x_i]}{\sqrt{\text{Var}[x_i] + \varepsilon}} * \gamma_i + \beta_i$$ | ||
for all $i$. Here $*$ denotes elementwise multiplication and $\gamma$, $\beta$ have | ||
shape $(C,)$ if `channelwise_affine=True` and $\gamma = 1$, $\beta = 0$ if | ||
`channelwise_affine=False`. Expectations are computed over all spatial dimensions | ||
*and* over the batch dimension, and updated batch-by-batch according to `momentum`. | ||
!!! warning | ||
This layer must be used inside of a `vmap` or `pmap` with a matching | ||
`axis_name`. (Not doing so will raise a `NameError`.) | ||
!!! warning | ||
[`equinox.experimental.BatchNorm`][] saves the running statistics as a side | ||
effect of its forward pass. Side effects are quite unusual in JAX; as such | ||
`BatchNorm` is considered experimental. Let us know how you find it! | ||
!!! example | ||
```python | ||
import equinox as eqx | ||
import jax | ||
import jax.numpy as jnp | ||
import jax.random as jr | ||
key = jr.PRNGKey(0) | ||
mkey, dkey = jr.split(key) | ||
model = eqx.nn.Sequential([ | ||
eqx.nn.Linear(in_features=3, out_features=4, key=mkey), | ||
eqx.experimental.BatchNorm(input_size=4, axis_name="batch"), | ||
]) | ||
x = jr.normal(dkey, (10, 3)) | ||
jax.vmap(model, axis_name="batch")(x) | ||
``` | ||
""" # noqa: E501 | ||
    weight: Optional[Array["input_size"]]
    bias: Optional[Array["input_size"]]
    first_time_index: StateIndex
    state_index: StateIndex
    axis_name: str
    inference: bool
    input_size: int = static_field()
    eps: float = static_field()
    channelwise_affine: bool = static_field()
    momentum: float = static_field()

    def __init__(
        self,
        input_size: int,
        axis_name: str,
        eps: float = 1e-5,
        channelwise_affine: bool = True,
        momentum: float = 0.99,
        inference: bool = False,
        **kwargs,
    ):
        """**Arguments:**

        - `input_size`: The number of channels in the input array.
        - `axis_name`: The name of the batch axis to compute statistics over, as passed
            to `axis_name` in `jax.vmap` or `jax.pmap`.
        - `eps`: Value added to the denominator for numerical stability.
        - `channelwise_affine`: Whether the module has learnable channel-wise affine
            parameters.
        - `momentum`: The rate at which to update the running statistics. Should be a
            value between 0 and 1 exclusive.
        - `inference`: If `False` then the batch means and variances will be calculated
            and used to update the running statistics. If `True` then the running
            statistics are directly used for normalisation.
        """

        super().__init__(**kwargs)

        if channelwise_affine:
            self.weight = jnp.ones((input_size,))
            self.bias = jnp.zeros((input_size,))
        else:
            self.weight = None
            self.bias = None
        self.first_time_index = StateIndex()
        self.state_index = StateIndex()
        self.inference = inference
        self.axis_name = axis_name
        self.input_size = input_size
        self.eps = eps
        self.channelwise_affine = channelwise_affine
        self.momentum = momentum

        set_state(self.first_time_index, jnp.array(True))

    def __call__(
        self,
        x: Array,
        *,
        key: Optional["jax.random.PRNGKey"] = None,
        inference: Optional[bool] = None,
    ) -> Array:
        """**Arguments:**

        - `x`: A JAX array of shape `(input_size, dim_1, ..., dim_N)`.
        - `key`: Ignored; provided for compatibility with the rest of the Equinox API.
            (Keyword only argument.)
        - `inference`: As per [`equinox.experimental.BatchNorm.__init__`][]. If
            `True` or `False` then it will take priority over `self.inference`. If
            `None` then the value from `self.inference` will be used.

        **Returns:**

        A JAX array of shape `(input_size, dim_1, ..., dim_N)`.

        **Raises:**

        A `NameError` if no `vmap`s are placed around this operation, or if this vmap
        does not have a matching `axis_name`.
        """

        def _stats(y):
            mean = jnp.mean(y)
            mean = lax.pmean(mean, self.axis_name)
            var = jnp.mean((y - mean) ** 2)
            var = lax.pmean(var, self.axis_name)
            return mean, var

        # Per-channel mean and variance over the spatial dimensions, averaged
        # across the named batch axis via `lax.pmean`.
        batch_state = jax.vmap(_stats)(x)

        if inference is None:
            inference = self.inference
        if inference:
            running_mean, running_var = get_state(self.state_index, like=batch_state)
        else:
            # On the first call there are no stored running statistics yet, so fall
            # back to the statistics of the current batch.
            first_time = get_state(self.first_time_index, like=jnp.array(False))
            running_state = lax.cond(
                first_time,
                lambda: batch_state,
                lambda: get_state(self.state_index, like=batch_state),
            )
            set_state(self.first_time_index, jnp.array(False))
            running_mean, running_var = running_state

            # Exponential moving average update of the running statistics.
            batch_mean, batch_var = batch_state
            running_mean = (
                1 - self.momentum
            ) * batch_mean + self.momentum * running_mean
            running_var = (1 - self.momentum) * batch_var + self.momentum * running_var
            set_state(self.state_index, (running_mean, running_var))

        def _norm(y, m, v, w, b):
            out = (y - m) / jnp.sqrt(v + self.eps)
            if self.channelwise_affine:
                out = out * w + b
            return out

        return jax.vmap(_norm)(x, running_mean, running_var, self.weight, self.bias)