Skip to content

Commit 21d9214

Browse files
Refactor embedding variable handling in JAX test utilities
1 parent e002b4b commit 21d9214

File tree

2 files changed

+23
-25
lines changed

2 files changed

+23
-25
lines changed

keras_rs/src/layers/embedding/jax/embedding_lookup_test.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -193,7 +193,7 @@ def test_forward_pass(self, ragged: bool, stacked: bool):
193193

194194
# Add pseudo gradients to the inputs.
195195
embedding_variables = jax.tree.map(
196-
lambda table: (table, None),
196+
lambda table: embedding.EmbeddingVariables(table=table, slot=()),
197197
sharded_tables,
198198
)
199199

@@ -288,7 +288,7 @@ def test_model_sharding(
288288

289289
# Add pseudo gradients to the inputs.
290290
embedding_variables = jax.tree.map(
291-
lambda table: (table, None),
291+
lambda table: embedding.EmbeddingVariables(table=table, slot=()),
292292
sharded_tables,
293293
)
294294

@@ -479,7 +479,8 @@ def test_autograd(
479479
)
480480
)
481481
sharded_table_and_slot_variables = typing.cast(
482-
dict[str, tuple[jax.Array, ...]], sharded_table_and_slot_variables
482+
dict[str, embedding.EmbeddingVariables],
483+
sharded_table_and_slot_variables,
483484
)
484485

485486
# Shard samples for lookup query.

keras_rs/src/layers/embedding/jax/test_utils.py

Lines changed: 19 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -4,16 +4,16 @@
44
from typing import Any, Mapping, TypeAlias, Union
55

66
import jax
7-
import keras
8-
import numpy as np
97
from jax import numpy as jnp
8+
from jax_tpu_embedding.sparsecore.lib.nn import embedding
109
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
1110
from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
1211
from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import TableSpec
13-
12+
import keras
1413
from keras_rs.src.layers.embedding.jax import embedding_utils
1514
from keras_rs.src.layers.embedding.jax.embedding_utils import FeatureSamples
1615
from keras_rs.src.types import Nested
16+
import numpy as np
1717

1818
ArrayLike: TypeAlias = Union[jax.Array, np.ndarray[Any, Any]]
1919
Shape: TypeAlias = tuple[int, ...]
@@ -142,7 +142,7 @@ def create_tables(
142142
def create_table_and_slot_variables(
143143
table_specs: Nested[TableSpec],
144144
keys: Nested[ArrayLike] | None = None,
145-
) -> Nested[ArrayLike]:
145+
) -> Nested[embedding.EmbeddingVariables]:
146146
"""Creates and initializes embedding tables and slot variables.
147147
148148
Args:
@@ -164,7 +164,7 @@ def create_table_and_slot_variables(
164164
def _create_table_and_slot_variables(
165165
table_spec: TableSpec,
166166
key: ArrayLike,
167-
) -> tuple[jax.Array, tuple[jax.Array, ...]]:
167+
) -> embedding.EmbeddingVariables:
168168
slot_initializers = table_spec.optimizer.slot_variables_initializers()
169169
num_slot_variables = len(keras.tree.flatten(slot_initializers))
170170
slot_keys = jnp.unstack(jax.random.split(key, num_slot_variables))
@@ -178,10 +178,10 @@ def _create_table_and_slot_variables(
178178
slot_initializers,
179179
slot_keys,
180180
)
181-
return (table, slot_variables)
181+
return embedding.EmbeddingVariables(table, slot_variables)
182182

183183
# Initialize tables.
184-
output: Nested[ArrayLike] = jax.tree.map(
184+
output: Nested[embedding.EmbeddingVariables] = jax.tree.map(
185185
_create_table_and_slot_variables,
186186
table_specs,
187187
keys,
@@ -311,14 +311,14 @@ def _create_samples(
311311

312312
def stack_shard_and_put_tables(
313313
table_specs: Nested[TableSpec],
314-
tables: Nested[jax.Array],
314+
tables: Nested[Any],
315315
num_shards: int,
316316
sharding: jax.sharding.Sharding,
317-
) -> dict[str, Nested[jax.Array]]:
317+
) -> dict[str, Any]:
318318
sharded_tables = embedding_utils.stack_and_shard_tables(
319319
table_specs, tables, num_shards
320320
)
321-
output: dict[str, Nested[jax.Array]] = jax.device_put(
321+
output: dict[str, Any] = jax.device_put(
322322
jax.tree.map(
323323
# Flatten shard dimension to allow auto-sharding to split the array.
324324
lambda table: table.reshape((-1, table.shape[-1])),
@@ -469,27 +469,24 @@ def compute_expected_lookup_grad(
469469
def _update_table_and_slot_variables(
470470
table_spec: TableSpec,
471471
grad: jax.Array,
472-
table_and_slot_variables: tuple[jax.Array, tuple[jax.Array, ...]],
473-
) -> tuple[
474-
jax.Array,
475-
embedding_spec.SGDSlotVariables | embedding_spec.AdagradSlotVariables,
476-
]:
472+
table_and_slot_variables: embedding.EmbeddingVariables,
473+
) -> embedding.EmbeddingVariables:
477474
"""Updates a table and its slot variables based on the gradient."""
478-
table = table_and_slot_variables[0]
475+
table = table_and_slot_variables.table
479476
optimizer = table_spec.optimizer
480477

481478
# Adagrad, update and apply gradient accumulator.
482479
if isinstance(optimizer, embedding_spec.AdagradOptimizerSpec):
483-
accumulator = table_and_slot_variables[1][0]
480+
accumulator = table_and_slot_variables.slot.accumulator
484481
accumulator = accumulator + grad * grad
485482
learning_rate = optimizer.get_learning_rate(0) / jnp.sqrt(accumulator)
486-
return (
483+
return embedding.EmbeddingVariables(
487484
table - learning_rate * grad,
488485
embedding_spec.AdagradSlotVariables(accumulator=accumulator),
489486
)
490487

491488
# SGD
492-
return (
489+
return embedding.EmbeddingVariables(
493490
table - optimizer.get_learning_rate(0) * grad,
494491
embedding_spec.SGDSlotVariables(),
495492
)
@@ -500,8 +497,8 @@ def compute_expected_updates(
500497
feature_samples: Nested[FeatureSamples],
501498
activation_gradients: Nested[jax.Array],
502499
table_specs: Nested[TableSpec],
503-
table_and_slot_variables: Nested[jax.Array],
504-
) -> Nested[jax.Array]:
500+
table_and_slot_variables: Nested[embedding.EmbeddingVariables],
501+
) -> Nested[embedding.EmbeddingVariables]:
505502
"""Computes the expected updates for a given embedding lookup.
506503
507504
Args:
@@ -522,7 +519,7 @@ def compute_expected_updates(
522519
)
523520

524521
# Apply updates per table.
525-
output: Nested[jax.Array] = jax.tree.map(
522+
output: Nested[embedding.EmbeddingVariables] = jax.tree.map(
526523
_update_table_and_slot_variables,
527524
table_specs,
528525
table_grads,

0 commit comments

Comments
 (0)