diff --git a/warp/_src/jax_experimental/ffi.py b/warp/_src/jax_experimental/ffi.py index e9fe408f47..ad2214627d 100644 --- a/warp/_src/jax_experimental/ffi.py +++ b/warp/_src/jax_experimental/ffi.py @@ -1203,6 +1203,7 @@ def jax_kernel( hashable_launch_dims = launch_dims if not enable_backward: + hashable_in_out = tuple(in_out_argnames) if in_out_argnames is not None else None key = ( kernel.func, kernel.sig, @@ -1210,6 +1211,7 @@ def jax_kernel( vmap_method, hashable_launch_dims, hashable_output_dims, + hashable_in_out, module_preload_mode, has_side_effect, ) @@ -1548,12 +1550,18 @@ def jax_callable( hashable_output_dims = output_dims # Note: we don't include graph_cache_max in the key, it is applied below. + hashable_in_out = tuple(in_out_argnames) if in_out_argnames is not None else None + hashable_stage_in = tuple(stage_in_argnames) if stage_in_argnames is not None else None + hashable_stage_out = tuple(stage_out_argnames) if stage_out_argnames is not None else None key = ( func, num_outputs, graph_mode, vmap_method, hashable_output_dims, + hashable_in_out, + hashable_stage_in, + hashable_stage_out, module_preload_mode, has_side_effect, )