53
53
# KerasRS to TensorFlow
54
54
55
55
56
- def translate_keras_rs_configuration (
56
+ def keras_to_tf_tpu_configuration (
57
57
feature_configs : types .Nested [FeatureConfig ],
58
58
table_stacking : str | Sequence [str ] | Sequence [Sequence [str ]],
59
59
num_replicas_in_sync : int ,
@@ -66,14 +66,15 @@ def translate_keras_rs_configuration(
66
66
Args:
67
67
feature_configs: The nested Keras RS feature configs.
68
68
table_stacking: The Keras RS table stacking.
69
+ num_replicas_in_sync: The number of replicas in sync from the strategy.
69
70
70
71
Returns:
71
72
A tuple containing the TensorFlow TPU feature configs and the TensorFlow
72
73
TPU sparse core embedding config.
73
74
"""
74
- tables : dict [TableConfig , tf .tpu .experimental .embedding .TableConfig ] = {}
75
+ tables : dict [int , tf .tpu .experimental .embedding .TableConfig ] = {}
75
76
feature_configs = keras .tree .map_structure (
76
- lambda f : translate_keras_rs_feature_config (
77
+ lambda f : keras_to_tf_tpu_feature_config (
77
78
f , tables , num_replicas_in_sync
78
79
),
79
80
feature_configs ,
@@ -108,9 +109,9 @@ def translate_keras_rs_configuration(
108
109
return feature_configs , sparse_core_embedding_config
109
110
110
111
111
- def translate_keras_rs_feature_config (
112
+ def keras_to_tf_tpu_feature_config (
112
113
feature_config : FeatureConfig ,
113
- tables : dict [TableConfig , tf .tpu .experimental .embedding .TableConfig ],
114
+ tables : dict [int , tf .tpu .experimental .embedding .TableConfig ],
114
115
num_replicas_in_sync : int ,
115
116
) -> tf .tpu .experimental .embedding .FeatureConfig :
116
117
"""Translates a Keras RS feature config to a TensorFlow TPU feature config.
@@ -120,7 +121,8 @@ def translate_keras_rs_feature_config(
120
121
121
122
Args:
122
123
feature_config: The Keras RS feature config to translate.
123
- tables: A mapping of KerasRS table configs to TF TPU table configs.
124
+ tables: A mapping of KerasRS table config ids to TF TPU table configs.
125
+ num_replicas_in_sync: The number of replicas in sync from the strategy.
124
126
125
127
Returns:
126
128
The TensorFlow TPU feature config.
@@ -131,10 +133,10 @@ def translate_keras_rs_feature_config(
131
133
f"but got { num_replicas_in_sync } ."
132
134
)
133
135
134
- table = tables .get (feature_config .table , None )
136
+ table = tables .get (id ( feature_config .table ) , None )
135
137
if table is None :
136
- table = translate_keras_rs_table_config (feature_config .table )
137
- tables [feature_config .table ] = table
138
+ table = keras_to_tf_tpu_table_config (feature_config .table )
139
+ tables [id ( feature_config .table ) ] = table
138
140
139
141
if len (feature_config .output_shape ) < 2 :
140
142
raise ValueError (
@@ -168,7 +170,7 @@ def translate_keras_rs_feature_config(
168
170
)
169
171
170
172
171
- def translate_keras_rs_table_config (
173
+ def keras_to_tf_tpu_table_config (
172
174
table_config : TableConfig ,
173
175
) -> tf .tpu .experimental .embedding .TableConfig :
174
176
initializer = table_config .initializer
@@ -179,13 +181,13 @@ def translate_keras_rs_table_config(
179
181
vocabulary_size = table_config .vocabulary_size ,
180
182
dim = table_config .embedding_dim ,
181
183
initializer = initializer ,
182
- optimizer = translate_optimizer (table_config .optimizer ),
184
+ optimizer = to_tf_tpu_optimizer (table_config .optimizer ),
183
185
combiner = table_config .combiner ,
184
186
name = table_config .name ,
185
187
)
186
188
187
189
188
- def translate_keras_optimizer (
190
+ def keras_to_tf_tpu_optimizer (
189
191
optimizer : keras .optimizers .Optimizer ,
190
192
) -> TfTpuOptimizer :
191
193
"""Translates a Keras optimizer to a TensorFlow TPU `_Optimizer`.
@@ -238,7 +240,12 @@ def translate_keras_optimizer(
238
240
"Unsupported optimizer option `Optimizer.loss_scale_factor`."
239
241
)
240
242
241
- optimizer_mapping = OPTIMIZER_MAPPINGS .get (type (optimizer ), None )
243
+ optimizer_mapping = None
244
+ for optimizer_class , mapping in OPTIMIZER_MAPPINGS .items ():
245
+ # Handle subclasses of the main optimizer class.
246
+ if isinstance (optimizer , optimizer_class ):
247
+ optimizer_mapping = mapping
248
+ break
242
249
if optimizer_mapping is None :
243
250
raise ValueError (
244
251
f"Unsupported optimizer type { type (optimizer )} . Optimizer must be "
@@ -258,7 +265,7 @@ def translate_keras_optimizer(
258
265
return optimizer_mapping .tpu_optimizer_class (** tpu_optimizer_kwargs )
259
266
260
267
261
- def translate_optimizer (
268
+ def to_tf_tpu_optimizer (
262
269
optimizer : str | keras .optimizers .Optimizer | TfTpuOptimizer | None ,
263
270
) -> TfTpuOptimizer :
264
271
"""Translates a Keras optimizer into a TensorFlow TPU `_Optimizer`.
@@ -299,7 +306,7 @@ def translate_optimizer(
299
306
"'sgd', 'adagrad', 'adam', or 'ftrl'"
300
307
)
301
308
elif isinstance (optimizer , keras .optimizers .Optimizer ):
302
- return translate_keras_optimizer (optimizer )
309
+ return keras_to_tf_tpu_optimizer (optimizer )
303
310
else :
304
311
raise ValueError (
305
312
f"Unknown optimizer type { type (optimizer )} . Please pass an "
@@ -312,7 +319,7 @@ def translate_optimizer(
312
319
# TensorFlow to TensorFlow
313
320
314
321
315
- def clone_tf_feature_configs (
322
+ def clone_tf_tpu_feature_configs (
316
323
feature_configs : types .Nested [tf .tpu .experimental .embedding .FeatureConfig ],
317
324
) -> types .Nested [tf .tpu .experimental .embedding .FeatureConfig ]:
318
325
"""Clones and resolves TensorFlow TPU feature configs.
@@ -327,7 +334,7 @@ def clone_tf_feature_configs(
327
334
"""
328
335
table_configs_dict = {}
329
336
330
- def clone_and_resolve_tf_feature_config (
337
+ def clone_and_resolve_tf_tpu_feature_config (
331
338
fc : tf .tpu .experimental .embedding .FeatureConfig ,
332
339
) -> tf .tpu .experimental .embedding .FeatureConfig :
333
340
if fc .table not in table_configs_dict :
@@ -336,7 +343,7 @@ def clone_and_resolve_tf_feature_config(
336
343
vocabulary_size = fc .table .vocabulary_size ,
337
344
dim = fc .table .dim ,
338
345
initializer = fc .table .initializer ,
339
- optimizer = translate_optimizer (fc .table .optimizer ),
346
+ optimizer = to_tf_tpu_optimizer (fc .table .optimizer ),
340
347
combiner = fc .table .combiner ,
341
348
name = fc .table .name ,
342
349
quantization_config = fc .table .quantization_config ,
@@ -352,5 +359,5 @@ def clone_and_resolve_tf_feature_config(
352
359
)
353
360
354
361
return keras .tree .map_structure (
355
- clone_and_resolve_tf_feature_config , feature_configs
362
+ clone_and_resolve_tf_tpu_feature_config , feature_configs
356
363
)
0 commit comments