Commit 14ffa77

Mohamed Hammad authored and recml authors committed

Internal change

PiperOrigin-RevId: 750714985

1 parent: b58765e · commit: 14ffa77

19 files changed: +2264, -58 lines

recml/core/__init__.py (+6, -4)

@@ -27,10 +27,12 @@
 from recml.core.training.core import Experiment
 from recml.core.training.core import run_experiment
 from recml.core.training.core import Trainer
-from recml.core.training.jax import JaxState
-from recml.core.training.jax import JaxTask
-from recml.core.training.jax import JaxTrainer
-from recml.core.training.jax import KerasState
+from recml.core.training.jax_trainer import JaxState
+from recml.core.training.jax_trainer import JaxTask
+from recml.core.training.jax_trainer import JaxTrainer
+from recml.core.training.jax_trainer import KerasState
+from recml.core.training.keras_trainer import KerasTask
+from recml.core.training.keras_trainer import KerasTrainer
 from recml.core.training.optax_factory import AdagradFactory
 from recml.core.training.optax_factory import AdamFactory
 from recml.core.training.optax_factory import OptimizerFactory
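
The hunk above only moves the internal module: the trainer classes are still re-exported from `recml.core`, so package-level imports keep working, while anything importing `recml.core.training.jax` directly must switch to `recml.core.training.jax_trainer`. A small illustrative snippet that relies only on the re-exports visible in this hunk:

from recml.core import JaxState, JaxTask, JaxTrainer  # unchanged public path
from recml.core import KerasTask, KerasTrainer        # imported here from the keras_trainer module

# Direct module imports must be updated:
#   before: from recml.core.training import jax as jax_lib
#   after:  from recml.core.training import jax_trainer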

recml/core/data/iterator.py (+18, -12)

@@ -57,16 +57,15 @@ def __next__(self) -> clu_data.Element:
     if self._prefetched_batch is not None:
       batch = self._prefetched_batch
       self._prefetched_batch = None
-      return batch
-
-    batch = next(self._iterator)
-    if self._postprocessor is not None:
-      batch = self._postprocessor(batch)
+    else:
+      batch = next(self._iterator)
+      if self._postprocessor is not None:
+        batch = self._postprocessor(batch)
 
     def _maybe_to_numpy(
-        x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor,
+        x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | np.ndarray,
     ) -> np.ndarray | tf.SparseTensor | tf.RaggedTensor:
-      if isinstance(x, (tf.SparseTensor, tf.RaggedTensor)):
+      if isinstance(x, (tf.SparseTensor, tf.RaggedTensor, np.ndarray)):
         return x
       if hasattr(x, "_numpy"):
         numpy = x._numpy()  # pylint: disable=protected-access
@@ -83,13 +82,16 @@ def _maybe_to_numpy(
   @property
   def element_spec(self) -> clu_data.ElementSpec:
     if self._element_spec is not None:
-      batch = self._element_spec
-    else:
-      batch = self.__next__()
-      self._prefetched_batch = batch
+      return self._element_spec
+
+    batch = next(self._iterator)
+    if self._postprocessor is not None:
+      batch = self._postprocessor(batch)
+
+    self._prefetched_batch = batch
 
     def _to_element_spec(
-        x: np.ndarray | tf.SparseTensor | tf.RaggedTensor,
+        x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | np.ndarray,
     ) -> clu_data.ArraySpec:
       if isinstance(x, tf.SparseTensor):
         return clu_data.ArraySpec(
@@ -101,6 +103,10 @@ def _to_element_spec(
             dtype=x.dtype.as_numpy_dtype,  # pylint: disable=attribute-error
             shape=tuple(x.shape.as_list()),  # pylint: disable=attribute-error
         )
+      if isinstance(x, tf.Tensor):
+        return clu_data.ArraySpec(
+            dtype=x.dtype.as_numpy_dtype, shape=tuple(x.shape.as_list())
+        )
       return clu_data.ArraySpec(dtype=x.dtype, shape=tuple(x.shape))
 
     element_spec = tf.nest.map_structure(_to_element_spec, batch)
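
Taken together, the two hunks above form a prefetch handshake: `element_spec` now pulls and post-processes one batch directly from the underlying iterator so it can inspect the structure, then stashes it in `_prefetched_batch`; the next call to `__next__` returns that stashed batch instead of silently dropping it. A minimal self-contained sketch of the pattern (the class below is a toy stand-in, not the actual recml iterator, and its spec is a plain dict rather than a clu ArraySpec tree):

import numpy as np


class PrefetchingIterator:
  """Toy stand-in mirroring the prefetch logic in the diff above."""

  def __init__(self, iterator, postprocessor=None):
    self._iterator = iterator
    self._postprocessor = postprocessor
    self._prefetched_batch = None
    self._element_spec = None

  def __next__(self):
    if self._prefetched_batch is not None:
      # Hand back the batch that `element_spec` already consumed.
      batch = self._prefetched_batch
      self._prefetched_batch = None
    else:
      batch = next(self._iterator)
      if self._postprocessor is not None:
        batch = self._postprocessor(batch)
    return batch

  @property
  def element_spec(self):
    if self._element_spec is not None:
      return self._element_spec
    # Pull one batch to inspect it, then stash it so no data is dropped.
    batch = next(self._iterator)
    if self._postprocessor is not None:
      batch = self._postprocessor(batch)
    self._prefetched_batch = batch
    self._element_spec = {k: (v.shape, v.dtype) for k, v in batch.items()}
    return self._element_spec


it = PrefetchingIterator(iter([{"x": np.zeros((2, 3))}]))
spec = it.element_spec  # inspects the first batch without losing it
first = next(it)        # returns that same first batch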

recml/core/ops/embedding_ops.py (new file, +114 lines)

# Copyright 2024 RecML authors <[email protected]>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Embedding lookup ops."""

from collections.abc import Mapping, Sequence
import dataclasses
import functools

from etils import epy
import jax
from jax.experimental import shard_map

with epy.lazy_imports():
  # pylint: disable=g-import-not-at-top
  from jax_tpu_embedding.sparsecore.lib.nn import embedding
  from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
  # pylint: enable=g-import-not-at-top


@dataclasses.dataclass
class SparsecoreParams:
  """Embedding parameters."""

  feature_specs: embedding.Nested[embedding_spec.FeatureSpec]
  abstract_mesh: jax.sharding.AbstractMesh
  data_axes: Sequence[str | None]
  embedding_axes: Sequence[str | None]
  sharding_strategy: str


@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
def sparsecore_lookup(
    sparsecore_params: SparsecoreParams,
    tables: Mapping[str, tuple[jax.Array, ...]],
    csr_inputs: tuple[jax.Array, ...],
):
  return shard_map.shard_map(
      functools.partial(
          embedding.tpu_sparse_dense_matmul,
          global_device_count=sparsecore_params.abstract_mesh.size,
          feature_specs=sparsecore_params.feature_specs,
          sharding_strategy=sparsecore_params.sharding_strategy,
      ),
      mesh=sparsecore_params.abstract_mesh,
      in_specs=(
          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
          jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes),
      ),
      out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
      check_rep=False,
  )(*csr_inputs, tables)


def _emb_lookup_fwd(
    sparsecore_params: SparsecoreParams,
    tables: Mapping[str, tuple[jax.Array, ...]],
    csr_inputs: tuple[jax.Array, ...],
):
  out = sparsecore_lookup(sparsecore_params, tables, csr_inputs)
  return out, (tables, csr_inputs)


def _emb_lookup_bwd(
    sparsecore_params: SparsecoreParams,
    res: tuple[Mapping[str, tuple[jax.Array, ...]], tuple[jax.Array, ...]],
    gradients: embedding.Nested[jax.Array],
) -> tuple[embedding.Nested[jax.Array], None]:
  """Backward pass for embedding lookup."""
  (tables, csr_inputs) = res

  emb_table_grads = shard_map.shard_map(
      functools.partial(
          embedding.tpu_sparse_dense_matmul_grad,
          feature_specs=sparsecore_params.feature_specs,
          sharding_strategy=sparsecore_params.sharding_strategy,
      ),
      mesh=sparsecore_params.abstract_mesh,
      in_specs=(
          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
          jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes),
      ),
      out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
      check_rep=False,
  )(gradients, *csr_inputs, tables)

  # `tpu_sparse_dense_matmul_grad` returns a general mapping (usually a dict).
  # It may not be the same type as the embedding table (e.g. FrozenDict).
  # Here we use flatten / unflatten to ensure the types are the same.
  emb_table_grads = jax.tree.unflatten(
      jax.tree.structure(tables), jax.tree.leaves(emb_table_grads)
  )

  return emb_table_grads, None


sparsecore_lookup.defvjp(_emb_lookup_fwd, _emb_lookup_bwd)
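
Since `sparsecore_lookup` is registered as a `jax.custom_vjp` with the params argument marked non-differentiable, `jax.grad` of any loss built on top of it routes through `_emb_lookup_bwd` (i.e. `tpu_sparse_dense_matmul_grad`) instead of differentiating through the shard_map'ed forward pass. A rough, hedged usage sketch: `feature_specs`, `mesh`, `tables`, and `csr_inputs` are assumed to come from the jax_tpu_embedding SparseCore preprocessing utilities and are not defined in this commit, and the axis names and sharding strategy below are placeholders only.

import jax
import jax.numpy as jnp

# Assumption: feature_specs, mesh, tables and csr_inputs already exist.
params = SparsecoreParams(
    feature_specs=feature_specs,       # nested FeatureSpec pytree (assumed)
    abstract_mesh=mesh.abstract_mesh,  # mesh spanning the TPU devices (assumed)
    data_axes=("data",),               # placeholder axis names
    embedding_axes=(None, "model"),    # placeholder axis names
    sharding_strategy="MOD",           # placeholder strategy
)


def loss_fn(tables):
  activations = sparsecore_lookup(params, tables, csr_inputs)
  # Toy scalar objective purely to exercise differentiation; a real task would
  # feed the activations into its dense layers instead.
  return sum(jnp.sum(a**2) for a in jax.tree.leaves(activations))


# Gradients reach the embedding tables through the custom backward pass;
# the CSR inputs receive no gradient (the VJP returns None for them).
table_grads = jax.grad(loss_fn)(tables)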

recml/core/training/core.py (+12, -1)

@@ -14,7 +14,7 @@
 """Core training library for Jax."""
 
 import abc
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 import dataclasses
 import enum
 from typing import Any, Generic, TypeVar
@@ -33,6 +33,8 @@
 TRAINING_COMPLETE_MARKER_FILE = "marker.txt"
 TRAIN_LOG_DIRNAME = "train"
 EVAL_LOG_DIRNAME = "val"
+KERAS_MODEL_SAVEFILE = "model.keras"
+ORBAX_CHECKPOINT_DEFAULT_KEY = "default"
 
 DEFAULT_RNG_SEED = 0
 IN_TRAINER_CONTEXT = False  # Set to true when run from the main trainer.
@@ -171,6 +173,15 @@ def get_iterators(
   return train_dataset, eval_datasets  # pytype: disable=bad-return-type
 
 
+def get_shape(
+    x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor,
+) -> Sequence[int | None]:
+  """Gets the shape of a dense / sparse / ragged tensor."""
+  if isinstance(x, tf.SparseTensor):
+    return [x.shape[0]] + [None for _ in x.shape[1:]]
+  return x.shape.as_list()
+
+
 def in_tracing_context() -> bool:
   """Returns whether the current context is a tracing context."""
   return isinstance(jnp.ones(()), jax.core.Tracer)
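
The new `get_shape` helper returns the static shape for dense and ragged tensors, but for sparse tensors it keeps only the leading (batch) dimension and marks the trailing dimensions as unknown, since those are data dependent. A quick illustrative check, assuming `get_shape` is imported from `recml.core.training.core` as added above:

import tensorflow as tf

from recml.core.training.core import get_shape

dense = tf.zeros([8, 16])
sparse = tf.sparse.from_dense(dense)
ragged = tf.ragged.constant([[1, 2], [3]])

print(get_shape(dense))   # [8, 16]
print(get_shape(sparse))  # [8, None] -- trailing dims treated as unknown
print(get_shape(ragged))  # [2, None] -- ragged dims are already None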

recml/core/training/jax.py renamed to recml/core/training/jax_trainer.py (+50, -7)

@@ -26,6 +26,7 @@
 from clu import periodic_actions
 import clu.metrics as clu_metrics
 from flax import struct
+import flax.linen as nn
 import jax
 import jax.numpy as jnp
 import keras
@@ -40,7 +41,7 @@
 import tensorflow as tf
 
 
-# pylint: disable=logging-fstring-interpolation
+# pylint: disable=logging-fstring-interpolation, bad-whitespace
 
 StateT = TypeVar("StateT")
 MetricsT = TypeVar("MetricsT", bound=Mapping[str, clu_metrics.Metric])
@@ -67,43 +68,85 @@ class JaxState(struct.PyTreeNode, Generic[MetaT]):
     step: A counter of the current step of the job. It starts at zero and it is
       incremented by 1 on a call to `state.update(...)`. This should be a Jax
       array and not a Python integer.
-    apply: A function that can be used to apply the forward pass of the model.
-      For Flax models this is usually set to `model.apply`.
     params: A pytree of trainable variables that will be updated by `tx` and
       used in `apply`.
     tx: An optax gradient transformation that will be used to update the
      parameters contained in `params` on a call to `state.update(...)`.
     opt_state: The optimizer state for `tx`. This is usually created by calling
      `tx.init(params)`.
+    _apply: An optional function that can be used to apply the forward pass of
+      the model. For Flax models this is usually set to `model.apply` while for
+      Haiku models this is usually set to `transform.apply`.
+    _model: An optional reference to a stateless Flax model for convenience.
     mutable: A pytree of mutable variables that are used by `apply`.
     meta: Arbitrary metadata that is recorded on the state. This can be useful
      for tracking additional references in the state.
   """
 
   step: jax.Array
-  apply: Callable[..., Any] = struct.field(pytree_node=False)
   params: PyTree = struct.field(pytree_node=True)
   tx: optax.GradientTransformation = struct.field(pytree_node=False)
   opt_state: optax.OptState = struct.field(pytree_node=True)
   mutable: PyTree = struct.field(pytree_node=True, default_factory=dict)
   meta: MetaT = struct.field(pytree_node=False, default_factory=dict)
+  _apply: Callable[..., Any] | None = struct.field(
+      pytree_node=False, default_factory=None
+  )
+  _model: nn.Module | None = struct.field(pytree_node=False, default=None)
+
+  @property
+  def model(self) -> nn.Module:
+    """Returns a reference to the model used to create the state."""
+    if self._model is None:
+      raise ValueError("No Flax `model` is set on the state.")
+    return self._model
+
+  def apply(self, *args, **kwargs) -> Any:
+    """Applies the forward pass of the model."""
+    if self._apply is None:
+      raise ValueError("No `apply` function is set on the state.")
+    return self._apply(*args, **kwargs)
 
   @classmethod
   def create(
       cls,
       *,
-      apply: Callable[..., Any],
+      apply: Callable[..., Any] | None = None,
+      model: nn.Module | None = None,
       params: PyTree,
       tx: optax.GradientTransformation,
       **kwargs,
   ) -> Self:
-    """Creates a new instance from a Jax apply function and Optax optimizer."""
+    """Creates a new instance from a Jax model / apply fn and Optax optimizer.
+
+    Args:
+      apply: A function that can be used to apply the forward pass of the model.
+        For Flax models this is usually set to `model.apply`. This cannot be set
+        along with `model`.
+      model: A reference to a stateless Flax model. This cannot be set along
+        with `apply`. When set the `apply` attribute of the state will be set to
+        `model.apply`.
+      params: A pytree of trainable variables that will be updated by `tx` and
+        used in `apply`.
+      tx: An optax gradient transformation that will be used to update the
+        parameters contained in `params` on a call to `state.update(...)`.
+      **kwargs: Other updates to set on the new state.
+
+    Returns:
+      A new instance of the state.
+    """
+    if apply is not None and model is not None:
+      raise ValueError("Only one of `apply` or `model` can be provided.")
+    elif model is not None:
+      apply = model.apply
+
     return cls(
         step=jnp.zeros([], dtype=jnp.int32),
-        apply=apply,
         params=params,
         tx=tx,
         opt_state=tx.init(params),
+        _apply=apply,
+        _model=model,
         **kwargs,
     )
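
With this change the state stores the forward function (and optionally the Flax module) as private fields and exposes them through `state.apply(...)` and `state.model`, while `JaxState.create` accepts either `apply=` or `model=` but not both. A short hedged sketch of the new construction paths; the toy MLP and optimizer below are illustrative and not taken from the commit:

import flax.linen as nn
import jax
import jax.numpy as jnp
import optax

from recml.core.training.jax_trainer import JaxState


class MLP(nn.Module):  # toy model for illustration only
  @nn.compact
  def __call__(self, x):
    return nn.Dense(1)(jnp.tanh(nn.Dense(8)(x)))


model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones([1, 4]))["params"]

# Either pass the module itself ...
state = JaxState.create(model=model, params=params, tx=optax.adam(1e-3))
preds = state.apply({"params": state.params}, jnp.ones([2, 4]))
module = state.model  # the module is recoverable from the state

# ... or pass an apply function directly (e.g. a Haiku transform's apply).
state2 = JaxState.create(apply=model.apply, params=params, tx=optax.adam(1e-3))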

recml/core/training/jax_quality_test.py renamed to recml/core/training/jax_trainer_quality_test.py (+3, -3)

@@ -25,13 +25,13 @@
 import jax.numpy as jnp
 import jaxtyping as jt
 import optax
-from recml.core.training import jax as jax_lib
+from recml.core.training import jax_trainer
 from recml.core.training import partitioning
 import tensorflow as tf
 import tensorflow_datasets as tfds
 
 
-class _MNISTTask(jax_lib.JaxTask):
+class _MNISTTask(jax_trainer.JaxTask):
   """Task for fitting a CNN on MNIST."""
 
   def create_datasets(self) -> tuple[tf.data.Dataset, tf.data.Dataset]:
@@ -126,7 +126,7 @@ def setUp(self):
   def test_mnist_e2e(self):
     model_dir = self.create_tempdir().full_path
     task = _MNISTTask()
-    trainer = jax_lib.JaxTrainer(
+    trainer = jax_trainer.JaxTrainer(
         partitioner=partitioning.DataParallelPartitioner(),
         train_steps=1000,
         steps_per_eval=50,
