Releases: patrick-kidger/equinox
Equinox v0.10.0
Highlights
- A dramatically simplified API for equinox.{filter_jit, filter_grad, filter_value_and_grad, filter_vmap, filter_pmap}. This is a backward-incompatible change.
- equinox.internal.while_loop, which is a reverse-mode autodifferentiable while loop, using recursive checkpointing.
Full change list
New features
Some relatively minor new features are available in this release.
- Added support for donating buffers when using eqx.{filter_jit, filter_pmap}. (Thanks @uuirs in #235!)
- Added eqx.nn.PRelu. (Thanks @enver1323 in #249!)
- Added eqx.tree_pprint.
- Added eqx.module_update_wrapper.
- eqx.filter_custom_jvp now supports keyword arguments (which are always treated as nondifferentiable).
New internal features
Introducing a slew of new features for the advanced JAX user.
These are all available in the equinox.internal namespace (conventionally imported as eqxi). Note that these come without stability guarantees, as they often depend on functionality that JAX doesn't make fully public.
- eqxi.abstractattribute, for marking abstract instance attributes of abstract Equinox modules.
- eqxi.tree_pp, for producing a pretty-print doc of an object. (This is what is then formatted to a particular width in e.g. eqx.tree_pformat.) In addition, classes can now have custom pretty-printing behaviour when used with eqx.{tree_pp, tree_pformat, tree_pprint}, by setting a __tree_pp__ method.
- eqxi.if_mapped, as an alternative to the usual eqx.if_array passed to eqx.{filter_vmap, filter_pmap}(out_axes=...).
- eqxi.{finalise_jaxpr, finalise_fn}, for tracing through the impl rules of custom primitives (so that the custom primitive no longer appears in the jaxpr). This is useful for replacing such custom primitives prior to offloading a jaxpr to some other IR, e.g. via jax2tf.
- eqxi.{nonbatchable, nondifferentiable, nondifferentiable_backward, nontraceable}, for asserting that an operation is never batched, differentiated, or subject to any transform at all.
- eqxi.to_onnx, for exporting to ONNX.
- eqxi.while_loop, for reverse-mode autodifferentiable while loops, in particular making use of recursive checkpointing (à la treeverse).
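For context on why eqxi.while_loop is useful: JAX's own jax.lax.while_loop does not support reverse-mode autodiff. The sketch below uses only public JAX APIs and is not Equinox's implementation. It shows the error, then a naive workaround (a swapped-in technique, not recursive checkpointing): a fixed-length jax.lax.scan that freezes the carry once the condition turns false. The max_steps bound and the helper names are assumptions of the sketch. This naive version keeps every intermediate carry for the backward pass, so memory grows linearly in max_steps; that cost is what recursive checkpointing aims to reduce.

```python
import jax
import jax.numpy as jnp

def cond_fun(carry):
    i, x = carry
    return x < 10.0

def body_fun(carry):
    i, x = carry
    return i + 1, x * 1.5

def unbounded(x0):
    init = (jnp.asarray(0, dtype=jnp.int32), jnp.asarray(x0, dtype=jnp.float32))
    _, x = jax.lax.while_loop(cond_fun, body_fun, init)
    return x

def bounded(x0, max_steps=16):
    # Hypothetical naive bounded loop: always run max_steps iterations, but keep
    # the old carry once cond_fun is False. Reverse-mode AD works through scan,
    # at the price of storing the carry of every step.
    def step(carry, _):
        keep_going = cond_fun(carry)
        new_carry = body_fun(carry)
        carry = jax.tree_util.tree_map(
            lambda new, old: jnp.where(keep_going, new, old), new_carry, carry
        )
        return carry, None

    init = (jnp.asarray(0, dtype=jnp.int32), jnp.asarray(x0, dtype=jnp.float32))
    (_, x), _ = jax.lax.scan(step, init, None, length=max_steps)
    return x

try:
    jax.grad(unbounded)(1.0)
    while_loop_grad_failed = False
except Exception:  # JAX rejects reverse-mode AD through lax.while_loop
    while_loop_grad_failed = True

value = bounded(1.0)                 # six multiplications: 1.5 ** 6
grad_value = jax.grad(bounded)(1.0)  # d/dx0 of x0 * 1.5 ** 6
```

Here both value and grad_value are 1.5 ** 6 = 11.390625, since the loop multiplies by 1.5 six times before x exceeds 10.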
Backward-incompatible changes
- The API for equinox.{filter_jit, filter_grad, filter_value_and_grad, filter_vmap, filter_pmap} has been dramatically simplified. If you were using the extra arguments to these functions (i.e. not just calling @eqx.filter_jit etc. directly) then this is a backward-incompatible change; see the discussion below for more details.
- Removed equinox.nn.{AvgPool1D, AvgPool2D, AvgPool3D, MaxPool1D, MaxPool2D, MaxPool3D}. Use AvgPool1d etc. (lower-case "d") instead. (These were backward-compatibility stubs that have now been removed.)
- Removed equinox.Module.{tree_flatten, tree_unflatten}. These were never technically public API; use jax.tree_util.{tree_flatten, tree_unflatten} instead.
- equinox.filter_closure_convert now asserts that you call it with arguments compatible with those it was closure-converted with.
- Dropped support for Python 3.7.
Other
- The Python overhead when crossing a filter_jit or filter_pmap boundary should now be much reduced.
- eqx.tree_inference now runs faster. (Thanks @uuirs in #233!)
- Lots of documentation improvements; in particular a new "Tricks" section for some advanced notes. (Thanks @carlosgmartin in #239!)
Filtered transformation API changes (AKA: "my code isn't working any more?")
These APIs have been simplified and made much easier to understand. No functionality has been lost; things might just need tweaking.
filter_jit
This previously took default, args, kwargs, out and fn arguments, for controlling what should be traced and what should be held static.
In practice all JAX arrays and NumPy arrays always had to be traced, and everything that wasn't a JAXable type (JAX array, NumPy array, bool, int, float, complex) had to be held static. So these arguments just weren't that useful: pretty much the only thing you could do with them was to specify that you'd like to trace a bool/int/float/complex.
This minor use-case wasn't worth complicating such an important API for, which is why these arguments have been removed.
If after this change you still want to trace with respect to a bool/int/float/complex, then simply wrap it into a JAX array or NumPy array first, e.g. np.asarray(x).
filter_grad and filter_value_and_grad
These previously took an arg argument, for controlling what parts of the first argument should be differentiated.
This was useful occasionally -- e.g. when freezing parts of a layer -- but in practice it still wasn't used that often. As such this argument has been removed for the sake of simplicity.
If after this change you want to replicate the previous behaviour, then it is simple to do so using partition and combine:
```python
# Before
@eqx.filter_grad(arg=foo)
def loss(first_arg, ...):
    ...

loss(bar, ...)

# After
@eqx.filter_grad
def loss(diff_first_arg, static_first_arg, ...):
    first_arg = eqx.combine(diff_first_arg, static_first_arg)
    ...

diff_bar, static_bar = eqx.partition(bar, foo)
loss(diff_bar, static_bar, ...)
```
See also the updated frozen layer example for a demonstration.
filter_vmap
This previously took default, args, kwargs, out and fn arguments, for controlling what axes should be vectorised over.
In practice this API was just a bit more complicated than it really needed to be. The only useful feature relative to jax.vmap was kwargs, for easily specifying just a few named arguments that should behave differently.
The new API instead accepts in_axes and out_axes arguments, just like jax.vmap. To replace kwargs, one extra feature is supported: in_axes may be a dictionary of named arguments, e.g.
```python
@eqx.filter_vmap(in_axes=dict(bar=None))
def fn(foo, bar):
    ...
```
All arguments not named in the in_axes dictionary will have the default value of eqx.if_array(0) (i.e. 0 if is_array(x) else None) applied to them.
On which note, a new eqx.if_array(i) now exists, to make it easier to specify values for in_axes and out_axes.
If you were using the old fn argument, then this can be replicated by instead decorating a function that accepts the callable:
```python
# Before
eqx.filter_vmap(foo, fn=bar)(x, y)

# After
@eqx.filter_vmap(in_axes=dict(fn=bar))
def accepts_foo(fn, x, y):
    return fn(x, y)

accepts_foo(foo, x, y)
```
filter_pmap
This previously took default, args, kwargs, out and fn arguments, for controlling what axes should be parallelised over, and which arguments should be traced vs static.
This was a fiendishly complicated API merging together both the filter_jit and filter_vmap APIs.
The JIT part of it is now handled automatically, as with filter_jit: all arrays are traced, everything else is static.
The vmap part of it is now handled in the same way as filter_vmap, using in_axes and out_axes arguments.
New Contributors
- @carlosgmartin made their first contribution in #239
- @enver1323 made their first contribution in #249
Full Changelog: v0.9.2...v0.10.0
Equinox v0.9.2
Autogenerated release notes as follows:
What's Changed
- Minor doc fixes by @patrick-kidger in #228
- Allow passing file-like objects to eqx.serialise/deserialise by @jatentaki in #229
- Fixed broken filter_closure_convert (and new JAX breaking Equinox's experimental stateful operations) by @patrick-kidger in #232
Full Changelog: v0.9.1...v0.9.2
Equinox v0.9.1
New features
These are all pretty self-explanatory!
equinox.filter_make_jaxpr
equinox.filter_vjp
equinox.filter_closure_convert
equinox.filter_pure_callback
Also:
equinox.internal.debug_backward_nan(x) will print out the primal and cotangent for x, and if the cotangent has a NaN then the computation is halted.
Bugfixes
- equinox.{is_array, is_array_like, is_inexact_array, is_inexact_array_like} all now recognise NumPy scalars as being array types.
- equinox.internal.{error_if, branched_error_if} are now compatible with jax.ensure_compile_time_eval.
- equinox.internal.noinline will now no longer throw an assert error during tracing under certain edge-case conditions. (In particular, when part of a branch of a vmap'd lax.cond with a batched predicate.)
- equinox.tree_pformat now prints out jax.tree_util.Partials and dataclass types (not instances) correctly.
Tweaks:
- equinox.internal.noinline is now compatible with jax.jit, i.e. a noinline-wrapped function can be passed across a jit API boundary. (Previously equinox.filter_jit was required.)
- equinox.internal.announce_jaxpr has been renamed to equinox.internal.announce_transform.
- equinox.internal.{nondifferentiable, nondifferentiable_backward} now take a msg argument for overriding their error messages.
Full Changelog: v0.9.0...v0.9.1
Equinox v0.9.0
This is a big update. The highlight here is the new equinox.internal namespace, which contains a slew of advanced features.
These are only "semi-public". They are deliberately not in the main documentation, and exist primarily for the benefit of downstream libraries like Diffrax. But you may still have fun playing with them.
Features
In equinox.internal:
- Autodiff:
  - nondifferentiable: will raise an error at trace time if you attempt to differentiate it.
  - nondifferentiable_backward: will raise an error at trace time if you attempt to reverse-mode differentiate it.
- Debug tooling:
  - announce_jaxpr: will call a custom callback whenever it is traced/transformed in a jaxpr. print(<transform stack>) is the default callback.
- Runtime errors:
  - error_if: can raise runtime errors. (Works on CPU; doesn't work on TPU. GPU support may be flaky.)
  - branched_error_if: can raise one of multiple errors, depending on a traced value.
- Floating point manipulation:
  - nextafter: returns the next floating point number. Unlike jnp.nextafter, it is differentiable.
  - prevbefore: returns the previous floating point number. Is differentiable.
- MLIR sub-graphs:
  - noinline: used to mark that a subcomputation should be placed in a separate computation graph, e.g. to avoid compiling the same thing multiple times if it is called repeatedly. Can also be used to iteratively recompile just parts of a computation graph, if the sub/super-graph is the only thing that changes.
- Omega:
  - ω: nice syntax for tree arithmetic. For example (x**ω + y**ω).ω == tree_map(operator.add, x, y). Like tree-math but with nicer syntax.
- Custom primitives:
  - filter_primitive_{def, jvp, transpose, batching, bind}: define rules for custom primitives that accept arbitrary PyTrees, not just JAX arrays.
  - create_vprim: autodefines batching rules for higher-order primitives, according to transform(vmap(prim)) == vmap(transform(prim)).
- String handling:
  - str2jax: turns a string into a JAX'able object.
- Unvmap'ing:
  - unvmap_{any, all, max}: apply reductions whilst ignoring the batch dimension.
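To make the ω notation concrete, here is a minimal, hypothetical re-creation of the idea on top of jax.tree_util.tree_map. It is not Equinox's implementation, and it supports only +, - and *. The trick: tree ** ω falls back to ω.__rpow__, which wraps the tree; arithmetic on wrapped trees acts leaf-wise; .ω unwraps again.

```python
import operator
import jax

class _OmegaTree:
    """Hypothetical wrapper: arithmetic operators act leaf-wise on a PyTree."""

    def __init__(self, tree):
        self.ω = tree  # `.ω` unwraps back to a plain PyTree

    def _binary(self, other, op):
        if isinstance(other, _OmegaTree):
            result = jax.tree_util.tree_map(op, self.ω, other.ω)
        else:  # broadcast a scalar against every leaf
            result = jax.tree_util.tree_map(lambda leaf: op(leaf, other), self.ω)
        return _OmegaTree(result)

    def __add__(self, other):
        return self._binary(other, operator.add)

    def __sub__(self, other):
        return self._binary(other, operator.sub)

    def __mul__(self, other):
        return self._binary(other, operator.mul)


class _OmegaMarker:
    # dicts/tuples have no __pow__, so `tree ** ω` calls this __rpow__.
    def __rpow__(self, tree):
        return _OmegaTree(tree)


ω = _OmegaMarker()

x = {"a": 1.0, "b": (2.0, 3.0)}
y = {"a": 10.0, "b": (20.0, 30.0)}
total = (x**ω + y**ω).ω   # {"a": 11.0, "b": (22.0, 33.0)}
scaled = (x**ω * 2.0).ω   # {"a": 2.0, "b": (4.0, 6.0)}
same = total == jax.tree_util.tree_map(operator.add, x, y)  # True
```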
In equinox:
- Autodiff:
  - New filtered transformations: eqx.{filter_jvp, filter_custom_jvp}.
Bugfixes / backward incompatibilities
- eqx.nn.GRUCell will now use its bias term. (Previously it was never adding this.)
- eqx.filter_eval_shape will no longer promote array-likes to arrays, in either its input or its output.
- eqx.tree_equal now treats JAX arrays and NumPy arrays as equal.
Misc
- Improved compilation speed of eqx.filter_vmap.
New Contributors
- @jondeaton made their first contribution in #204
- @IsaacBreen made their first contribution in #215
Full Changelog: v0.8.0...v0.9.0
Equinox v0.8.0
The ongoing march of small tweaks progresses.
Main changes this release:
- eqx.{is_array, is_inexact_array} now return True for np.ndarrays rather than False. This is technically a breaking change, hence the new minor version bump. Rationale in #202.
- We now use jaxtyping. Hurrah!
Other changes:
- make sequential module immutable by @jenkspt in #195
- Add support for asymmetric padding in Conv and ConvTransposed. by @Gurvan in #197
New Contributors
Full Changelog: v0.7.1...v0.8.0
Equinox v0.7.1
Autogenerated release notes as follows:
What's Changed
- Fixed NotImplementedError when computing gradients of stateful models by @patrick-kidger in #191
- fix attention with mask and add tests by @uuirs in #190
New Contributors
Full Changelog: v0.7.0...v0.7.1
Equinox v0.7.0
- Multiple bugfixes for differentiating through, and serialising, eqx.experimental.BatchNorm.
  - This is the reason for the version bump: if you are using eqx.experimental.{BatchNorm, SpectralNorm, StateIndex} then the serialisation format has changed.
- Feature: use_ceil added to all pooling layers.
Autogenerated release notes as follows:
What's Changed
- Add len and iter methods to nn.Sequential by @jenkspt in #174
- Add attention functions and tests by @jenkspt in #181
- Fixed BatchNorm not de/serialising correctly by @patrick-kidger in #172
- Ordered tree map by @paganpasta in #170
- added use_ceil to pooling by @paganpasta in #176
- Dev by @patrick-kidger in #184
Full Changelog: v0.6.0...v0.7.0
Equinox v0.6.0
- Refactor: the serialisation format for eqx.experimental.{BatchNorm, SpectralNorm, StateIndex} under eqx.tree_{de,}serialise_leaves has been tweaked slightly to avoid an edge-case crash. [This is the reason for the minor version bump to 0.6.0, as this is technically a (very minor) compatibility break.]
- Refactor: changed from jax.tree_map to jax.tree_util.tree_map to remove all the deprecation warnings JAX has started giving.
- Feature: added eqx.nn.Lambda (for use with eqx.nn.Sequential).
- Feature: added eqx.default_{de,}serialise_filter_spec (for use with eqx.tree_{de,}serialise_leaves).
- Bugfix: fixed BatchNorm crashing under jax.grad.
- Documentation: lots of tidy-ups and improvements.
Autogenerated release notes as follows:
What's Changed
- Doc tweak by @patrick-kidger in #141
- Fix GroupNorm channels argument and docstring by @jenkspt in #148
- make Sequential indexable and add tests by @jenkspt in #153
- replace tree_* with tree_util.tree_* to avoid jax warning messages by @amir-saadat in #156
- Extend deserial filter by @paganpasta in #145
- added lambda_layer to composites by @paganpasta in #158
- Tweaked docs for Lambda by @patrick-kidger in #159
- Tweaked intro docs to improve readability of filtering by @patrick-kidger in #160
- Batch norm grad crash fix by @patrick-kidger in #162
- added #outputs to the StateIndex example by @paganpasta in #164
- Fixed crash when serialising StateIndices without saved state by @patrick-kidger in #167
- v0.6.0 by @patrick-kidger in #169
New Contributors
- @jenkspt made their first contribution in #148
- @amir-saadat made their first contribution in #156
Full Changelog: v0.5.6...v0.6.0
Equinox v0.5.6
Autogenerated release notes as follows:
What's Changed
- Adaptive avg pool 1d by @paganpasta in #129
- {Avg,Max}Pool{1,2,3}D -> {Avg,Max}Pool{1,2,3}d. Removed wrong stride default. by @patrick-kidger in #135
- Tweaked AdaptivePool by @patrick-kidger in #139
- Adds adaptive pooling by @patrick-kidger in #140
New Contributors
- @paganpasta made their first contribution in #129
Full Changelog: v0.5.5...v0.5.6
Equinox v0.5.5
Autogenerated release notes as follows:
What's Changed
- Fix doc typo by @patrick-kidger in #130
- Updated pooling docs with init and call by @patrick-kidger in #131
- Doc fix by @patrick-kidger in #132
- Tidied helper into a relative import by @patrick-kidger in #133
- minor bug fix by @patrick-kidger in #134
Full Changelog: v0.5.4...v0.5.5