Better cache #145
This is a follow-up to #143. I have tried hard to overcome the caching issues in the Cosmology object, but it seems impossible to cover all cases. The error is not easy to reproduce, yet it happens very often, mainly when the cosmo object enters a new jit scope and loses its cache. In that case (inside any tracing transform such as lax.scan or vmap) a leak is produced.

All of this can be fixed with a workaround where we clear the workspace every time (very bad for performance). In my case, though, I was taking gradients of a full PM+lightcone where the cosmo object is used in many scopes, and I had gradient issues, mainly this error, in here to be exact. This is unavoidable: it is caused by the checkpointing scheme trying to reuse the (now unused) k steps that were used to compute the tables to be interpolated.

The scheme I implemented is much more idiomatic (see jax-ml/jax#26274). It supports gradients and vmap (sequential for now), works pretty well, and is very stable.

To test, you can run this:
import jax_cosmo as jc
import jax
cosmo = jc.Planck15()
print("First call to radial_comoving_distance:")
%time r = jc.background.radial_comoving_distance(cosmo, 0.5).block_until_ready()
print("Second call to radial_comoving_distance (should be faster due to caching):")
%time r = jc.background.radial_comoving_distance(cosmo, 0.5).block_until_ready()
print("Third call with int a to cause a rejit")
%time r = jc.background.radial_comoving_distance(cosmo, 0).block_until_ready()
print("Call again with int a to see if it is cached")
%time r = jc.background.radial_comoving_distance(cosmo, 0).block_until_ready()

From the timings you get, you can see that the rejit did not recompute the table and used the existing cache.
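Outside IPython the %time magic is not available. A minimal sketch of an equivalent timing helper follows, using only the standard library. The name `timed` is hypothetical and not part of jax_cosmo; it calls block_until_ready only when the result provides it, and a stand-in workload is used so the sketch runs without jax_cosmo installed.

```python
import time


def timed(label, fn, *args, **kwargs):
    """Call fn, wait for any asynchronous result, and report wall time."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    # JAX arrays are computed asynchronously; block so the timing is honest.
    wait = getattr(result, "block_until_ready", None)
    if callable(wait):
        result = wait()
    elapsed = time.perf_counter() - start
    print(f"{label}: {elapsed * 1e3:.3f} ms")
    return result, elapsed


# Stand-in workload (hypothetical), used here instead of a jax_cosmo call.
value, seconds = timed("sum of squares", lambda n: sum(i * i for i in range(n)), 1000)
```

With jax_cosmo available, each timed line above could instead be written as timed("first call", jc.background.radial_comoving_distance, cosmo, 0.5).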
Please merge after #146
This pull request introduces a host-side caching system for JAX functions to avoid issues with JAX tracers and improve efficiency, and adds support for second-order growth factor and growth rate calculations in the cosmological background module. The changes also refactor how distance and growth tables are computed and accessed, and add comprehensive tests for the new second-order growth functionality.
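The idea of a host-side LRU cache can be sketched in plain Python. This is an illustration only, not the decorator from this pull request: `host_lru_cache` and `build_table` are hypothetical names, and the sketch leaves out everything JAX-specific (traced arguments, gradients, vmap). It assumes the cache key is built from concrete, hashable values.

```python
from collections import OrderedDict
from functools import wraps


def host_lru_cache(maxsize=8):
    """Decorator sketch: keep the most recent results in a host-side dict."""

    def decorator(fn):
        cache = OrderedDict()

        @wraps(fn)
        def wrapper(*args):
            # Keys must be concrete, hashable values (not JAX tracers).
            key = tuple(args)
            if key in cache:
                cache.move_to_end(key)  # mark as most recently used
                return cache[key]
            result = fn(*args)
            cache[key] = result
            if len(cache) > maxsize:
                cache.popitem(last=False)  # evict the least recently used entry
            return result

        wrapper.cache = cache  # exposed for inspection in this sketch
        return wrapper

    return decorator


calls = []


@host_lru_cache(maxsize=2)
def build_table(omega_m, n_steps):
    """Hypothetical stand-in for an expensive table computation."""
    calls.append((omega_m, n_steps))
    return [omega_m * i / n_steps for i in range(n_steps + 1)]


build_table(0.3, 4)
build_table(0.3, 4)   # served from the cache, no recomputation
build_table(0.25, 4)
build_table(0.2, 4)   # cache is full, so the (0.3, 4) entry is evicted
```

Because the cache lives in ordinary Python state keyed on concrete values, it does not depend on which tracing scope the caller is in, which is the property the description above is after.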
Caching and Performance Improvements:
Added a @caching decorator in jax_cosmo/cache.py that enables host-side LRU caching for functions, preventing UnexpectedTracerError and improving efficiency in nested JAX transformations. The decorator is now used for distance and growth table computations.

Background Cosmology API Enhancements:
Added growth_factor_second and growth_rate_second in background.py to compute the second-order growth factor and its rate, including their ODE solvers and interpolation logic. These are exposed in the module's public API.
Refactored the _workspace-based caching: functions such as radial_comoving_distance, a_of_chi, growth_factor, and growth_rate now retrieve cached tables via the new mechanism.

Numerical/Algorithmic Improvements:
Updated the ODE solver (odeint) to return both the final and intermediate values, supporting the new caching and table-building logic.

Testing:
Added tests for growth_factor_second and growth_rate_second to ensure correctness and normalization, including numerical consistency checks.
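An integrator that hands back both the final state and the intermediate values, as described for odeint above, might look like the sketch below. This is not the odeint from this pull request: `rk4_with_history` is a hypothetical name, and the fixed-step RK4 scheme is chosen only for illustration. The intermediate values are what a table builder would then interpolate.

```python
def rk4_with_history(f, y0, t_grid):
    """Integrate dy/dt = f(y, t) over t_grid with classic fixed-step RK4.

    Returns (final_value, history), where history[i] is y at t_grid[i].
    """
    y = y0
    history = [y0]
    for t0, t1 in zip(t_grid[:-1], t_grid[1:]):
        h = t1 - t0
        k1 = f(y, t0)
        k2 = f(y + 0.5 * h * k1, t0 + 0.5 * h)
        k3 = f(y + 0.5 * h * k2, t0 + 0.5 * h)
        k4 = f(y + h * k3, t1)
        y = y + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        history.append(y)  # keep every step so a table can be built later
    return y, history


# Test problem: dy/dt = y with y(0) = 1, whose exact solution is exp(t).
t_grid = [i / 100 for i in range(101)]
final, table = rk4_with_history(lambda y, t: y, 1.0, t_grid)
```

In JAX the same return shape falls out of jax.lax.scan, which returns the final carry together with the stacked per-step outputs.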