Skip to content

Commit 5de40a7

Browse files
Refactor: Use embedding.EmbeddingVariables in Keras-RS JAX embedding tests. (#156)
This change updates the Keras-RS JAX embedding tests to use the `embedding.EmbeddingVariables` dataclass from `jax_tpu_embedding` for representing embedding tables and slot variables, instead of a custom tuple structure. This involves updating type hints, variable access, and build dependencies.
1 parent cee2286 commit 5de40a7

File tree

2 files changed

+21
-22
lines changed

2 files changed

+21
-22
lines changed

keras_rs/src/layers/embedding/jax/embedding_lookup_test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def test_forward_pass(self, ragged: bool, stacked: bool):
193193

194194
# Add pseudo gradients to the inputs.
195195
embedding_variables = jax.tree.map(
196-
lambda table: (table, None),
196+
lambda table: embedding.EmbeddingVariables(table=table, slot=()),
197197
sharded_tables,
198198
)
199199

@@ -288,7 +288,7 @@ def test_model_sharding(
288288

289289
# Add pseudo gradients to the inputs.
290290
embedding_variables = jax.tree.map(
291-
lambda table: (table, None),
291+
lambda table: embedding.EmbeddingVariables(table=table, slot=()),
292292
sharded_tables,
293293
)
294294

@@ -479,7 +479,8 @@ def test_autograd(
479479
)
480480
)
481481
sharded_table_and_slot_variables = typing.cast(
482-
dict[str, tuple[jax.Array, ...]], sharded_table_and_slot_variables
482+
dict[str, embedding.EmbeddingVariables],
483+
sharded_table_and_slot_variables,
483484
)
484485

485486
# Shard samples for lookup query.

keras_rs/src/layers/embedding/jax/test_utils.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import keras
88
import numpy as np
99
from jax import numpy as jnp
10+
from jax_tpu_embedding.sparsecore.lib.nn import embedding
1011
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
1112
from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
1213
from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import TableSpec
@@ -142,7 +143,7 @@ def create_tables(
142143
def create_table_and_slot_variables(
143144
table_specs: Nested[TableSpec],
144145
keys: Nested[ArrayLike] | None = None,
145-
) -> Nested[ArrayLike]:
146+
) -> Nested[embedding.EmbeddingVariables]:
146147
"""Creates and initializes embedding tables and slot variables.
147148
148149
Args:
@@ -164,7 +165,7 @@ def create_table_and_slot_variables(
164165
def _create_table_and_slot_variables(
165166
table_spec: TableSpec,
166167
key: ArrayLike,
167-
) -> tuple[jax.Array, tuple[jax.Array, ...]]:
168+
) -> embedding.EmbeddingVariables:
168169
slot_initializers = table_spec.optimizer.slot_variables_initializers()
169170
num_slot_variables = len(keras.tree.flatten(slot_initializers))
170171
slot_keys = jnp.unstack(jax.random.split(key, num_slot_variables))
@@ -178,10 +179,10 @@ def _create_table_and_slot_variables(
178179
slot_initializers,
179180
slot_keys,
180181
)
181-
return (table, slot_variables)
182+
return embedding.EmbeddingVariables(table, slot_variables)
182183

183184
# Initialize tables.
184-
output: Nested[ArrayLike] = jax.tree.map(
185+
output: Nested[embedding.EmbeddingVariables] = jax.tree.map(
185186
_create_table_and_slot_variables,
186187
table_specs,
187188
keys,
@@ -311,14 +312,14 @@ def _create_samples(
311312

312313
def stack_shard_and_put_tables(
313314
table_specs: Nested[TableSpec],
314-
tables: Nested[jax.Array],
315+
tables: Nested[embedding.EmbeddingVariables],
315316
num_shards: int,
316317
sharding: jax.sharding.Sharding,
317-
) -> dict[str, Nested[jax.Array]]:
318+
) -> dict[str, embedding.EmbeddingVariables]:
318319
sharded_tables = embedding_utils.stack_and_shard_tables(
319320
table_specs, tables, num_shards
320321
)
321-
output: dict[str, Nested[jax.Array]] = jax.device_put(
322+
output: dict[str, embedding.EmbeddingVariables] = jax.device_put(
322323
jax.tree.map(
323324
# Flatten shard dimension to allow auto-sharding to split the array.
324325
lambda table: table.reshape((-1, table.shape[-1])),
@@ -469,27 +470,24 @@ def compute_expected_lookup_grad(
469470
def _update_table_and_slot_variables(
470471
table_spec: TableSpec,
471472
grad: jax.Array,
472-
table_and_slot_variables: tuple[jax.Array, tuple[jax.Array, ...]],
473-
) -> tuple[
474-
jax.Array,
475-
embedding_spec.SGDSlotVariables | embedding_spec.AdagradSlotVariables,
476-
]:
473+
table_and_slot_variables: embedding.EmbeddingVariables,
474+
) -> embedding.EmbeddingVariables:
477475
"""Updates a table and its slot variables based on the gradient."""
478-
table = table_and_slot_variables[0]
476+
table = table_and_slot_variables.table
479477
optimizer = table_spec.optimizer
480478

481479
# Adagrad, update and apply gradient accumulator.
482480
if isinstance(optimizer, embedding_spec.AdagradOptimizerSpec):
483-
accumulator = table_and_slot_variables[1][0]
481+
accumulator = table_and_slot_variables.slot.accumulator
484482
accumulator = accumulator + grad * grad
485483
learning_rate = optimizer.get_learning_rate(0) / jnp.sqrt(accumulator)
486-
return (
484+
return embedding.EmbeddingVariables(
487485
table - learning_rate * grad,
488486
embedding_spec.AdagradSlotVariables(accumulator=accumulator),
489487
)
490488

491489
# SGD
492-
return (
490+
return embedding.EmbeddingVariables(
493491
table - optimizer.get_learning_rate(0) * grad,
494492
embedding_spec.SGDSlotVariables(),
495493
)
@@ -500,8 +498,8 @@ def compute_expected_updates(
500498
feature_samples: Nested[FeatureSamples],
501499
activation_gradients: Nested[jax.Array],
502500
table_specs: Nested[TableSpec],
503-
table_and_slot_variables: Nested[jax.Array],
504-
) -> Nested[jax.Array]:
501+
table_and_slot_variables: Nested[embedding.EmbeddingVariables],
502+
) -> Nested[embedding.EmbeddingVariables]:
505503
"""Computes the expected updates for a given embedding lookup.
506504
507505
Args:
@@ -522,7 +520,7 @@ def compute_expected_updates(
522520
)
523521

524522
# Apply updates per table.
525-
output: Nested[jax.Array] = jax.tree.map(
523+
output: Nested[embedding.EmbeddingVariables] = jax.tree.map(
526524
_update_table_and_slot_variables,
527525
table_specs,
528526
table_grads,

0 commit comments

Comments (0)