# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# %%
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from flax import nnx

# ## Data
# We create a simple dataset of points sampled from a noisy sine wave.
X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None]
Y = 0.8 * jnp.sin(X) + 0.1 + np.random.normal(0, 0.1, size=X.shape)


def dataset(batch_size):
  # infinite generator that yields random minibatches of (inputs, targets)
  while True:
    idx = np.random.choice(len(X), size=batch_size)
    yield X[idx], Y[idx]

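# A quick sanity check of the data pipeline above (added only for illustration):
# one draw from the infinite generator gives arrays of shape (batch_size, 1).
_x_batch, _y_batch = next(dataset(8))
assert _x_batch.shape == (8, 1) and _y_batch.shape == (8, 1)
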
# ## Model
# Here we define an MLP made of a stack of blocks. Each block contains a linear
# layer, batch normalization, and a dropout layer.
#
# In this version we want the Modules to be pytrees so they can be used with JAX
# transforms, so we use a new Pytree type as the base. The main difference with
# current NNX is that attributes containing arrays or other pytrees now need to be
# explicitly marked with `nnx.data` to be included in the pytree.
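# As a minimal, hypothetical illustration of that rule: an attribute holding plain
# arrays must be wrapped in `nnx.data` to become pytree leaves, while ordinary
# Python attributes stay static. The `Scaler` class below is only a sketch and is
# not used anywhere else in this example.
class Scaler(nnx.Module):
  def __init__(self):
    self.factors = nnx.data([jnp.array(1.0), jnp.array(2.0)])  # pytree data
    self.name = 'scaler'  # plain attribute, treated as static metadata

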
class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.din, self.dout = din, dout
    initializer = jax.nn.initializers.lecun_normal()
    # nnx.data is used to mark attributes as pytree data;
    # Param, BatchStat, and Cache are built-in Variable subtypes
    self.w = nnx.Param(initializer(rngs.params(), (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))

  def __call__(self, x: jax.Array):
    return x @ self.w + self.b[None]


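# A quick shape check for the Linear module above. This is only a sketch: it
# assumes nnx.Param forwards array operations (as in current NNX), so that
# `x @ self.w + self.b[None]` also works when the module is called directly.
_lin = Linear(din=2, dout=4, rngs=nnx.Rngs(params=0))
assert _lin(jnp.ones((3, 2))).shape == (3, 4)
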
# Block implements linear, batch norm, and dropout. Its behavior
# is controlled by the 'use_stats' and 'deterministic' flags.
class Block(nnx.Module):
  def __init__(
    self,
    din: int,
    dout: int,
    *,
    dropout_rate: float = 0.05,
    momentum: float = 0.95,
    use_stats: bool = False,
    deterministic: bool = False,
    rngs: nnx.Rngs,
  ):
    # ----------- linear -------------------
    self.din, self.dout = din, dout
    initializer = jax.nn.initializers.lecun_normal()
    self.w = nnx.Param(initializer(rngs.params(), (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    # ----------- batch norm ---------------
    self.mu = momentum  # EMA momentum for the running statistics
    self.use_stats = use_stats
    self.mean = nnx.BatchStat(jnp.zeros((dout,)))
    self.var = nnx.BatchStat(jnp.ones((dout,)))
    self.scale = nnx.Param(jnp.ones((dout,)))
    self.bias = nnx.Param(jnp.zeros((dout,)))
    # ----------- dropout ------------------
    self.dropout_rate = dropout_rate
    self.deterministic = deterministic

  def __call__(
    self, x: jax.Array, *, rngs: nnx.Rngs | None = None
  ) -> jax.Array:
    # ----------- linear --------------------
    x = x @ self.w + self.b[None]
    # ----------- batch norm ----------------
    if self.use_stats:
      mean = self.mean
      var = self.var
    else:
      mean = jnp.mean(x, axis=0)
      var = jnp.var(x, axis=0)
      # EMA updates of the running statistics.
      # stop_gradient is used until Hijax supports updates from grad tracers.
      sg = jax.lax.stop_gradient
      self.mean.value = sg(self.mu * self.mean + (1 - self.mu) * mean)
      self.var.value = sg(self.mu * self.var + (1 - self.mu) * var)
    x = (x - mean[None]) / jnp.sqrt(var[None] + 1e-5)
    x = x * self.scale + self.bias
    # ----------- dropout -------------------
    if not self.deterministic and self.dropout_rate > 0.0:
      assert rngs is not None
      keep_prob = 1.0 - self.dropout_rate
      mask = jax.random.bernoulli(rngs.dropout(), keep_prob, x.shape)
      x = jnp.where(mask, x / keep_prob, jnp.zeros_like(x))
    # ----------- activation ---------------
    x = jax.nn.gelu(x)
    return x


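# To make the batch-norm bookkeeping above concrete, here is a small plain-jnp
# sketch of the same two steps (illustrative values only, not used by the model):
# normalize with the batch statistics, then fold them into the running averages
# with momentum `mu`, starting from zeros/ones as Block does.
_h = jnp.array([[0.0, 2.0], [2.0, 4.0]])  # fake (batch, features) activations
_batch_mean, _batch_var = jnp.mean(_h, axis=0), jnp.var(_h, axis=0)
_normed = (_h - _batch_mean[None]) / jnp.sqrt(_batch_var[None] + 1e-5)
_mu = 0.95
_running_mean = _mu * jnp.zeros(2) + (1 - _mu) * _batch_mean
_running_var = _mu * jnp.ones(2) + (1 - _mu) * _batch_var
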
class Model(nnx.Module):
  def __init__(
    self,
    num_blocks: int,
    din: int,
    dhidden: int,
    dout: int,
    *,
    use_scan: bool = True,
    rngs: nnx.Rngs,
  ):
    self.count = nnx.Variable(jnp.array(0))
    self.block_in = Block(din, dhidden, rngs=rngs)
    self.linear_out = Linear(dhidden, dout, rngs=rngs)

    # 'blocks' is either a list of blocks or a single block
    # whose parameters contain an additional 'layer' dimension,
    # here created using jax.vmap
    if use_scan:

      @jax.vmap
      def create_block(rngs, /):
        # return nnx.stateless(Block(dhidden, dhidden, rngs=rngs))
        return Block(dhidden, dhidden, rngs=rngs)

      # self.blocks = nnx.stateful(create_block(rngs.fork(split=num_blocks)))
      self.blocks = create_block(rngs.fork(split=num_blocks))
    else:
      self.blocks = nnx.List(
        [Block(dhidden, dhidden, rngs=rngs) for _ in range(num_blocks)]
      )

  def __call__(self, x: jax.Array, *, rngs: nnx.Rngs | None = None):
    self.count.value += 1
    x = self.block_in(x, rngs=rngs)

    # On the forward pass we either iterate over the block list or use
    # jax.lax.scan to apply the stacked blocks. If we had shared state we
    # would use split and merge to pass the shared state as a capture.
    if isinstance(self.blocks, nnx.List):
      for block in self.blocks:
        x = block(x, rngs=rngs)
    else:

      def block_fw(x, block: Block):
        x = block(x, rngs=rngs)
        return x, None

      x, _ = jax.lax.scan(block_fw, x, self.blocks)
    x = self.linear_out(x)
    return x


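# The scan branch above follows the usual JAX pattern of scanning a per-layer
# function over parameters stacked along a leading 'layer' axis. A minimal
# pure-JAX sketch of that pattern with hypothetical shapes (3 layers, width 4):
_stacked_w = jnp.stack([jnp.eye(4) * (i + 1) for i in range(3)])  # (3, 4, 4)


def _apply_layer(h, w):
  return h @ w, None  # carry the activations, emit no per-step output


_h0 = jnp.ones((2, 4))
_h3, _ = jax.lax.scan(_apply_layer, _h0, _stacked_w)  # applies the 3 layers in order
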
# ## Optimizer
class OptState(nnx.Variable): ...


# Optimizers are an interesting case as they are inherently stateful and
# pose a good use case for MutableHijax. Here we implement SGD with momentum.
# The optimizer receives the params as constructor arguments but doesn't hold
# a reference to them; it only uses the params to initialize its state by
# creating new OptState Variables that reuse each param's metadata.
class SGD(nnx.Pytree):
  def __init__(self, params, lr: float, decay: float = 0.9):
    self.lr = lr
    self.decay = decay

    def make_opt_state(x):
      if isinstance(x, nnx.Variable):
        return OptState(jnp.zeros_like(x.value), **x.get_metadata())
      else:
        return OptState(jnp.zeros_like(x))

    self.momentum = nnx.data(jax.tree.map(make_opt_state, params))

  # During the update we simply map over (params, momentum, grads); for each
  # triplet we apply the SGD update rule, which updates both the optimizer's
  # state (momentum) and the params in place.
  def update(self, params, grads):
    def update_fn(
      param: nnx.Variable[jax.Array],
      momentum: nnx.Variable[jax.Array],
      grad: nnx.Variable[jax.Array],
    ):
      momentum.value = self.decay * momentum + (1 - self.decay) * grad
      param.value -= self.lr * momentum

    # is_leaf might not be necessary as MutableHijaxVariables are not pytrees
    jax.tree.map(update_fn, params, self.momentum, grads)


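# For reference, the same SGD-with-momentum rule written as a pure function over
# an ordinary pytree of arrays (no in-place mutation). This is only a sketch with
# hypothetical names; the training loop below uses the mutable SGD class above.
def _sgd_step(params, momentum, grads, lr=3e-3, decay=0.9):
  new_momentum = jax.tree.map(
    lambda m, g: decay * m + (1 - decay) * g, momentum, grads
  )
  new_params = jax.tree.map(lambda p, m: p - lr * m, params, new_momentum)
  return new_params, new_momentum
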
# ## Training

nnx.use_hijax('mutable')
rngs = nnx.Rngs(params=0, dropout=1)
model = Model(
  num_blocks=3, din=1, dhidden=256, dout=1, use_scan=False, rngs=rngs
)
optimizer = SGD(params=nnx.state(model, nnx.Param), lr=3e-3, decay=0.99)

# Create a copy of the model structure and set its attributes to eval mode.
# This works because the copies share the underlying ArrayRefs, so both models
# will always stay in sync.
eval_model = nnx.merge(*nnx.split(model))
eval_model.set_attributes(use_stats=True, deterministic=True)


# The training step uses 'jax.jit' and receives the model and optimizer as
# arguments; this is supported because they are now pytrees. First we group the
# model state into the params and the non-differentiable state using 'split'.
# We differentiate the loss function with 'jax.grad' with respect to the params
# only. Inside the loss function we merge the params and non-diff state back
# into a single model and then compute the loss by calling the model on the
# inputs.
@jax.jit
def train_step(model: Model, optimizer: SGD, rngs: nnx.Rngs, x, y):
  graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)

  def loss_fn(params):
    model = nnx.merge(graphdef, params, nondiff)
    loss = jnp.mean((model(x, rngs=rngs) - y) ** 2)
    return loss

  # For the time being we have to use 'nnx.stateless' to make the Variables
  # immutable, as 'jax.grad' doesn't support Hijax types yet.
  grads = jax.grad(loss_fn)(nnx.stateless(params))
  # 'update' mutates the optimizer's state and the params in place
  # so we don't need to return anything 🚀
  optimizer.update(params, grads)


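# The split/merge pattern above mirrors the plain-JAX idiom of differentiating
# with respect to only a subset of the inputs while closing over the rest. A
# small standalone sketch with hypothetical names, using plain dicts of arrays:
def _loss(diff_params, frozen_state, x, y):
  pred = x @ diff_params['w'] + diff_params['b'] + frozen_state['offset']
  return jnp.mean((pred - y) ** 2)


_diff = {'w': jnp.ones((1, 1)), 'b': jnp.zeros((1,))}
_frozen = {'offset': jnp.array(0.1)}
_grads = jax.grad(_loss)(_diff, _frozen, X, Y)  # gradients only for _diff
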
# simple test step that computes the loss
@jax.jit
def test_step(model: Model, x, y):
  return {'loss': jnp.mean((model(x) - y) ** 2)}


# minimalistic training loop
total_steps = 400
for step, (x, y) in enumerate(dataset(32)):
  train_step(model, optimizer, rngs, x, y)

  if step % 10 == 0:
    logs = test_step(eval_model, X, Y)
    print(f'step: {step}, loss: {logs["loss"]}')

  if step >= total_steps - 1:
    break

# ## Sample
# Sampling is trivial: just call 'eval_model'.
print('times called:', eval_model.count.value)

y_pred = eval_model(X)

plt.scatter(X, Y, color='blue')
plt.plot(X, y_pred, color='black')
plt.show()