@@ -4,16 +4,16 @@
 from typing import Any, Mapping, TypeAlias, Union
 
 import jax
-import keras
-import numpy as np
 from jax import numpy as jnp
+from jax_tpu_embedding.sparsecore.lib.nn import embedding
 from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
 from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
 from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import TableSpec
-
+import keras
 from keras_rs.src.layers.embedding.jax import embedding_utils
 from keras_rs.src.layers.embedding.jax.embedding_utils import FeatureSamples
 from keras_rs.src.types import Nested
+import numpy as np
 
 ArrayLike: TypeAlias = Union[jax.Array, np.ndarray[Any, Any]]
 Shape: TypeAlias = tuple[int, ...]
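The recurring change in this diff is the switch from bare `(table, slot_variables)` tuples to `embedding.EmbeddingVariables`. The hunks below construct it positionally and read `.table` and `.slot.accumulator`, which is consistent with a NamedTuple of roughly the following shape; this sketch is inferred from the diff, not copied from the library:

```python
# Inferred shape of embedding.EmbeddingVariables, based on how this diff
# constructs it (positionally) and reads it (.table, .slot.accumulator).
# The real definition lives in jax_tpu_embedding.sparsecore.lib.nn.embedding.
from typing import Any, NamedTuple

import jax


class EmbeddingVariables(NamedTuple):
    table: jax.Array  # The embedding table itself.
    slot: Any  # Optimizer slot variables, e.g. AdagradSlotVariables.
```

Because a NamedTuple is still a tuple, existing positional unpacking and `jax.tree` traversal keep working while opaque `[0]` / `[1][0]` indexing is replaced with named fields.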
@@ -142,7 +142,7 @@ def create_tables(
 def create_table_and_slot_variables(
     table_specs: Nested[TableSpec],
     keys: Nested[ArrayLike] | None = None,
-) -> Nested[ArrayLike]:
+) -> Nested[embedding.EmbeddingVariables]:
     """Creates and initializes embedding tables and slot variables.
 
     Args:
@@ -164,7 +164,7 @@ def create_table_and_slot_variables(
     def _create_table_and_slot_variables(
         table_spec: TableSpec,
         key: ArrayLike,
-    ) -> tuple[jax.Array, tuple[jax.Array, ...]]:
+    ) -> embedding.EmbeddingVariables:
         slot_initializers = table_spec.optimizer.slot_variables_initializers()
         num_slot_variables = len(keras.tree.flatten(slot_initializers))
         slot_keys = jnp.unstack(jax.random.split(key, num_slot_variables))
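For reference, a minimal illustration of the key handling above: one PRNG key is split into `num_slot_variables` independent keys, and `jnp.unstack` turns the stacked key array into a sequence so each slot initializer gets its own key. The count of 2 below is arbitrary:

```python
import jax
from jax import numpy as jnp

key = jax.random.key(0)
# split() returns a stacked array of keys; unstack() yields one key per row.
slot_keys = jnp.unstack(jax.random.split(key, 2))
assert len(slot_keys) == 2  # One independent key per slot variable.
```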
@@ -178,10 +178,10 @@ def _create_table_and_slot_variables(
             slot_initializers,
             slot_keys,
         )
-        return (table, slot_variables)
+        return embedding.EmbeddingVariables(table, slot_variables)
 
     # Initialize tables.
-    output: Nested[ArrayLike] = jax.tree.map(
+    output: Nested[embedding.EmbeddingVariables] = jax.tree.map(
         _create_table_and_slot_variables,
         table_specs,
         keys,
@@ -311,14 +311,14 @@ def _create_samples(
 
 def stack_shard_and_put_tables(
     table_specs: Nested[TableSpec],
-    tables: Nested[jax.Array],
+    tables: Nested[Any],
     num_shards: int,
     sharding: jax.sharding.Sharding,
-) -> dict[str, Nested[jax.Array]]:
+) -> dict[str, Any]:
     sharded_tables = embedding_utils.stack_and_shard_tables(
         table_specs, tables, num_shards
     )
-    output: dict[str, Nested[jax.Array]] = jax.device_put(
+    output: dict[str, Any] = jax.device_put(
         jax.tree.map(
             # Flatten shard dimension to allow auto-sharding to split the array.
             lambda table: table.reshape((-1, table.shape[-1])),
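The lambda's reshape is the load-bearing detail here: per the inline comment, the stacked tables carry an explicit shard dimension, and flattening to 2D lets the `sharding` passed to `jax.device_put` split rows across devices without being constrained by that grouping. A shape-only sketch, with made-up sizes:

```python
import jax.numpy as jnp

# (num_shards, rows_per_shard, embedding_dim) -- sizes are illustrative only.
stacked = jnp.ones((4, 8, 16))
flat = stacked.reshape((-1, stacked.shape[-1]))
assert flat.shape == (32, 16)  # Shard dimension folded into the rows.
```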
@@ -469,27 +469,24 @@ def compute_expected_lookup_grad(
 def _update_table_and_slot_variables(
     table_spec: TableSpec,
     grad: jax.Array,
-    table_and_slot_variables: tuple[jax.Array, tuple[jax.Array, ...]],
-) -> tuple[
-    jax.Array,
-    embedding_spec.SGDSlotVariables | embedding_spec.AdagradSlotVariables,
-]:
+    table_and_slot_variables: embedding.EmbeddingVariables,
+) -> embedding.EmbeddingVariables:
     """Updates a table and its slot variables based on the gradient."""
-    table = table_and_slot_variables[0]
+    table = table_and_slot_variables.table
     optimizer = table_spec.optimizer
 
     # Adagrad, update and apply gradient accumulator.
     if isinstance(optimizer, embedding_spec.AdagradOptimizerSpec):
-        accumulator = table_and_slot_variables[1][0]
+        accumulator = table_and_slot_variables.slot.accumulator
         accumulator = accumulator + grad * grad
         learning_rate = optimizer.get_learning_rate(0) / jnp.sqrt(accumulator)
-        return (
+        return embedding.EmbeddingVariables(
             table - learning_rate * grad,
             embedding_spec.AdagradSlotVariables(accumulator=accumulator),
         )
 
     # SGD
-    return (
+    return embedding.EmbeddingVariables(
         table - optimizer.get_learning_rate(0) * grad,
         embedding_spec.SGDSlotVariables(),
     )
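The Adagrad branch above is the standard update, restated here against plain arrays so it can be sanity-checked in isolation; the learning-rate value is illustrative, standing in for `optimizer.get_learning_rate(0)` in the real code:

```python
import jax.numpy as jnp

table = jnp.array([1.0, 2.0])
grad = jnp.array([0.5, -0.5])
accumulator = jnp.array([1.0, 1.0])

accumulator = accumulator + grad * grad      # A += g^2
learning_rate = 0.1 / jnp.sqrt(accumulator)  # per-parameter step size
table = table - learning_rate * grad         # w -= (lr / sqrt(A)) * g
```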
@@ -500,8 +497,8 @@ def compute_expected_updates(
     feature_samples: Nested[FeatureSamples],
     activation_gradients: Nested[jax.Array],
     table_specs: Nested[TableSpec],
-    table_and_slot_variables: Nested[jax.Array],
-) -> Nested[jax.Array]:
+    table_and_slot_variables: Nested[embedding.EmbeddingVariables],
+) -> Nested[embedding.EmbeddingVariables]:
     """Computes the expected updates for a given embedding lookup.
 
     Args:
@@ -522,7 +519,7 @@ def compute_expected_updates(
     )
 
     # Apply updates per table.
-    output: Nested[jax.Array] = jax.tree.map(
+    output: Nested[embedding.EmbeddingVariables] = jax.tree.map(
         _update_table_and_slot_variables,
         table_specs,
         table_grads,
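One subtlety makes this `jax.tree.map` work even though `EmbeddingVariables` is itself tuple-structured: `jax.tree.map` takes its structure from the first tree (`table_specs`) and flattens the remaining trees only up to that structure, so each `TableSpec` leaf receives the whole variables tuple for its table rather than its individual arrays. The hunk is truncated here, but given the helper's third parameter the map presumably receives the variables tree as well. A toy demonstration with stand-in leaves:

```python
import jax

specs = {"vocab": object()}  # A TableSpec stand-in; treated as a leaf.
variables = {"vocab": ("table_array", ("slot_array",))}

# The tuple under "vocab" is passed whole, not flattened into its parts.
out = jax.tree.map(lambda spec, v: v[0], specs, variables)
assert out == {"vocab": "table_array"}
```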