Sync to Github. #32

Merged
merged 1 commit into from
Apr 25, 2025
10 changes: 6 additions & 4 deletions recml/core/__init__.py
@@ -27,10 +27,12 @@
from recml.core.training.core import Experiment
from recml.core.training.core import run_experiment
from recml.core.training.core import Trainer
from recml.core.training.jax import JaxState
from recml.core.training.jax import JaxTask
from recml.core.training.jax import JaxTrainer
from recml.core.training.jax import KerasState
from recml.core.training.jax_trainer import JaxState
from recml.core.training.jax_trainer import JaxTask
from recml.core.training.jax_trainer import JaxTrainer
from recml.core.training.jax_trainer import KerasState
from recml.core.training.keras_trainer import KerasTask
from recml.core.training.keras_trainer import KerasTrainer
from recml.core.training.optax_factory import AdagradFactory
from recml.core.training.optax_factory import AdamFactory
from recml.core.training.optax_factory import OptimizerFactory
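This hunk only changes where the package root sources its training symbols: the Jax exports now come from the renamed `recml.core.training.jax_trainer` module, and the Keras trainer exports are added from `recml.core.training.keras_trainer`. Code that imports through the package root should be unaffected; a minimal sketch (assuming `recml` is installed):

# Hypothetical downstream import; the names resolve via jax_trainer / keras_trainer after this change.
from recml.core import JaxState, JaxTask, JaxTrainer, KerasTask, KerasTrainer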
30 changes: 18 additions & 12 deletions recml/core/data/iterator.py
@@ -57,16 +57,15 @@ def __next__(self) -> clu_data.Element:
if self._prefetched_batch is not None:
batch = self._prefetched_batch
self._prefetched_batch = None
return batch

batch = next(self._iterator)
if self._postprocessor is not None:
batch = self._postprocessor(batch)
else:
batch = next(self._iterator)
if self._postprocessor is not None:
batch = self._postprocessor(batch)

def _maybe_to_numpy(
x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor,
x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | np.ndarray,
) -> np.ndarray | tf.SparseTensor | tf.RaggedTensor:
if isinstance(x, (tf.SparseTensor, tf.RaggedTensor)):
if isinstance(x, (tf.SparseTensor, tf.RaggedTensor, np.ndarray)):
return x
if hasattr(x, "_numpy"):
numpy = x._numpy() # pylint: disable=protected-access
@@ -83,13 +82,16 @@ def _maybe_to_numpy
@property
def element_spec(self) -> clu_data.ElementSpec:
if self._element_spec is not None:
batch = self._element_spec
else:
batch = self.__next__()
self._prefetched_batch = batch
return self._element_spec

batch = next(self._iterator)
if self._postprocessor is not None:
batch = self._postprocessor(batch)

self._prefetched_batch = batch

def _to_element_spec(
x: np.ndarray | tf.SparseTensor | tf.RaggedTensor,
x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | np.ndarray,
) -> clu_data.ArraySpec:
if isinstance(x, tf.SparseTensor):
return clu_data.ArraySpec(
@@ -101,6 +103,10 @@ def _to_element_spec(
dtype=x.dtype.as_numpy_dtype, # pylint: disable=attribute-error
shape=tuple(x.shape.as_list()), # pylint: disable=attribute-error
)
if isinstance(x, tf.Tensor):
return clu_data.ArraySpec(
dtype=x.dtype.as_numpy_dtype, shape=tuple(x.shape.as_list())
)
return clu_data.ArraySpec(dtype=x.dtype, shape=tuple(x.shape))

element_spec = tf.nest.map_structure(_to_element_spec, batch)
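The restructured iterator peeks one batch to build `element_spec` (applying the postprocessor first) and caches it in `_prefetched_batch`, so the following `__next__` call returns the cached batch instead of dropping it; numpy arrays now pass through `_maybe_to_numpy` untouched, and dense `tf.Tensor` specs get their own branch in `_to_element_spec`. A standalone sketch of the peek-and-cache pattern, with hypothetical names rather than the library's exact class:

# Minimal sketch of the peek-and-cache pattern used above (hypothetical class).
class PeekableIterator:

  def __init__(self, iterator):
    self._iterator = iterator
    self._prefetched = None

  def peek(self):
    # Materialize one element and cache it so it is not lost.
    if self._prefetched is None:
      self._prefetched = next(self._iterator)
    return self._prefetched

  def __next__(self):
    # Serve the cached element first, then fall back to the underlying iterator.
    if self._prefetched is not None:
      batch, self._prefetched = self._prefetched, None
      return batch
    return next(self._iterator)

In the real class, `element_spec` plays the role of `peek`: it inspects the cached batch to derive `clu_data.ArraySpec`s without consuming it.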
113 changes: 113 additions & 0 deletions recml/core/ops/embedding_ops.py
@@ -0,0 +1,113 @@
# Copyright 2024 RecML authors <[email protected]>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Embedding lookup ops."""

from collections.abc import Mapping, Sequence
import dataclasses
import functools
from typing import Any, TypeVar

from etils import epy
import jax
from jax.experimental import shard_map

with epy.lazy_imports():
# pylint: disable=g-import-not-at-top
from jax_tpu_embedding.sparsecore.lib.nn import embedding
# pylint: enable=g-import-not-at-top


T = TypeVar("T")
Nested = T | Sequence[T] | Mapping[str, T]
FeatureSpec = Any


@dataclasses.dataclass
class SparsecoreParams:
"""Embedding parameters."""

feature_specs: Nested[FeatureSpec]
abstract_mesh: jax.sharding.AbstractMesh
data_axes: Sequence[str | None]
embedding_axes: Sequence[str | None]
sharding_strategy: str


@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
def sparsecore_lookup(
sparsecore_params: SparsecoreParams,
tables: Mapping[str, tuple[jax.Array, ...]],
csr_inputs: tuple[jax.Array, ...],
):
return shard_map.shard_map(
functools.partial(
embedding.tpu_sparse_dense_matmul,
global_device_count=sparsecore_params.abstract_mesh.size,
feature_specs=sparsecore_params.feature_specs,
sharding_strategy=sparsecore_params.sharding_strategy,
),
mesh=sparsecore_params.abstract_mesh,
in_specs=(
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes),
),
out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
check_rep=False,
)(csr_inputs, tables)


def _emb_lookup_fwd(
sparsecore_params: SparsecoreParams,
tables: Mapping[str, tuple[jax.Array, ...]],
csr_inputs: tuple[jax.Array, ...],
):
out = sparsecore_lookup(sparsecore_params, tables, csr_inputs)
return out, (tables, csr_inputs)


def _emb_lookup_bwd(
sparsecore_params: SparsecoreParams,
res: tuple[Mapping[str, tuple[jax.Array, ...]], tuple[jax.Array, ...]],
gradients: Nested[jax.Array],
) -> tuple[Nested[jax.Array], None]:
"""Backward pass for embedding lookup."""
(tables, csr_inputs) = res

emb_table_grads = shard_map.shard_map(
functools.partial(
embedding.tpu_sparse_dense_matmul_grad,
feature_specs=sparsecore_params.feature_specs,
sharding_strategy=sparsecore_params.sharding_strategy,
),
mesh=sparsecore_params.abstract_mesh,
in_specs=(
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes),
),
out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
check_rep=False,
)(gradients, csr_inputs, tables)

# `tpu_sparse_dense_matmul_grad` returns a general mapping (usually a dict).
# It may not be the same type as the embedding table (e.g. FrozenDict).
# Here we use flatten / unflatten to ensure the types are the same.
emb_table_grads = jax.tree.unflatten(
jax.tree.structure(tables), jax.tree.leaves(emb_table_grads)
)

return emb_table_grads, None


sparsecore_lookup.defvjp(_emb_lookup_fwd, _emb_lookup_bwd)
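The new module pairs a `shard_map`-wrapped forward lookup with a hand-written backward pass through `jax.custom_vjp`, returning gradients only for the embedding tables and `None` for the CSR inputs. A toy sketch of the same custom-VJP wiring on a dense gather, using hypothetical names and no SparseCore dependency:

# Toy illustration of the custom-VJP pattern above (not the SparseCore kernels).
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
def scaled_lookup(scale, table, ids):
  return scale * table[ids]


def _fwd(scale, table, ids):
  return scaled_lookup(scale, table, ids), (table, ids)


def _bwd(scale, res, grad_out):
  table, ids = res
  # Scatter-add the output cotangent back into the table rows that were read.
  grad_table = jnp.zeros_like(table).at[ids].add(scale * grad_out)
  return grad_table, None  # No gradient flows to the integer ids.


scaled_lookup.defvjp(_fwd, _bwd)

grads = jax.grad(
    lambda t: scaled_lookup(2.0, t, jnp.array([0, 2])).sum()
)(jnp.ones((4, 3)))

As in `_emb_lookup_bwd`, the cotangent is mapped back onto a structure matching the table argument while the lookup indices stay non-differentiable.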
13 changes: 12 additions & 1 deletion recml/core/training/core.py
@@ -14,7 +14,7 @@
"""Core training library for Jax."""

import abc
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
import dataclasses
import enum
from typing import Any, Generic, TypeVar
@@ -33,6 +33,8 @@
TRAINING_COMPLETE_MARKER_FILE = "marker.txt"
TRAIN_LOG_DIRNAME = "train"
EVAL_LOG_DIRNAME = "val"
KERAS_MODEL_SAVEFILE = "model.keras"
ORBAX_CHECKPOINT_DEFAULT_KEY = "default"

DEFAULT_RNG_SEED = 0
IN_TRAINER_CONTEXT = False # Set to true when run from the main trainer.
@@ -171,6 +173,15 @@ def get_iterators(
return train_dataset, eval_datasets # pytype: disable=bad-return-type


def get_shape(
x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor,
) -> Sequence[int | None]:
"""Gets the shape of a dense / sparse / ragged tensor."""
if isinstance(x, tf.SparseTensor):
return [x.shape[0]] + [None for _ in x.shape[1:]]
return x.shape.as_list()


def in_tracing_context() -> bool:
"""Returns whether the current context is a tracing context."""
return isinstance(jnp.ones(()), jax.core.Tracer)
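The new `get_shape` helper keeps only the leading (batch) dimension of a sparse tensor and marks the remaining dimensions as unknown, while dense and ragged tensors report their static shapes. A hypothetical quick check of the behavior described above:

# Hypothetical quick check of get_shape (tensor values are illustrative).
import tensorflow as tf

dense = tf.zeros([8, 16])
sparse = tf.sparse.from_dense(dense)
ragged = tf.ragged.constant([[1, 2], [3]])

# get_shape(dense)  -> [8, 16]
# get_shape(sparse) -> [8, None]
# get_shape(ragged) -> [2, None]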
57 changes: 50 additions & 7 deletions recml/core/training/jax.py → recml/core/training/jax_trainer.py
@@ -26,6 +26,7 @@
from clu import periodic_actions
import clu.metrics as clu_metrics
from flax import struct
import flax.linen as nn
import jax
import jax.numpy as jnp
import keras
@@ -40,7 +41,7 @@
import tensorflow as tf


# pylint: disable=logging-fstring-interpolation
# pylint: disable=logging-fstring-interpolation, bad-whitespace

StateT = TypeVar("StateT")
MetricsT = TypeVar("MetricsT", bound=Mapping[str, clu_metrics.Metric])
@@ -67,43 +68,85 @@ class JaxState(struct.PyTreeNode, Generic[MetaT]):
step: A counter of the current step of the job. It starts at zero and it is
incremented by 1 on a call to `state.update(...)`. This should be a Jax
array and not a Python integer.
apply: A function that can be used to apply the forward pass of the model.
For Flax models this is usually set to `model.apply`.
params: A pytree of trainable variables that will be updated by `tx` and
used in `apply`.
tx: An optax gradient transformation that will be used to update the
parameters contained in `params` on a call to `state.update(...)`.
opt_state: The optimizer state for `tx`. This is usually created by calling
`tx.init(params)`.
_apply: An optional function that can be used to apply the forward pass of
the model. For Flax models this is usually set to `model.apply` while for
Haiku models this is usually set to `transform.apply`.
_model: An optional reference to a stateless Flax model for convenience.
mutable: A pytree of mutable variables that are used by `apply`.
meta: Arbitrary metadata that is recorded on the state. This can be useful
for tracking additional references in the state.
"""

step: jax.Array
apply: Callable[..., Any] = struct.field(pytree_node=False)
params: PyTree = struct.field(pytree_node=True)
tx: optax.GradientTransformation = struct.field(pytree_node=False)
opt_state: optax.OptState = struct.field(pytree_node=True)
mutable: PyTree = struct.field(pytree_node=True, default_factory=dict)
meta: MetaT = struct.field(pytree_node=False, default_factory=dict)
_apply: Callable[..., Any] | None = struct.field(
pytree_node=False, default=None
)
_model: nn.Module | None = struct.field(pytree_node=False, default=None)

@property
def model(self) -> nn.Module:
"""Returns a reference to the model used to create the state."""
if self._model is None:
raise ValueError("No Flax `model` is set on the state.")
return self._model

def apply(self, *args, **kwargs) -> Any:
"""Applies the forward pass of the model."""
if self._apply is None:
raise ValueError("No `apply` function is set on the state.")
return self._apply(*args, **kwargs)

@classmethod
def create(
cls,
*,
apply: Callable[..., Any],
apply: Callable[..., Any] | None = None,
model: nn.Module | None = None,
params: PyTree,
tx: optax.GradientTransformation,
**kwargs,
) -> Self:
"""Creates a new instance from a Jax apply function and Optax optimizer."""
"""Creates a new instance from a Jax model / apply fn and Optax optimizer.

Args:
apply: A function that can be used to apply the forward pass of the model.
For Flax models this is usually set to `model.apply`. This cannot be set
along with `model`.
model: A reference to a stateless Flax model. This cannot be set along
with `apply`. When set, the `apply` attribute of the state will be set to
`model.apply`.
params: A pytree of trainable variables that will be updated by `tx` and
used in `apply`.
tx: An optax gradient transformation that will be used to update the
parameters contained in `params` on a call to `state.update(...)`.
**kwargs: Other updates to set on the new state.

Returns:
A new instance of the state.
"""
if apply is not None and model is not None:
raise ValueError("Only one of `apply` or `model` can be provided.")
elif model is not None:
apply = model.apply

return cls(
step=jnp.zeros([], dtype=jnp.int32),
apply=apply,
params=params,
tx=tx,
opt_state=tx.init(params),
_apply=apply,
_model=model,
**kwargs,
)
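With `apply` now a plain method backed by the optional `_apply` field, `create` accepts either a Flax module or a bare apply function, but not both. A hypothetical usage sketch of the updated API (`SimpleMLP` is just a stand-in `flax.linen` module):

# Hypothetical usage of the updated JaxState.create API (SimpleMLP is illustrative).
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax

from recml.core.training.jax_trainer import JaxState


class SimpleMLP(nn.Module):

  @nn.compact
  def __call__(self, x):
    return nn.Dense(features=1)(x)


model = SimpleMLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones([1, 4]))

# Pass the module so the state also exposes `state.model`...
state = JaxState.create(model=model, params=params, tx=optax.adam(1e-3))
# ...or pass the apply function directly; supplying both raises a ValueError.
state = JaxState.create(apply=model.apply, params=params, tx=optax.adam(1e-3))

preds = state.apply(state.params, jnp.ones([1, 4]))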

@@ -25,13 +25,13 @@
import jax.numpy as jnp
import jaxtyping as jt
import optax
from recml.core.training import jax as jax_lib
from recml.core.training import jax_trainer
from recml.core.training import partitioning
import tensorflow as tf
import tensorflow_datasets as tfds


class _MNISTTask(jax_lib.JaxTask):
class _MNISTTask(jax_trainer.JaxTask):
"""Task for fitting a CNN on MNIST."""

def create_datasets(self) -> tuple[tf.data.Dataset, tf.data.Dataset]:
@@ -126,7 +126,7 @@ def setUp(self):
def test_mnist_e2e(self):
model_dir = self.create_tempdir().full_path
task = _MNISTTask()
trainer = jax_lib.JaxTrainer(
trainer = jax_trainer.JaxTrainer(
partitioner=partitioning.DataParallelPartitioner(),
train_steps=1000,
steps_per_eval=50,