
Better cache #145

Open
ASKabalan wants to merge 5 commits into DifferentiableUniverseInitiative:master from ASKabalan:better-cache

Conversation

@ASKabalan
Member

@ASKabalan ASKabalan commented Feb 17, 2026

This pull request introduces a host-side caching system for JAX functions to avoid issues with JAX tracers and improve efficiency, and adds support for second-order growth factor and growth rate calculations in the cosmological background module. The changes also refactor how distance and growth tables are computed and accessed, and add comprehensive tests for the new second-order growth functionality.

Caching and Performance Improvements:

  • Added a new @caching decorator in jax_cosmo/cache.py that enables host-side LRU caching for functions, preventing UnexpectedTracerError and improving efficiency in nested JAX transformations. The decorator is now used for distance and growth table computations.
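To illustrate the general pattern (this is a hedged sketch, not the PR's actual implementation in jax_cosmo/cache.py): a host-side cache memoizes results in a plain Python dict keyed on concrete argument values, so cached tables survive across jit scopes, while traced (abstract) arguments simply bypass the cache instead of leaking tracers. The name `distance_table` below is a hypothetical stand-in.

```python
import functools

import jax
import jax.numpy as jnp


def caching(fn):
    # Hypothetical sketch of host-side caching: results live in a plain
    # Python dict keyed on concrete argument values. Tracers cannot be
    # converted to floats, so traced calls fall through uncached rather
    # than leaking into the cache.
    cache = {}

    @functools.wraps(fn)
    def wrapper(*args):
        try:
            # float() raises ConcretizationTypeError on JAX tracers,
            # so only concrete calls ever touch the cache.
            key = tuple(float(x) for x in jax.tree_util.tree_leaves(args))
        except (TypeError, jax.errors.ConcretizationTypeError):
            return fn(*args)
        if key not in cache:
            cache[key] = fn(*args)
        return cache[key]

    return wrapper


calls = []


@caching
def distance_table(omega_m):
    # Stand-in for an expensive table computation.
    calls.append(omega_m)
    return jnp.linspace(0.0, omega_m, 8)


distance_table(0.3)
distance_table(0.3)           # served from the cache; fn ran only once
jax.jit(distance_table)(0.3)  # traced call bypasses the cache
```

The key point is that the cache never stores tracer-backed values, which is what makes it safe across nested jit/scan/vmap scopes.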

Background Cosmology API Enhancements:

  • Introduced new functions growth_factor_second and growth_rate_second in background.py to compute the second-order growth factor and its rate, including their ODE solvers and interpolation logic. These are exposed in the module's public API.
  • Refactored table computation for distances and growth factors to use the new caching system, replacing the previous _workspace-based caching. Functions such as radial_comoving_distance, a_of_chi, growth_factor, and growth_rate now retrieve cached tables via the new mechanism.
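As a sanity check of the second-order growth equations (a toy sketch, not jax_cosmo's implementation), one can solve the first- and second-order growth ODEs in x = ln(a) for an Einstein-de Sitter universe, where Omega_m(a) = 1 and dlnH/dlna = -3/2. Under the usual convention the analytic solutions are D1 = a and D2 = -(3/7) a^2:

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.experimental.ode import odeint


def rhs(y, x):
    # First- and second-order growth ODEs in x = ln(a), specialized to
    # Einstein-de Sitter: D'' + (2 + dlnH/dlna) D' = (3/2) Omega_m D,
    # with the second order sourced by -(3/2) Omega_m D1^2.
    d1, d1p, d2, d2p = y
    friction = 2.0 - 1.5  # 2 + dlnH/dlna = 1/2 in EdS
    d1pp = -friction * d1p + 1.5 * d1
    d2pp = -friction * d2p + 1.5 * d2 - 1.5 * d1**2
    return jnp.array([d1p, d1pp, d2p, d2pp])


x = jnp.linspace(jnp.log(1e-3), 0.0, 256)
a0 = jnp.exp(x[0])
# Growing-mode initial conditions: D1 = a, D2 = -(3/7) a^2.
y0 = jnp.array([a0, a0, -3.0 / 7.0 * a0**2, -6.0 / 7.0 * a0**2])
sol = odeint(rhs, y0, x)
d1, d2 = sol[:, 0], sol[:, 2]
ratio = float(d2[-1] / d1[-1] ** 2)  # expect about -3/7
```

Recovering D2/D1^2 close to -3/7 at a = 1 confirms the integration is consistent with the known EdS limit.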

Numerical/Algorithmic Improvements:

  • Updated the ODE integrator (odeint) to return both the final and intermediate values, supporting the new caching and table-building logic.
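The shape of that change can be sketched as a fixed-step RK4 integrator built on lax.scan that returns both the final state and the full trajectory, mirroring the `(yf, _), y = jax.lax.scan(rk4, ...)` pattern visible in the traceback below (the actual jax_cosmo integrator may differ in details):

```python
import jax
import jax.numpy as jnp


def odeint_with_trajectory(f, y0, t):
    # Fixed-step RK4 via lax.scan. The scan carry holds the running
    # state; the per-step output accumulates the whole trajectory, so
    # both the final value and all intermediate values are returned.
    def rk4(carry, t_next):
        y, t_prev = carry
        h = t_next - t_prev  # zero on the first step, so y is unchanged
        k1 = f(y, t_prev)
        k2 = f(y + h * k1 / 2.0, t_prev + h / 2.0)
        k3 = f(y + h * k2 / 2.0, t_prev + h / 2.0)
        k4 = f(y + h * k3, t_next)
        y_next = y + h / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
        return (y_next, t_next), y_next

    (yf, _), ys = jax.lax.scan(rk4, (y0, t[0]), t)
    return yf, ys


# Example: dy/dt = y from y(0) = 1 gives y(1) close to e.
t = jnp.linspace(0.0, 1.0, 101)
yf, ys = odeint_with_trajectory(lambda y, t: y, jnp.array(1.0), t)
```

Returning `ys` alongside `yf` is what lets the caller build an interpolation table from a single integration pass.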

Testing:

  • Added new tests for growth_factor_second and growth_rate_second to ensure correctness and normalization, including numerical consistency checks.
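A hypothetical sketch of that kind of consistency check: the second-order growth rate f2 = dln D2 / dln a should match a finite difference of the second-order growth factor. Here the Einstein-de Sitter analytic form D2(a) = -(3/7) a^2, for which f2 = 2 exactly, stands in for jax_cosmo's tabulated solution:

```python
import jax.numpy as jnp


def growth_factor_second_eds(a):
    # Einstein-de Sitter analytic second-order growth factor.
    return -3.0 / 7.0 * a**2


def growth_rate_fd(d2_fn, a, eps=1e-3):
    # Central finite difference of ln|D2| with respect to ln(a),
    # i.e. a numerical estimate of the growth rate f2.
    lna = jnp.log(a)
    lo = jnp.log(jnp.abs(d2_fn(jnp.exp(lna - eps))))
    hi = jnp.log(jnp.abs(d2_fn(jnp.exp(lna + eps))))
    return (hi - lo) / (2.0 * eps)


f2 = float(growth_rate_fd(growth_factor_second_eds, 0.5))  # expect ~2.0
```

The same finite-difference comparison, run against the library's growth_factor_second and growth_rate_second, is the kind of numerical consistency the new tests verify.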

@ASKabalan
Member Author

ASKabalan commented Feb 17, 2026

This is a follow-up to #143.

I have tried hard to overcome the caching issues in the Cosmology object, but it seems impossible to cover all cases.

Reproducing the error is not easy, but it happens very often, mainly when the cosmo object enters a new jit scope and loses its cache. In that case (inside any tracing transform like lax.scan or vmap) a tracer leak is produced.

All of this can be fixed with a workaround that clears the workspace every time (very bad for performance), but in my case, taking gradients of a full PM+lightcone simulation where the cosmo object is used in many scopes, I still had gradient issues,

mainly this error:

  File "/home/wassim/micromamba/envs/ffi11/lib/python3.11/site-packages/jax/_src/interpreters/ad.py", line 453, in accum
    ct_check(self, x)
  File "/home/wassim/micromamba/envs/ffi11/lib/python3.11/site-packages/jax/_src/interpreters/ad.py", line 467, in ct_check
    raise ValueError(
ValueError: Input primal JAX type to VJP function is float64[]. Hence the expected cotangent type is float64[] but got float64[128]

coming from here, to be exact:

  File "/home/wassim/Projects/NBody/fwd_model_tools/DEBUGGING/debug_grad.py", line 51, in simulation
    dx, p = ffi.lpt(cosmo, initial_field, ts=0.1, order=1, painting=ffi.PaintingOptions(target="particles"))
  File "/home/wassim/Projects/NBody/fwd_model_tools/src/fwd_model_tools/pm/lpt.py", line 128, in lpt
    comoving_centers = jc.background.radial_comoving_distance(cosmo, a.flatten()).reshape(a.shape)
  File "/home/wassim/Projects/NBody/jax_cosmo/jax_cosmo/background.py", line 302, in radial_comoving_distance
    _ensure_background_workspace(cosmo)
  File "/home/wassim/Projects/NBody/jax_cosmo/jax_cosmo/background.py", line 217, in _ensure_background_workspace
    chi_final, chitab = odeint(dchioverdlna, 0.0, np.log(atab))
  File "/home/wassim/Projects/NBody/jax_cosmo/jax_cosmo/scipy/ode.py", line 21, in odeint
    (yf, _), y = jax.lax.scan(rk4, (y0, np.array(t[0])), t)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: Input primal JAX type to VJP function is float64[]. Hence the expected cotangent type is float64[] but got float64[128]

This is unavoidable: it is basically caused by the checkpointing scheme trying to reuse the (now unused) intermediate steps that were used to compute the tables being interpolated.

The scheme I implemented is much more idiomatic; see jax-ml/jax#26274.

It supports gradients and vmap (although sequentially for now), works pretty well, and is very stable.

To test, you can run:

import jax_cosmo as jc
import jax

cosmo = jc.Planck15()

print(f"First call to radial_comoving_distance:")
%time r = jc.background.radial_comoving_distance(cosmo, 0.5).block_until_ready()
print(f"Second call to radial_comoving_distance (should be faster due to caching):")
%time r = jc.background.radial_comoving_distance(cosmo, 0.5).block_until_ready()
print(f"Third call with int a to cause a rejit")
%time r = jc.background.radial_comoving_distance(cosmo, 0).block_until_ready()
print(f"Call again with int a to see if it is cached")
%time r = jc.background.radial_comoving_distance(cosmo, 0).block_until_ready()

You will get:

First call to radial_comoving_distance:
CPU times: user 467 ms, sys: 22.7 ms, total: 490 ms
Wall time: 468 ms
Second call to radial_comoving_distance (should be faster due to caching):
CPU times: user 24.7 ms, sys: 0 ns, total: 24.7 ms
Wall time: 24.4 ms
Third call with int a to cause a rejit
CPU times: user 88.1 ms, sys: 901 μs, total: 89 ms
Wall time: 86.6 ms
Call again with int a to see if it is cached
CPU times: user 24.7 ms, sys: 0 ns, total: 24.7 ms
Wall time: 24.6 ms

So you can see that the re-jit did not recompute the table and instead used the existing cache.

@ASKabalan
Member Author

Please merge after #146

@ASKabalan ASKabalan mentioned this pull request Mar 2, 2026
