
Conversation

@edgarWolf
Owner

Done in this PR

Implemented a faster version of curve sampling in JaxCurveGenerationSolver using jax.vmap.
According to my benchmarks this is much faster. Tested on 10k curves, i.e. 10k latent inputs that need to be solved: with the old code the computation finished in 268 seconds, while the vectorized code finishes in 8 seconds (roughly a 33x speedup).

But this comes with two downsides:

  • We can't use the previous solution as a warm start for the next problem: in order to use vmap, all problems need to be solvable independently of each other, so every problem gets the same initial guess w0.
  • We lose the callback, since the whole matrix is computed at once. This might be less of an issue because we still get all results, just much faster and not sequentially. Maybe we should provide a way to get the losses for each instance (see the sketch below)?
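
Roughly, the vectorized pattern looks like this. This is a minimal sketch only: the loss, the gradient-descent loop, and names like solve_single are placeholders, not the actual JaxCurveGenerationSolver internals.

import jax
import jax.numpy as jnp

def solve_single(latent_params, w0, num_steps=100, lr=1e-2):
    # Placeholder single-problem solver: plain gradient descent on a dummy loss.
    def loss(w):
        return jnp.sum((w - latent_params) ** 2)

    def step(w, _):
        w = w - lr * jax.grad(loss)(w)
        return w, loss(w)

    w_final, losses = jax.lax.scan(step, w0, None, length=num_steps)
    return w_final, losses

# Vectorize over the batch of latents; w0 is shared (in_axes=None), which is
# exactly why every problem starts from the same initial guess.
solve_batch = jax.jit(jax.vmap(solve_single, in_axes=(0, None)))

latents = jnp.ones((10_000, 4))  # 10k latent parameter vectors (shape assumed)
w0 = jnp.zeros(4)
w_all, losses_all = solve_batch(latents, w0)  # per-instance loss histories

Returning the per-step losses from the scan would be one way to recover the information the callback used to provide.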

@windisch What do you think?

@edgarWolf edgarWolf requested a review from windisch November 15, 2025 15:20
@edgarWolf edgarWolf linked an issue Nov 15, 2025 that may be closed by this pull request
@windisch
Collaborator

Very nice! I think having this performance boost makes the usage of previous solutions obsolete, at least partially, because using the previous solutions also gives us some sort of stability across the curves (execution dimension). Can we make it a (default) option and keep the other version?

@edgarWolf
Owner Author

Very nice! I think having this performance boost makes the usage of previous solutions obsolete, at least partially, because using the previous solutions also gives us some sort of stability across the curves (execution dimension). Can we make it a (default) option and keep the other version?

We may keep the old code, but in that case I suggest adding a docstring (or even a small section in the docs) to highlight that the current version (sequentially solving for the functions) is way slower than the vectorized one.
And I would definitely make the new version the default, maybe via a flag like

def __init__(self, f, w0, max_fit_attemps, vectorized=True):

in JaxCurveGenerationSolver
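
Concretely, the dispatch could look roughly like this (sketch only; _solve_vectorized and _solve_sequential are placeholder names for the two code paths, not existing methods):

class JaxCurveGenerationSolver:
    def __init__(self, f, w0, max_fit_attemps, vectorized=True):
        self.f = f
        self.w0 = w0
        self.max_fit_attempts = max_fit_attemps
        self.vectorized = vectorized

    def solve(self, latents):
        if self.vectorized:
            # Default path: solve all problems at once via jax.vmap.
            return self._solve_vectorized(latents)
        # Opt-in path: old sequential loop that warm-starts each problem
        # from the previous solution.
        return self._solve_sequential(latents)

    def _solve_vectorized(self, latents):
        ...  # vmap-based implementation

    def _solve_sequential(self, latents):
        ...  # existing loop-based implementation

That way the sequential version stays available for the stability use case without changing the call site.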

@windisch
Collaborator

windisch commented Dec 7, 2025

Very nice! I think having this performance boost makes the usage of previous solutions obsolete, at least partially, because using the previous solutions also gives us some sort of stability across the curves (execution dimension). Can we make it a (default) option and keep the other version?

We may keep the old code, but in that case I suggest adding a docstring (or even a small section in the docs) to highlight that the current version (sequentially solving for the functions) is way slower than the vectorized one. And I would definitely make the new version the default, maybe via a flag like

def __init__(self, f, w0, max_fit_attemps, vectorized=True):

in JaxCurveGenerationSolver

Fully agree! Sounds good!

self.max_fit_attempts = max_fit_attemps

def _latents_to_array(self, latents):
    l_x0_mat = jnp.array([latent.x0 for latent in latents])
Collaborator


Something like this would be a bit more code-efficient (though harder to understand):

# Build one array per latent attribute; "..." stands for the remaining attributes.
return tuple(
    jnp.array([getattr(latent, dim) for latent in latents])
    for dim in ["x0", "x1", ...]
)



Development

Successfully merging this pull request may close these issues.

Speedup jax with vmap
