Description
Hello,
I'd really like to try out asal. However, most of the ALife models I have are implemented in PyTorch. I thought I could implement my substrate by simply wrapping my torch model and doing the JAX conversion in and out of the methods that are exposed to asal.
This proved much harder than anticipated. I don't know JAX, but I keep getting errors such as:

      File "/home/frotaur/asal/main_opt.py", line 71, in calc_loss
        rollout_data = rollout_fn(rng, params)
      File "/home/frotaur/asal/rollout.py", line 84, in rollout_simulation
        _, state_vid = jax.lax.scan(step_fn, s0, split(rng, rollout_steps))
                       ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/frotaur/asal/rollout.py", line 82, in step_fn
        next_state = substrate.step_state(_rng, state, params)
      File "/home/frotaur/asal/substrates/__init__.py", line 80, in step_state
        return self.substrate.step_state(rng, state, params)
               ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
      File "/home/frotaur/asal/substrates/difflenia_sub.py", line 56, in step_state
        if(torch.allclose(torch.tensor(params['mu'])-self.lenia.params['mu'])):
                          ~~~~~~~~~~~~^^^^^^^^^^^^^^
      File "/home/frotaur/.pvenv/lib/python3.13/site-packages/torch/utils/dlpack.py", line 101, in from_dlpack
        device = ext_tensor.__dlpack_device__()
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    AttributeError: DynamicJaxprTracer has no attribute __dlpack_device__

I assume these errors are caused by my torch operations not being JIT-compilable.
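A minimal reproduction of this failure mode, written for illustration only. NumPy stands in for torch so the sketch needs just JAX and NumPy, and `foreign_step` and `body` are hypothetical names. `jax.lax.scan` traces its body function, so inside it every array is an abstract tracer with no concrete values, and a library that needs real numbers cannot read it. The `torch.tensor(params['mu'])` call in the traceback above appears to hit the same underlying problem, just with a different error message.

```python
import jax
import jax.numpy as jnp
import numpy as np


def foreign_step(state):
    # np.asarray needs concrete numbers; a traced value has none,
    # so this only works when called eagerly on a real array.
    return jnp.asarray(np.asarray(state) + 1.0)


def body(carry, _):
    nxt = foreign_step(carry)
    return nxt, nxt


# Eager call: `state` is a concrete array, so the conversion succeeds.
eager_result = foreign_step(jnp.zeros(3))

# Inside lax.scan the body is traced, so `state` is a tracer and
# the NumPy conversion raises TracerArrayConversionError.
try:
    jax.lax.scan(body, jnp.zeros(3), None, length=4)
    traced_ok = True
except jax.errors.TracerArrayConversionError:
    traced_ok = False
```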
So I was wondering: do you think it is possible to implement the actual ALife models in something other than JAX, and use the substrate class as a wrapper around the model so that it works with asal?
Or are we stuck having to re-implement everything in JAX to use this codebase?
If needed, I can provide a minimal PyTorch implementation of Game of Life (GoL), as well as my attempt at a substrate wrapper for it.
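One possible route, sketched here under assumptions rather than as asal's documented approach, is to keep the model in PyTorch and call it from inside the traced JAX code through `jax.pure_callback`. That function runs an arbitrary Python callback on concrete host (NumPy) arrays at execution time instead of at trace time. `torch_gol_step`, `step_state` and `rollout` below are hypothetical. `step_state(rng, state, params)` only mimics the signature visible in the traceback, and the sketch assumes the substrate state is a single float32 array. If it is a pytree, the result spec has to mirror that pytree. Caveats: JAX cannot differentiate through a `pure_callback`, every step round-trips between device and host, and under `jax.vmap` recent JAX versions may require an explicit `vmap_method` argument (check the docs for your version).

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


def torch_gol_step(grid):
    """One Game of Life step computed in PyTorch (hypothetical helper).

    Called through jax.pure_callback, so `grid` arrives as a NumPy array
    and the result must be a NumPy array with the declared shape/dtype.
    """
    g = torch.tensor(np.asarray(grid), dtype=torch.float32)
    # Count the 8 neighbours on a torus by summing shifted copies of the grid.
    n = sum(
        torch.roll(g, shifts=(dy, dx), dims=(0, 1))
        for dy in (-1, 0, 1)
        for dx in (-1, 0, 1)
        if (dy, dx) != (0, 0)
    )
    alive = (n == 3) | ((g == 1) & (n == 2))
    return alive.to(torch.float32).numpy()


def step_state(rng, state, params):
    """Hypothetical substrate step; rng and params are unused by plain GoL."""
    out_spec = jax.ShapeDtypeStruct(state.shape, jnp.float32)
    # pure_callback lets traced code (jit, lax.scan) call back into Python
    # with concrete arrays when the computation actually runs.
    return jax.pure_callback(torch_gol_step, out_spec, state)


def rollout(s0, steps):
    def body(state, _):
        nxt = step_state(None, state, None)
        return nxt, nxt

    _, states = jax.lax.scan(body, s0, None, length=steps)
    return states


# A blinker on a 5x5 torus: horizontal -> vertical -> horizontal.
s0 = jnp.zeros((5, 5), jnp.float32).at[2, 1:4].set(1.0)
states = rollout(s0, 2)
```

Any parameters the torch model needs (such as `params['mu']` in the traceback) would likewise have to be passed as extra arguments to `jax.pure_callback`, so that they reach the callback as concrete arrays instead of being read as tracers.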
Thanks for the help!