Skip to content

Commit

Permalink
Tweaked example
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Mar 27, 2022
1 parent b65f303 commit 3af0cde
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions docs/api/stateful.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ import jax.lax as lax
import jax.numpy as jnp

index = eqx.experimental.StateIndex()
eqx.experimental.set_state(index, jnp.array(0))
init = jnp.array(0)
eqx.experimental.set_state(index, init)

@jax.jit
def scan_fun(_, __):
val = eqx.experimental.get_state(index)
val = eqx.experimental.get_state(index, like=init)
val = val + 1
eqx.experimental.set_state(index, val)
return None, val
Expand Down

0 comments on commit 3af0cde

Please sign in to comment.