Commit 17fc681

support Box
1 parent f6ed5ec commit 17fc681

45 files changed: +1648 −1492 lines

docs_nnx/guides/array_ref.ipynb

Lines changed: 0 additions & 602 deletions
This file was deleted.

docs_nnx/guides/hijax.ipynb

Lines changed: 572 additions & 0 deletions
Large diffs are not rendered by default.

docs_nnx/guides/array_ref.md renamed to docs_nnx/guides/hijax.md

Lines changed: 36 additions & 46 deletions
@@ -21,41 +21,31 @@ import optax
 
 +++
 
-### Array Refs 101
+### Variables Refs
 
 ```{code-cell} ipython3
-a_ref = jax.new_ref(jnp.array([1, 2, 3]))
+variable = nnx.Variable(jnp.array([1, 2, 3]), is_hijax=True)
+print(f"{variable.is_hijax = }\n")
 
 @jax.jit
-def increment(a_ref: jax.Ref): # no return!
-  array: jax.Array = a_ref[...] # access
-  a_ref[...] = array + 1 # update
+def increment(variable: nnx.Variable[jax.Array]): # no return!
+  new_value = variable + 1 # Array-like operations
+  variable[...] = new_value # in-place updates
 
-print("[1] =", a_ref); increment(a_ref); print("[2] =", a_ref)
+print("Before =", variable); increment(variable); print("After =", variable)
 ```
 
 ```{code-cell} ipython3
-@jax.jit
-def inc(x):
-  x[...] += 1
-
-print(increment.lower(a_ref).as_text())
+# TODO: enable once as_text is fixed
+# print(increment.lower(variable).as_text())
 ```
 
-### Variables Refs
-
 ```{code-cell} ipython3
-variable = nnx.Variable(jnp.array([1, 2, 3]), use_ref=True)
-print(f"{variable.has_ref = }\n")
-
-print("[1] =", variable); increment(variable); print("[2] =", variable)
-```
+nnx.use_hijax(True)
 
-```{code-cell} ipython3
-with nnx.use_refs(True):
-  variable = nnx.Variable(jnp.array([1, 2, 3]))
+variable = nnx.Variable(jnp.array([1, 2, 3]))
 
-print(f"{variable.has_ref = }")
+print(f"{variable.is_hijax = }")
 ```
 
 Mention `nnx.use_refs` can be used as global flag
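To make this hunk easier to follow, here is a consolidated sketch of the pattern the renamed guide now documents. The `is_hijax` flag, the `[...]` indexing, and the in-place update inside `jax.jit` come from the added lines above; the imports and the final print are assumed boilerplate:

```python
import jax
import jax.numpy as jnp
from flax import nnx

# A hijax Variable wraps a mutable reference that jitted code can update in place.
variable = nnx.Variable(jnp.array([1, 2, 3]), is_hijax=True)

@jax.jit
def increment(v: nnx.Variable[jax.Array]):  # nothing is returned
  new_value = v + 1   # Variables support array-like arithmetic
  v[...] = new_value  # in-place update through the reference

increment(variable)
print(variable[...])  # the update is visible outside the jitted function
```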
@@ -73,12 +63,14 @@ class Linear(nnx.Module):
   def __call__(self, x):
     return x @ self.kernel + self.bias[None]
 
-model = Linear(1, 3, rngs=nnx.Rngs(0)) # without array refs
-refs_model = nnx.to_refs(model) # convert to array refs
-arrays_model = nnx.to_arrays(refs_model) # convert to regular arrays
+with nnx.use_hijax(False): # use lojax Variables
+  model = Linear(1, 3, rngs=nnx.Rngs(0))
+
+hijax_model = nnx.to_hijax(model) # convert hijax Variables
+arrays_model = nnx.to_lojax(hijax_model) # convert to lojax Variables
 
-print("nnx.to_refs(model) =", refs_model)
-print("nnx.to_arrays(refs_model) =", arrays_model)
+print("nnx.to_hijax(model) =", hijax_model)
+print("nnx.to_lojax(refs_model) =", arrays_model)
 ```
 
 ## Examples
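A short sketch of the conversion round trip this hunk introduces. `nnx.use_hijax`, `nnx.to_hijax`, and `nnx.to_lojax` are the names added above; the `Linear.__init__` body is an assumption, since only `__call__` appears in the context lines:

```python
import jax
import jax.numpy as jnp
from flax import nnx

class Linear(nnx.Module):
  def __init__(self, din, dout, *, rngs: nnx.Rngs):
    # assumed parameter initialization; the guide's own __init__ is not shown here
    self.kernel = nnx.Param(jax.random.normal(rngs.params(), (din, dout)))
    self.bias = nnx.Param(jnp.zeros((dout,)))

  def __call__(self, x):
    return x @ self.kernel + self.bias[None]

with nnx.use_hijax(False):                # create with lojax (plain array) Variables
  model = Linear(1, 3, rngs=nnx.Rngs(0))

hijax_model = nnx.to_hijax(model)         # lift to mutable, reference-backed Variables
lojax_model = nnx.to_lojax(hijax_model)   # and back to plain array Variables
```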
@@ -99,9 +91,9 @@ class Block(nnx.Module):
 ### Training Loop
 
 ```{code-cell} ipython3
-with nnx.use_refs(True):
-  model = Block(2, 64, 3, rngs=nnx.Rngs(0))
-  optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
+# hijax Variables by default
+model = Block(2, 64, 3, rngs=nnx.Rngs(0))
+optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
 
 @jax.jit
 def train_step(model, optimizer, x, y):
@@ -110,7 +102,7 @@ def train_step(model, optimizer, x, y):
     model = nnx.merge(graphdef, params, nondiff)
     return ((model(x) - y) ** 2).mean()
 
-  loss, grads = jax.value_and_grad(loss_fn)(nnx.to_arrays(params)) # freeze ArrayRefs for jax.grad
+  loss, grads = jax.value_and_grad(loss_fn)(nnx.to_lojax(params)) # lojax Variables for jax.grad
   optimizer.update(model, grads)
 
   return loss
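Reassembled for context, the full `train_step` these two hunks edit might look roughly as follows. The `nnx.split` line is an assumption (it sits outside the hunks), while the `nnx.merge`, `nnx.to_lojax`, and `optimizer.update` lines mirror the diff:

```python
import jax
from flax import nnx

@jax.jit
def train_step(model, optimizer, x, y):
  # assumed: split the model into differentiable Params and everything else
  graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)

  def loss_fn(params):
    model = nnx.merge(graphdef, params, nondiff)
    return ((model(x) - y) ** 2).mean()

  # jax.value_and_grad differentiates plain arrays, so the hijax params are
  # converted to lojax Variables first, as the changed line above does.
  loss, grads = jax.value_and_grad(loss_fn)(nnx.to_lojax(params))
  optimizer.update(model, grads)
  return loss
```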
@@ -121,12 +113,11 @@ train_step(model, optimizer, x=jnp.ones((10, 2)), y=jnp.ones((10, 3)))
 ### Scan Over Layers
 
 ```{code-cell} ipython3
-@nnx.vmap
+@jax.vmap
 def create_stack(rngs):
-  return Block(2, 64, 2, rngs=rngs)
+  return nnx.to_lojax(Block(2, 64, 2, rngs=rngs))
 
-with nnx.use_refs(True):
-  block_stack = create_stack(nnx.Rngs(0).fork(split=8))
+block_stack = nnx.to_hijax(create_stack(nnx.Rngs(0).fork(split=8)))
 
 def scan_fn(x, block):
   x = block(x)
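A hedged consolidation of the hunk above: the stack is created with `jax.vmap` over lojax Variables and then lifted back to hijax. The `Block` class, the return value of `scan_fn`, and the scan driver that consumes it sit outside this hunk, so they are only assumed here:

```python
import jax
from flax import nnx

@jax.vmap
def create_stack(rngs: nnx.Rngs):
  # build each layer with lojax (plain array) Variables so vmap can stack them
  # along a leading layer axis
  return nnx.to_lojax(Block(2, 64, 2, rngs=rngs))

# lift the stacked parameters back to hijax Variables after vmap
block_stack = nnx.to_hijax(create_stack(nnx.Rngs(0).fork(split=8)))

def scan_fn(x, block):
  x = block(x)
  return x, None  # assumed carry/output pair; the scan driver is outside the hunk
```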
@@ -150,42 +141,42 @@ def create_model(rngs):
   return Block(2, 64, 3, rngs=rngs)
 
 try:
-  with nnx.use_refs(True):
-    model = create_model(nnx.Rngs(0))
+  model = create_model(nnx.Rngs(0))
 except Exception as e:
   print(f"Error:", e)
 ```
 
 ```{code-cell} ipython3
-with nnx.use_refs(False): # <-- disable array refs
+with nnx.use_hijax(False): # <-- disable hijax Variables
   model = create_model(nnx.Rngs(0))
 
-model = nnx.to_refs(model) # convert to mutable after creation
+model = nnx.to_hijax(model) # convert to mutable after creation
 
 print("model.linear =", model.linear)
 ```
 
 ```{code-cell} ipython3
+# TODO: why does this work?
 @nnx.jit
 def create_model(rngs):
   return Block(2, 64, 3, rngs=rngs)
 
-with nnx.use_refs(True):
-  model = create_model(nnx.Rngs(0))
+model = create_model(nnx.Rngs(0))
 
 print("model.linear =", model.linear)
 ```
 
 ### Reference Sharing (aliasing)
 
 ```{code-cell} ipython3
+# TODO: why does this not fail?
 def get_error(f, *args):
   try:
     return f(*args)
   except Exception as e:
     return f"{type(e).__name__}: {e}"
 
-
-x = jax.new_ref(jnp.array(0))
+
+x = nnx.Variable(jnp.array(0))
 
 @jax.jit
 def f(a, b):
@@ -211,9 +202,8 @@ class SharedModules(nnx.Pytree):
 def g(pytree):
   ...
 
-with nnx.use_refs(True):
-  shared_variables = SharedVariables()
-  shared_modules = SharedModules()
+shared_variables = SharedVariables()
+shared_modules = SharedModules()
 
 print("SharedVariables", get_error(g, shared_variables))
 print("SharedModules", get_error(g, shared_modules))

docs_nnx/index.rst

Lines changed: 2 additions & 6 deletions
@@ -107,15 +107,11 @@ Basic usage
   model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
   optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
 
-  @nnx.jit  # automatic state management for JAX transforms
+  @nnx.jit  # automatic state propagation
   def train_step(model, optimizer, x, y):
-    def loss_fn(model):
-      y_pred = model(x)  # call methods directly
-      return ((y_pred - y) ** 2).mean()
-
+    loss_fn = lambda model: ((model(x) - y) ** 2).mean()
     loss, grads = nnx.value_and_grad(loss_fn)(model)
     optimizer.update(model, grads)  # in-place updates
-
     return loss
 
 
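The `index.rst` snippet after this change, assembled into one runnable block; the `Model` definition is an assumed stand-in for the one defined earlier on that page:

```python
import jax.numpy as jnp
import optax
from flax import nnx

class Model(nnx.Module):  # assumed stand-in for the page's Model
  def __init__(self, din, dmid, dout, *, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    return self.out(nnx.relu(self.linear(x)))

model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

@nnx.jit  # automatic state propagation
def train_step(model, optimizer, x, y):
  loss_fn = lambda model: ((model(x) - y) ** 2).mean()
  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(model, grads)  # in-place updates
  return loss

loss = train_step(model, optimizer, x=jnp.ones((10, 2)), y=jnp.ones((10, 3)))
```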

docs_nnx/nnx_basics.ipynb

Lines changed: 27 additions & 48 deletions
Large diffs are not rendered by default.

docs_nnx/nnx_basics.md

Lines changed: 4 additions & 4 deletions
@@ -40,7 +40,7 @@ class Linear(nnx.Module):
     self.din, self.dout = din, dout
 
   def __call__(self, x: jax.Array):
-    return x @ self.w + self.b
+    return x @ self.w + self.b[None]
 ```
 
 Also note that the inner values of `Variable`s can be accessed using the `value` property, but for convenience they implement all numeric operators and can be used directly in arithmetic expressions (as shown in the code above).
@@ -73,12 +73,12 @@ class Counter(nnx.Module):
     self.count = Count(jnp.array(0))
 
   def __call__(self):
-    self.count.value += 1
+    self.count[...] += 1
 
 counter = Counter()
-print(f'{counter.count.value = }')
+print(f'{counter.count[...] = }')
 counter()
-print(f'{counter.count.value = }')
+print(f'{counter.count[...] = }')
 ```
 
 Mutable references are usually avoided in JAX. But Flax NNX provides sound mechanisms
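To see the `[...]` accessor in context, a small reconstruction of the `Counter` example this hunk edits; the `Count` subclass is assumed to be the plain `nnx.Variable` subclass the basics guide defines earlier:

```python
import jax.numpy as jnp
from flax import nnx

class Count(nnx.Variable):  # assumed: defined earlier in the guide
  pass

class Counter(nnx.Module):
  def __init__(self):
    self.count = Count(jnp.array(0))

  def __call__(self):
    self.count[...] += 1  # index with [...] instead of going through .value

counter = Counter()
print(f'{counter.count[...] = }')  # 0
counter()
print(f'{counter.count[...] = }')  # 1
```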

examples/gemma/helpers.py

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ def assign_val_fn(
     mapped_path: tuple[str | int, ...],
     val: Any,
   ) -> dict[tuple[str, ...], Any]:
-    state[mapped_path].value = val
+    state[mapped_path].set_value(val)
     return state
 
   mdl: M = nnx.eval_shape(module_factory)
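The only change here swaps the `.value` setter for `set_value`. A minimal hedged illustration of the same operation on a standalone Variable, assuming `set_value` behaves as this call site relies on:

```python
import jax.numpy as jnp
from flax import nnx

variable = nnx.Param(jnp.zeros((3,)))
variable.set_value(jnp.ones((3,)))  # replaces the stored value, as the diff now does
print(variable[...])                # [1. 1. 1.]
```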

examples/gemma/helpers_test.py

Lines changed: 2 additions & 2 deletions
@@ -137,11 +137,11 @@ def _map_key_fn(key: tuple[str, ...]) -> tuple[str | int, ...]:
     np.testing.assert_array_equal(output, linen_output)
     for i in range(len(num_features)):
       np.testing.assert_array_equal(
-        mdl.layers[i].layers[0].mean.value,
+        mdl.layers[i].layers[0].mean[...],
         linen_vars['batch_stats'][f'layers_{i}']['layers_0']['mean'],
       )
       np.testing.assert_array_equal(
-        mdl.layers[i].layers[0].var.value,
+        mdl.layers[i].layers[0].var[...],
         linen_vars['batch_stats'][f'layers_{i}']['layers_0']['var'],
       )

examples/gemma/layers.py

Lines changed: 4 additions & 4 deletions
@@ -44,11 +44,11 @@ def __init__(
     self.w = nnx.Param(kernel_init(rngs.params(), shape, dtype))
 
   def __call__(self, x: ArrayLike) -> Array:
-    return jnp.einsum(self.einsum_str, x, self.w.value)
+    return jnp.einsum(self.einsum_str, x, self.w[...])
 
   @property
   def shape(self) -> Shape:
-    return self.w.value.shape
+    return self.w.shape
 
 
 class RMSNorm(nnx.Module):
@@ -65,12 +65,12 @@ def __init__(
     self.scale = nnx.Param(scale_init(rngs.params(), dim, dtype))
 
   def __call__(self, x: Array) -> Array:
-    dtype = self.scale.value.dtype
+    dtype = self.scale.dtype
     var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
     normed_inputs = jnp.asarray(x * jax.lax.rsqrt(var + 1e-06), dtype=dtype)
     # normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale is
     # a rank-1 tensor. To avoid implicit rank-promotion, reshape scale to
     # a (1, ..., 1, D) tensor, so the rank of scale matches normed_inputs.
-    scale = jnp.expand_dims(self.scale.value, axis=range(len(x.shape) - 1))
+    scale = jnp.expand_dims(self.scale, axis=range(len(x.shape) - 1))
     normed_inputs = normed_inputs * (1 + scale)
     return normed_inputs
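These hunks rely on `Param` forwarding array attributes such as `shape` and `dtype`, and on `[...]` for explicit unboxing where a plain `jax.Array` is needed. A hedged standalone sketch of that behaviour, assuming the forwarding works as the diff implies:

```python
import jax.numpy as jnp
from flax import nnx

w = nnx.Param(jnp.zeros((4, 8)))

print(w.shape)   # attribute access is forwarded to the underlying array
print(w.dtype)

# explicit unboxing with [...] where a plain jax.Array is needed
y = jnp.einsum('bd,df->bf', jnp.ones((2, 4)), w[...])
print(y.shape)   # (2, 8)
```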

examples/gemma/modules.py

Lines changed: 3 additions & 3 deletions
@@ -63,15 +63,15 @@ def encode(self, x: ArrayLike) -> Array:
     return x
 
   def decode(self, x: ArrayLike) -> Array:
-    return jnp.dot(x, self.input_embedding.value.T)
+    return jnp.dot(x, self.input_embedding.T)
 
   @property
   def embed_dim(self):
-    return self.input_embedding.value.shape[1]
+    return self.input_embedding.shape[1]
 
   @property
   def num_embed(self):
-    return self.input_embedding.value.shape[0]
+    return self.input_embedding.shape[0]
 
 
 class Attention(nnx.Module):
