
Equinox v0.9.0

Released by github-actions on 02 Nov 23:15.
1f5373f

This is a big update. The highlight here is the new equinox.internal namespace, which contains a slew of advanced features.

These features are only "semi-public": they are deliberately left out of the main documentation, and exist primarily for the benefit of downstream libraries such as Diffrax. But you may still have fun playing with them.

Features

  • equinox.internal:
    • Autodiff:
      • nondifferentiable: will raise an error at trace-time if you attempt to differentiate it.
      • nondifferentiable_backward: will raise an error at trace-time if you attempt to reverse-mode differentiate it.
    • Debug tooling:
      • announce_jaxpr: will call a custom callback whenever it is traced/transformed in a jaxpr. print(<transform stack>) is the default callback.
    • Runtime errors:
      • error_if: can raise runtime errors. (Works on CPU; doesn't work on TPU. GPU support may be flaky.)
      • branched_error_if: can raise one of multiple errors, depending on a traced value.
    • Floating point manipulation:
      • nextafter: returns the next floating point number. Unlike jnp.nextafter, it is differentiable.
      • prevbefore: returns the previous floating point number. Also differentiable.
    • MLIR sub-graphs:
      • noinline: marks a subcomputation to be placed in a separate computation graph, e.g. to avoid compiling the same subcomputation multiple times when it is called repeatedly. It can also be used to iteratively recompile just part of a computation graph, when that sub- or super-graph is the only thing that changes.
    • Omega:
      • ω: nice syntax for tree arithmetic. For example (x**ω + y**ω).ω == tree_map(operator.add, x, y). Like tree-math but with nicer syntax.
    • Custom primitives:
      • filter_primitive_{def, jvp, transpose, batching, bind}: define rules for custom primitives that accept arbitrary PyTrees, not just JAX arrays.
      • create_vprim: automatically defines batching rules for higher-order primitives, according to transform(vmap(prim)) == vmap(transform(prim)).
    • String handling:
      • str2jax: turns a string into a JAX'able object.
    • Unvmap'ing:
      • unvmap_{any, all, max}: apply reductions whilst ignoring the batch dimension.
  • New filtered transformations: eqx.{filter_jvp,filter_custom_jvp}
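A rough way to get the trace-time behaviour described for nondifferentiable is a jax.custom_jvp whose JVP rule simply raises. This is a minimal sketch using only core JAX, not Equinox's implementation; the names nondifferentiable_sketch and loss are hypothetical. Because JAX only invokes a custom JVP rule while tracing a derivative, ordinary (and jitted) calls are unaffected.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def nondifferentiable_sketch(x):
    # Hypothetical helper. Primal computation: the identity, so ordinary calls are unaffected.
    return x


@nondifferentiable_sketch.defjvp
def _nondifferentiable_jvp(primals, tangents):
    # JAX only invokes this rule while tracing a derivative (jax.jvp,
    # jax.grad, ...), so the error is raised at trace time, not at run time.
    raise RuntimeError("nondifferentiable_sketch: this value must not be differentiated")


def loss(x):
    return jnp.sum(nondifferentiable_sketch(x) ** 2)


print(loss(jnp.array([1.0, 2.0])))  # the forward pass works as usual
# jax.grad(loss)(jnp.array([1.0, 2.0]))  # would raise RuntimeError while tracing
```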
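Equinox's error_if is not reproduced here. To get a feel for runtime errors inside compiled JAX code, JAX ships a related but different mechanism, jax.experimental.checkify. It records failed checks in a returned error value rather than raising from inside the compiled computation. The function safe_log below is a hypothetical example.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def safe_log(x):
    # Hypothetical example. The check is recorded in checkify's error state
    # instead of being raised immediately inside the compiled code.
    checkify.check(jnp.all(x > 0), "safe_log: expected strictly positive input")
    return jnp.log(x)


# checkify threads the error state through the computation; jit compiles it.
checked_log = jax.jit(checkify.checkify(safe_log))

err_ok, out_ok = checked_log(jnp.array([1.0, 2.0]))
err_bad, out_bad = checked_log(jnp.array([-1.0, 2.0]))

print(err_ok.get())   # None: the check passed
print(err_bad.get())  # the failure message; err_bad.throw() would raise it
```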
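As noted above, jnp.nextafter itself is not differentiable. A minimal sketch of differentiable variants wraps it in jax.custom_jvp. This sketch rests on one assumption, which may not match Equinox's actual rule: the one-ulp step is treated as a constant offset, so the derivative is taken to be 1. nextafter_sketch and prevbefore_sketch are hypothetical names.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def nextafter_sketch(x):
    # Hypothetical helper: next representable float above x, in x's own dtype.
    x = jnp.asarray(x)
    return jnp.nextafter(x, jnp.full_like(x, jnp.inf))


@jax.custom_jvp
def prevbefore_sketch(x):
    # Hypothetical helper: previous representable float below x, in x's own dtype.
    x = jnp.asarray(x)
    return jnp.nextafter(x, jnp.full_like(x, -jnp.inf))


# Assumption: a one-ulp step is treated as a constant offset, so the
# derivative is 1 and the incoming tangent is passed straight through.
@nextafter_sketch.defjvp
def _nextafter_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    return nextafter_sketch(x), t


@prevbefore_sketch.defjvp
def _prevbefore_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    return prevbefore_sketch(x), t


one = jnp.float32(1.0)
print(nextafter_sketch(one), prevbefore_sketch(one))
print(jax.grad(nextafter_sketch)(one))
```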
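The ω syntax reads like ordinary arithmetic because of a wrapper class. Below is a hypothetical, much-reduced re-creation of the idea. The wrapper's operators act leafwise via jax.tree_util.tree_map, and a metaclass (here the made-up _OmegaMeta) makes `tree ** ω` wrap a PyTree. It only supports +, - and *, and is a sketch of the idea, not Equinox's implementation.

```python
import operator

import jax
import jax.numpy as jnp


class _OmegaMeta(type):
    # `tree ** ω` ends up here: the PyTree (e.g. a dict) has no __pow__,
    # so Python falls back to the reflected method on the type of ω.
    def __rpow__(cls, tree):
        return cls(tree)


class ω(metaclass=_OmegaMeta):
    """Hypothetical minimal wrapper; `.ω` unwraps back to the plain PyTree."""

    def __init__(self, tree):
        self.ω = tree

    def _binop(self, other, op):
        if isinstance(other, ω):
            # Tree (op) tree: combine matching leaves.
            return ω(jax.tree_util.tree_map(op, self.ω, other.ω))
        # Tree (op) scalar: apply the scalar to every leaf.
        return ω(jax.tree_util.tree_map(lambda leaf: op(leaf, other), self.ω))

    def __add__(self, other):
        return self._binop(other, operator.add)

    def __sub__(self, other):
        return self._binop(other, operator.sub)

    def __mul__(self, other):
        return self._binop(other, operator.mul)


x = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
y = {"a": jnp.array([10.0, 20.0]), "b": jnp.array(30.0)}

z = (x**ω + y**ω).ω  # leafwise sum, same structure as x and y
w = (x**ω * 2.0).ω   # scalar applied to every leaf
reference = jax.tree_util.tree_map(operator.add, x, y)
```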
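The rule behind create_vprim, transform(vmap(prim)) == vmap(transform(prim)), can be checked numerically for an ordinary function with stock JAX transforms. In this sketch the transform is forward-mode differentiation (jax.jvp), and f is an arbitrary, hypothetical example function. It only illustrates the identity and does not define a primitive.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical scalar function used only to test the identity.
    return jnp.sin(x) * x


xs = jnp.linspace(0.0, 1.0, 4)
ts = jnp.ones_like(xs)

# transform(vmap(f)): forward-mode derivative of the batched function.
_, lhs = jax.jvp(jax.vmap(f), (xs,), (ts,))

# vmap(transform(f)): batch the per-example forward-mode derivative.
_, rhs = jax.vmap(lambda x, t: jax.jvp(f, (x,), (t,)))(xs, ts)

print(jnp.allclose(lhs, rhs))
```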
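The unvmap_* helpers are described only as applying reductions "whilst ignoring the batch dimension". This sketch assumes that means the reduction also runs across the vmapped axis, so every batch element sees the same result. Under that assumption, the effect can be approximated in plain JAX with a named-axis collective such as jax.lax.pmax. This swaps in a different technique: unlike the unvmap_* helpers, it requires passing axis_name to jax.vmap. batch_max and batch_any are hypothetical names.

```python
import jax
import jax.numpy as jnp


def batch_max(x):
    # Hypothetical helper. Reduce within this example, then take a max across
    # the named vmap axis, so every batch element receives the batch-wide value.
    return jax.lax.pmax(jnp.max(x), axis_name="batch")


def batch_any(x):
    # Hypothetical helper. Cast to int so that a max across the batch acts as a logical OR.
    return jax.lax.pmax(jnp.any(x).astype(jnp.int32), axis_name="batch") > 0


xs = jnp.array([[1.0, 5.0], [3.0, 2.0], [0.0, 4.0]])
maxes = jax.vmap(batch_max, axis_name="batch")(xs)

flags = jnp.array([[False, False], [False, True]])
anys = jax.vmap(batch_any, axis_name="batch")(flags)

print(maxes, anys)
```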

Bugfixes / backward incompatibilities

  • eqx.nn.GRUCell will now use its bias term. (Previously it was never adding this.)
  • eqx.filter_eval_shape will no longer promote array-likes to arrays, in either its input or its output.
  • eqx.tree_equal now treats JAX arrays and NumPy arrays as equal.

Misc

  • Improved compilation speed of eqx.filter_vmap.

New Contributors

Full Changelog: v0.8.0...v0.9.0