2 changes: 1 addition & 1 deletion docs_nnx/mnist_tutorial.md
@@ -158,7 +158,7 @@ def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, b
grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(model, batch)
metrics.update(loss=loss, logits=logits, labels=batch['label']) # In-place updates.
optimizer.update(grads) # In-place updates.
optimizer.update(model, grads) # In-place updates.

@nnx.jit
def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
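The one functional change in this file is the optimizer call: in the newer Flax NNX API, nnx.Optimizer.update receives the model being optimized as its first argument, followed by the gradients. A minimal sketch of a training step written against that signature follows; the loss-function body is an illustrative cross-entropy formulation rather than text copied from this PR, and CNN, model, optimizer, and metrics are assumed to be defined as in the surrounding tutorial.

import optax
from flax import nnx

@nnx.jit
def train_step(model, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
  def loss_fn(model, batch):
    logits = model(batch['image'])
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label']
    ).mean()
    return loss, logits

  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch['label'])  # In-place updates.
  optimizer.update(model, grads)  # Updated signature: model first, then gradients.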
144 changes: 71 additions & 73 deletions examples/gemma/input_pipeline.py
@@ -15,15 +15,11 @@
"""Input pipeline for a LM1B dataset."""

import os
import typing
from typing import Any

import tokenizer
import tensorflow as tf
import tensorflow_datasets as tfds
import tokenizer
from clu import deterministic_data

if typing.TYPE_CHECKING:
from train import TrainConfig

AUTOTUNE = tf.data.experimental.AUTOTUNE
Features = dict[str, tf.Tensor]
@@ -58,9 +54,9 @@ def get_raw_dataset(dataset_name: str, split: str) -> tf.data.Dataset:


def pack_dataset(
dataset: tf.data.Dataset,
key2length: int | dict[str, int],
keys: list[str] | None = None,
dataset: tf.data.Dataset,
key2length: int | dict[str, int],
keys: list[str] | None = None,
) -> tf.data.Dataset:
"""Creates a 'packed' version of a dataset on-the-fly.

@@ -107,8 +103,8 @@ def pack_dataset(
for k in keys:
if k not in shapes:
raise ValueError(
'Key %s not found in dataset. Available keys are %s'
% (k, shapes.keys())
'Key %s not found in dataset. Available keys are %s'
% (k, shapes.keys())
)
if not shapes[k].is_compatible_with(tf.TensorShape([None])): # type: ignore[wrong-arg-types]
raise ValueError('Tensors to be packed must be one-dimensional.')
@@ -122,14 +118,14 @@

# trim to length
dataset = dataset.map(
lambda x: {k: x[k][: key2length[k]] for k in keys},
num_parallel_calls=AUTOTUNE,
lambda x: {k: x[k][: key2length[k]] for k in keys},
num_parallel_calls=AUTOTUNE,
)
# Setting batch_size=length ensures that the concatenated sequences (if they
# have length >=1) are sufficient to fill at least one packed example.
batch_size = max(key2length.values())
dataset = dataset.padded_batch(
batch_size, padded_shapes={k: [-1] for k in keys}
batch_size, padded_shapes={k: [-1] for k in keys}
)
dataset = _pack_with_tf_ops(dataset, keys, key2length)

@@ -141,7 +137,7 @@ def my_fn(x):


def _pack_with_tf_ops(
dataset: tf.data.Dataset, keys: list[str], key2length: dict[str, int]
dataset: tf.data.Dataset, keys: list[str], key2length: dict[str, int]
) -> tf.data.Dataset:
"""Helper-function for packing a dataset which has already been batched.

@@ -166,8 +162,8 @@ def write_packed_example(partial, outputs):
new_outputs = {}
for k in keys_etc:
new_outputs[k] = outputs[k].write(
outputs[k].size(),
tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]),
outputs[k].size(),
tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]),
)
return new_partial, new_outputs

@@ -188,10 +184,10 @@ def map_fn(x):
outputs = {}
for k in keys:
outputs[k] = tf.TensorArray(
tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]
tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]
)
outputs[k + '_position'] = tf.TensorArray(
tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]
tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]
)

def body_fn(i, partial, outputs):
@@ -213,10 +209,10 @@ def body_fn(i, partial, outputs):
one_example[k] = val
for k in keys:
can_append = tf.logical_and(
can_append,
tf.less_equal(
tf.size(partial[k]) + tf.size(one_example[k]), key2length[k]
),
can_append,
tf.less_equal(
tf.size(partial[k]) + tf.size(one_example[k]), key2length[k]
),
)

def false_fn():
@@ -232,28 +228,28 @@ def true_fn():
new_seq_len = tf.size(new_seq)
new_partial[k] = tf.concat([partial[k], new_seq], 0)
new_partial[k + '_position'] = tf.concat(
[partial[k + '_position'], tf.range(new_seq_len)], 0
[partial[k + '_position'], tf.range(new_seq_len)], 0
)
partial = new_partial
return i + 1, partial, outputs

# For loop over all examples in the batch.
i, partial, outputs = tf.while_loop(
cond=lambda *_: True,
body=body_fn,
loop_vars=(i, partial, outputs),
shape_invariants=(
tf.TensorShape([]),
{k: tf.TensorShape([None]) for k in keys_etc}, # type: ignore[wrong-arg-types]
{k: tf.TensorShape(None) for k in keys_etc}, # type: ignore[wrong-arg-types]
),
maximum_iterations=dynamic_batch_size,
_, partial, outputs = tf.while_loop(
cond=lambda *_: True,
body=body_fn,
loop_vars=(i, partial, outputs),
shape_invariants=(
tf.TensorShape([]),
{k: tf.TensorShape([None]) for k in keys_etc}, # type: ignore[wrong-arg-types]
{k: tf.TensorShape(None) for k in keys_etc}, # type: ignore[wrong-arg-types]
),
maximum_iterations=dynamic_batch_size,
)
_, outputs = write_packed_example(partial, outputs)
packed = {k: outputs[k].stack() for k in keys_etc}
for k in keys:
packed[k + '_segmentation'] = tf.cumsum(
tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1
tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1
) * tf.cast(tf.not_equal(packed[k], 0), tf.int32)
return packed

@@ -263,25 +259,25 @@

def shift_data_by_truncation(x):
# https://github.com/AI-Hypercomputer/maxtext/blob/7fe1de75b3919c0fda00d23ad6cb29def9098362/MaxText/input_pipeline/_input_pipeline_utils.py#L53
x["inputs"] = x["inputs"][:-1]
x["targets"] = x["targets"][1:]
x['inputs'] = x['inputs'][:-1]
x['targets'] = x['targets'][1:]
return x


# -----------------------------------------------------------------------------
# Main dataset prep routines.
# -----------------------------------------------------------------------------
def preprocess_data(
dataset,
shuffle: bool,
num_epochs: int | None = 1,
pack_examples: bool = True,
shuffle_buffer_size: int = 1024,
max_length: int = 512,
batch_size: int = 256,
drop_remainder: bool = True,
prefetch_size: int = AUTOTUNE,
shift: bool = True,
dataset,
shuffle: bool,
num_epochs: int | None = 1,
pack_examples: bool = True,
shuffle_buffer_size: int = 1024,
max_length: int = 512,
batch_size: int = 256,
drop_remainder: bool = True,
prefetch_size: int = AUTOTUNE,
shift: bool = True,
):
"""Shuffle and batch/pack the given dataset."""

@@ -303,18 +299,20 @@ def filter_fn(x):
# Shift inputs for teacher-forced training
if shift:
dataset = dataset.map(
shift_data_by_truncation, num_parallel_calls=AUTOTUNE, deterministic=True
shift_data_by_truncation,
num_parallel_calls=AUTOTUNE,
deterministic=True,
)

if pack_examples:
dataset = pack_dataset(dataset, max_length)
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
else: # simple (static-shape) padded batching
dataset = dataset.padded_batch(
batch_size,
padded_shapes={'inputs': max_length, 'targets': max_length},
padding_values={'inputs': 0, 'targets': 0},
drop_remainder=drop_remainder,
batch_size,
padded_shapes={'inputs': max_length, 'targets': max_length},
padding_values={'inputs': 0, 'targets': 0},
drop_remainder=drop_remainder,
)

if prefetch_size:
@@ -324,10 +322,10 @@ def filter_fn(x):


def get_datasets(
config: "TrainConfig",
*,
n_devices: int,
vocab_path: str | None = None,
config: Any,
*,
n_devices: int,
vocab_path: str | None = None,
):
"""Load and return dataset of batched examples for use during training."""
if vocab_path is None:
@@ -343,16 +341,16 @@ def get_datasets(

# Tokenize data.
sp_processor = tokenizer.load_or_train_tokenizer(
train_data,
vocab_path=vocab_path,
vocab_size=config.vocab_size,
max_corpus_chars=config.max_corpus_chars,
train_data,
vocab_path=vocab_path,
vocab_size=config.vocab_size,
max_corpus_chars=config.max_corpus_chars,
)
train_data = train_data.map(
tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE
tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE
)
eval_data = eval_data.map(
tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE
tokenizer.TokenizeOp(sp_processor), num_parallel_calls=AUTOTUNE
)

batch_size = config.per_device_batch_size * n_devices
@@ -362,20 +360,20 @@
eval_batch_size = batch_size

train_ds = preprocess_data(
train_data,
shuffle=True,
num_epochs=None,
pack_examples=True,
batch_size=batch_size,
max_length=config.max_target_length,
train_data,
shuffle=True,
num_epochs=None,
pack_examples=True,
batch_size=batch_size,
max_length=config.max_target_length,
)

eval_ds = preprocess_data(
eval_data,
shuffle=False,
pack_examples=False,
batch_size=eval_batch_size,
max_length=config.max_eval_target_length,
eval_data,
shuffle=False,
pack_examples=False,
batch_size=eval_batch_size,
max_length=config.max_eval_target_length,
)

return train_ds, eval_ds, sp_processor
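Because most of this file's diff is re-indentation, the actual behavior of pack_dataset is easy to lose track of. The sketch below is not part of the change: it feeds two toy variable-length examples through pack_dataset and shows the fields it is expected to produce. The import path is an assumption, and the printed values reflect the packing logic shown above rather than output captured from a run.

import tensorflow as tf
from input_pipeline import pack_dataset  # assumed import path (examples/gemma)

def gen():
  # Two toy variable-length examples.
  yield {'inputs': [1, 2, 3], 'targets': [1, 2, 3]}
  yield {'inputs': [4, 5], 'targets': [4, 5]}

raw = tf.data.Dataset.from_generator(
    gen,
    output_signature={
        'inputs': tf.TensorSpec([None], tf.int32),
        'targets': tf.TensorSpec([None], tf.int32),
    },
)

packed = pack_dataset(raw, key2length=8)
for ex in packed.take(1):
  print(ex['inputs'])               # expected: [1 2 3 4 5 0 0 0]
  print(ex['inputs_position'])      # expected: [0 1 2 0 1 0 0 0]
  print(ex['inputs_segmentation'])  # expected: [1 1 1 2 2 0 0 0]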
25 changes: 14 additions & 11 deletions examples/gemma/main.py
@@ -18,21 +18,24 @@
that can be easily tested and imported in Colab.
"""

import jax
import tensorflow as tf
import train
from absl import app, flags, logging
from absl import app
from absl import flags
from absl import logging
from clu import platform
import train
import jax
from ml_collections import config_flags
import tensorflow as tf


FLAGS = flags.FLAGS

flags.DEFINE_string('workdir', None, 'Directory to store model data.')
config_flags.DEFINE_config_file(
'config',
'configs/default.py',
'File path to the training hyperparameter configuration.',
lock_config=True,
'config',
'configs/default.py',
'File path to the training hyperparameter configuration.',
lock_config=True,
)
flags.mark_flags_as_required(['workdir'])

@@ -51,11 +54,11 @@ def main(argv):
# Add a note so that we can tell which task is which JAX host.
# (Depending on the platform task 0 is not guaranteed to be host 0)
platform.work_unit().set_task_status(
f'process_index: {jax.process_index()}, '
f'process_count: {jax.process_count()}'
f'process_index: {jax.process_index()}, '
f'process_count: {jax.process_count()}'
)
platform.work_unit().create_artifact(
platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir'
platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir'
)

train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
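main.py itself is only touched for import ordering and argument re-flow. For readers unfamiliar with config_flags: the file passed via --config (configs/default.py above) is expected to expose a get_config() function returning an ml_collections.ConfigDict. A minimal hypothetical config module might look like the following; the field names simply mirror the ones read in input_pipeline.py and are not taken from the real configs/default.py.

# hypothetical configs/minimal.py, for illustration only
import ml_collections

def get_config() -> ml_collections.ConfigDict:
  config = ml_collections.ConfigDict()
  config.per_device_batch_size = 32   # fields referenced in input_pipeline.py
  config.vocab_size = 30_000
  config.max_corpus_chars = 10_000_000
  config.max_target_length = 512
  config.max_eval_target_length = 512
  return config

# Selected at launch time with something like:
#   python main.py --workdir=/tmp/lm1b --config=configs/minimal.py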