2 changes: 1 addition & 1 deletion docs_nnx/mnist_tutorial.md
@@ -158,7 +158,7 @@ def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, b
grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(model, batch)
metrics.update(loss=loss, logits=logits, labels=batch['label']) # In-place updates.
optimizer.update(grads) # In-place updates.
optimizer.update(model, grads) # In-place updates.

@nnx.jit
def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
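For context (this block is not part of the diff), a minimal sketch of how the updated call pairs with the `nnx.Optimizer(..., wrt=nnx.Param)` construction introduced in `examples/mnist/train.py` below. `CNN` and `loss_fn` are assumed to be defined as in that file; the batch contents are placeholder values.

```python
# Sketch only, not part of the PR: pairing the new optimizer construction with
# the model-first update call. CNN and loss_fn are the definitions from
# examples/mnist/train.py; the batch below is a dummy stand-in.
import jax.numpy as jnp
import optax
from flax import nnx

model = CNN(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.sgd(0.005, 0.9), wrt=nnx.Param)

batch = {
    'image': jnp.zeros((8, 28, 28, 1)),
    'label': jnp.zeros((8,), dtype=jnp.int32),
}
grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(model, batch)
optimizer.update(model, grads)  # The model is now passed explicitly.
```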
3 changes: 1 addition & 2 deletions examples/mnist/README.md
@@ -20,8 +20,7 @@ https://colab.research.google.com/github/google/flax/blob/main/examples/mnist/mn
[gs://flax_public/examples/mnist/default]: https://console.cloud.google.com/storage/browser/flax_public/examples/mnist/default

```
I0828 08:51:41.821526 139971964110656 train.py:130] train epoch: 10, loss: 0.0097, accuracy: 99.69
I0828 08:51:42.248714 139971964110656 train.py:180] eval epoch: 10, loss: 0.0299, accuracy: 99.14
I1009 17:56:42.674334 3280981 train.py:175] epoch: 10, train_loss: 0.0073, train_accuracy: 99.75, test_loss: 0.0294, test_accuracy: 99.25
```

### How to run
2 changes: 1 addition & 1 deletion examples/mnist/main.py
@@ -26,7 +26,7 @@
from ml_collections import config_flags
import tensorflow as tf

import train
import train # pylint: disable=g-bad-import-order


FLAGS = flags.FLAGS
225 changes: 126 additions & 99 deletions examples/mnist/train.py
@@ -20,147 +20,174 @@

# See issue #620.
# pytype: disable=wrong-keyword-args
import functools
from typing import Any

from absl import logging
from flax import linen as nn
from flax import nnx
from flax.metrics import tensorboard
from flax.training import train_state
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

tf.random.set_seed(0) # Set the random seed for reproducibility.

class CNN(nn.Module):

class CNN(nnx.Module):
"""A simple CNN model."""

@nn.compact
def __init__(self, *, rngs: nnx.Rngs):
self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
self.batch_norm1 = nnx.BatchNorm(32, rngs=rngs)
self.dropout1 = nnx.Dropout(rate=0.025, rngs=rngs)
self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
self.batch_norm2 = nnx.BatchNorm(64, rngs=rngs)
self.avg_pool = functools.partial(
nnx.avg_pool, window_shape=(2, 2), strides=(2, 2)
)
self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
self.dropout2 = nnx.Dropout(rate=0.025, rngs=rngs)
self.linear2 = nnx.Linear(256, 10, rngs=rngs)

def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
x = self.avg_pool(nnx.relu(self.batch_norm1(self.dropout1(self.conv1(x)))))
x = self.avg_pool(nnx.relu(self.batch_norm2(self.conv2(x))))
x = x.reshape(x.shape[0], -1) # flatten
x = nnx.relu(self.dropout2(self.linear1(x)))
x = self.linear2(x)
return x


@jax.jit
def apply_model(state, images, labels):
"""Computes gradients, loss and accuracy for a single batch."""

def loss_fn(params):
logits = state.apply_fn({'params': params}, images)
one_hot = jax.nn.one_hot(labels, 10)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(state.params)
accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
return grads, loss, accuracy
def loss_fn(model: CNN, batch) -> tuple[jax.Array, Any]:
logits = model(batch['image'])
loss = optax.softmax_cross_entropy_with_integer_labels(
logits=logits, labels=batch['label']
).mean()
return loss, logits


@jax.jit
def update_model(state, grads):
return state.apply_gradients(grads=grads)
@nnx.jit
def train_step(
model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch
) -> None:
"""Train for a single step."""
grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(model, batch)
metrics.update(
loss=loss, logits=logits, labels=batch['label']
) # In-place updates.
optimizer.update(model, grads) # In-place updates.


def train_epoch(state, train_ds, batch_size, rng):
"""Train for a single epoch."""
train_ds_size = len(train_ds['image'])
steps_per_epoch = train_ds_size // batch_size
@nnx.jit
def eval_step(model: CNN, metrics: nnx.MultiMetric, batch) -> None:
loss, logits = loss_fn(model, batch)
metrics.update(
loss=loss, logits=logits, labels=batch['label']
) # In-place updates.

perms = jax.random.permutation(rng, len(train_ds['image']))
perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))

epoch_loss = []
epoch_accuracy = []

for perm in perms:
batch_images = train_ds['image'][perm, ...]
batch_labels = train_ds['label'][perm, ...]
grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
state = update_model(state, grads)
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
train_loss = np.mean(epoch_loss)
train_accuracy = np.mean(epoch_accuracy)
return state, train_loss, train_accuracy


def get_datasets():
def get_datasets(
config: ml_collections.ConfigDict,
) -> tuple[tf.data.Dataset, tf.data.Dataset]:
"""Load MNIST train and test datasets into memory."""
ds_builder = tfds.builder('mnist')
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
train_ds['image'] = jnp.float32(train_ds['image']) / 255.0
test_ds['image'] = jnp.float32(test_ds['image']) / 255.0
return train_ds, test_ds
batch_size = config.batch_size
train_ds: tf.data.Dataset = tfds.load('mnist', split='train')
test_ds: tf.data.Dataset = tfds.load('mnist', split='test')

train_ds = train_ds.map(
lambda sample: {
'image': tf.cast(sample['image'], tf.float32) / 255,
'label': sample['label'],
}
) # normalize train set
test_ds = test_ds.map(
lambda sample: {
'image': tf.cast(sample['image'], tf.float32) / 255,
'label': sample['label'],
}
) # normalize the test set.

# Create a shuffled dataset by allocating a buffer size of 1024 to randomly
# draw elements from.
train_ds = train_ds.shuffle(1024)
# Group into batches of `batch_size` and skip incomplete batches, prefetch the
# next sample to improve latency.
train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1)
# Group into batches of `batch_size` and skip incomplete batches, prefetch the
# next sample to improve latency.
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)


def create_train_state(rng, config):
"""Creates initial `TrainState`."""
cnn = CNN()
params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
tx = optax.sgd(config.learning_rate, config.momentum)
return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)
return train_ds, test_ds


def train_and_evaluate(
config: ml_collections.ConfigDict, workdir: str
) -> train_state.TrainState:
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> None:
"""Execute model training and evaluation loop.

Args:
config: Hyperparameter configuration for training and evaluation.
workdir: Directory where the tensorboard summaries are written to.

Returns:
The train state (which includes the `.params`).
workdir: Directory path to store metrics.
"""
train_ds, test_ds = get_datasets()
rng = jax.random.key(0)
train_ds, test_ds = get_datasets(config)

# Instantiate the model.
model = CNN(rngs=nnx.Rngs(0))

learning_rate = config.learning_rate
momentum = config.momentum

summary_writer = tensorboard.SummaryWriter(workdir)
summary_writer.hparams(dict(config))

rng, init_rng = jax.random.split(rng)
state = create_train_state(init_rng, config)
optimizer = nnx.Optimizer(
model, optax.sgd(learning_rate, momentum), wrt=nnx.Param
)
metrics = nnx.MultiMetric(
accuracy=nnx.metrics.Accuracy(),
loss=nnx.metrics.Average('loss'),
)

for epoch in range(1, config.num_epochs + 1):
rng, input_rng = jax.random.split(rng)
state, train_loss, train_accuracy = train_epoch(
state, train_ds, config.batch_size, input_rng
)
_, test_loss, test_accuracy = apply_model(
state, test_ds['image'], test_ds['label']
)

logging.info(
# Run the optimization for one step and make a stateful update to the
# following:
# - The train state's model parameters
# - The optimizer state
# - The training loss and accuracy batch metrics
model.train() # Switch to train mode

for batch in train_ds.as_numpy_iterator():
train_step(model, optimizer, metrics, batch)
# Compute the training metrics.
train_metrics = metrics.compute()
metrics.reset() # Reset the metrics for the test set.

# Compute the metrics on the test set after each training epoch.
model.eval() # Switch to eval mode
for batch in test_ds.as_numpy_iterator():
eval_step(model, metrics, batch)

# Compute the eval metrics.
eval_metrics = metrics.compute()
metrics.reset() # Reset the metrics for the next training epoch.

logging.info( # pylint: disable=logging-not-lazy
'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f,'
' test_accuracy: %.2f'
% (
epoch,
train_loss,
train_accuracy * 100,
test_loss,
test_accuracy * 100,
train_metrics['loss'],
train_metrics['accuracy'] * 100,
eval_metrics['loss'],
eval_metrics['accuracy'] * 100,
)
)

summary_writer.scalar('train_loss', train_loss, epoch)
summary_writer.scalar('train_accuracy', train_accuracy, epoch)
summary_writer.scalar('test_loss', test_loss, epoch)
summary_writer.scalar('test_accuracy', test_accuracy, epoch)
# Write the metrics to TensorBoard.
summary_writer.scalar('train_loss', train_metrics['loss'], epoch)
summary_writer.scalar('train_accuracy', train_metrics['accuracy'], epoch)
summary_writer.scalar('test_loss', eval_metrics['loss'], epoch)
summary_writer.scalar('test_accuracy', eval_metrics['accuracy'], epoch)

summary_writer.flush()
return state
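For orientation (not part of the diff), a minimal sketch of driving the rewritten entry point directly. The config field names are the ones read above (`learning_rate`, `momentum`, `batch_size`, `num_epochs`); the values and the workdir path are placeholders, and in the example itself the config is supplied by `main.py` via `config_flags`.

```python
# Sketch only, not part of the PR: invoking the NNX train_and_evaluate loop.
import ml_collections

import train  # examples/mnist/train.py as rewritten above

config = ml_collections.ConfigDict()
config.learning_rate = 0.005  # Placeholder values, not the example's defaults.
config.momentum = 0.9
config.batch_size = 32
config.num_epochs = 10

# Loads MNIST, runs the training/eval loop, and writes train/test loss and
# accuracy to TensorBoard summaries under the given workdir.
train.train_and_evaluate(config, workdir='/tmp/mnist_metrics')
```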
6 changes: 3 additions & 3 deletions examples/wmt/models.py
@@ -20,8 +20,7 @@
# pytype: disable=wrong-keyword-args
# pytype: disable=attribute-error

from typing import Any, Optional
from collections.abc import Callable
from typing import Any, Callable

from flax import linen as nn
from flax import struct
@@ -549,7 +548,8 @@ def decode(

# Make padding attention masks.
if config.decode:
# for fast autoregressive decoding only a special encoder-decoder mask is used
# for fast autoregressive decoding only a special encoder-decoder mask is
# used
decoder_mask = None
encoder_decoder_mask = nn.make_attention_mask(
jnp.ones_like(targets) > 0, inputs > 0, dtype=config.dtype
8 changes: 4 additions & 4 deletions examples/wmt/train.py
@@ -250,7 +250,7 @@ def loss_fn(params):
if state.dynamic_scale:
# if is_fin == False the gradients contain Inf/NaNs and optimizer state and
# params should be restored (= skip this step).
select_fn = functools.partial(jnp.where, is_fin)
select_fn = functools.partial(jnp.where, is_fin) # pylint: disable=undefined-variable
new_state = new_state.replace(
opt_state=jax.tree_util.tree_map(
select_fn, new_state.opt_state, state.opt_state
@@ -259,7 +259,7 @@ def loss_fn(params):
select_fn, new_state.params, state.params
),
)
metrics["loss_scale"] = dynamic_scale.scale * metrics["denominator"]
metrics["loss_scale"] = dynamic_scale.scale * metrics["denominator"] # pylint: disable=undefined-variable

return new_state, metrics

@@ -649,8 +649,8 @@ def decode_tokens(toks):
metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics)
denominator = metrics_sums.pop("denominator")
summary = jax.tree_util.tree_map(
lambda x: x / denominator, metrics_sums
) # pylint: disable=cell-var-from-loop
lambda x: x / denominator, metrics_sums # pylint: disable=cell-var-from-loop
)
summary["learning_rate"] = lr
summary = {"train_" + k: v for k, v in summary.items()}
writer.write_scalars(step, summary)