Skip to content

Commit 764cd93

Browse files
authored
Make TableConfig and FeatureConfig unhashable. (#150)
`TableConfig` contains a keras `Optimizer` and a keras `Initializer`, which are not hashable and are even mutable in the case of the optimizer. The "unsafe hash" was actually unsafe. Instead of using `TableConfig`s as keys directly, we use the id of the `TableConfig`, which is correct because we are detecting reused instances.

Also:
- renamed TensorFlow config conversion functions to be more consistent with the JAX ones
- added unit tests for the TensorFlow config conversion functions
- fixed optimizer conversion with the Torch backend
1 parent 2fbdc2c commit 764cd93

File tree

7 files changed

+263
-41
lines changed

7 files changed

+263
-41
lines changed

keras_rs/src/layers/embedding/base_distributed_embedding.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -822,13 +822,13 @@ def _default_device_init(
822822
table_stacking: str | Sequence[Sequence[str]],
823823
) -> None:
824824
del table_stacking
825-
table_to_embedding_layer: dict[TableConfig, EmbedReduce] = {}
825+
table_config_id_to_embedding_layer: dict[int, EmbedReduce] = {}
826826
self._default_device_embedding_layers: dict[str, EmbedReduce] = {}
827827

828828
for path, feature_config in feature_configs.items():
829-
if feature_config.table in table_to_embedding_layer:
829+
if id(feature_config.table) in table_config_id_to_embedding_layer:
830830
self._default_device_embedding_layers[path] = (
831-
table_to_embedding_layer[feature_config.table]
831+
table_config_id_to_embedding_layer[id(feature_config.table)]
832832
)
833833
else:
834834
embedding_layer = EmbedReduce(
@@ -838,7 +838,9 @@ def _default_device_init(
838838
embeddings_initializer=feature_config.table.initializer,
839839
combiner=feature_config.table.combiner,
840840
)
841-
table_to_embedding_layer[feature_config.table] = embedding_layer
841+
table_config_id_to_embedding_layer[id(feature_config.table)] = (
842+
embedding_layer
843+
)
842844
self._default_device_embedding_layers[path] = embedding_layer
843845

844846
def _default_device_build(
@@ -1013,8 +1015,8 @@ def get_config(self) -> dict[str, Any]:
10131015

10141016
# The serialized `TableConfig` objects.
10151017
table_config_dicts: list[dict[str, Any]] = []
1016-
# Mapping from `TableConfig` to index in `table_config_dicts`.
1017-
table_config_indices: dict[TableConfig, int] = {}
1018+
# Mapping from `TableConfig` id to index in `table_config_dicts`.
1019+
table_config_id_to_index: dict[int, int] = {}
10181020

10191021
def serialize_feature_config(
10201022
feature_config: FeatureConfig,
@@ -1024,17 +1026,17 @@ def serialize_feature_config(
10241026
# key.
10251027
feature_config_dict = feature_config.get_config()
10261028

1027-
if feature_config.table not in table_config_indices:
1029+
if id(feature_config.table) not in table_config_id_to_index:
10281030
# Save the serialized `TableConfig` the first time we see it and
10291031
# remember its index.
1030-
table_config_indices[feature_config.table] = len(
1032+
table_config_id_to_index[id(feature_config.table)] = len(
10311033
table_config_dicts
10321034
)
10331035
table_config_dicts.append(feature_config_dict["table"])
10341036

10351037
# Replace the serialized `TableConfig` with its index.
1036-
feature_config_dict["table"] = table_config_indices[
1037-
feature_config.table
1038+
feature_config_dict["table"] = table_config_id_to_index[
1039+
id(feature_config.table)
10381040
]
10391041
return feature_config_dict
10401042

keras_rs/src/layers/embedding/distributed_embedding_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
@keras_rs_export("keras_rs.layers.TableConfig")
13-
@dataclasses.dataclass(eq=True, unsafe_hash=True, order=True)
13+
@dataclasses.dataclass(order=True)
1414
class TableConfig:
1515
"""Configuration for one embedding table.
1616
@@ -88,7 +88,7 @@ def from_config(cls, config: dict[str, Any]) -> "TableConfig":
8888

8989

9090
@keras_rs_export("keras_rs.layers.FeatureConfig")
91-
@dataclasses.dataclass(eq=True, unsafe_hash=True, order=True)
91+
@dataclasses.dataclass(order=True)
9292
class FeatureConfig:
9393
"""Configuration for one embedding feature.
9494

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -690,19 +690,22 @@ def _default_device_set_tables(
690690
raise ValueError("Layer must first be built before setting tables.")
691691

692692
if "default_device" in self._placement_to_path_to_feature_config:
693-
table_to_embedding_layer = {}
693+
table_name_to_embedding_layer = {}
694694
for (
695695
path,
696696
feature_config,
697697
) in self._placement_to_path_to_feature_config[
698698
"default_device"
699699
].items():
700-
table_to_embedding_layer[feature_config.table] = (
700+
table_name_to_embedding_layer[feature_config.table.name] = (
701701
self._default_device_embedding_layers[path]
702702
)
703703

704-
for table, embedding_layer in table_to_embedding_layer.items():
705-
table_values = tables.get(table.name, None)
704+
for (
705+
table_name,
706+
embedding_layer,
707+
) in table_name_to_embedding_layer.items():
708+
table_values = tables.get(table_name, None)
706709
if table_values is not None:
707710
if embedding_layer.lora_enabled:
708711
raise ValueError("Cannot set table if LoRA is enabled.")

keras_rs/src/layers/embedding/jax/distributed_embedding_test.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -561,12 +561,18 @@ def loss_fn(y_true, y_pred):
561561
# Setup a model with a zero initializer but otherwise the same
562562
# feature configs to test restore. Keep the same embedding layer name to
563563
# ensure the correct weights are restored.
564+
table_config_id_to_table_config_with_zero_init = {
565+
id(table_config): dataclasses.replace(
566+
table_config, initializer="zeros"
567+
)
568+
for table_config in table_configs
569+
}
564570
feature_configs_with_zero_init = {
565571
feature_config.name: dataclasses.replace(
566572
feature_config,
567-
table=dataclasses.replace(
568-
feature_config.table, initializer="zeros"
569-
),
573+
table=table_config_id_to_table_config_with_zero_init[
574+
id(feature_config.table)
575+
],
570576
)
571577
for feature_config in feature_configs
572578
}

keras_rs/src/layers/embedding/tensorflow/config_conversion.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
# KerasRS to TensorFlow
5454

5555

56-
def translate_keras_rs_configuration(
56+
def keras_to_tf_tpu_configuration(
5757
feature_configs: types.Nested[FeatureConfig],
5858
table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
5959
num_replicas_in_sync: int,
@@ -66,14 +66,15 @@ def translate_keras_rs_configuration(
6666
Args:
6767
feature_configs: The nested Keras RS feature configs.
6868
table_stacking: The Keras RS table stacking.
69+
num_replicas_in_sync: The number of replicas in sync from the strategy.
6970
7071
Returns:
7172
A tuple containing the TensorFlow TPU feature configs and the TensorFlow
7273
TPU sparse core embedding config.
7374
"""
74-
tables: dict[TableConfig, tf.tpu.experimental.embedding.TableConfig] = {}
75+
tables: dict[int, tf.tpu.experimental.embedding.TableConfig] = {}
7576
feature_configs = keras.tree.map_structure(
76-
lambda f: translate_keras_rs_feature_config(
77+
lambda f: keras_to_tf_tpu_feature_config(
7778
f, tables, num_replicas_in_sync
7879
),
7980
feature_configs,
@@ -108,9 +109,9 @@ def translate_keras_rs_configuration(
108109
return feature_configs, sparse_core_embedding_config
109110

110111

111-
def translate_keras_rs_feature_config(
112+
def keras_to_tf_tpu_feature_config(
112113
feature_config: FeatureConfig,
113-
tables: dict[TableConfig, tf.tpu.experimental.embedding.TableConfig],
114+
tables: dict[int, tf.tpu.experimental.embedding.TableConfig],
114115
num_replicas_in_sync: int,
115116
) -> tf.tpu.experimental.embedding.FeatureConfig:
116117
"""Translates a Keras RS feature config to a TensorFlow TPU feature config.
@@ -120,7 +121,8 @@ def translate_keras_rs_feature_config(
120121
121122
Args:
122123
feature_config: The Keras RS feature config to translate.
123-
tables: A mapping of KerasRS table configs to TF TPU table configs.
124+
tables: A mapping of KerasRS table config ids to TF TPU table configs.
125+
num_replicas_in_sync: The number of replicas in sync from the strategy.
124126
125127
Returns:
126128
The TensorFlow TPU feature config.
@@ -131,10 +133,10 @@ def translate_keras_rs_feature_config(
131133
f"but got {num_replicas_in_sync}."
132134
)
133135

134-
table = tables.get(feature_config.table, None)
136+
table = tables.get(id(feature_config.table), None)
135137
if table is None:
136-
table = translate_keras_rs_table_config(feature_config.table)
137-
tables[feature_config.table] = table
138+
table = keras_to_tf_tpu_table_config(feature_config.table)
139+
tables[id(feature_config.table)] = table
138140

139141
if len(feature_config.output_shape) < 2:
140142
raise ValueError(
@@ -168,7 +170,7 @@ def translate_keras_rs_feature_config(
168170
)
169171

170172

171-
def translate_keras_rs_table_config(
173+
def keras_to_tf_tpu_table_config(
172174
table_config: TableConfig,
173175
) -> tf.tpu.experimental.embedding.TableConfig:
174176
initializer = table_config.initializer
@@ -179,13 +181,13 @@ def translate_keras_rs_table_config(
179181
vocabulary_size=table_config.vocabulary_size,
180182
dim=table_config.embedding_dim,
181183
initializer=initializer,
182-
optimizer=translate_optimizer(table_config.optimizer),
184+
optimizer=to_tf_tpu_optimizer(table_config.optimizer),
183185
combiner=table_config.combiner,
184186
name=table_config.name,
185187
)
186188

187189

188-
def translate_keras_optimizer(
190+
def keras_to_tf_tpu_optimizer(
189191
optimizer: keras.optimizers.Optimizer,
190192
) -> TfTpuOptimizer:
191193
"""Translates a Keras optimizer to a TensorFlow TPU `_Optimizer`.
@@ -238,7 +240,12 @@ def translate_keras_optimizer(
238240
"Unsupported optimizer option `Optimizer.loss_scale_factor`."
239241
)
240242

241-
optimizer_mapping = OPTIMIZER_MAPPINGS.get(type(optimizer), None)
243+
optimizer_mapping = None
244+
for optimizer_class, mapping in OPTIMIZER_MAPPINGS.items():
245+
# Handle subclasses of the main optimizer class.
246+
if isinstance(optimizer, optimizer_class):
247+
optimizer_mapping = mapping
248+
break
242249
if optimizer_mapping is None:
243250
raise ValueError(
244251
f"Unsupported optimizer type {type(optimizer)}. Optimizer must be "
@@ -258,7 +265,7 @@ def translate_keras_optimizer(
258265
return optimizer_mapping.tpu_optimizer_class(**tpu_optimizer_kwargs)
259266

260267

261-
def translate_optimizer(
268+
def to_tf_tpu_optimizer(
262269
optimizer: str | keras.optimizers.Optimizer | TfTpuOptimizer | None,
263270
) -> TfTpuOptimizer:
264271
"""Translates a Keras optimizer into a TensorFlow TPU `_Optimizer`.
@@ -299,7 +306,7 @@ def translate_optimizer(
299306
"'sgd', 'adagrad', 'adam', or 'ftrl'"
300307
)
301308
elif isinstance(optimizer, keras.optimizers.Optimizer):
302-
return translate_keras_optimizer(optimizer)
309+
return keras_to_tf_tpu_optimizer(optimizer)
303310
else:
304311
raise ValueError(
305312
f"Unknown optimizer type {type(optimizer)}. Please pass an "
@@ -312,7 +319,7 @@ def translate_optimizer(
312319
# TensorFlow to TensorFlow
313320

314321

315-
def clone_tf_feature_configs(
322+
def clone_tf_tpu_feature_configs(
316323
feature_configs: types.Nested[tf.tpu.experimental.embedding.FeatureConfig],
317324
) -> types.Nested[tf.tpu.experimental.embedding.FeatureConfig]:
318325
"""Clones and resolves TensorFlow TPU feature configs.
@@ -327,7 +334,7 @@ def clone_tf_feature_configs(
327334
"""
328335
table_configs_dict = {}
329336

330-
def clone_and_resolve_tf_feature_config(
337+
def clone_and_resolve_tf_tpu_feature_config(
331338
fc: tf.tpu.experimental.embedding.FeatureConfig,
332339
) -> tf.tpu.experimental.embedding.FeatureConfig:
333340
if fc.table not in table_configs_dict:
@@ -336,7 +343,7 @@ def clone_and_resolve_tf_feature_config(
336343
vocabulary_size=fc.table.vocabulary_size,
337344
dim=fc.table.dim,
338345
initializer=fc.table.initializer,
339-
optimizer=translate_optimizer(fc.table.optimizer),
346+
optimizer=to_tf_tpu_optimizer(fc.table.optimizer),
340347
combiner=fc.table.combiner,
341348
name=fc.table.name,
342349
quantization_config=fc.table.quantization_config,
@@ -352,5 +359,5 @@ def clone_and_resolve_tf_feature_config(
352359
)
353360

354361
return keras.tree.map_structure(
355-
clone_and_resolve_tf_feature_config, feature_configs
362+
clone_and_resolve_tf_tpu_feature_config, feature_configs
356363
)

0 commit comments

Comments
 (0)