Merge pull request #50 from patrick-kidger/v030
Version 0.3.0 -- BatchNorm and stateful
patrick-kidger authored Mar 27, 2022
2 parents d394f55 + 5120743 commit 2dba6cf
Showing 22 changed files with 829 additions and 900 deletions.
2 changes: 1 addition & 1 deletion docs/api/utilities.md → docs/api/helpers.md
@@ -1,4 +1,4 @@
# Utilities
# Helpers for trees

::: equinox.apply_updates

8 changes: 8 additions & 0 deletions docs/api/nn/normalisation.md
@@ -5,3 +5,11 @@
        members:
            - __init__
            - __call__

---

::: equinox.experimental.BatchNorm
    selection:
        members:
            - __init__
            - __call__
45 changes: 45 additions & 0 deletions docs/api/stateful.md
@@ -0,0 +1,45 @@
# Stateful operations

These operations can be used to save/load JAX arrays as a side effect of your JAX operations, even under JIT.

!!! warning

    This is considered experimental.

Use cases:
- Something like [`equinox.experimental.BatchNorm`][], for which we would like to save the running statistics as a side-effect.
- Implicitly passing information between loop iterations -- i.e. rather than explicitly via the `carry` argument to `lax.scan`. Perhaps you're using a third-party library that handles the `lax.scan`, but you want to pass your own information between repeated invocations.

Example:
```python
import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp

index = eqx.experimental.StateIndex()
eqx.experimental.set_state(index, jnp.array(0))

def scan_fun(_, __):
    val = eqx.experimental.get_state(index)
    val = val + 1
    eqx.experimental.set_state(index, val)
    return None, val

_, out = lax.scan(scan_fun, None, xs=None, length=5)
print(out) # [1 2 3 4 5]
```

---

::: equinox.experimental.StateIndex
    selection:
        members: false

---

::: equinox.experimental.get_state

---

::: equinox.experimental.set_state
38 changes: 38 additions & 0 deletions docs/faq.md
@@ -0,0 +1,38 @@
# FAQ

## Optax is throwing an error.

Probably you're writing code that looks like
```python
optim = optax.adam(learning_rate)
optim.init(model)
```
and getting an error that looks like
```
TypeError: zeros_like requires ndarray or scalar arguments, got <class 'jax._src.custom_derivatives.custom_jvp'> at position 0.
```

This can be fixed by doing
```python
optim.init(eqx.filter(model, eqx.is_array))
```
which after a little thought should make sense: Optax can only optimise JAX arrays. It's not meaningful to ask Optax to optimise whatever other arbitrary Python objects may be part of your model (e.g. the activation function of an `eqx.nn.MLP`).
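
As a sketch of how this fits into a full training step -- the model, data, and learning rate below are placeholders; the `eqx.filter(..., eqx.is_array)` call is the actual fix:
```python
import equinox as eqx
import jax
import jax.numpy as jnp
import optax

model = eqx.nn.MLP(in_size=2, out_size=1, width_size=8, depth=2,
                   key=jax.random.PRNGKey(0))
optim = optax.adam(1e-3)
# Initialise the optimiser state from only the JAX arrays in the model;
# non-array leaves (like the activation function) are filtered out.
opt_state = optim.init(eqx.filter(model, eqx.is_array))

@eqx.filter_value_and_grad
def loss_fn(model, x, y):
    pred = jax.vmap(model)(x)
    return jnp.mean((pred - y) ** 2)

x, y = jnp.zeros((32, 2)), jnp.zeros((32, 1))
loss, grads = loss_fn(model, x, y)
updates, opt_state = optim.update(grads, opt_state)
model = eqx.apply_updates(model, updates)
```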

## A module saved in two places has become two independent copies.

Probably you're doing something like
```python
class Module(eqx.Module):
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear

    def __init__(...):
        shared_linear = eqx.nn.Linear(...)
        self.linear1 = shared_linear
        self.linear2 = shared_linear
```
in which the same object is saved multiple times in the model. After making some gradient updates you'll find that `self.linear1` and `self.linear2` are now different.

Recall that in Equinox, models are PyTrees. Meanwhile, JAX treats all PyTrees as *trees*: that is, the same object does not appear in the tree more than once. (If it did, the structure would be a *directed acyclic graph* instead.) If JAX ever encounters the same object multiple times, it will unwittingly make independent copies of it whenever it transforms the overall PyTree.

The resolution is simple: just don't store the same object in multiple places in the PyTree.
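
For example, here is a minimal sketch of both options (layer sizes and class/field names are illustrative): either create two independent layers, or store the shared layer in a single field and reuse it at call time.
```python
import equinox as eqx
import jax.random as jr

class Unshared(eqx.Module):
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear

    def __init__(self, key):
        key1, key2 = jr.split(key)
        # Two distinct objects: each appears in the PyTree exactly once.
        self.linear1 = eqx.nn.Linear(3, 3, key=key1)
        self.linear2 = eqx.nn.Linear(3, 3, key=key2)

class Shared(eqx.Module):
    shared_linear: eqx.nn.Linear

    def __init__(self, key):
        # Stored in a single field...
        self.shared_linear = eqx.nn.Linear(3, 3, key=key)

    def __call__(self, x):
        # ...and reused at call time, so the weights stay tied after updates.
        return self.shared_linear(self.shared_linear(x))
```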
16 changes: 4 additions & 12 deletions equinox/__init__.py
@@ -1,26 +1,18 @@
from . import nn
from . import experimental, nn
from .filters import (
combine,
filter,
is_array,
is_array_like,
is_inexact_array,
is_inexact_array_like,
merge,
partition,
split,
)
from .grad import (
filter_custom_vjp,
filter_grad,
filter_value_and_grad,
gradf,
value_and_grad_f,
)
from .jit import filter_jit, jitf
from .grad import filter_custom_vjp, filter_grad, filter_value_and_grad
from .jit import filter_jit
from .module import Module, static_field
from .tree import tree_at, tree_equal, tree_pformat
from .update import apply_updates


__version__ = "0.2.2"
__version__ = "0.3.0"
19 changes: 0 additions & 19 deletions equinox/deprecated.py

This file was deleted.

2 changes: 2 additions & 0 deletions equinox/experimental/__init__.py
@@ -0,0 +1,2 @@
from .normalisation import BatchNorm
from .stateful import get_state, set_state, StateIndex
174 changes: 174 additions & 0 deletions equinox/experimental/normalisation.py
@@ -0,0 +1,174 @@
from typing import Optional

import jax
import jax.lax as lax
import jax.numpy as jnp

from ..custom_types import Array
from ..module import Module, static_field
from .stateful import get_state, set_state, StateIndex


# This is marked experimental because it uses the experimental stateful functionality.
class BatchNorm(Module):
r"""Computes a mean and standard deviation over the batch and spatial
dimensions of an array, and uses these to normalise the whole array. Optionally
applies a channelwise affine transformation afterwards.
Given an input array $x = [x_1, ... x_C]$ with $C$ channels, this layer computes
$$\frac{x_i - \mathbb{E}[x_i]}{\sqrt{\text{Var}[x_i] + \varepsilon}} * \gamma_i + \beta_i$$
for all $i$. Here $*$ denotes elementwise multiplication and $\gamma$, $\beta$ have
shape $(C,)$ if `channelwise_affine=True` and $\gamma = 1$, $\beta = 0$ if
`channelwise_affine=False`. Expectations are computed over all spatial dimensions
*and* over the batch dimension, and updated batch-by-batch according to `momentum`.
!!! warning
This layer must be used inside of a `vmap` or `pmap` with a matching
`axis_name`. (Not doing so will raise a `NameError`.)
!!! warning
[`equinox.experimental.BatchNorm`][] saves the running statistics as a side
effect of its forward pass. Side effects are quite unusual in JAX; as such
`BatchNorm` is considered experimental. Let us know how you find it!
!!! example
```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
key = jr.PRNGKey(0)
mkey, dkey = jr.split(key)
model = eqx.nn.Sequential([
eqx.nn.Linear(in_features=3, out_features=4, key=mkey),
eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
])
x = jr.normal(dkey, (10, 3))
jax.vmap(model, axis_name="batch")(x)
```
""" # noqa: E501
weight: Optional[Array["input_size"]]
bias: Optional[Array["input_size"]]
first_time_index: StateIndex
state_index: StateIndex
axis_name: str
inference: bool
input_size: int = static_field()
eps: float = static_field()
channelwise_affine: bool = static_field()
momentum: float = static_field()

    def __init__(
        self,
        input_size: int,
        axis_name: str,
        eps: float = 1e-5,
        channelwise_affine: bool = True,
        momentum: float = 0.99,
        inference: bool = False,
        **kwargs,
    ):
        """**Arguments:**

        - `input_size`: The number of channels in the input array.
        - `axis_name`: The name of the batch axis to compute statistics over, as passed
            to `axis_name` in `jax.vmap` or `jax.pmap`.
        - `eps`: Value added to the denominator for numerical stability.
        - `channelwise_affine`: Whether the module has learnable channel-wise affine
            parameters.
        - `momentum`: The rate at which to update the running statistics. Should be a
            value between 0 and 1 exclusive.
        - `inference`: If `False` then the batch means and variances will be calculated
            and used to update the running statistics. If `True` then the running
            statistics are directly used for normalisation.
        """

        super().__init__(**kwargs)

        if channelwise_affine:
            self.weight = jnp.ones((input_size,))
            self.bias = jnp.zeros((input_size,))
        else:
            self.weight = None
            self.bias = None
        self.first_time_index = StateIndex()
        self.state_index = StateIndex()
        self.inference = inference
        self.axis_name = axis_name
        self.input_size = input_size
        self.eps = eps
        self.channelwise_affine = channelwise_affine
        self.momentum = momentum

        set_state(self.first_time_index, jnp.array(True))

    def __call__(
        self,
        x: Array,
        *,
        key: Optional["jax.random.PRNGKey"] = None,
        inference: Optional[bool] = None,
    ) -> Array:
        """**Arguments:**

        - `x`: A JAX array of shape `(input_size, dim_1, ..., dim_N)`.
        - `key`: Ignored; provided for compatibility with the rest of the Equinox API.
            (Keyword only argument.)
        - `inference`: As per [`equinox.experimental.BatchNorm.__init__`][]. If
            `True` or `False` then it will take priority over `self.inference`. If
            `None` then the value of `self.inference` will be used.

        **Returns:**

        A JAX array of shape `(input_size, dim_1, ..., dim_N)`.

        **Raises:**

        A `NameError` if no `vmap`s are placed around this operation, or if this vmap
        does not have a matching `axis_name`.
        """

        def _stats(y):
            # `y` is a single channel of `x`. Take statistics over its spatial
            # dimensions, then over the batch dimension via the named axis.
            mean = jnp.mean(y)
            mean = lax.pmean(mean, self.axis_name)
            var = jnp.mean((y - mean) ** 2)
            var = lax.pmean(var, self.axis_name)
            return mean, var

        # Map over the channel dimension.
        batch_state = jax.vmap(_stats)(x)

        if inference is None:
            inference = self.inference
        if inference:
            running_mean, running_var = get_state(self.state_index, like=batch_state)
        else:
            # On the first call there are no saved statistics yet, so fall back to
            # the batch statistics.
            first_time = get_state(self.first_time_index, like=jnp.array(False))
            running_state = lax.cond(
                first_time,
                lambda: batch_state,
                lambda: get_state(self.state_index, like=batch_state),
            )
            set_state(self.first_time_index, jnp.array(False))
            running_mean, running_var = running_state

            batch_mean, batch_var = batch_state
            running_mean = (
                1 - self.momentum
            ) * batch_mean + self.momentum * running_mean
            running_var = (1 - self.momentum) * batch_var + self.momentum * running_var
            set_state(self.state_index, (running_mean, running_var))

        def _norm(y, m, v, w, b):
            out = (y - m) / jnp.sqrt(v + self.eps)
            if self.channelwise_affine:
                out = out * w + b
            return out

        return jax.vmap(_norm)(x, running_mean, running_var, self.weight, self.bias)