
Releases: patrick-kidger/equinox

Equinox v0.10.0

21 Feb 00:15
76a2baa

Highlights

  1. A dramatically simplified API for equinox.{filter_jit, filter_grad, filter_value_and_grad, filter_vmap, filter_pmap}. This is a backward-incompatible change.

  2. equinox.internal.while_loop, which is a reverse-mode autodifferentiable while loop, using recursive checkpointing.

Full change list

New features

Some relatively minor new features are available in this release.

  • Added support for donating buffers when using eqx.{filter_jit, filter_pmap}. (Thanks @uuirs in #235!)
  • Added eqx.nn.PRelu. (Thanks @enver1323 in #249!)
  • Added eqx.tree_pprint.
  • Added eqx.module_update_wrapper.
  • eqx.filter_custom_jvp now supports keyword arguments (which are always treated as nondifferentiable).

New internal features

Introducing a slew of new features for the advanced JAX user.

These are all available in the equinox.internal namespace. Note that these come without stability guarantees, as they often depend on functionality that JAX doesn't make fully public.

  • eqxi.abstractattribute, for marking abstract instance attributes of abstract Equinox modules.
  • eqxi.tree_pp, for producing a pretty-print doc of an object. (This doc is what e.g. eqx.tree_pformat then formats to a particular width.) In addition, classes can now customise how they are displayed by eqxi.tree_pp and eqx.{tree_pformat, tree_pprint}, by defining a __tree_pp__ method.
  • eqxi.if_mapped, as an alternative to the usual eqx.if_array passed to eqx.{filter_vmap, filter_pmap}(out_axes=...).
  • eqxi.{finalise_jaxpr, finalise_fn}, for tracing through the impl rules of custom primitives (so that the custom primitive no longer appears in the jaxpr). This is useful for replacing such custom primitives before offloading a jaxpr to some other IR, e.g. via jax2tf.
  • eqxi.{nonbatchable, nondifferentiable, nondifferentiable_backward, nontraceable}, for asserting that an operation is never batched, never differentiated, never reverse-mode differentiated, or never subject to any transform at all, respectively.
  • eqxi.to_onnx for exporting to ONNX.
  • eqxi.while_loop, for reverse-mode autodifferentiable while loops; in particular these make use of recursive checkpointing (à la treeverse).

Backward-incompatible changes

  • The API for equinox.{filter_jit, filter_grad, filter_value_and_grad, filter_vmap, filter_pmap} has been dramatically simplified. If you were using the extra arguments to these functions (i.e. not just calling @eqx.filter_jit etc. directly) then this is a backward-incompatible change; see the discussion below for more details.
  • Removed equinox.nn.{AvgPool1D, AvgPool2D, AvgPool3D, MaxPool1D, MaxPool2D, MaxPool3D}. Use AvgPool1d etc. (lower-case "d") instead. (These were backward-compatibility stubs that have now been removed.)
  • Removed equinox.Module.{tree_flatten, tree_unflatten}. These were never technically public API; use jax.tree_util.{tree_flatten, tree_unflatten} instead.
  • equinox.filter_closure_convert now asserts that you call it with arguments compatible with those it was closure-converted with.
  • Dropped support for Python 3.7.

Other

  • The Python overhead when crossing a filter_jit or filter_pmap boundary should now be much reduced.
  • eqx.tree_inference now runs faster. (Thanks @uuirs in #233!)
  • Lots of documentation improvements; in particular a new "Tricks" section for some advanced notes. (Thanks @carlosgmartin in #239!)

Filtered transformation API changes (AKA: "my code isn't working any more?")

These APIs have been simplified and made much easier to understand. No functionality has been lost, things might just need tweaking.

filter_jit

This previously took default, args, kwargs, out, fn arguments, for controlling what should be traced and what should be held static.

In practice all JAX arrays and NumPy arrays always had to be traced, and everything that wasn't a JAXable type (JAX array, NumPy array, bool, int, float, complex) had to be held static. So these arguments just weren't that useful: pretty much the only thing you could do with them was to specify that you'd like to trace a bool/int/float/complex.

This minor use case didn't justify complicating such an important API, which is why these arguments have been removed.

If after this change you still want to trace with respect to bool/int/float/complex, then do so simply by wrapping them into JAX arrays or NumPy arrays first: np.asarray(x).

filter_grad and filter_value_and_grad

These previously took an arg argument, for controlling what parts of the first argument should be differentiated.

This was occasionally useful (e.g. when freezing parts of a layer), but in practice it wasn't used that often. As such, this argument has been removed for the sake of simplicity.

If after this change you want to replicate the previous behaviour, then it is simple to do so using partition and combine:

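```python
# Before
@eqx.filter_grad(arg=foo)
def loss(first_arg, *args):
    ...

loss(bar, *args)

# After
@eqx.filter_grad
def loss(diff_first_arg, static_first_arg, *args):
    # Reassemble the full first argument inside the differentiated function.
    first_arg = eqx.combine(diff_first_arg, static_first_arg)
    ...

# `foo` is the filter spec that used to be passed as `arg`.
diff_bar, static_bar = eqx.partition(bar, foo)
loss(diff_bar, static_bar, *args)
```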

See also the updated frozen layer example for a demonstration.

filter_vmap

This previously took default, args, kwargs, out, fn arguments, for controlling what axes should be vectorised over.

In practice this API was just a bit more complicated than it really needed to be. The only useful feature relative to jax.vmap was kwargs, for easily specifying just a few named arguments that should behave differently.

The new API instead accepts in_axes and out_axes arguments, just like jax.vmap. To replace kwargs, one extra feature is supported: in_axes may be a dictionary of named arguments, e.g.

```python
@eqx.filter_vmap(in_axes=dict(bar=None))
def fn(foo, bar):
    ...
```

All arguments not named in in_axes will have the default value eqx.if_array(0) applied to them, i.e. each leaf x gets 0 if eqx.is_array(x) else None.

On which note, a new eqx.if_array(i) now exists, to make it easier to specify values for in_axes and out_axes.

If you were using the old fn argument, then this can be replicated by instead decorating a function that accepts the callable:

```python
# Before
eqx.filter_vmap(foo, fn=bar)(x, y)

# After
@eqx.filter_vmap(in_axes=dict(fn=bar))
def accepts_foo(fn, x, y):
    return fn(x, y)

accepts_foo(foo, x, y)
```

filter_pmap

This previously took default, args, kwargs, out, fn arguments, for controlling what axes should be parallelised over, and which arguments should be traced vs static.

This was a fiendishly complicated API merging together both the filter_jit and filter_vmap APIs.

The JIT part of it is now handled automatically, as with filter_jit: all arrays are traced, everything else is static.

The vmap part of it is now handled in the same way as filter_vmap, using in_axes and out_axes arguments.

Full Changelog: v0.9.2...v0.10.0

Equinox v0.9.2

17 Nov 05:43
253522c

Full Changelog: v0.9.1...v0.9.2

Equinox v0.9.1

15 Nov 05:31
e560936

New features

These are all pretty self-explanatory!

  • equinox.filter_make_jaxpr
  • equinox.filter_vjp
  • equinox.filter_closure_convert
  • equinox.filter_pure_callback

Also:

  • equinox.internal.debug_backward_nan(x) will print out the primal and cotangent for x, and if the cotangent has a NaN then the computation is halted.

Bugfixes

  • equinox.{is_array, is_array_like, is_inexact_array, is_inexact_array_like} all now recognise NumPy scalars as being array types.
  • equinox.internal.{error_if, branched_error_if} are now compatible with jax.ensure_compile_time_eval.
  • equinox.internal.noinline no longer throws an assert error during tracing under certain edge-case conditions. (In particular, when part of a branch of a vmap'd lax.cond with a batched predicate.)
  • equinox.tree_pformat now prints out jax.tree_util.Partials, and dataclass types (not instances) correctly.

Tweaks

  • equinox.internal.noinline is now compatible with jax.jit, i.e. a noinline-wrapped function can be passed across a jit API boundary. (Previously equinox.filter_jit was required.)
  • equinox.internal.announce_jaxpr has been renamed to equinox.internal.announce_transform.
  • equinox.internal.{nondifferentiable, nondifferentiable_backward} now take a msg argument for overriding their error messages.

Full Changelog: v0.9.0...v0.9.1

Equinox v0.9.0

02 Nov 23:15
1f5373f

This is a big update. The highlight here is the new equinox.internal namespace, which contains a slew of advanced features.

These are only "semi public". These are deliberately not in the main documentation, and exist primarily for the benefit of downstream libraries like Diffrax. But you may still have fun playing with them.

Features

  • equinox.internal.
    • Autodiff:
      • nondifferentiable: will raise an error at trace-time if you attempt to differentiate it.
      • nondifferentiable_backward: will raise an error at trace-time if you attempt to reverse-mode differentiate it.
    • Debug tooling:
      • announce_jaxpr: will call a custom callback whenever it is traced/transformed in a jaxpr. print(<transform stack>) is the default callback.
    • Runtime errors:
      • error_if: can raise runtime errors. (Works on CPU; doesn't work on TPU. GPU support may be flaky.)
      • branched_error_if: can raise one of multiple errors, depending on a traced value.
    • Floating point manipulation:
      • nextafter: returns the next floating point number. Unlike jnp.nextafter, it is differentiable.
      • prevbefore: returns the previous floating point number. It is also differentiable.
    • MLIR sub-graphs:
      • noinline: used to mark that a subcomputation should be placed in a separate computation graph, e.g. to avoid compiling the same thing multiple times if it is called repeatedly. Can also be used to iteratively recompile just parts of a computation graph, if the sub/super-graph is the only thing that changes.
    • Omega:
      • ω: nice syntax for tree arithmetic. For example (x**ω + y**ω).ω == tree_map(operator.add, x, y). Like tree-math but with nicer syntax.
    • Custom primitives:
      • filter_primitive_{def, jvp, transpose, batching, bind}: define rules for custom primitives that accept arbitrary PyTrees, not just JAX arrays.
      • create_vprim: Autodefines batching rules for higher-order primitives, according to transform(vmap(prim)) == vmap(transform(prim)).
    • String handling:
      • str2jax: turns a string into a JAX'able object.
    • Unvmap'ing:
      • unvmap_{any, all, max}: apply reductions whilst ignoring the batch dimension.
  • New filtered transformations: eqx.{filter_jvp,filter_custom_jvp}

Bugfixes / backward incompatibilities

  • eqx.nn.GRUCell will now use its bias term. (Previously it was never adding this.)
  • eqx.filter_eval_shape will no longer promote array-likes to arrays, in either its input or its output.
  • eqx.tree_equal now treats JAX arrays and NumPy arrays as equal.

Misc

  • Improved compilation speed of eqx.filter_vmap.

Full Changelog: v0.8.0...v0.9.0

Equinox v0.8.0

22 Sep 23:17
c057deb

The ongoing march of small tweaks progresses.

Main changes this release:

  • eqx.{is_array,is_inexact_array} now return True for np.ndarrays rather than False. This is technically a breaking change, hence the new minor version bump. Rationale in #202.
  • We now use jaxtyping. Hurrah!

Other changes:

  • make sequential module immutable by @jenkspt in #195
  • Add support for asymmetric padding in Conv and ConvTransposed. by @Gurvan in #197

Full Changelog: v0.7.1...v0.8.0

Equinox v0.7.1

06 Sep 01:52
5f2f038

Autogenerated release notes as follows:

What's Changed

  • Fixed NotImplementedError when computing gradients of stateful models by @patrick-kidger in #191
  • fix attention with mask and add tests by @uuirs in #190

Full Changelog: v0.7.0...v0.7.1

Equinox v0.7.0

30 Aug 21:13
18d260d
  • Multiple bugfixes for differentiating through, and serialising, eqx.experimental.BatchNorm.
    • This is the reason for the version bump: if you are using eqx.experimental.{BatchNorm,SpectralNorm,StateIndex} then the serialisation format has changed.
  • Feature: use_ceil added to all pooling layers.

Full Changelog: v0.6.0...v0.7.0

Equinox v0.6.0

03 Aug 21:35
6058475
  • Refactor: the serialisation format for eqx.experimental.{BatchNorm,SpectralNorm,StateIndex} under eqx.tree_{de,}serialise_leaves has been tweaked slightly to avoid an edge-case crash. [This is the reason for the minor version bump to 0.6.0, as this is technically a (very minor) compatibility break.]
  • Refactor: changed from jax.tree_map to jax.tree_util.tree_map to remove all the deprecation warnings JAX has started giving.
  • Feature: added eqx.nn.Lambda (for use with eqx.nn.Sequential)
  • Feature: added eqx.default_{de,}serialise_filter_spec (for use with eqx.tree_{de,}serialise_leaves).
  • Bugfix: fixed BatchNorm crashing under jax.grad.
  • Documentation: lots of tidy-ups and improvements.

Full Changelog: v0.5.6...v0.6.0

Equinox v0.5.6

20 Jul 17:37
483c7a4

Full Changelog: v0.5.5...v0.5.6

Equinox v0.5.5

20 Jul 17:28

Full Changelog: v0.5.4...v0.5.5