def translate_keras_rs_configuration(
    feature_configs: types.Nested[FeatureConfig],
    table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
+    num_replicas_in_sync: int,
) -> tuple[
    types.Nested[tf.tpu.experimental.embedding.FeatureConfig],
    tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig,
@@ -72,7 +73,10 @@ def translate_keras_rs_configuration(
    """
    tables: dict[TableConfig, tf.tpu.experimental.embedding.TableConfig] = {}
    feature_configs = keras.tree.map_structure(
-        lambda f: translate_keras_rs_feature_config(f, tables), feature_configs
+        lambda f: translate_keras_rs_feature_config(
+            f, tables, num_replicas_in_sync
+        ),
+        feature_configs,
    )

    # max_ids_per_chip_per_sample
@@ -107,6 +111,7 @@ def translate_keras_rs_configuration(
def translate_keras_rs_feature_config(
    feature_config: FeatureConfig,
    tables: dict[TableConfig, tf.tpu.experimental.embedding.TableConfig],
+    num_replicas_in_sync: int,
) -> tf.tpu.experimental.embedding.FeatureConfig:
    """Translates a Keras RS feature config to a TensorFlow TPU feature config.
@@ -120,18 +125,46 @@ def translate_keras_rs_feature_config(
    Returns:
        The TensorFlow TPU feature config.
    """
+    if num_replicas_in_sync <= 0:
+        raise ValueError(
+            "`num_replicas_in_sync` must be positive, "
+            f"but got {num_replicas_in_sync}."
+        )
+
    table = tables.get(feature_config.table, None)
    if table is None:
        table = translate_keras_rs_table_config(feature_config.table)
        tables[feature_config.table] = table

+    if len(feature_config.output_shape) < 2:
+        raise ValueError(
+            f"Invalid `output_shape` {feature_config.output_shape} in "
+            f"`FeatureConfig` {feature_config}. It must have at least 2 "
+            "dimensions: a batch dimension and an embedding dimension."
+        )
+
+    # Exclude last dimension, TensorFlow's TPUEmbedding doesn't want it.
+    output_shape = list(feature_config.output_shape[0:-1])
+
+    batch_size = output_shape[0]
+    per_replica_batch_size: int | None = None
+    if batch_size is not None:
+        if batch_size % num_replicas_in_sync != 0:
+            raise ValueError(
+                f"Invalid `output_shape` {feature_config.output_shape} in "
+                f"`FeatureConfig` {feature_config}. Batch size {batch_size} is "
+                f"not a multiple of the number of TPUs {num_replicas_in_sync}."
+            )
+        per_replica_batch_size = batch_size // num_replicas_in_sync
+
+    # TensorFlow's TPUEmbedding wants the per replica batch size.
+    output_shape = [per_replica_batch_size] + output_shape[1:]
+
    # max_sequence_length
    return tf.tpu.experimental.embedding.FeatureConfig(
        name=feature_config.name,
        table=table,
-        output_shape=feature_config.output_shape[
-            0:-1
-        ],  # exclude last dimension
+        output_shape=output_shape,
    )
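For reference, here is a minimal standalone sketch of the per-replica batch size computation this change introduces (illustrative only, not part of the commit; the function name and example values below are made up):

```python
def split_batch_across_replicas(global_batch_size: int, num_replicas_in_sync: int) -> int:
    """Returns the per-replica batch size for a given global batch size."""
    # The global batch dimension must divide evenly across all replicas.
    if global_batch_size % num_replicas_in_sync != 0:
        raise ValueError(
            f"Batch size {global_batch_size} is not a multiple of the "
            f"number of TPUs {num_replicas_in_sync}."
        )
    return global_batch_size // num_replicas_in_sync


# Example: a global batch of 128 split across 8 replicas gives 16 per replica.
assert split_batch_across_replicas(128, 8) == 16
```

In a typical `tf.distribute` setup, the value passed as `num_replicas_in_sync` would presumably come from `strategy.num_replicas_in_sync` on the active `TPUStrategy`.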