implement vmap for jax solver #19
Conversation
Very nice! I think having this performance boost makes the usage of the previous solutions obsolete, at least partially: using the previous solutions also gives us some stability over the curves (along the execution dimension). Can we make it a (default) option and keep the other version?
We may keep the old code, but in that case I suggest adding a docstring (or even a small section in the docs) to highlight that our current version (sequentially solving the functions) is much slower than the vectorized version, e.g.:

```python
def __init__(self, f, w0, max_fit_attemps, vectorized=True):
```
Fully agree! Sounds good!
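As a rough sketch, the proposed `vectorized` option could dispatch between the two paths like the following; the class internals here are placeholders, not the repo's actual implementation:

```python
import jax
import jax.numpy as jnp


class JaxCurveGenerationSolver:
    """Sketch only: the solver internals below are placeholders."""

    def __init__(self, f, w0, max_fit_attemps, vectorized=True):
        self.f = f
        self.w0 = jnp.asarray(w0)
        self.max_fit_attempts = max_fit_attemps
        # When True, solve all curves in one batched call via jax.vmap;
        # when False, fall back to the (much slower) sequential loop.
        self.vectorized = vectorized

    def _solve_single(self, latent):
        # Placeholder for the existing per-curve solve.
        return self.w0

    def solve(self, latents):
        if self.vectorized:
            # One batched solve over the leading axis of `latents`.
            return jax.vmap(self._solve_single)(latents)
        # Legacy path: one solve per curve, kept for stability/callbacks.
        return jnp.stack([self._solve_single(latent) for latent in latents])
```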
```python
self.max_fit_attempts = max_fit_attemps

def _latents_to_array(self, latents):
    l_x0_mat = jnp.array([latent.x0 for latent in latents])
```
Something like the following would be a bit more code-efficient (though harder to understand):

```python
return (
    jnp.array([getattr(latent, dim) for latent in latents]) for dim in ["x0", "x1", ...]
)
```
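A runnable sketch of that suggestion, with the `attr` call written as `getattr`; the `Latent` container and the `dims` tuple here are assumptions for illustration:

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass
class Latent:
    # Hypothetical container; the real latent class lives in the repo.
    x0: float
    x1: float


def latents_to_arrays(latents, dims=("x0", "x1")):
    # One array per attribute, built generically via getattr instead of
    # one hand-written comprehension per attribute.
    return tuple(
        jnp.array([getattr(latent, dim) for latent in latents]) for dim in dims
    )


l_x0_mat, l_x1_mat = latents_to_arrays([Latent(0.0, 1.0), Latent(0.5, 2.0)])
```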
Done in this PR
Implemented a faster version of curve sampling in `JaxCurveGenerationSolver` using `jax.vmap`. This is much faster according to my benchmarks: tested on 10k curves, i.e. 10k latent instances that need to be solved, the old code succeeded in 268 seconds, while the vectorized code succeeds in 8 seconds.
But this comes with two downsides:

- … `w0`.
- `callback`, since the whole matrix is computed at once. This might be less of an issue, since we still get all the results, just much faster and not computed sequentially. Maybe we should provide a way to get the losses for each instance? @windisch What do you think?
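For reference, a minimal self-contained sketch of the `jax.vmap` pattern: the model `f`, the loss, and the gradient-descent fit below are stand-ins rather than the repo's solver, but they show how a batched solve can still return per-instance losses even without a per-curve callback:

```python
import jax
import jax.numpy as jnp


def f(w, x):
    # Stand-in model: a line. The repo's f will differ.
    return w[0] * x + w[1]


def loss(w, x, y):
    return jnp.mean((f(w, x) - y) ** 2)


def fit_one(x, y, w0, steps=200, lr=0.5):
    # Plain gradient descent for a single curve; no per-step callback here.
    grad = jax.grad(loss)

    def step(w, _):
        return w - lr * grad(w, x, y), None

    w, _ = jax.lax.scan(step, w0, None, length=steps)
    return w, loss(w, x, y)


# Batch over the leading axis of xs/ys; the initial guess w0 is shared.
fit_all = jax.jit(jax.vmap(fit_one, in_axes=(0, 0, None)))

x = jnp.linspace(0.0, 1.0, 50)
xs = jnp.stack([x, x, x])                      # (3, 50)
ys = jnp.stack([2 * x + 1, -x, 0.5 * x + 3])   # (3, 50)
ws, losses = fit_all(xs, ys, jnp.zeros(2))
# `losses` has shape (3,): one final loss per instance, which is one way
# to expose per-instance results from the batched solve.
```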