I can reproduce the issue with jax 0.6.0 (current main) on a Linux box. As you have already mentioned, the numerics of non-jitted and jitted JAX code are slightly different, and the learning rate is too high. In general, with too high a learning rate the SGD algorithm may diverge, and the final result then becomes very sensitive to any perturbation of the numerics, in any program.

In this particular case, with the given learning rate of 2.0, the final loss and gradient are very large (approximately 40 and 300000, respectively), meaning that the optimization did not converge; the results from the non-jitted and jitted code are then not comparable by definition, since both runs are divergent. With a smaller learning rate, say 2e-2, the optimization converges (the final loss and gradient are on the order of 0.01), and the results from the non-jitted and jitted programs become comparable and are indeed close.

In sum, the cause of the reported huge discrepancy between the non-jitted and jitted programs is a divergent optimization process whose result is very sensitive to the small numerical differences between non-jitted and jitted code. As a resolution, use hyperparameters (such as the learning rate) that lead to a convergent optimization process, which suppresses the possible discrepancies between the numerics of non-jitted and jitted programs.
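A minimal plain-JAX sketch of the point above (this is not the original MWE; the model, data, and step counts are illustrative assumptions): the same SGD loop is run eagerly and under jax.jit at the two learning rates. At lr=2.0 the trajectory is divergent, so any tiny numerical difference between the eager and compiled computations can be amplified into a large gap; at lr=2e-2 the trajectory is stable and the two runs should agree closely.

```python
import jax
import jax.numpy as jnp


def init_params(key):
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (1, 32)),
        "w2": jax.random.normal(k2, (32, 1)),
    }


def predict(params, x):
    h = jnp.sin(x @ params["w1"])  # sinusoidal activation, as in the MWE
    return h @ params["w2"]


def loss_fn(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)


def sgd_step(params, x, y, lr):
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


def train(step_fn, lr, n_steps=200):
    x = jnp.linspace(-1.0, 1.0, 64)[:, None]
    y = jnp.sin(3.0 * x)
    params = init_params(jax.random.PRNGKey(0))
    for _ in range(n_steps):
        params = step_fn(params, x, y, lr)
    return loss_fn(params, x, y)


for lr in (2.0, 2e-2):
    loss_eager = train(sgd_step, lr)
    loss_jit = train(jax.jit(sgd_step), lr)
    # At lr=2.0 the run typically diverges and the two losses can differ
    # wildly; at lr=2e-2 the trajectory is stable and the two losses should
    # agree to within floating-point noise.
    print(f"lr={lr}: eager={float(loss_eager):.6g}, jit={float(loss_jit):.6g}")
```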
I’ve been working on differentiable simulations where gradients are propagated through physical solvers. Recently, I encountered a puzzling discrepancy: running the code with JIT enabled gives significantly different results than the non-JIT version, even though both use the same inputs and computation logic. I know that JIT can change the numerics, but I had always assumed the margin was not of particular concern (somewhere in the trailing decimals). The difference here, however, is difficult to ignore.
The issue surfaced (seemingly at random) for certain combinations of hyperparameters. After isolating the components of my simulation, I built a minimal working example (MWE) that doesn’t involve physics or FEA but still demonstrates the problem (I know the learning rate is too high and that I am using a sinusoidal activation, but that is by construction, to illustrate the point).
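A schematic stand-in for the MWE (the exact sizes, data, and step count below are arbitrary, not the real setup) captures the structure: a small equinox MLP with a sinusoidal activation trained by optax SGD at learning rate 2.0, where the only difference between the two runs is whether the update step is wrapped in eqx.filter_jit.

```python
# Schematic stand-in for the MWE (sizes, data and step count are arbitrary).
import equinox as eqx
import jax
import jax.numpy as jnp
import optax

key = jax.random.PRNGKey(0)
model = eqx.nn.MLP(in_size=1, out_size=1, width_size=32, depth=2,
                   activation=jnp.sin, key=key)  # sinusoidal activation
optim = optax.sgd(learning_rate=2.0)             # deliberately very high

x = jnp.linspace(-1.0, 1.0, 64)[:, None]
y = jnp.sin(3.0 * x)


def step(model, opt_state, x, y):
    def loss_fn(m):
        return jnp.mean((jax.vmap(m)(x) - y) ** 2)

    loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
    updates, opt_state = optim.update(
        grads, opt_state, eqx.filter(model, eqx.is_array)
    )
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss


def train(step_fn, model, n_steps=100):
    opt_state = optim.init(eqx.filter(model, eqx.is_array))
    for _ in range(n_steps):
        model, opt_state, loss = step_fn(model, opt_state, x, y)
    return loss


# The only difference between the two runs is eqx.filter_jit, yet at this
# learning rate the final losses can differ by far more than rounding error.
print("eager:", float(train(step, model)))
print("jit:  ", float(train(eqx.filter_jit(step), model)))
```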
Has anyone else faced this issue?
JAX = 0.5.3 (CPU)
macOS (Apple M3)
Python = 3.11
optax = 0.2.4
equinox = 0.12.1