diff --git a/docs_nnx/mnist_tutorial.md b/docs_nnx/mnist_tutorial.md index 3eb95545d..4d007101e 100644 --- a/docs_nnx/mnist_tutorial.md +++ b/docs_nnx/mnist_tutorial.md @@ -158,7 +158,7 @@ def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, b grad_fn = nnx.value_and_grad(loss_fn, has_aux=True) (loss, logits), grads = grad_fn(model, batch) metrics.update(loss=loss, logits=logits, labels=batch['label']) # In-place updates. - optimizer.update(grads) # In-place updates. + optimizer.update(model, grads) # In-place updates. @nnx.jit def eval_step(model: CNN, metrics: nnx.MultiMetric, batch): diff --git a/examples/gemma/input_pipeline.py b/examples/gemma/input_pipeline.py index da9ae4733..82b3441e1 100644 --- a/examples/gemma/input_pipeline.py +++ b/examples/gemma/input_pipeline.py @@ -15,15 +15,11 @@ """Input pipeline for a LM1B dataset.""" import os -import typing +from typing import Any +import tokenizer import tensorflow as tf import tensorflow_datasets as tfds -import tokenizer -from clu import deterministic_data - -if typing.TYPE_CHECKING: - from train import TrainConfig AUTOTUNE = tf.data.experimental.AUTOTUNE Features = dict[str, tf.Tensor] @@ -58,9 +54,9 @@ def get_raw_dataset(dataset_name: str, split: str) -> tf.data.Dataset: def pack_dataset( - dataset: tf.data.Dataset, - key2length: int | dict[str, int], - keys: list[str] | None = None, + dataset: tf.data.Dataset, + key2length: int | dict[str, int], + keys: list[str] | None = None, ) -> tf.data.Dataset: """Creates a 'packed' version of a dataset on-the-fly. @@ -107,8 +103,8 @@ def pack_dataset( for k in keys: if k not in shapes: raise ValueError( - 'Key %s not found in dataset. Available keys are %s' - % (k, shapes.keys()) + 'Key %s not found in dataset. Available keys are %s' + % (k, shapes.keys()) ) if not shapes[k].is_compatible_with(tf.TensorShape([None])): # type: ignore[wrong-arg-types] raise ValueError('Tensors to be packed must be one-dimensional.') @@ -122,14 +118,14 @@ def pack_dataset( # trim to length dataset = dataset.map( - lambda x: {k: x[k][: key2length[k]] for k in keys}, - num_parallel_calls=AUTOTUNE, + lambda x: {k: x[k][: key2length[k]] for k in keys}, + num_parallel_calls=AUTOTUNE, ) # Setting batch_size=length ensures that the concatenated sequences (if they # have length >=1) are sufficient to fill at least one packed example. batch_size = max(key2length.values()) dataset = dataset.padded_batch( - batch_size, padded_shapes={k: [-1] for k in keys} + batch_size, padded_shapes={k: [-1] for k in keys} ) dataset = _pack_with_tf_ops(dataset, keys, key2length) @@ -141,7 +137,7 @@ def my_fn(x): def _pack_with_tf_ops( - dataset: tf.data.Dataset, keys: list[str], key2length: dict[str, int] + dataset: tf.data.Dataset, keys: list[str], key2length: dict[str, int] ) -> tf.data.Dataset: """Helper-function for packing a dataset which has already been batched. 
@@ -166,8 +162,8 @@ def write_packed_example(partial, outputs): new_outputs = {} for k in keys_etc: new_outputs[k] = outputs[k].write( - outputs[k].size(), - tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]), + outputs[k].size(), + tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]), ) return new_partial, new_outputs @@ -188,10 +184,10 @@ def map_fn(x): outputs = {} for k in keys: outputs[k] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] + tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] ) outputs[k + '_position'] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] + tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] ) def body_fn(i, partial, outputs): @@ -213,10 +209,10 @@ def body_fn(i, partial, outputs): one_example[k] = val for k in keys: can_append = tf.logical_and( - can_append, - tf.less_equal( - tf.size(partial[k]) + tf.size(one_example[k]), key2length[k] - ), + can_append, + tf.less_equal( + tf.size(partial[k]) + tf.size(one_example[k]), key2length[k] + ), ) def false_fn(): @@ -232,28 +228,28 @@ def true_fn(): new_seq_len = tf.size(new_seq) new_partial[k] = tf.concat([partial[k], new_seq], 0) new_partial[k + '_position'] = tf.concat( - [partial[k + '_position'], tf.range(new_seq_len)], 0 + [partial[k + '_position'], tf.range(new_seq_len)], 0 ) partial = new_partial return i + 1, partial, outputs # For loop over all examples in the batch. - i, partial, outputs = tf.while_loop( - cond=lambda *_: True, - body=body_fn, - loop_vars=(i, partial, outputs), - shape_invariants=( - tf.TensorShape([]), - {k: tf.TensorShape([None]) for k in keys_etc}, # type: ignore[wrong-arg-types] - {k: tf.TensorShape(None) for k in keys_etc}, # type: ignore[wrong-arg-types] - ), - maximum_iterations=dynamic_batch_size, + _, partial, outputs = tf.while_loop( + cond=lambda *_: True, + body=body_fn, + loop_vars=(i, partial, outputs), + shape_invariants=( + tf.TensorShape([]), + {k: tf.TensorShape([None]) for k in keys_etc}, # type: ignore[wrong-arg-types] + {k: tf.TensorShape(None) for k in keys_etc}, # type: ignore[wrong-arg-types] + ), + maximum_iterations=dynamic_batch_size, ) _, outputs = write_packed_example(partial, outputs) packed = {k: outputs[k].stack() for k in keys_etc} for k in keys: packed[k + '_segmentation'] = tf.cumsum( - tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1 + tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1 ) * tf.cast(tf.not_equal(packed[k], 0), tf.int32) return packed @@ -263,8 +259,8 @@ def true_fn(): def shift_data_by_truncation(x): # https://github.com/AI-Hypercomputer/maxtext/blob/7fe1de75b3919c0fda00d23ad6cb29def9098362/MaxText/input_pipeline/_input_pipeline_utils.py#L53 - x["inputs"] = x["inputs"][:-1] - x["targets"] = x["targets"][1:] + x['inputs'] = x['inputs'][:-1] + x['targets'] = x['targets'][1:] return x @@ -272,16 +268,16 @@ def shift_data_by_truncation(x): # Main dataset prep routines. 
# ----------------------------------------------------------------------------- def preprocess_data( - dataset, - shuffle: bool, - num_epochs: int | None = 1, - pack_examples: bool = True, - shuffle_buffer_size: int = 1024, - max_length: int = 512, - batch_size: int = 256, - drop_remainder: bool = True, - prefetch_size: int = AUTOTUNE, - shift: bool = True, + dataset, + shuffle: bool, + num_epochs: int | None = 1, + pack_examples: bool = True, + shuffle_buffer_size: int = 1024, + max_length: int = 512, + batch_size: int = 256, + drop_remainder: bool = True, + prefetch_size: int = AUTOTUNE, + shift: bool = True, ): """Shuffle and batch/pack the given dataset.""" @@ -303,7 +299,9 @@ def filter_fn(x): # Shift inputs for teacher-forced training if shift: dataset = dataset.map( - shift_data_by_truncation, num_parallel_calls=AUTOTUNE, deterministic=True + shift_data_by_truncation, + num_parallel_calls=AUTOTUNE, + deterministic=True, ) if pack_examples: @@ -311,10 +309,10 @@ def filter_fn(x): dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) else: # simple (static-shape) padded batching dataset = dataset.padded_batch( - batch_size, - padded_shapes={'inputs': max_length, 'targets': max_length}, - padding_values={'inputs': 0, 'targets': 0}, - drop_remainder=drop_remainder, + batch_size, + padded_shapes={'inputs': max_length, 'targets': max_length}, + padding_values={'inputs': 0, 'targets': 0}, + drop_remainder=drop_remainder, ) if prefetch_size: @@ -324,10 +322,10 @@ def filter_fn(x): def get_datasets( - config: "TrainConfig", - *, - n_devices: int, - vocab_path: str | None = None, + config: Any, + *, + n_devices: int, + vocab_path: str | None = None, ): """Load and return dataset of batched examples for use during training.""" if vocab_path is None: @@ -343,16 +341,16 @@ def get_datasets( # Tokenize data. sp_processor = tokenizer.load_or_train_tokenizer( - train_data, - vocab_path=vocab_path, - vocab_size=config.vocab_size, - max_corpus_chars=config.max_corpus_chars, + train_data, + vocab_path=vocab_path, + vocab_size=config.vocab_size, + max_corpus_chars=config.max_corpus_chars, ) train_data = train_data.map( - tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE + tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE ) eval_data = eval_data.map( - tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE + tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE ) batch_size = config.per_device_batch_size * n_devices @@ -362,20 +360,20 @@ def get_datasets( eval_batch_size = batch_size train_ds = preprocess_data( - train_data, - shuffle=True, - num_epochs=None, - pack_examples=True, - batch_size=batch_size, - max_length=config.max_target_length, + train_data, + shuffle=True, + num_epochs=None, + pack_examples=True, + batch_size=batch_size, + max_length=config.max_target_length, ) eval_ds = preprocess_data( - eval_data, - shuffle=False, - pack_examples=False, - batch_size=eval_batch_size, - max_length=config.max_eval_target_length, + eval_data, + shuffle=False, + pack_examples=False, + batch_size=eval_batch_size, + max_length=config.max_eval_target_length, ) return train_ds, eval_ds, sp_processor diff --git a/examples/gemma/main.py b/examples/gemma/main.py index f4185e216..cd97f3f10 100644 --- a/examples/gemma/main.py +++ b/examples/gemma/main.py @@ -18,21 +18,24 @@ that can be easily tested and imported in Colab. 
""" -import jax -import tensorflow as tf -import train -from absl import app, flags, logging +from absl import app +from absl import flags +from absl import logging from clu import platform +import train +import jax from ml_collections import config_flags +import tensorflow as tf + FLAGS = flags.FLAGS flags.DEFINE_string('workdir', None, 'Directory to store model data.') config_flags.DEFINE_config_file( - 'config', - 'configs/default.py', - 'File path to the training hyperparameter configuration.', - lock_config=True, + 'config', + 'configs/default.py', + 'File path to the training hyperparameter configuration.', + lock_config=True, ) flags.mark_flags_as_required(['workdir']) @@ -51,11 +54,11 @@ def main(argv): # Add a note so that we can tell which task is which JAX host. # (Depending on the platform task 0 is not guaranteed to be host 0) platform.work_unit().set_task_status( - f'process_index: {jax.process_index()}, ' - f'process_count: {jax.process_count()}' + f'process_index: {jax.process_index()}, ' + f'process_count: {jax.process_count()}' ) platform.work_unit().create_artifact( - platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' + platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' ) train.train_and_evaluate(FLAGS.config, FLAGS.workdir) diff --git a/examples/gemma/tokenizer.py b/examples/gemma/tokenizer.py index 5dcaa437e..c9d33d2e4 100644 --- a/examples/gemma/tokenizer.py +++ b/examples/gemma/tokenizer.py @@ -14,28 +14,34 @@ """Provides op for tokenizing a dataset.""" +from collections.abc import Iterable import dataclasses import os import sys import tempfile import time from typing import Any -from collections.abc import Iterable +from absl import logging import jax import tensorflow as tf + +# pylint: disable=g-import-not-at-top if sys.version_info < (3, 13): import tensorflow_text as tftxt -from absl import logging -from sentencepiece import SentencePieceTrainer, SentencePieceProcessor + +from sentencepiece import SentencePieceProcessor # pylint: disable=g-importing-member +from sentencepiece import SentencePieceTrainer # pylint: disable=g-importing-member +# pylint: enable=g-import-not-at-top + Features = dict[str, tf.Tensor] def _dump_chars_to_textfile( - dataset: tf.data.Dataset, - maxchars: int = int(1e7), - data_keys=('inputs', 'targets'), + dataset: tf.data.Dataset, + maxchars: int = int(1e7), + data_keys=('inputs', 'targets'), ) -> tuple[str, int]: """Write part of a TFDS sentence dataset to lines in a text file. @@ -50,7 +56,7 @@ def _dump_chars_to_textfile( char_count = 0 ds_iter = dataset.as_numpy_iterator() with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/ds_chars' + delete=False, prefix='/tmp/ds_chars' ) as outfp: while char_count < maxchars: example = next(ds_iter) @@ -62,18 +68,18 @@ def _dump_chars_to_textfile( def _train_sentencepiece( - dataset: tf.data.Dataset, - *, - vocab_size: int, - maxchars: int = int(1e7), - model_path: str, - model_type: str = 'unigram', - character_coverage: float = 1.0, - data_keys=('inputs', 'targets'), - pad_id: int = 0, - eos_id: int = 1, - bos_id: int = 2, - unk_id: int = 3, + dataset: tf.data.Dataset, + *, + vocab_size: int, + maxchars: int = int(1e7), + model_path: str, + model_type: str = 'unigram', + character_coverage: float = 1.0, + data_keys=('inputs', 'targets'), + pad_id: int = 0, + eos_id: int = 1, + bos_id: int = 2, + unk_id: int = 3, ): """Train SentencePiece tokenizer from subset of tf dataset. 
@@ -100,14 +106,13 @@ def _train_sentencepiece( else: abs_model_path = os.path.abspath(os.path.expanduser(model_path)) fname, _ = _dump_chars_to_textfile( - dataset, maxchars=maxchars, data_keys=data_keys + dataset, maxchars=maxchars, data_keys=data_keys ) with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/sp_tmp' + delete=False, prefix='/tmp/sp_tmp' ) as model_fp: pass # we just want a prefix'd tmp-filename - argstr = ' '.join( - [ + argstr = ' '.join([ f'--input={fname}', f'--vocab_size={vocab_size}', f'--character_coverage={character_coverage}', @@ -124,8 +129,7 @@ def _train_sentencepiece( f'--bos_id={bos_id}', f'--eos_id={eos_id}', f'--unk_id={unk_id}', - ] - ) + ]) SentencePieceTrainer.Train(argstr) if jax.process_index() == 0: # Use an intermediate filename that is renamed to the target name to address @@ -142,27 +146,27 @@ def _train_sentencepiece( def _load_sentencepiece_tokenizer( - model_path: str, - add_bos: bool = False, - add_eos: bool = True, - reverse: bool = False, + model_path: str, + add_bos: bool = False, + add_eos: bool = True, + reverse: bool = False, ): """Load a tf-text SentencePiece tokenizer from given model filepath.""" with tf.io.gfile.GFile(model_path, 'rb') as model_fp: sp_model = model_fp.read() sp_tokenizer = tftxt.SentencepieceTokenizer( - model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse + model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse ) return sp_tokenizer def load_or_train_tokenizer( - dataset: tf.data.Dataset, - *, - vocab_path: str, - vocab_size: int, - max_corpus_chars: int, - data_keys: tuple[str, str] = ('inputs', 'targets'), + dataset: tf.data.Dataset, + *, + vocab_path: str, + vocab_size: int, + max_corpus_chars: int, + data_keys: tuple[str, str] = ('inputs', 'targets'), ): """Loads the tokenizer at `vocab_path` or trains a one from `dataset`.""" try: @@ -170,11 +174,11 @@ def load_or_train_tokenizer( except tf.errors.NotFoundError: logging.info('SentencePiece vocab not found, building one from data.') vocab_path = _train_sentencepiece( - dataset, - vocab_size=vocab_size, - maxchars=max_corpus_chars, - model_path=vocab_path, - data_keys=data_keys, + dataset, + vocab_size=vocab_size, + maxchars=max_corpus_chars, + model_path=vocab_path, + data_keys=data_keys, ) return _load_sentencepiece_tokenizer(vocab_path) @@ -192,5 +196,5 @@ def __call__(self, features: Features) -> Features: def load_sentencepiece_processor(vocab_path: str): spp = SentencePieceProcessor() - spp.load(vocab_path) + spp.Load(vocab_path) return spp diff --git a/examples/gemma/train.py b/examples/gemma/train.py index b5bc07745..b4da6e951 100644 --- a/examples/gemma/train.py +++ b/examples/gemma/train.py @@ -22,26 +22,25 @@ import dataclasses import os +from typing import Any +from absl import logging +from clu import metric_writers +from clu import periodic_actions +from flax import nnx import input_pipeline -import jax -import jax.numpy as jnp +import sampler as sampler_lib import tokenizer import transformer as transformer_lib +import utils +from flax.training import checkpoints +from flax.training import common_utils +import jax +from jax import random +import jax.numpy as jnp import numpy as np import optax -import sampler as sampler_lib import tensorflow as tf -import utils -from absl import logging -from clu import metric_writers, periodic_actions -from jax import random -from jax.sharding import Mesh, NamedSharding -from jax.sharding import PartitionSpec as P -from utils import TrainState - -from flax import nnx -from 
flax.training import checkpoints, common_utils @dataclasses.dataclass(unsafe_hash=True) @@ -53,13 +52,14 @@ class MeshRules: def __call__(self, *keys: str) -> tuple[str, ...]: return tuple( - getattr(self, key) if key is not None else None - for key in keys + getattr(self, key) if key is not None else None for key in keys ) @dataclasses.dataclass(unsafe_hash=True) class TrainConfig: + """Configuration for training a gemma model.""" + # Path to load or store sentencepiece vocab file. vocab_path: str | None # Vocabulary size if `vocab_path` is not given. @@ -107,10 +107,11 @@ class TrainConfig: # Gemma transformer name. # Possible values defined in transformer.TransformerConfig: - # (gemma_2b, gemma_7b, gemma2_2b, gemma2_9b, gemma2_27b, gemma3_1b, gemma3_4b, ...) + # (gemma_2b, gemma_7b, gemma2_2b, gemma2_9b, gemma2_27b, gemma3_1b, gemma3_4b, + # ...) transformer_name: str | None # or alternatively define the model using the dict of parameters - transformer_params: dict | None + transformer_params: dict[Any, Any] | None # Whether to save model checkpoints. save_checkpoints: bool @@ -157,8 +158,8 @@ def __post_init__(self): def rsqrt_schedule( - init_value: float, - shift: int = 0, + init_value: float, + shift: int = 0, ): """Applies a reverse square-root schedule. @@ -182,20 +183,20 @@ def schedule(count): def create_learning_rate_schedule(learning_rate: float, warmup_steps: int): """Creates a rsqrt schedule with linear warmup.""" return optax.join_schedules( - [ - optax.linear_schedule( - init_value=0, - end_value=learning_rate, - transition_steps=warmup_steps, - ), - rsqrt_schedule(init_value=learning_rate, shift=warmup_steps), - ], - boundaries=[warmup_steps], + [ + optax.linear_schedule( + init_value=0, + end_value=learning_rate, + transition_steps=warmup_steps, + ), + rsqrt_schedule(init_value=learning_rate, shift=warmup_steps), + ], + boundaries=[warmup_steps], ) def compute_weighted_cross_entropy( - logits, targets, weights=None, label_smoothing=0.0 + logits, targets, weights=None, label_smoothing=0.0 ): """Compute weighted cross entropy and entropy for log probs and targets. @@ -211,18 +212,18 @@ def compute_weighted_cross_entropy( """ if logits.ndim != targets.ndim + 1: raise ValueError( - 'Incorrect shapes. Got shape %s logits and %s targets' - % (str(logits.shape), str(targets.shape)) + 'Incorrect shapes. Got shape %s logits and %s targets' + % (str(logits.shape), str(targets.shape)) ) vocab_size = logits.shape[-1] confidence = 1.0 - label_smoothing low_confidence = (1.0 - confidence) / (vocab_size - 1) normalizing_constant = -( - confidence * jnp.log(confidence) - + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) + confidence * jnp.log(confidence) + + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) ) soft_targets = common_utils.onehot( - targets, vocab_size, on_value=confidence, off_value=low_confidence + targets, vocab_size, on_value=confidence, off_value=low_confidence ) loss = -jnp.sum(soft_targets * nnx.log_softmax(logits), axis=-1) @@ -249,8 +250,8 @@ def compute_weighted_accuracy(logits, targets, weights=None): """ if logits.ndim != targets.ndim + 1: raise ValueError( - 'Incorrect shapes. Got shape %s logits and %s targets' - % (str(logits.shape), str(targets.shape)) + 'Incorrect shapes. 
Got shape %s logits and %s targets' + % (str(logits.shape), str(targets.shape)) ) loss = jnp.equal(jnp.argmax(logits, axis=-1), targets) normalizing_factor = np.prod(logits.shape[:-1]) @@ -264,13 +265,13 @@ def compute_weighted_accuracy(logits, targets, weights=None): def compute_metrics(logits, labels, weights, label_smoothing=0.0): """Compute summary metrics.""" loss, weight_sum = compute_weighted_cross_entropy( - logits, labels, weights, label_smoothing + logits, labels, weights, label_smoothing ) acc, _ = compute_weighted_accuracy(logits, labels, weights) metrics = { - 'loss': loss, - 'accuracy': acc, - 'denominator': weight_sum, + 'loss': loss, + 'accuracy': acc, + 'denominator': weight_sum, } return metrics @@ -280,10 +281,10 @@ def compute_metrics(logits, labels, weights, label_smoothing=0.0): def train_step( - state: TrainState, - batch, - learning_rate_fn, - label_smoothing=0.0, + state: utils.TrainState, + batch, + learning_rate_fn, + label_smoothing=0.0, ): """Perform a single training step.""" # X_position and X_segmentation are needed only when using "packed examples" @@ -293,16 +294,20 @@ def train_step( # like a normal, unpacked sequence example. train_keys = ['inputs', 'inputs_position', 'inputs_segmentation', 'targets'] (inputs, inputs_positions, inputs_segmentation, targets) = ( - batch.get(k, None) for k in train_keys + batch.get(k, None) for k in train_keys ) # TODO: this should be defined globally pad_id = 0 weights = jnp.where(inputs > pad_id, 1, 0).astype(jnp.float32) input_mask = inputs > pad_id - attention_mask = transformer_lib.make_causal_attn_mask(input_mask) # (B, L, L) + attention_mask = transformer_lib.make_causal_attn_mask( + input_mask + ) # (B, L, L) # inputs_segmentation: (B, L) - mask = inputs_segmentation[:, :, None] == inputs_segmentation[:, None, :] # (B, L, L) + mask = ( + inputs_segmentation[:, :, None] == inputs_segmentation[:, None, :] + ) # (B, L, L) attention_mask = jnp.logical_and(mask, attention_mask) def loss_fn(params): @@ -310,14 +315,14 @@ def loss_fn(params): module = nnx.merge(state.graphdef, params) logits, _ = module( - inputs, - positions=inputs_positions, - attention_mask=attention_mask, - cache=None, + inputs, + positions=inputs_positions, + attention_mask=attention_mask, + cache=None, ) loss, weight_sum = compute_weighted_cross_entropy( - logits, targets, weights, label_smoothing + logits, targets, weights, label_smoothing ) mean_loss = loss / weight_sum return mean_loss, logits @@ -334,10 +339,10 @@ def loss_fn(params): def eval_step( - params: nnx.State, - batch, - graphdef: nnx.GraphDef[transformer_lib.Transformer], - label_smoothing=0.0, + params: nnx.State, + batch, + graphdef: nnx.GraphDef[transformer_lib.Transformer], + label_smoothing=0.0, ): """Calculate evaluation metrics on a batch.""" inputs, targets = batch['inputs'], batch['targets'] @@ -351,21 +356,21 @@ def eval_step( module = nnx.merge(graphdef, params) logits, _ = module( - inputs, - positions=inputs_positions, - attention_mask=attention_mask, - cache=None, + inputs, + positions=inputs_positions, + attention_mask=attention_mask, + cache=None, ) return compute_metrics(logits, targets, weights, label_smoothing) def evaluate( - *, - jit_eval_step, - state: TrainState, - eval_ds: tf.data.Dataset, - num_eval_steps: int, + *, + jit_eval_step, + state: utils.TrainState, + eval_ds: tf.data.Dataset, + num_eval_steps: int, ): """Evaluate the target an return a dictionary with the metrics.""" logging.info('Gathering evaluation metrics.') @@ -379,8 +384,8 @@ def evaluate( 
eval_metrics_sums = jax.tree.map(jnp.sum, eval_metrics) eval_denominator = eval_metrics_sums.pop('denominator') eval_summary = jax.tree.map( - lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop - eval_metrics_sums, + lambda x: x / eval_denominator, # pylint: disable=cell-var-from-loop + eval_metrics_sums, ) return eval_summary @@ -406,7 +411,7 @@ def train_and_evaluate(config: TrainConfig, workdir: str): # --------------------------------------------------------------------------- logging.info('Initializing dataset.') train_ds, eval_ds, encoder = input_pipeline.get_datasets( - n_devices=jax.local_device_count(), config=config, vocab_path=vocab_path + n_devices=jax.local_device_count(), config=config, vocab_path=vocab_path ) train_iter = iter(train_ds) @@ -417,48 +422,48 @@ def train_and_evaluate(config: TrainConfig, workdir: str): # --------------------------------------------------------------------------- if config.transformer_name is not None: model_config = transformer_lib.TransformerConfig.from_version_name( - config.transformer_name, - num_embed=vocab_size, - dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, - axis_rules=config.axis_rules, + config.transformer_name, + num_embed=vocab_size, + dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, + axis_rules=config.axis_rules, ) else: assert config.transformer_params is not None model_config = transformer_lib.TransformerConfig.from_dict( - **config.transformer_params, - num_embed=vocab_size, - dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, - axis_rules=config.axis_rules, + **config.transformer_params, + num_embed=vocab_size, + dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32, + axis_rules=config.axis_rules, ) # Mesh definition devices_array = utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) + mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) start_step = 0 rng = jax.random.PRNGKey(config.seed) rng, init_rng = jax.random.split(rng) - rng, inference_rng = random.split(rng) + _, inference_rng = random.split(rng) def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): return transformer_lib.Transformer(config, rngs=nnx.Rngs(params=key)) learning_rate_fn = create_learning_rate_schedule( - learning_rate=config.learning_rate, warmup_steps=config.warmup_steps + learning_rate=config.learning_rate, warmup_steps=config.warmup_steps ) optimizer = optax.adamw( - learning_rate_fn, - b1=0.9, - b2=0.98, - eps=1e-9, - weight_decay=config.weight_decay, + learning_rate_fn, + b1=0.9, + b2=0.98, + eps=1e-9, + weight_decay=config.weight_decay, ) state, state_sharding = utils.setup_initial_state( - constructor, optimizer, model_config, init_rng, mesh + constructor, optimizer, model_config, init_rng, mesh ) - data_sharding = NamedSharding(mesh, P(config.data_sharding)) + data_sharding = jax.NamedSharding(mesh, jax.P(config.data_sharding)) if config.restore_checkpoints: # Restore unreplicated optimizer + model state from last checkpoint. @@ -467,38 +472,38 @@ def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): start_step = int(state.step) writer = metric_writers.create_default_writer( - workdir, just_logging=jax.process_index() > 0 + workdir, just_logging=jax.process_index() > 0 ) if start_step == 0: writer.write_hparams(dataclasses.asdict(config)) # compile multidevice versions of train/eval/predict step fn. 
jit_train_step = jax.jit( - train_step, - in_shardings=( - state_sharding, - data_sharding, - ), # type: ignore - out_shardings=(state_sharding, None), # type: ignore - static_argnames=("learning_rate_fn", "label_smoothing"), - donate_argnums=0, + train_step, + in_shardings=( + state_sharding, + data_sharding, + ), # type: ignore + out_shardings=(state_sharding, None), # type: ignore + static_argnames=('learning_rate_fn', 'label_smoothing'), + donate_argnums=0, ) jit_eval_step = jax.jit( - eval_step, - in_shardings=( - state_sharding.params, - data_sharding, - ), # type: ignore - out_shardings=None, # type: ignore - static_argnames=("graphdef", "label_smoothing"), + eval_step, + in_shardings=( + state_sharding.params, + data_sharding, + ), # type: ignore + out_shardings=None, # type: ignore + static_argnames=('graphdef', 'label_smoothing'), ) vocab = tokenizer.load_sentencepiece_processor(vocab_path) - sampler = sampler_lib.Sampler( - transformer=nnx.merge(state.graphdef, state.params), - vocab=vocab, - cache_size=1024, + sampler = sampler_lib.Sampler( + transformer=nnx.merge(state.graphdef, state.params), + vocab=vocab, + cache_size=1024, ) # Main Train Loop @@ -509,12 +514,12 @@ def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): logging.info('Starting training loop.') hooks = [] report_progress = periodic_actions.ReportProgress( - num_train_steps=config.num_train_steps, writer=writer + num_train_steps=config.num_train_steps, writer=writer ) if jax.process_index() == 0: hooks += [ - report_progress, - periodic_actions.Profile(logdir=workdir, num_profile_steps=5), + report_progress, + periodic_actions.Profile(logdir=workdir, num_profile_steps=5), ] train_metrics = [] with metric_writers.ensure_flushes(writer): @@ -525,12 +530,12 @@ def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): with jax.profiler.StepTraceAnnotation('train', step_num=step): with report_progress.timed('data'): batch = next(train_iter) - batch = jax.tree.map(lambda x: jnp.asarray(x, device=data_sharding), batch) + batch = jax.tree.map( + lambda x: jnp.asarray(x, device=data_sharding), batch + ) with report_progress.timed('train_step'): - state, metrics = jit_train_step( - state, batch, learning_rate_fn, 0.0 - ) + state, metrics = jit_train_step(state, batch, learning_rate_fn, 0.0) train_metrics.append(metrics) # Quick indication that training is happening. @@ -541,14 +546,17 @@ def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): # Write batch loss and lr every step to TB # without overwhelming the stdout: if jax.process_index() == 0: - tb_writer = writer._writers[-1] + tb_writer = writer._writers[-1] # pylint: disable=protected-access lr = train_metrics[-1]['learning_rate'] train_batch_loss = train_metrics[-1]['loss'] denominator = train_metrics[-1]['denominator'] - tb_writer.write_scalars(step, { - "train_learning_rate": lr, - "train_loss": train_batch_loss / denominator, - }) + tb_writer.write_scalars( + step, + { + 'train_learning_rate': lr, + 'train_loss': train_batch_loss / denominator, + }, + ) # Periodic metric handling. 
if (step > 0 and step % config.eval_every_steps == 0) or is_last_step: @@ -569,33 +577,33 @@ def constructor(config: transformer_lib.TransformerConfig, key: jax.Array): # update sampler's transformer state: sampler.transformer_state = state.params exemplars = sampler( - config.prompts, - total_generation_steps=config.num_predict_steps, - temperature=config.sampling_temperature, - top_p=config.sampling_top_p, - seed=inference_rng, - echo=True, + config.prompts, + total_generation_steps=config.num_predict_steps, + temperature=config.sampling_temperature, + top_p=config.sampling_top_p, + seed=inference_rng, + echo=True, ) - writer.write_texts(step, {'samples': exemplars.text}) + writer.write_texts(step, {'samples': exemplars.text[0]}) with report_progress.timed('eval'): eval_results = evaluate( - jit_eval_step=jit_eval_step, - state=state, - eval_ds=eval_ds, - num_eval_steps=config.num_eval_steps, + jit_eval_step=jit_eval_step, + state=state, + eval_ds=eval_ds, + num_eval_steps=config.num_eval_steps, ) # (clipped) perplexity after averaging log-perplexity eval_results['perplexity'] = jnp.clip( - jnp.exp(eval_results['loss']), max=1.0e4 + jnp.exp(eval_results['loss']), max=1.0e4 ) writer.write_scalars( - step, {'eval_' + k: v for k, v in eval_results.items()} + step, {'eval_' + k: v for k, v in eval_results.items()} ) # Save a checkpoint on one host after every checkpoint_freq steps. save_checkpoint = ( - step % config.checkpoint_every_steps == 0 or is_last_step + step % config.checkpoint_every_steps == 0 or is_last_step ) if config.save_checkpoints and save_checkpoint: logging.info('Saving checkpoint step %d.', step) diff --git a/examples/gemma/utils.py b/examples/gemma/utils.py index 18f6909cc..5162c0618 100644 --- a/examples/gemma/utils.py +++ b/examples/gemma/utils.py @@ -12,39 +12,45 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Copied over from MaxText (https://github.com/google/maxtext/blob/main/MaxText/max_utils.py). +# Copied over from MaxText +# (https://github.com/google/maxtext/blob/main/MaxText/max_utils.py). +"""Provides utilities for training the Flax gemma example.""" -import logging -from typing import Any, TYPE_CHECKING from collections.abc import Callable +import logging +from typing import Any +from flax import nnx +import transformer +from flax.training import train_state import jax +from jax.experimental import mesh_utils import jax.numpy as jnp import numpy as np -from jax.experimental import mesh_utils -from transformer import TransformerConfig, Transformer -from flax import nnx -from flax.training import train_state - -if TYPE_CHECKING: - from train import TrainConfig - Dtype = Any Shape = tuple[int, ...] class TrainState(train_state.TrainState): - graphdef: nnx.GraphDef[Transformer] + graphdef: nnx.GraphDef[transformer.Transformer] # Mesh utils. # ----------------------------------------------------------------------------- -def create_device_mesh(config: "TrainConfig"): - """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas.""" +def create_device_mesh(config: Any): + """Creates a device mesh with each slice in its own data parallel group. + + If there is only one slice, uses two replicas. + + Args: + config: The training configuration. + Returns: + The device mesh. 
+ """ devices = jax.devices() num_devices = len(devices) try: @@ -52,58 +58,58 @@ def create_device_mesh(config: "TrainConfig"): except AttributeError: num_slices = 1 num_devices_per_slice = num_devices // num_slices - logging.info(f'Devices: {devices}') - logging.info(f'Number of devices: {num_devices}') + logging.info(f'Devices: {devices}') # pylint: disable=logging-fstring-interpolation + logging.info(f'Number of devices: {num_devices}') # pylint: disable=logging-fstring-interpolation multi_slice_env = hasattr(jax.devices()[0], 'slice_index') dcn_parallelism = [ - config.dcn_data_parallelism, - config.dcn_fsdp_parallelism, - config.dcn_tensor_parallelism, + config.dcn_data_parallelism, + config.dcn_fsdp_parallelism, + config.dcn_tensor_parallelism, ] ici_parallelism = [ - config.ici_data_parallelism, - config.ici_fsdp_parallelism, - config.ici_tensor_parallelism, + config.ici_data_parallelism, + config.ici_fsdp_parallelism, + config.ici_tensor_parallelism, ] # Find possible unspecified parallelisms dcn_parallelism = fill_unspecified_mesh_axes( - dcn_parallelism, num_slices, 'DCN' + dcn_parallelism, num_slices, 'DCN' ) ici_parallelism = fill_unspecified_mesh_axes( - ici_parallelism, num_devices_per_slice, 'ICI' + ici_parallelism, num_devices_per_slice, 'ICI' ) if multi_slice_env: mesh = mesh_utils.create_hybrid_device_mesh( - ici_parallelism, dcn_parallelism + ici_parallelism, dcn_parallelism ) else: mesh = mesh_utils.create_device_mesh(ici_parallelism) - logging.info(f'Decided on mesh: {mesh}') - logging.info(f'Mesh shape: {mesh.shape}') + logging.info(f'Decided on mesh: {mesh}') # pylint: disable=logging-fstring-interpolation + logging.info(f'Mesh shape: {mesh.shape}') # pylint: disable=logging-fstring-interpolation return mesh def fill_unspecified_mesh_axes( - parallelism_vals, target_product, parallelism_type + parallelism_vals, target_product, parallelism_type ): - """Evaluates unspecified DCN/ICI parallelism values""" + """Evaluates unspecified DCN/ICI parallelism values.""" if -1 in parallelism_vals: assert parallelism_vals.count(-1) == 1, ( - f'Found unspecified values (-1) for more than one {parallelism_type} ' - ' parallelism axis. At most one axis can be unspecified.' + f'Found unspecified values (-1) for more than one {parallelism_type} ' + ' parallelism axis. At most one axis can be unspecified.' 
) determined_val = target_product / np.prod(parallelism_vals) * -1 assert determined_val >= 1 and determined_val.is_integer, ( - 'Unspecified value unable to be determined with the given ' - f' {parallelism_type} parallelism values' + 'Unspecified value unable to be determined with the given ' + f' {parallelism_type} parallelism values' ) parallelism_vals[parallelism_vals.index(-1)] = int(determined_val) @@ -111,8 +117,8 @@ def fill_unspecified_mesh_axes( target_type = 'slices' if parallelism_type == 'DCN' else 'devices per slice' assert np.prod(parallelism_vals) == target_product, ( - f'Number of {target_type} {target_product} does not match the product' - f' of the {parallelism_type} parallelism {np.prod(parallelism_vals)}' + f'Number of {target_type} {target_product} does not match the product' + f' of the {parallelism_type} parallelism {np.prod(parallelism_vals)}' ) return parallelism_vals @@ -129,14 +135,15 @@ def _to_array(x): def setup_initial_state( - constructor: Callable[[TransformerConfig, jax.Array], Transformer], - tx, - config: TransformerConfig, - rng: jax.Array, - mesh: jax.sharding.Mesh, + constructor: Callable[ + [transformer.TransformerConfig, jax.Array], transformer.Transformer + ], + tx, + config: transformer.TransformerConfig, + rng: jax.Array, + mesh: jax.sharding.Mesh, ) -> tuple[TrainState, TrainState]: - """We initialize the model and optimizer state, and optionally load from a - checkpoint as necessary. + """We initialize train state, optionally loading from checkpoint. Args: constructor: the model constructor @@ -155,10 +162,10 @@ def sharded_init(): model = constructor(config, rng) graphdef, params = nnx.split(model, nnx.Param) state = TrainState.create( - apply_fn=graphdef.apply, - params=params, - tx=tx, - graphdef=graphdef, + apply_fn=graphdef.apply, + params=params, + tx=tx, + graphdef=graphdef, ) state = jax.tree.map(_to_array, state) state_spec = nnx.get_partition_spec(state) diff --git a/examples/mnist/README.md b/examples/mnist/README.md index 2a8169271..386bc5fea 100644 --- a/examples/mnist/README.md +++ b/examples/mnist/README.md @@ -20,8 +20,7 @@ https://colab.research.google.com/github/google/flax/blob/main/examples/mnist/mn [gs://flax_public/examples/mnist/default]: https://console.cloud.google.com/storage/browser/flax_public/examples/mnist/default ``` -I0828 08:51:41.821526 139971964110656 train.py:130] train epoch: 10, loss: 0.0097, accuracy: 99.69 -I0828 08:51:42.248714 139971964110656 train.py:180] eval epoch: 10, loss: 0.0299, accuracy: 99.14 +I1009 17:56:42.674334 3280981 train.py:175] epoch: 10, train_loss: 0.0073, train_accuracy: 99.75, test_loss: 0.0294, test_accuracy: 99.25 ``` ### How to run diff --git a/examples/mnist/main.py b/examples/mnist/main.py index 887ecf71e..fe500cb96 100644 --- a/examples/mnist/main.py +++ b/examples/mnist/main.py @@ -26,7 +26,7 @@ from ml_collections import config_flags import tensorflow as tf -import train +import train # pylint: disable=g-bad-import-order FLAGS = flags.FLAGS diff --git a/examples/mnist/train.py b/examples/mnist/train.py index 0886a1963..34858e9a7 100644 --- a/examples/mnist/train.py +++ b/examples/mnist/train.py @@ -20,147 +20,174 @@ # See issue #620. 
# pytype: disable=wrong-keyword-args +import functools +from typing import Any from absl import logging -from flax import linen as nn +from flax import nnx from flax.metrics import tensorboard -from flax.training import train_state import jax -import jax.numpy as jnp import ml_collections -import numpy as np import optax +import tensorflow as tf import tensorflow_datasets as tfds +tf.random.set_seed(0) # Set the random seed for reproducibility. -class CNN(nn.Module): + +class CNN(nnx.Module): """A simple CNN model.""" - @nn.compact + def __init__(self, *, rngs: nnx.Rngs): + self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs) + self.batch_norm1 = nnx.BatchNorm(32, rngs=rngs) + self.dropout1 = nnx.Dropout(rate=0.025, rngs=rngs) + self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs) + self.batch_norm2 = nnx.BatchNorm(64, rngs=rngs) + self.avg_pool = functools.partial( + nnx.avg_pool, window_shape=(2, 2), strides=(2, 2) + ) + self.linear1 = nnx.Linear(3136, 256, rngs=rngs) + self.dropout2 = nnx.Dropout(rate=0.025, rngs=rngs) + self.linear2 = nnx.Linear(256, 10, rngs=rngs) + def __call__(self, x): - x = nn.Conv(features=32, kernel_size=(3, 3))(x) - x = nn.relu(x) - x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = nn.Conv(features=64, kernel_size=(3, 3))(x) - x = nn.relu(x) - x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = x.reshape((x.shape[0], -1)) # flatten - x = nn.Dense(features=256)(x) - x = nn.relu(x) - x = nn.Dense(features=10)(x) + x = self.avg_pool(nnx.relu(self.batch_norm1(self.dropout1(self.conv1(x))))) + x = self.avg_pool(nnx.relu(self.batch_norm2(self.conv2(x)))) + x = x.reshape(x.shape[0], -1) # flatten + x = nnx.relu(self.dropout2(self.linear1(x))) + x = self.linear2(x) return x -@jax.jit -def apply_model(state, images, labels): - """Computes gradients, loss and accuracy for a single batch.""" - - def loss_fn(params): - logits = state.apply_fn({'params': params}, images) - one_hot = jax.nn.one_hot(labels, 10) - loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) - return loss, logits - - grad_fn = jax.value_and_grad(loss_fn, has_aux=True) - (loss, logits), grads = grad_fn(state.params) - accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) - return grads, loss, accuracy +def loss_fn(model: CNN, batch) -> tuple[jax.Array, Any]: + logits = model(batch['image']) + loss = optax.softmax_cross_entropy_with_integer_labels( + logits=logits, labels=batch['label'] + ).mean() + return loss, logits -@jax.jit -def update_model(state, grads): - return state.apply_gradients(grads=grads) +@nnx.jit +def train_step( + model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch +) -> None: + """Train for a single step.""" + grad_fn = nnx.value_and_grad(loss_fn, has_aux=True) + (loss, logits), grads = grad_fn(model, batch) + metrics.update( + loss=loss, logits=logits, labels=batch['label'] + ) # In-place updates. + optimizer.update(model, grads) # In-place updates. -def train_epoch(state, train_ds, batch_size, rng): - """Train for a single epoch.""" - train_ds_size = len(train_ds['image']) - steps_per_epoch = train_ds_size // batch_size +@nnx.jit +def eval_step(model: CNN, metrics: nnx.MultiMetric, batch) -> None: + loss, logits = loss_fn(model, batch) + metrics.update( + loss=loss, logits=logits, labels=batch['label'] + ) # In-place updates. 
- perms = jax.random.permutation(rng, len(train_ds['image'])) - perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch - perms = perms.reshape((steps_per_epoch, batch_size)) - epoch_loss = [] - epoch_accuracy = [] - - for perm in perms: - batch_images = train_ds['image'][perm, ...] - batch_labels = train_ds['label'][perm, ...] - grads, loss, accuracy = apply_model(state, batch_images, batch_labels) - state = update_model(state, grads) - epoch_loss.append(loss) - epoch_accuracy.append(accuracy) - train_loss = np.mean(epoch_loss) - train_accuracy = np.mean(epoch_accuracy) - return state, train_loss, train_accuracy - - -def get_datasets(): +def get_datasets( + config: ml_collections.ConfigDict, +) -> tuple[tf.data.Dataset, tf.data.Dataset]: """Load MNIST train and test datasets into memory.""" - ds_builder = tfds.builder('mnist') - ds_builder.download_and_prepare() - train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1)) - test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1)) - train_ds['image'] = jnp.float32(train_ds['image']) / 255.0 - test_ds['image'] = jnp.float32(test_ds['image']) / 255.0 - return train_ds, test_ds + batch_size = config.batch_size + train_ds: tf.data.Dataset = tfds.load('mnist', split='train') + test_ds: tf.data.Dataset = tfds.load('mnist', split='test') + + train_ds = train_ds.map( + lambda sample: { + 'image': tf.cast(sample['image'], tf.float32) / 255, + 'label': sample['label'], + } + ) # normalize train set + test_ds = test_ds.map( + lambda sample: { + 'image': tf.cast(sample['image'], tf.float32) / 255, + 'label': sample['label'], + } + ) # normalize the test set. + + # Create a shuffled dataset by allocating a buffer size of 1024 to randomly + # draw elements from. + train_ds = train_ds.shuffle(1024) + # Group into batches of `batch_size` and skip incomplete batches, prefetch the + # next sample to improve latency. + train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1) + # Group into batches of `batch_size` and skip incomplete batches, prefetch the + # next sample to improve latency. + test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1) - -def create_train_state(rng, config): - """Creates initial `TrainState`.""" - cnn = CNN() - params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params'] - tx = optax.sgd(config.learning_rate, config.momentum) - return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx) + return train_ds, test_ds -def train_and_evaluate( - config: ml_collections.ConfigDict, workdir: str -) -> train_state.TrainState: +def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> None: """Execute model training and evaluation loop. Args: config: Hyperparameter configuration for training and evaluation. - workdir: Directory where the tensorboard summaries are written to. - - Returns: - The train state (which includes the `.params`). + workdir: Directory path to store metrics. """ - train_ds, test_ds = get_datasets() - rng = jax.random.key(0) + train_ds, test_ds = get_datasets(config) + + # Instantiate the model. 
+ model = CNN(rngs=nnx.Rngs(0)) + + learning_rate = config.learning_rate + momentum = config.momentum summary_writer = tensorboard.SummaryWriter(workdir) summary_writer.hparams(dict(config)) - rng, init_rng = jax.random.split(rng) - state = create_train_state(init_rng, config) + optimizer = nnx.Optimizer( + model, optax.sgd(learning_rate, momentum), wrt=nnx.Param + ) + metrics = nnx.MultiMetric( + accuracy=nnx.metrics.Accuracy(), + loss=nnx.metrics.Average('loss'), + ) for epoch in range(1, config.num_epochs + 1): - rng, input_rng = jax.random.split(rng) - state, train_loss, train_accuracy = train_epoch( - state, train_ds, config.batch_size, input_rng - ) - _, test_loss, test_accuracy = apply_model( - state, test_ds['image'], test_ds['label'] - ) - - logging.info( + # Run the optimization for one step and make a stateful update to the + # following: + # - The train state's model parameters + # - The optimizer state + # - The training loss and accuracy batch metrics + model.train() # Switch to train mode + + for batch in train_ds.as_numpy_iterator(): + train_step(model, optimizer, metrics, batch) + # Compute the training metrics. + train_metrics = metrics.compute() + metrics.reset() # Reset the metrics for the test set. + + # Compute the metrics on the test set after each training epoch. + model.eval() # Switch to eval mode + for batch in test_ds.as_numpy_iterator(): + eval_step(model, metrics, batch) + + # Compute the eval metrics. + eval_metrics = metrics.compute() + metrics.reset() # Reset the metrics for the next training epoch. + + logging.info( # pylint: disable=logging-not-lazy 'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f,' ' test_accuracy: %.2f' % ( epoch, - train_loss, - train_accuracy * 100, - test_loss, - test_accuracy * 100, + train_metrics['loss'], + train_metrics['accuracy'] * 100, + eval_metrics['loss'], + eval_metrics['accuracy'] * 100, ) ) - summary_writer.scalar('train_loss', train_loss, epoch) - summary_writer.scalar('train_accuracy', train_accuracy, epoch) - summary_writer.scalar('test_loss', test_loss, epoch) - summary_writer.scalar('test_accuracy', test_accuracy, epoch) + # Write the metrics to TensorBoard. + summary_writer.scalar('train_loss', train_metrics['loss'], epoch) + summary_writer.scalar('train_accuracy', train_metrics['accuracy'], epoch) + summary_writer.scalar('test_loss', eval_metrics['loss'], epoch) + summary_writer.scalar('test_accuracy', eval_metrics['accuracy'], epoch) summary_writer.flush() - return state diff --git a/examples/vae/main.py b/examples/vae/main.py index 537ec08d6..20ed9dba1 100644 --- a/examples/vae/main.py +++ b/examples/vae/main.py @@ -31,12 +31,14 @@ FLAGS = flags.FLAGS +flags.DEFINE_string('workdir', None, 'Directory to store model data.') config_flags.DEFINE_config_file( 'config', None, 'File path to the training hyperparameter configuration.', lock_config=True, ) +flags.mark_flags_as_required(['config', 'workdir']) def main(argv): @@ -56,7 +58,7 @@ def main(argv): f'process_count: {jax.process_count()}' ) - train.train_and_evaluate(FLAGS.config) + train.train_and_evaluate(FLAGS.config, FLAGS.workdir) if __name__ == '__main__': diff --git a/examples/vae/train.py b/examples/vae/train.py index 84f1b582a..2d7056bac 100644 --- a/examples/vae/train.py +++ b/examples/vae/train.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Training and evaluation logic.""" +from typing import Any from absl import logging from flax import linen as nn @@ -24,6 +25,7 @@ import jax.numpy as jnp import ml_collections import optax +import tensorflow as tf import tensorflow_datasets as tfds @@ -47,6 +49,7 @@ def compute_metrics(recon_x, x, mean, logvar): def train_step(state, batch, z_rng, latents): + """Train step.""" def loss_fn(params): recon_x, mean, logvar = models.model(latents).apply( {'params': params}, batch, z_rng @@ -62,6 +65,7 @@ def loss_fn(params): def eval_f(params, images, z, z_rng, latents): + """Evaluation function.""" def eval_model(vae): recon_images, mean, logvar = vae(images, z_rng) comparison = jnp.concatenate([ @@ -77,8 +81,10 @@ def eval_model(vae): return nn.apply(eval_model, models.model(latents))({'params': params}) -def train_and_evaluate(config: ml_collections.ConfigDict): +def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): """Train and evaulate pipeline.""" + tf.io.gfile.makedirs(workdir) + rng = random.key(0) rng, key = random.split(rng) @@ -116,9 +122,11 @@ def train_and_evaluate(config: ml_collections.ConfigDict): state.params, test_ds, z, eval_rng, config.latents ) vae_utils.save_image( - comparison, f'results/reconstruction_{epoch}.png', nrow=8 + comparison, f'{workdir}/reconstruction_{epoch}.png', nrow=8 + ) + vae_utils.save_image( + sample, f'{workdir}/sample_{epoch}.png', nrow=8 ) - vae_utils.save_image(sample, f'results/sample_{epoch}.png', nrow=8) print( 'eval epoch: {}, loss: {:.4f}, BCE: {:.4f}, KLD: {:.4f}'.format( diff --git a/examples/wmt/models.py b/examples/wmt/models.py index 5da0f7065..725b1082f 100644 --- a/examples/wmt/models.py +++ b/examples/wmt/models.py @@ -20,8 +20,7 @@ # pytype: disable=wrong-keyword-args # pytype: disable=attribute-error -from typing import Any, Optional -from collections.abc import Callable +from typing import Any, Callable from flax import linen as nn from flax import struct @@ -549,7 +548,8 @@ def decode( # Make padding attention masks. if config.decode: - # for fast autoregressive decoding only a special encoder-decoder mask is used + # for fast autoregressive decoding only a special encoder-decoder mask is + # used decoder_mask = None encoder_decoder_mask = nn.make_attention_mask( jnp.ones_like(targets) > 0, inputs > 0, dtype=config.dtype diff --git a/examples/wmt/train.py b/examples/wmt/train.py index 0f9f2bec5..e5ee1e9e1 100644 --- a/examples/wmt/train.py +++ b/examples/wmt/train.py @@ -250,7 +250,7 @@ def loss_fn(params): if state.dynamic_scale: # if is_fin == False the gradients contain Inf/NaNs and optimizer state and # params should be restored (= skip this step). 
- select_fn = functools.partial(jnp.where, is_fin) + select_fn = functools.partial(jnp.where, is_fin) # pylint: disable=undefined-variable new_state = new_state.replace( opt_state=jax.tree_util.tree_map( select_fn, new_state.opt_state, state.opt_state @@ -259,7 +259,7 @@ def loss_fn(params): select_fn, new_state.params, state.params ), ) - metrics["loss_scale"] = dynamic_scale.scale * metrics["denominator"] + metrics["loss_scale"] = dynamic_scale.scale * metrics["denominator"] # pylint: disable=undefined-variable return new_state, metrics @@ -649,8 +649,8 @@ def decode_tokens(toks): metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics) denominator = metrics_sums.pop("denominator") summary = jax.tree_util.tree_map( - lambda x: x / denominator, metrics_sums - ) # pylint: disable=cell-var-from-loop + lambda x: x / denominator, metrics_sums # pylint: disable=cell-var-from-loop + ) summary["learning_rate"] = lr summary = {"train_" + k: v for k, v in summary.items()} writer.write_scalars(step, summary)
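
Note: the MNIST and tutorial changes above rely on the newer Flax NNX optimizer API, where `nnx.Optimizer` is constructed with `wrt=nnx.Param` and `optimizer.update` receives the model alongside the gradients. A minimal, self-contained sketch of that pattern follows; the toy model, shapes, and hyperparameters are hypothetical and only illustrate the call signatures used in the diff:

    import jax.numpy as jnp
    import optax
    from flax import nnx

    # Hypothetical toy model, only to illustrate the call signatures.
    model = nnx.Linear(4, 2, rngs=nnx.Rngs(0))
    optimizer = nnx.Optimizer(model, optax.sgd(1e-3, momentum=0.9), wrt=nnx.Param)

    def loss_fn(model):
      x = jnp.ones((8, 4))   # dummy batch
      y = jnp.zeros((8, 2))  # dummy targets
      return ((model(x) - y) ** 2).mean()

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)  # updates params and optimizer state in place

Both `metrics.update` and `optimizer.update` mutate their state in place, which is why the rewritten `train_step` in `examples/mnist/train.py` returns `None` rather than a new train state.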