Skip to content

Commit

Permalink
scan faq
Browse files Browse the repository at this point in the history
  • Loading branch information
lockwo authored and patrick-kidger committed Jun 7, 2024
1 parent 111d258 commit 0cbbdbf
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions docs/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,53 @@ class Model(eqx.Module):
return self.param * x + jax.lax.stop_gradient(self.buffer)
```

## How to use (non-array) modules as inputs to scan/cond/while etc.?

If you have a non jax array in a module and pass it to `scan`/`cond`/etc. you will see an error of the form `TypeError: Value <[non-jax object] at 0x1321b09d0> with type <class [not jax array]> is not a valid JAX type`. The way to solve this is with filtering, specifically, you can filter out the static (i.e. non-jax arrays) and capture them via closure. For example,

```python
mlp = eqx.nn.MLP(...)

def rollout(mlp, xs):
def step(carry, x):
mlp = carry
val = mlp(x)
carry = mlp
return carry, [val]

_, scan_out = jax.lax.scan(
step,
[mlp],
xs
)

return scan_out

key, subkey = jax.random.split(key)
vals = rollout(mlp, jax.random.normal(key=subkey, shape=(200, 3)))
```

will error. To fix this, you can explicitly capture the static elements via

```python
def rollout(mlp, xs):
arr, static = eqx.partition(mlp, eqx.is_array)
def step(carry, x):
mlp = eqx.combine(carry, static)
val = mlp(x)
carry, _ = eqx.partition(mlp, eqx.is_array)
return carry, [val]

_, scan_out = jax.lax.scan(
step,
arr,
xs
)
return scan_out
```

What about if you want a module function to be the function being `scan`-ed over? If you just try to `jax.lax.scan(module, ...)` you will see a `TypeError: unhashable type: 'jaxlib.xla_extension.ArrayImpl'`. This is a [bug in jax](https://github.com/google/jax/issues/13554) that can be avoided by simply wrapping the module function in a lambda, e.g. `jax.lax.scan(lambda x, y: module(x, y), ...)`.

## I think my function is being recompiled each time it is run.

Use [`equinox.debug.assert_max_traces`][], for example
Expand Down

0 comments on commit 0cbbdbf

Please sign in to comment.