Equinox v0.9.0
This is a big update. The highlight is the new `equinox.internal` namespace, which contains a slew of advanced features.
These are only "semi-public": they are deliberately left out of the main documentation, and exist primarily for the benefit of downstream libraries like Diffrax. But you may still have fun playing with them.
Features
- `equinox.internal`:
  - Autodiff:
    - `nondifferentiable`: will raise an error at trace time if you attempt to differentiate it.
    - `nondifferentiable_backward`: will raise an error at trace time if you attempt to reverse-mode differentiate it.
  - Debug tooling:
    - `announce_jaxpr`: will call a custom callback whenever it is traced or transformed in a jaxpr. The default callback is `print(<transform stack>)`.
  - Runtime errors:
    - `error_if`: can raise runtime errors. (Works on CPU; doesn't work on TPU. GPU support may be flaky.)
    - `branched_error_if`: can raise one of multiple errors, depending on a traced value.
  - Floating point manipulation:
    - `nextafter`: returns the next floating point number. Unlike `jnp.nextafter`, it is differentiable.
    - `prevbefore`: returns the previous floating point number. Also differentiable.
  - MLIR sub-graphs:
    - `noinline`: marks a subcomputation to be placed in a separate computation graph, e.g. to avoid compiling the same thing multiple times if it is called repeatedly. It can also be used to iteratively recompile just part of a computation graph, if the sub/super-graph is the only thing that changes.
  - Omega:
    - `ω`: nice syntax for tree arithmetic. For example, `(x**ω + y**ω).ω == tree_map(operator.add, x, y)`. Like tree-math, but with nicer syntax.
  - Custom primitives:
    - `filter_primitive_{def, jvp, transpose, batching, bind}`: define rules for custom primitives that accept arbitrary PyTrees, not just JAX arrays.
    - `create_vprim`: automatically defines batching rules for higher-order primitives, according to `transform(vmap(prim)) == vmap(transform(prim))`.
  - String handling:
    - `str2jax`: turns a string into a JAX'able object.
  - Unvmap'ing:
    - `unvmap_{any, all, max}`: apply reductions whilst ignoring the batch dimension.
- Autodiff:
  - New filtered transformations: `eqx.filter_jvp`, `eqx.filter_custom_jvp`.
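The trace-time guard described for `nondifferentiable` can be imitated in plain JAX. This is a minimal sketch, not Equinox's implementation: it assumes a `jax.custom_jvp` whose JVP rule raises, and the name `my_nondifferentiable` is hypothetical. Since JAX builds reverse mode on top of forward mode, the guard also trips under `jax.grad`.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def my_nondifferentiable(x):
    # Hypothetical stand-in for equinox.internal.nondifferentiable:
    # the primal computation is just the identity.
    return x


@my_nondifferentiable.defjvp
def _my_nondifferentiable_jvp(primals, tangents):
    # This rule only runs when JAX traces a derivative, so the error is
    # raised at trace time, before anything is compiled or executed.
    raise RuntimeError("Tried to differentiate a nondifferentiable quantity.")


def loss(x):
    return jnp.sum(my_nondifferentiable(x) ** 2)


print(loss(jnp.arange(3.0)))  # forward evaluation works: 5.0
try:
    jax.grad(loss)(jnp.arange(3.0))
except RuntimeError as e:
    print("grad failed:", e)
```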
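For a feel of what runtime errors inside compiled JAX code look like, stock JAX offers `jax.experimental.checkify`. This is a swapped-in technique, not how `error_if` is implemented, and note the inverted convention: `checkify.check` fails when its predicate is False, whereas an "error if" condition fires when it is True. The function `safe_log` is a hypothetical example.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def safe_log(x):
    # checkify.check fails when its predicate is False, so the condition
    # "error if any x <= 0" is written as the assertion "all x > 0".
    checkify.check(jnp.all(x > 0), "safe_log received a non-positive value")
    return jnp.log(x)


# checkify.checkify functionalizes the check: the error is returned as a
# value rather than raised, which is what allows the function to be jitted.
checked_log = jax.jit(checkify.checkify(safe_log))

err, out = checked_log(jnp.array([1.0, 2.0]))
print(err.get())  # None: no check failed

err, out = checked_log(jnp.array([1.0, -2.0]))
print(err.get())  # a string containing the error message
# Calling err.throw() would raise it as a Python exception instead.
```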
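One way to make a differentiable `nextafter`/`prevbefore` is to wrap `jnp.nextafter` in a `jax.custom_jvp`. This sketch assumes the derivative of a one-ulp step is taken to be 1 (a straight-through choice); the release notes do not say which derivative Equinox uses, and `my_nextafter`/`my_prevbefore` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def _make_one_ulp_step(direction):
    """Build a differentiable one-ulp step towards `direction` (+inf or -inf)."""

    @jax.custom_jvp
    def step(x):
        return jnp.nextafter(x, direction)

    @step.defjvp
    def _step_jvp(primals, tangents):
        (x,), (t,) = primals, tangents
        # Assumption: treat the tiny nudge as the identity for derivatives,
        # so the tangent passes straight through.
        return jnp.nextafter(x, direction), t

    return step


my_nextafter = _make_one_ulp_step(jnp.inf)
my_prevbefore = _make_one_ulp_step(-jnp.inf)

x = jnp.float32(1.0)
print(my_nextafter(x) > x, my_prevbefore(x) < x)  # True True
print(jax.grad(my_nextafter)(x))  # 1.0
```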
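The `ω` example in the feature list, `(x**ω + y**ω).ω == tree_map(operator.add, x, y)`, relies on ordinary Python operator overloading. The following is a minimal sketch of how such syntax can be built, assuming `x**ω` dispatches to `__rpow__` on the right-hand object; the classes `Omega` and `_OmegaMaker` are hypothetical and this is not Equinox's source.

```python
import operator

import jax
import jax.numpy as jnp


class Omega:
    """Hypothetical wrapper: arithmetic on it acts leaf-by-leaf on a PyTree."""

    def __init__(self, tree):
        self.tree = tree

    def _binop(self, other, op):
        if isinstance(other, Omega):
            return Omega(jax.tree_util.tree_map(op, self.tree, other.tree))
        # Non-Omega operands (e.g. Python scalars) are applied to every leaf.
        return Omega(jax.tree_util.tree_map(lambda leaf: op(leaf, other), self.tree))

    def __add__(self, other):
        return self._binop(other, operator.add)

    def __sub__(self, other):
        return self._binop(other, operator.sub)

    def __mul__(self, other):
        return self._binop(other, operator.mul)

    @property
    def ω(self):
        # Unwrap back to a plain PyTree.
        return self.tree


class _OmegaMaker:
    # `tree ** ω` falls back to ω.__rpow__(tree) because dicts, tuples and
    # lists do not define __pow__.
    def __rpow__(self, tree):
        return Omega(tree)


ω = _OmegaMaker()

x = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
y = {"a": jnp.array([10.0, 20.0]), "b": jnp.array(4.0)}
z = (x**ω + y**ω).ω
z_ref = jax.tree_util.tree_map(operator.add, x, y)  # the spelling ω replaces
```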
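To illustrate what "a reduction that ignores the batch dimension" means for `unvmap_{any, all, max}`, the sketch below uses JAX's named-axis collectives inside `jax.vmap`. This is a swapped-in technique, not how `unvmap_*` works internally: it needs the vmap to be given an explicit `axis_name`, and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def any_negative_across_batch(x):
    # Per-example predicate...
    local = jnp.any(x < 0)
    # ...then reduce over the vmapped axis named "batch". Every batch element
    # receives the same answer, as if the batch dimension were not there.
    return jax.lax.psum(local.astype(jnp.int32), axis_name="batch") > 0


def max_across_batch(x):
    # Maximum over this example's entries and over the whole batch.
    return jax.lax.pmax(jnp.max(x), axis_name="batch")


xs = jnp.array([[1.0, 2.0], [3.0, -4.0], [5.0, 6.0]])
print(jax.vmap(any_negative_across_batch, axis_name="batch")(xs))  # all True
print(jax.vmap(max_across_batch, axis_name="batch")(xs))  # all 6.0
```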
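The idea behind a "filtered" JVP is that inputs may mix JAX arrays with arbitrary Python objects, and only the arrays are differentiated. The following is a hypothetical sketch of that idea built on `jax.jvp`; it is not Equinox's implementation or signature, and it assumes the tangents have the same structure as the primals with `None` at every non-array leaf.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def filter_jvp_sketch(fn, primals, tangents):
    """Hypothetical filtered JVP: differentiate only the JAX-array leaves."""

    def is_none(t):
        return t is None

    leaves, treedef = jtu.tree_flatten(primals)
    # Keep None as a leaf so tangent leaves line up with primal leaves.
    t_leaves = jtu.tree_leaves(tangents, is_leaf=is_none)
    is_dynamic = [isinstance(leaf, jax.Array) for leaf in leaves]

    def fn_of_arrays(*arrays):
        # Re-insert the static (non-array) leaves around the traced arrays.
        it = iter(arrays)
        full = [next(it) if dyn else leaf for leaf, dyn in zip(leaves, is_dynamic)]
        return fn(jtu.tree_unflatten(treedef, full))

    dyn_primals = tuple(leaf for leaf, dyn in zip(leaves, is_dynamic) if dyn)
    dyn_tangents = tuple(t for t, dyn in zip(t_leaves, is_dynamic) if dyn)
    return jax.jvp(fn_of_arrays, dyn_primals, dyn_tangents)


def scale(args):
    x, mode = args  # `mode` is a plain string, which jax.jvp alone rejects
    return x * (2.0 if mode == "double" else 1.0)


out, t_out = filter_jvp_sketch(scale, (jnp.array(3.0), "double"), (jnp.array(1.0), None))
print(out, t_out)  # 6.0 2.0
```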
Bugfixes / backward incompatibilities
- `eqx.nn.GRUCell` will now use its bias term. (Previously it was never adding this.)
- `eqx.filter_eval_shape` will no longer promote array-likes to arrays, in either its input or its output.
- `eqx.tree_equal` now treats JAX arrays and NumPy arrays as equal.
Misc
- Improved compilation speed of `eqx.filter_vmap`.
New Contributors
- @jondeaton made their first contribution in #204
- @IsaacBreen made their first contribution in #215
Full Changelog: v0.8.0...v0.9.0