diff --git a/docs/faq.md b/docs/faq.md index dab442c0..5c87cba9 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -140,6 +140,53 @@ class Model(eqx.Module): return self.param * x + jax.lax.stop_gradient(self.buffer) ``` +## How to use (non-array) modules as inputs to scan/cond/while etc.? + +If you have a non jax array in a module and pass it to `scan`/`cond`/etc. you will see an error of the form `TypeError: Value <[non-jax object] at 0x1321b09d0> with type is not a valid JAX type`. The way to solve this is with filtering, specifically, you can filter out the static (i.e. non-jax arrays) and capture them via closure. For example, + +```python +mlp = eqx.nn.MLP(...) + +def rollout(mlp, xs): + def step(carry, x): + mlp = carry + val = mlp(x) + carry = mlp + return carry, [val] + + _, scan_out = jax.lax.scan( + step, + [mlp], + xs + ) + + return scan_out + +key, subkey = jax.random.split(key) +vals = rollout(mlp, jax.random.normal(key=subkey, shape=(200, 3))) +``` + +will error. To fix this, you can explicitly capture the static elements via + +```python +def rollout(mlp, xs): + arr, static = eqx.partition(mlp, eqx.is_array) + def step(carry, x): + mlp = eqx.combine(carry, static) + val = mlp(x) + carry, _ = eqx.partition(mlp, eqx.is_array) + return carry, [val] + + _, scan_out = jax.lax.scan( + step, + arr, + xs + ) + return scan_out +``` + +What about if you want a module function to be the function being `scan`-ed over? If you just try to `jax.lax.scan(module, ...)` you will see a `TypeError: unhashable type: 'jaxlib.xla_extension.ArrayImpl'`. This is a [bug in jax](https://github.com/google/jax/issues/13554) that can be avoided by simply wrapping the module function in a lambda, e.g. `jax.lax.scan(lambda x, y: module(x, y), ...)`. + ## I think my function is being recompiled each time it is run. Use [`equinox.debug.assert_max_traces`][], for example