
Commit 5289e75

Merge pull request #32 from AI-Hypercomputer/test_750675584
Sync to Github.
2 parents: b58765e + cf60ef9

19 files changed: +2267 -58 lines

recml/core/__init__.py (+6 -4)

@@ -27,10 +27,12 @@
 from recml.core.training.core import Experiment
 from recml.core.training.core import run_experiment
 from recml.core.training.core import Trainer
-from recml.core.training.jax import JaxState
-from recml.core.training.jax import JaxTask
-from recml.core.training.jax import JaxTrainer
-from recml.core.training.jax import KerasState
+from recml.core.training.jax_trainer import JaxState
+from recml.core.training.jax_trainer import JaxTask
+from recml.core.training.jax_trainer import JaxTrainer
+from recml.core.training.jax_trainer import KerasState
+from recml.core.training.keras_trainer import KerasTask
+from recml.core.training.keras_trainer import KerasTrainer
 from recml.core.training.optax_factory import AdagradFactory
 from recml.core.training.optax_factory import AdamFactory
 from recml.core.training.optax_factory import OptimizerFactory
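Downstream imports that go through `recml.core` are unchanged by the rename; only the internal module path moves, and two Keras symbols are newly re-exported. A minimal sketch of what this keeps working, assuming the re-exports above are the full public surface:

# Public imports are unaffected by the jax.py -> jax_trainer.py rename.
from recml.core import JaxState, JaxTask, JaxTrainer
# Newly re-exported Keras trainer symbols.
from recml.core import KerasTask, KerasTrainer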

recml/core/data/iterator.py (+18 -12)

@@ -57,16 +57,15 @@ def __next__(self) -> clu_data.Element:
     if self._prefetched_batch is not None:
       batch = self._prefetched_batch
       self._prefetched_batch = None
-      return batch
-
-    batch = next(self._iterator)
-    if self._postprocessor is not None:
-      batch = self._postprocessor(batch)
+    else:
+      batch = next(self._iterator)
+      if self._postprocessor is not None:
+        batch = self._postprocessor(batch)

     def _maybe_to_numpy(
-        x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor,
+        x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | np.ndarray,
     ) -> np.ndarray | tf.SparseTensor | tf.RaggedTensor:
-      if isinstance(x, (tf.SparseTensor, tf.RaggedTensor)):
+      if isinstance(x, (tf.SparseTensor, tf.RaggedTensor, np.ndarray)):
         return x
       if hasattr(x, "_numpy"):
         numpy = x._numpy()  # pylint: disable=protected-access
@@ -83,13 +82,16 @@ def _maybe_to_numpy(
   @property
   def element_spec(self) -> clu_data.ElementSpec:
     if self._element_spec is not None:
-      batch = self._element_spec
-    else:
-      batch = self.__next__()
-      self._prefetched_batch = batch
+      return self._element_spec
+
+    batch = next(self._iterator)
+    if self._postprocessor is not None:
+      batch = self._postprocessor(batch)
+
+    self._prefetched_batch = batch

     def _to_element_spec(
-        x: np.ndarray | tf.SparseTensor | tf.RaggedTensor,
+        x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor | np.ndarray,
     ) -> clu_data.ArraySpec:
       if isinstance(x, tf.SparseTensor):
         return clu_data.ArraySpec(
@@ -101,6 +103,10 @@ def _to_element_spec(
             dtype=x.dtype.as_numpy_dtype,  # pylint: disable=attribute-error
             shape=tuple(x.shape.as_list()),  # pylint: disable=attribute-error
         )
+      if isinstance(x, tf.Tensor):
+        return clu_data.ArraySpec(
+            dtype=x.dtype.as_numpy_dtype, shape=tuple(x.shape.as_list())
+        )
       return clu_data.ArraySpec(dtype=x.dtype, shape=tuple(x.shape))

     element_spec = tf.nest.map_structure(_to_element_spec, batch)
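The change makes `element_spec` peek one batch from the underlying iterator and stash it, so the next `__next__` call returns the stashed batch instead of silently dropping it, and it lets already-converted `np.ndarray` leaves pass through untouched. A minimal sketch of that prefetch contract, using a hypothetical stand-in class (the real wrapper in iterator.py additionally converts TF tensors to numpy and builds `clu_data.ArraySpec`s):

import numpy as np


class PeekableIterator:
  """Hypothetical stand-in showing only the control flow changed above."""

  def __init__(self, iterator, postprocessor=None):
    self._iterator = iterator
    self._postprocessor = postprocessor
    self._prefetched_batch = None
    self._element_spec = None

  def _fetch(self):
    batch = next(self._iterator)
    if self._postprocessor is not None:
      batch = self._postprocessor(batch)
    return batch

  def __next__(self):
    if self._prefetched_batch is not None:
      # Consume the batch stashed by `element_spec` instead of advancing.
      batch, self._prefetched_batch = self._prefetched_batch, None
    else:
      batch = self._fetch()
    return batch

  @property
  def element_spec(self):
    if self._element_spec is not None:
      return self._element_spec
    # Peek one batch to infer the spec and stash it so it is not lost.
    batch = self._fetch()
    self._prefetched_batch = batch
    self._element_spec = {k: (v.dtype, v.shape) for k, v in batch.items()}
    return self._element_spec


it = PeekableIterator(iter([{"x": np.ones((2, 3))}, {"x": np.zeros((2, 3))}]))
spec = it.element_spec  # {'x': (dtype('float64'), (2, 3))}
first = next(it)        # returns the peeked batch; nothing is skipped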

recml/core/ops/embedding_ops.py (new file, +113)

@@ -0,0 +1,113 @@
+# Copyright 2024 RecML authors <[email protected]>.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Embedding lookup ops."""
+
+from collections.abc import Mapping, Sequence
+import dataclasses
+import functools
+from typing import Any, TypeVar
+
+from etils import epy
+import jax
+from jax.experimental import shard_map
+
+with epy.lazy_imports():
+  # pylint: disable=g-import-not-at-top
+  from jax_tpu_embedding.sparsecore.lib.nn import embedding
+  # pylint: enable=g-import-not-at-top
+
+
+T = TypeVar("T")
+Nested = T | Sequence[T] | Mapping[str, T]
+FeatureSpec = Any
+
+
+@dataclasses.dataclass
+class SparsecoreParams:
+  """Embedding parameters."""
+
+  feature_specs: Nested[FeatureSpec]
+  abstract_mesh: jax.sharding.AbstractMesh
+  data_axes: Sequence[str | None]
+  embedding_axes: Sequence[str | None]
+  sharding_strategy: str
+
+
+@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
+def sparsecore_lookup(
+    sparsecore_params: SparsecoreParams,
+    tables: Mapping[str, tuple[jax.Array, ...]],
+    csr_inputs: tuple[jax.Array, ...],
+):
+  return shard_map.shard_map(
+      functools.partial(
+          embedding.tpu_sparse_dense_matmul,
+          global_device_count=sparsecore_params.abstract_mesh.size,
+          feature_specs=sparsecore_params.feature_specs,
+          sharding_strategy=sparsecore_params.sharding_strategy,
+      ),
+      mesh=sparsecore_params.abstract_mesh,
+      in_specs=(
+          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
+          jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes),
+      ),
+      out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
+      check_rep=False,
+  )(csr_inputs, tables)
+
+
+def _emb_lookup_fwd(
+    sparsecore_params: SparsecoreParams,
+    tables: Mapping[str, tuple[jax.Array, ...]],
+    csr_inputs: tuple[jax.Array, ...],
+):
+  out = sparsecore_lookup(sparsecore_params, tables, csr_inputs)
+  return out, (tables, csr_inputs)
+
+
+def _emb_lookup_bwd(
+    sparsecore_params: SparsecoreParams,
+    res: tuple[Mapping[str, tuple[jax.Array, ...]], tuple[jax.Array, ...]],
+    gradients: Nested[jax.Array],
+) -> tuple[Nested[jax.Array], None]:
+  """Backward pass for embedding lookup."""
+  (tables, csr_inputs) = res
+
+  emb_table_grads = shard_map.shard_map(
+      functools.partial(
+          embedding.tpu_sparse_dense_matmul_grad,
+          feature_specs=sparsecore_params.feature_specs,
+          sharding_strategy=sparsecore_params.sharding_strategy,
+      ),
+      mesh=sparsecore_params.abstract_mesh,
+      in_specs=(
+          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
+          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
+          jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes),
+      ),
+      out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
+      check_rep=False,
+  )(gradients, csr_inputs, tables)
+
+  # `tpu_sparse_dense_matmul_grad` returns a general mapping (usually a dict).
+  # It may not be the same type as the embedding table (e.g. FrozenDict).
+  # Here we use flatten / unflatten to ensure the types are the same.
+  emb_table_grads = jax.tree.unflatten(
+      jax.tree.structure(tables), jax.tree.leaves(emb_table_grads)
+  )
+
+  return emb_table_grads, None
+
+
+sparsecore_lookup.defvjp(_emb_lookup_fwd, _emb_lookup_bwd)
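Because `sparsecore_lookup` is wrapped in `jax.custom_vjp` with `SparsecoreParams` as a non-differentiable argument, gradients flow to the embedding tables only; the CSR inputs receive a `None` cotangent. The flatten/unflatten step at the end of the backward pass is what keeps the gradient pytree the same container type as `tables`. A small sketch of just that trick, with a hypothetical table name and FrozenDict as an example container:

import jax
import jax.numpy as jnp
from flax.core import FrozenDict  # example container; any registered pytree works

# Hypothetical embedding tables stored in a FrozenDict.
tables = FrozenDict({"item_table": (jnp.ones((16, 4)),)})

# Suppose the grad kernel returned a plain dict with matching leaves.
raw_grads = {"item_table": (jnp.zeros((16, 4)),)}

# Rebuild the grads with the treedef of `tables` so custom_vjp sees a
# cotangent with exactly the same structure as the primal input.
grads = jax.tree.unflatten(
    jax.tree.structure(tables), jax.tree.leaves(raw_grads)
)
assert isinstance(grads, FrozenDict)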

recml/core/training/core.py (+12 -1)

@@ -14,7 +14,7 @@
 """Core training library for Jax."""

 import abc
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 import dataclasses
 import enum
 from typing import Any, Generic, TypeVar
@@ -33,6 +33,8 @@
 TRAINING_COMPLETE_MARKER_FILE = "marker.txt"
 TRAIN_LOG_DIRNAME = "train"
 EVAL_LOG_DIRNAME = "val"
+KERAS_MODEL_SAVEFILE = "model.keras"
+ORBAX_CHECKPOINT_DEFAULT_KEY = "default"

 DEFAULT_RNG_SEED = 0
 IN_TRAINER_CONTEXT = False  # Set to true when run from the main trainer.
@@ -171,6 +173,15 @@ def get_iterators(
   return train_dataset, eval_datasets  # pytype: disable=bad-return-type


+def get_shape(
+    x: tf.Tensor | tf.SparseTensor | tf.RaggedTensor,
+) -> Sequence[int | None]:
+  """Gets the shape of a dense / sparse / ragged tensor."""
+  if isinstance(x, tf.SparseTensor):
+    return [x.shape[0]] + [None for _ in x.shape[1:]]
+  return x.shape.as_list()
+
+
 def in_tracing_context() -> bool:
   """Returns whether the current context is a tracing context."""
   return isinstance(jnp.ones(()), jax.core.Tracer)
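The new `get_shape` helper returns the static shape for dense and ragged tensors but keeps only the leading (batch) dimension of a sparse tensor, marking the rest as unknown. A small usage sketch, assuming `recml.core.training.core` is importable as named in the header above:

import tensorflow as tf
from recml.core.training import core

dense = tf.zeros([32, 16])
sparse = tf.sparse.from_dense(tf.zeros([32, 16]))

core.get_shape(dense)   # [32, 16]
core.get_shape(sparse)  # [32, None] -- only the leading dim is kept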

recml/core/training/jax.py renamed to recml/core/training/jax_trainer.py (+50 -7)

@@ -26,6 +26,7 @@
 from clu import periodic_actions
 import clu.metrics as clu_metrics
 from flax import struct
+import flax.linen as nn
 import jax
 import jax.numpy as jnp
 import keras
@@ -40,7 +41,7 @@
 import tensorflow as tf


-# pylint: disable=logging-fstring-interpolation
+# pylint: disable=logging-fstring-interpolation, bad-whitespace

 StateT = TypeVar("StateT")
 MetricsT = TypeVar("MetricsT", bound=Mapping[str, clu_metrics.Metric])
@@ -67,43 +68,85 @@ class JaxState(struct.PyTreeNode, Generic[MetaT]):
     step: A counter of the current step of the job. It starts at zero and it is
       incremented by 1 on a call to `state.update(...)`. This should be a Jax
       array and not a Python integer.
-    apply: A function that can be used to apply the forward pass of the model.
-      For Flax models this is usually set to `model.apply`.
     params: A pytree of trainable variables that will be updated by `tx` and
       used in `apply`.
     tx: An optax gradient transformation that will be used to update the
       parameters contained in `params` on a call to `state.update(...)`.
     opt_state: The optimizer state for `tx`. This is usually created by calling
       `tx.init(params)`.
+    _apply: An optional function that can be used to apply the forward pass of
+      the model. For Flax models this is usually set to `model.apply` while for
+      Haiku models this is usually set to `transform.apply`.
+    _model: An optional reference to a stateless Flax model for convenience.
     mutable: A pytree of mutable variables that are used by `apply`.
     meta: Arbitrary metadata that is recorded on the state. This can be useful
       for tracking additional references in the state.
   """

   step: jax.Array
-  apply: Callable[..., Any] = struct.field(pytree_node=False)
   params: PyTree = struct.field(pytree_node=True)
   tx: optax.GradientTransformation = struct.field(pytree_node=False)
   opt_state: optax.OptState = struct.field(pytree_node=True)
   mutable: PyTree = struct.field(pytree_node=True, default_factory=dict)
   meta: MetaT = struct.field(pytree_node=False, default_factory=dict)
+  _apply: Callable[..., Any] | None = struct.field(
+      pytree_node=False, default_factory=None
+  )
+  _model: nn.Module | None = struct.field(pytree_node=False, default=None)
+
+  @property
+  def model(self) -> nn.Module:
+    """Returns a reference to the model used to create the state."""
+    if self._model is None:
+      raise ValueError("No Flax `model` is set on the state.")
+    return self._model
+
+  def apply(self, *args, **kwargs) -> Any:
+    """Applies the forward pass of the model."""
+    if self._apply is None:
+      raise ValueError("No `apply` function is set on the state.")
+    return self._apply(*args, **kwargs)

   @classmethod
   def create(
       cls,
       *,
-      apply: Callable[..., Any],
+      apply: Callable[..., Any] | None = None,
+      model: nn.Module | None = None,
       params: PyTree,
       tx: optax.GradientTransformation,
       **kwargs,
   ) -> Self:
-    """Creates a new instance from a Jax apply function and Optax optimizer."""
+    """Creates a new instance from a Jax model / apply fn and Optax optimizer.
+
+    Args:
+      apply: A function that can be used to apply the forward pass of the model.
+        For Flax models this is usually set to `model.apply`. This cannot be set
+        along with `model`.
+      model: A reference to a stateless Flax model. This cannot be set along
+        with `apply`. When set the `apply` attribute of the state will be set to
+        `model.apply`.
+      params: A pytree of trainable variables that will be updated by `tx` and
+        used in `apply`.
+      tx: An optax gradient transformation that will be used to update the
+        parameters contained in `params` on a call to `state.update(...)`.
+      **kwargs: Other updates to set on the new state.
+
+    Returns:
+      A new instance of the state.
+    """
+    if apply is not None and model is not None:
+      raise ValueError("Only one of `apply` or `model` can be provided.")
+    elif model is not None:
+      apply = model.apply
+
     return cls(
         step=jnp.zeros([], dtype=jnp.int32),
-        apply=apply,
         params=params,
         tx=tx,
         opt_state=tx.init(params),
+        _apply=apply,
+        _model=model,
         **kwargs,
     )

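With this change a `JaxState` can be built either from a bare `apply` function or directly from a Flax module, and the forward pass is now invoked through the `state.apply(...)` method (backed by the private `_apply` field) rather than a public `apply` attribute. A minimal sketch of both paths, assuming a toy Flax model that is not part of this diff:

import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from recml.core.training import jax_trainer


class TinyMLP(nn.Module):
  """Toy model used only for illustration."""

  @nn.compact
  def __call__(self, x):
    return nn.Dense(1)(x)


model = TinyMLP()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
tx = optax.adam(1e-3)

# Path 1: pass the Flax module; `state.model` and `state.apply` both work.
state = jax_trainer.JaxState.create(
    model=model, params=variables["params"], tx=tx
)
out = state.apply({"params": state.params}, jnp.ones((1, 8)))

# Path 2: pass an apply function directly (e.g. a Haiku transform's apply).
state2 = jax_trainer.JaxState.create(
    apply=model.apply, params=variables["params"], tx=tx
)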

recml/core/training/jax_quality_test.py renamed to recml/core/training/jax_trainer_quality_test.py (+3 -3)

@@ -25,13 +25,13 @@
 import jax.numpy as jnp
 import jaxtyping as jt
 import optax
-from recml.core.training import jax as jax_lib
+from recml.core.training import jax_trainer
 from recml.core.training import partitioning
 import tensorflow as tf
 import tensorflow_datasets as tfds


-class _MNISTTask(jax_lib.JaxTask):
+class _MNISTTask(jax_trainer.JaxTask):
   """Task for fitting a CNN on MNIST."""

   def create_datasets(self) -> tuple[tf.data.Dataset, tf.data.Dataset]:
@@ -126,7 +126,7 @@ def setUp(self):
   def test_mnist_e2e(self):
     model_dir = self.create_tempdir().full_path
     task = _MNISTTask()
-    trainer = jax_lib.JaxTrainer(
+    trainer = jax_trainer.JaxTrainer(
         partitioner=partitioning.DataParallelPartitioner(),
         train_steps=1000,
         steps_per_eval=50,