
Commit dff0ff6

hijax Variable
1 parent e30c934 commit dff0ff6

File tree

11 files changed: +1391 -90 lines changed

docs_nnx/guides/pytree.ipynb

Lines changed: 1 addition & 1 deletion
@@ -295,7 +295,7 @@
     "source": [
      "The only change we had to make here is to use `nnx.List` to signal that `layers` contains `data`; the status of the rest of the attributes can be correctly inferred. The rules that determine whether a value is data or not are the following:\n",
      "\n",
-     "* `Array`s, `Variable`s, `ArrayRef`s, and `Pytree`s are data.\n",
+     "* `Array`s, `Variable`s, `ArrayRef`s, and `nnx.Pytree`s are data.\n",
      "* Types registered using `nnx.register_data_type` are data.\n",
      "* All other types are static.\n",
      "\n",

docs_nnx/guides/pytree.md

Lines changed: 1 addition & 1 deletion
@@ -155,7 +155,7 @@ pytree_structure(pytree)

 The only change we had to make here is to use `nnx.List` to signal that `layers` contains `data`; the status of the rest of the attributes can be correctly inferred. The rules that determine whether a value is data or not are the following:

-* `Array`s, `Variable`s, `ArrayRef`s, and `Pytree`s are data.
+* `Array`s, `Variable`s, `ArrayRef`s, and `nnx.Pytree`s are data.
 * Types registered using `nnx.register_data_type` are data.
 * All other types are static.
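A minimal sketch of how these rules play out, assuming the `nnx.Pytree`, `nnx.Param`, and `nnx.List` APIs referenced in this commit; the `Foo` class and its attributes are hypothetical and only illustrate the inference:

import jax
import jax.numpy as jnp
from flax import nnx

class Foo(nnx.Pytree):
  def __init__(self):
    self.w = nnx.Param(jnp.ones((2, 3)))      # Variable subtype -> data
    self.x = jnp.zeros((3,))                  # Array -> data
    self.layers = nnx.List([jnp.ones((3,))])  # explicitly marked as data
    self.name = 'foo'                         # any other type -> static

# data attributes become pytree leaves, static attributes live in the treedef
leaves, treedef = jax.tree.flatten(Foo())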

Lines changed: 273 additions & 0 deletions
@@ -0,0 +1,273 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# %%
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from flax import nnx

# ## Data
# We create a simple dataset of points sampled from a sine curve with some noise.
X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None]
Y = 0.8 * jnp.sin(X) + 0.1 + np.random.normal(0, 0.1, size=X.shape)


def dataset(batch_size):
  while True:
    idx = np.random.choice(len(X), size=batch_size)
    yield X[idx], Y[idx]


# ## Model
# Here we define an MLP made of a stack of blocks. Each block contains a linear layer,
# batch normalization, and a dropout layer.
#
# In this version we want the Modules to be pytrees so they can be used with JAX
# transforms, so we use a new Pytree type as the base. The main difference with
# current NNX is that attributes that contain arrays or other pytrees now need to
# be explicitly marked using `nnx.data` to be included in the pytree.
class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.din, self.dout = din, dout
    initializer = jax.nn.initializers.lecun_normal()
    # nnx.data is used to mark attributes as pytree data;
    # Param, BatchStat, and Cache are built-in Variable subtypes
    self.w = nnx.Param(initializer(rngs.params(), (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))

  def __call__(self, x: jax.Array):
    return x @ self.w + self.b[None]


# Block implements linear, batch norm, and dropout. Its behavior
# is controlled by the 'use_stats' and 'deterministic' flags.
class Block(nnx.Module):
  def __init__(
    self,
    din: int,
    dout: int,
    *,
    dropout_rate: float = 0.05,
    momentum: float = 0.95,
    use_stats: bool = False,
    deterministic: bool = False,
    rngs: nnx.Rngs,
  ):
    # ----------- linear -------------------
    self.din, self.dout = din, dout
    initializer = jax.nn.initializers.lecun_normal()
    self.w = nnx.Param(initializer(rngs.params(), (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    # ----------- batch norm ---------------
    self.mu = momentum
    self.use_stats = use_stats
    self.mean = nnx.BatchStat(jnp.zeros((dout,)))
    self.var = nnx.BatchStat(jnp.ones((dout,)))
    self.scale = nnx.Param(jnp.ones((dout,)))
    self.bias = nnx.Param(jnp.zeros((dout,)))
    # ----------- dropout ------------------
    self.dropout_rate = dropout_rate
    self.deterministic = deterministic

  def __call__(
    self, x: jax.Array, *, rngs: nnx.Rngs | None = None
  ) -> jax.Array:
    # ----------- linear --------------------
    x = x @ self.w + self.b[None]
    # ----------- batch norm ----------------
    if self.use_stats:
      mean = self.mean
      var = self.var
    else:
      mean = jnp.mean(x, axis=0)
      var = jnp.var(x, axis=0)
      # ema updates
      # stop_gradient is used until Hijax supports updates from grad tracers
      sg = jax.lax.stop_gradient
      self.mean.value = sg(self.mu * self.mean + (1 - self.mu) * mean)
      self.var.value = sg(self.mu * self.var + (1 - self.mu) * var)
    x = (x - mean[None]) / jnp.sqrt(var[None] + 1e-5)
    x = x * self.scale + self.bias
    # ----------- dropout -------------------
    if not self.deterministic and self.dropout_rate > 0.0:
      assert rngs is not None
      keep_prob = 1.0 - self.dropout_rate
      mask = jax.random.bernoulli(rngs.dropout(), keep_prob, x.shape)
      x = jnp.where(mask, x / keep_prob, jnp.zeros_like(x))
    # ----------- activation ---------------
    x = jax.nn.gelu(x)
    return x


class Model(nnx.Module):
  def __init__(
    self,
    num_blocks: int,
    din: int,
    dhidden: int,
    dout: int,
    *,
    use_scan: bool = True,
    rngs: nnx.Rngs,
  ):
    self.count = nnx.Variable(jnp.array(0))
    self.block_in = Block(din, dhidden, rngs=rngs)
    self.linear_out = Linear(dhidden, dout, rngs=rngs)

    # 'blocks' is either a list of blocks or a single block
    # whose parameters contain an additional 'layer' dimension,
    # here created using jax.vmap
    if use_scan:

      @jax.vmap
      def create_block(rngs, /):
        # return nnx.stateless(Block(dhidden, dhidden, rngs=rngs))
        return Block(dhidden, dhidden, rngs=rngs)

      # self.blocks = nnx.stateful(create_block(rngs.fork(split=num_blocks)))
      self.blocks = create_block(rngs.fork(split=num_blocks))
    else:
      self.blocks = nnx.List(
        [Block(dhidden, dhidden, rngs=rngs) for i in range(num_blocks)]
      )

  def __call__(self, x: jax.Array, *, rngs: nnx.Rngs | None = None):
    self.count.value += 1
    x = self.block_in(x, rngs=rngs)

    # on the forward pass we either iterate over the block
    # list or use jax.lax.scan to apply the blocks; if we
    # had shared state we would use split and merge to
    # pass the shared state as a capture
    if isinstance(self.blocks, nnx.List):
      for block in self.blocks:
        x = block(x, rngs=rngs)
    else:

      def block_fw(x, block: Block):
        x = block(x, rngs=rngs)
        return x, None

      x, _ = jax.lax.scan(block_fw, x, self.blocks)
    x = self.linear_out(x)
    return x

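# (Illustrative sketch, not part of this commit.) The __call__ comment above mentions
# using split and merge to pass shared state as a capture when scanning; assuming the
# existing nnx.split/nnx.merge API and a hypothetical 'shared' module referenced by
# every block, that pattern could look roughly like this:
def scan_with_shared(blocks, shared, x, rngs):
  graphdef, shared_state = nnx.split(shared)  # pull the shared module apart once

  def block_fw(x, block: Block):
    # rebuild the shared module from the captured graphdef/state inside the scan body
    shared_module = nnx.merge(graphdef, shared_state)
    x = shared_module(block(x, rngs=rngs))
    return x, None

  x, _ = jax.lax.scan(block_fw, x, blocks)
  return x
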
# ## Optimizer
class OptState(nnx.Variable): ...


# Optimizers are an interesting case as they are inherently stateful and
# pose a good use case for MutableHijax. Here we implement SGD with
# momentum. The optimizer receives the params as constructor arguments but doesn't
# hold a reference to them; it only uses the params to initialize its state
# by creating new OptState Variables that reuse the param's metadata.
class SGD(nnx.Pytree):
  def __init__(self, params, lr: float, decay: float = 0.9):
    self.lr = lr
    self.decay = decay

    def make_opt_state(x):
      if isinstance(x, nnx.Variable):
        return OptState(jnp.zeros_like(x.value), **x.get_metadata())
      else:
        return OptState(jnp.zeros_like(x))

    self.momentum = nnx.data(jax.tree.map(make_opt_state, params))

  # during the update we simply map over (params, momentum, grads);
  # for each triplet we implement the SGD update rule, which updates
  # both the optimizer's state (momentum) and the params in place.
  def update(self, params, grads):
    def update_fn(
      param: nnx.Variable[jax.Array],
      momentum: nnx.Variable[jax.Array],
      grad: nnx.Variable[jax.Array],
    ):
      momentum.value = self.decay * momentum + (1 - self.decay) * grad
      param.value -= self.lr * momentum

    # is_leaf might not be necessary as MutableHijaxVariables are not pytrees
    jax.tree.map(update_fn, params, self.momentum, grads)


# ## Training

nnx.use_hijax('mutable')
rngs = nnx.Rngs(params=0, dropout=1)
model = Model(
  num_blocks=3, din=1, dhidden=256, dout=1, use_scan=False, rngs=rngs
)
optimizer = SGD(params=nnx.state(model, nnx.Param), lr=3e-3, decay=0.99)

# Create a copy of the model structure and set its attributes to eval mode.
# This works because they share the underlying ArrayRefs, so both models
# will always be in sync.
eval_model = nnx.merge(*nnx.split(model))
eval_model.set_attributes(use_stats=True, deterministic=True)

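# (Illustrative check, not part of this commit.) Since 'model' and 'eval_model' share
# the same underlying ArrayRefs, a mutation made through one is visible through the
# other; the call counter is a harmless attribute to demonstrate this with:
model.count.value += 1
assert int(eval_model.count.value) == int(model.count.value)
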
# The training step uses 'jax.jit' and receives the model and optimizer as arguments;
# this is supported as they are now pytrees. The first thing we do is group the model
# state into the params and the non-differentiable state using 'split'. We differentiate
# the loss function using 'jax.grad' with respect to the params only. Inside the loss
# function we merge the params and non-diff state back into a single model and then
# compute the loss by calling the model with the inputs.
@jax.jit
def train_step(model: Model, optimizer: SGD, rngs: nnx.Rngs, x, y):
  graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)

  def loss_fn(params):
    model = nnx.merge(graphdef, params, nondiff)
    loss = jnp.mean((model(x, rngs=rngs) - y) ** 2)
    return loss

  # For the time being we have to use 'stateless' to make the Variables immutable,
  # as 'jax.grad' doesn't support Hijax types yet.
  grads = jax.grad(loss_fn)(nnx.stateless(params))
  # 'update' mutates the optimizer's state and the params in place,
  # so we don't need to return anything 🚀
  optimizer.update(params, grads)


# simple test step that computes the loss
@jax.jit
def test_step(model: Model, x, y):
  return {'loss': jnp.mean((model(x) - y) ** 2)}


# minimalistic training loop
total_steps = 400
for step, (x, y) in enumerate(dataset(32)):
  train_step(model, optimizer, rngs, x, y)

  if step % 10 == 0:
    logs = test_step(eval_model, X, Y)
    print(f'step: {step}, loss: {logs["loss"]}')

  if step >= total_steps - 1:
    break

# ## Sample
# Sampling is trivial, just use 'eval_model'
print('times called:', eval_model.count.value)

y_pred = eval_model(X)

plt.scatter(X, Y, color='blue')
plt.plot(X, y_pred, color='black')
plt.show()

flax/nnx/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -67,6 +67,8 @@
 from .graph import variables as variables
 from .graph import to_arrays as to_arrays
 from .graph import to_refs as to_refs
+from .graph import to_hijax as to_hijax
+from .graph import to_lojax as to_lojax
 from .graph import pure as pure
 from .graph import cached_partial as cached_partial
 from .graph import flatten as flatten
@@ -193,6 +195,8 @@
 from .variablelib import register_variable_name as register_variable_name
 from .variablelib import use_refs as use_refs
 from .variablelib import using_refs as using_refs
+from .variablelib import use_hijax as use_hijax
+from .variablelib import using_hijax as using_hijax
 from .visualization import display as display
 from .extract import to_tree as to_tree
 from .extract import from_tree as from_tree

flax/nnx/bridge/module.py

Lines changed: 1 addition & 1 deletion
@@ -390,7 +390,7 @@ def _get_variables(self) -> tp.Mapping:

       if (
         isinstance(variable, variablelib.Variable)
-        and not variable._var_metadata
+        and not variable.get_metadata()
       ):
         leaf = variable.value
       else:

flax/nnx/extract.py

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ def check_consistent_aliasing(
        lambda: f'Trying to extract graph node from different trace level, got {value!r}'
      )
      if isinstance(value, graph.Variable):
-       if not value._trace_state.is_valid():
+       if not value.trace_state.is_valid():
          raise ValueError(
            f'Cannot extract graph node from different trace level, got {value!r}'
          )
