
Commit e002b4b

Fix for batch size of FeatureConfig.output_shape on TensorFlow. (#139)
TensorFlow's `FeatureConfig.output_shape` expects the batch size to be per replica. However, `keras_rs` uses the global batch size in `input_shape` and `output_shape`. This commit modifies the conversion code to take this into account.
1 parent fc91f77 commit e002b4b
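
For intuition, the fix boils down to one division: `keras_rs` carries the global batch size in the batch dimension of `output_shape`, while TensorFlow's TPUEmbedding wants the batch size per replica. A minimal standalone sketch of the arithmetic (the values are illustrative, not from the commit):

```python
# Global batch size, as carried in the batch dimension of
# FeatureConfig.output_shape by keras_rs (illustrative value).
global_batch_size = 128
# Replica count, e.g. tf.distribute.TPUStrategy.num_replicas_in_sync.
num_replicas_in_sync = 4

# The global batch must split evenly across replicas.
if global_batch_size % num_replicas_in_sync != 0:
    raise ValueError(
        f"Batch size {global_batch_size} is not a multiple of "
        f"the number of TPUs {num_replicas_in_sync}."
    )

# This is the value TensorFlow expects in FeatureConfig.output_shape.
per_replica_batch_size = global_batch_size // num_replicas_in_sync
assert per_replica_batch_size == 32
```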

File tree: 2 files changed (+40, −5 lines)

  keras_rs/src/layers/embedding/tensorflow/config_conversion.py
  keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py

keras_rs/src/layers/embedding/tensorflow/config_conversion.py

Lines changed: 37 additions & 4 deletions
```diff
@@ -56,6 +56,7 @@
 def translate_keras_rs_configuration(
     feature_configs: types.Nested[FeatureConfig],
     table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
+    num_replicas_in_sync: int,
 ) -> tuple[
     types.Nested[tf.tpu.experimental.embedding.FeatureConfig],
     tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig,
@@ -72,7 +73,10 @@ def translate_keras_rs_configuration(
     """
     tables: dict[TableConfig, tf.tpu.experimental.embedding.TableConfig] = {}
     feature_configs = keras.tree.map_structure(
-        lambda f: translate_keras_rs_feature_config(f, tables), feature_configs
+        lambda f: translate_keras_rs_feature_config(
+            f, tables, num_replicas_in_sync
+        ),
+        feature_configs,
     )

     # max_ids_per_chip_per_sample
@@ -107,6 +111,7 @@ def translate_keras_rs_configuration(
 def translate_keras_rs_feature_config(
     feature_config: FeatureConfig,
     tables: dict[TableConfig, tf.tpu.experimental.embedding.TableConfig],
+    num_replicas_in_sync: int,
 ) -> tf.tpu.experimental.embedding.FeatureConfig:
     """Translates a Keras RS feature config to a TensorFlow TPU feature config.

@@ -120,18 +125,46 @@
     Returns:
         The TensorFlow TPU feature config.
     """
+    if num_replicas_in_sync <= 0:
+        raise ValueError(
+            "`num_replicas_in_sync` must be positive, "
+            f"but got {num_replicas_in_sync}."
+        )
+
     table = tables.get(feature_config.table, None)
     if table is None:
         table = translate_keras_rs_table_config(feature_config.table)
         tables[feature_config.table] = table

+    if len(feature_config.output_shape) < 2:
+        raise ValueError(
+            f"Invalid `output_shape` {feature_config.output_shape} in "
+            f"`FeatureConfig` {feature_config}. It must have at least 2 "
+            "dimensions: a batch dimension and an embedding dimension."
+        )
+
+    # Exclude last dimension, TensorFlow's TPUEmbedding doesn't want it.
+    output_shape = list(feature_config.output_shape[0:-1])
+
+    batch_size = output_shape[0]
+    per_replica_batch_size: int | None = None
+    if batch_size is not None:
+        if batch_size % num_replicas_in_sync != 0:
+            raise ValueError(
+                f"Invalid `output_shape` {feature_config.output_shape} in "
+                f"`FeatureConfig` {feature_config}. Batch size {batch_size} is "
+                f"not a multiple of the number of TPUs {num_replicas_in_sync}."
+            )
+        per_replica_batch_size = batch_size // num_replicas_in_sync
+
+    # TensorFlow's TPUEmbedding wants the per replica batch size.
+    output_shape = [per_replica_batch_size] + output_shape[1:]
+
     # max_sequence_length
     return tf.tpu.experimental.embedding.FeatureConfig(
         name=feature_config.name,
         table=table,
-        output_shape=feature_config.output_shape[
-            0:-1
-        ],  # exclude last dimension
+        output_shape=output_shape,
     )
```
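To make the shape handling concrete, here is a hedged usage sketch of the updated function. The import paths and the `TableConfig`/`FeatureConfig` constructors below follow keras_rs's public API as an assumption; only `translate_keras_rs_feature_config` and its new signature come from this diff:

```python
# Import paths are assumptions: the module path comes from the diff
# header, the FeatureConfig/TableConfig locations from keras_rs's
# public API.
from keras_rs.layers import FeatureConfig, TableConfig
from keras_rs.src.layers.embedding.tensorflow import config_conversion

# Hypothetical configs: global batch of 128, embedding dimension 16.
table = TableConfig(name="items", vocabulary_size=1_000, embedding_dim=16)
feature = FeatureConfig(
    name="item_ids",
    table=table,
    input_shape=(128,),
    output_shape=(128, 16),
)

tf_feature = config_conversion.translate_keras_rs_feature_config(
    feature, tables={}, num_replicas_in_sync=4
)
# TensorFlow's config now carries only the per-replica batch dimension:
# output_shape == [32], since 128 // 4 = 32 and the embedding dim is
# dropped.

# By contrast, output_shape=(100, 16) with num_replicas_in_sync=8 would
# raise ValueError, since 100 % 8 != 0.
```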

keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -107,7 +107,9 @@ def _sparsecore_init(
         )
         self._tpu_feature_configs, self._sparse_core_embedding_config = (
             config_conversion.translate_keras_rs_configuration(
-                feature_configs, table_stacking
+                feature_configs,
+                table_stacking,
+                strategy.num_replicas_in_sync,
             )
         )
         if tpu_embedding_feature == EMBEDDING_FEATURE_V1:
```
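
For reference, `num_replicas_in_sync` is a standard attribute of `tf.distribute.Strategy`; under a TPU strategy it counts the cores participating in synchronous training. A minimal sketch of how such a strategy is obtained (standard TPU-initialization boilerplate, not part of this commit):

```python
import tensorflow as tf

# Standard TPU setup; on a TPU host this yields one replica per core.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

# The value _sparsecore_init now forwards to
# translate_keras_rs_configuration.
print(strategy.num_replicas_in_sync)
```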
