From 3af0cdefd70e7a36850260a34cbc58d52c6a4b59 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sun, 27 Mar 2022 22:03:15 +0100 Subject: [PATCH] Tweaked example --- docs/api/stateful.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/api/stateful.md b/docs/api/stateful.md index 69c08f7a..0b36ac1e 100644 --- a/docs/api/stateful.md +++ b/docs/api/stateful.md @@ -18,10 +18,12 @@ import jax.lax as lax import jax.numpy as jnp index = eqx.experimental.StateIndex() -eqx.experimental.set_state(index, jnp.array(0)) +init = jnp.array(0) +eqx.experimental.set_state(index, init) +@jax.jit def scan_fun(_, __): - val = eqx.experimental.get_state(index) + val = eqx.experimental.get_state(index, like=init) val = val + 1 eqx.experimental.set_state(index, val) return None, val