Equinox v0.11.5
JAX compatibility
Recent versions of JAX (0.4.28+) have made some changes to:
- Hashing of tracers;
- Tree-map'ing over Nones;
- Callbacks;
- Pretty-printing.
With this update, we should now be compatible with both old and new versions of JAX, which fixes some new crashes and some new warnings. (#719, #724, #753, #758, thanks @jakevdp, @hawkinsp!)
Better errors
- The error messages from `eqx.error_if` are now substantially more informative: they include traceback information (including the stack), and mention the availability of the `EQX_ON_ERROR` variable. We also do a much better job hiding the large, unhelpful printouts that XLA gives by default. (#785, #803)
- The default value of `EQX_ON_ERROR_BREAKPOINT_FRAMES` is now `1`. (#777) The impact of this is that using `eqx.error_if` alongside `EQX_ON_ERROR=breakpoint` will now:
  - reliably open a debugger, rather than sometimes crashing at trace time due to upstream JAX bug #16732;
  - however, by default the debugger will no longer include any additional stack frames above it (accessed via `u`);
  - print an informative message explaining much of the above before the debugger opens.
Bugfixes
- `eqx.filter_{jacfwd, jacrev}` now apply filtering only to their inputs, not their outputs. Previously, filtering the outputs was problematic because there was no way to represent static-input-by-static-output in the returned Jacobian, so pieces were silently dropped. (#734, thanks @lockwo!)
- `eqx.tree_at` can now be used to replace empty tuples. (#715, #717, #722, thanks @lockwo!)
- `eqx.filter_custom_jvp` no longer crashes at trace time in some scenarios in which its `**kwargs` were erroneously counted as having tangents. (#745 (comment), #749)
- Fixed a trace-time crash when using a particular combination of vmap + autodiff + checkpointed while loops. This occurred when using `optimistix.BFGS` around `diffrax.diffeqsolve`. (#777)
- Fixed a trace-time crash when:
  - using a checkpointed while loop...
  - ...with a body function that has a closed-over tracer...
  - ...and that closed-over tracer is differentiated...
  - ...and there are no other closed-over tracers that are differentiated...
  - ...and the dependency on that tracer is only linear.
  - (patrick-kidger/diffrax#387 (comment), #752, thanks @dkweiss31!)
- Fixed a trace-time crash when taking the grad of a vmap of `lineax.linear_solve`. (patrick-kidger/lineax#101, #795, thanks @rhacking!)
- `eqx.nn.RMSNorm` now uses at least 32-bit precision for numerical stability. (#723, thanks @AakashKumarNain!)
New features
- `eqx.nn.{Linear,Conv,GRUCell,LSTMCell}` now support complex dtypes. (#765, thanks @ChenAo-Phys!)
- Added `eqx.nn.RotaryEmbedding(..., theta=...)`. (#735, thanks @Artur-Galstyan!)
Other changes
- Several doc fixes. (#708, #731, #733, #747, #750, #757 + several other PRs, thanks @Artur-Galstyan, @matteoguarrera, @lockwo, @nasyxx!)
- Several internal test fixes, as downstream libraries have changed slightly. (#740, #742 + several other PRs, big thanks to @GaetanLepage for reporting many of these!)
- There is now a Mistral 7B implementation using JAX+Equinox available over in AakashKumarNain/mistral_jax!
New Contributors
- @nasyxx made their first contribution in #708
- @jakevdp made their first contribution in #724
- @matteoguarrera made their first contribution in #739
Full Changelog: v0.11.4...v0.11.5