@@ -7,6 +7,7 @@
 import keras
 import numpy as np
 from jax import numpy as jnp
+from jax_tpu_embedding.sparsecore.lib.nn import embedding
 from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
 from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
 from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import TableSpec
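The newly imported `embedding` module provides `EmbeddingVariables`, which the rest of this diff uses both by positional construction and by `.table` / `.slot.accumulator` attribute access. A minimal sketch of the structure those usages imply (the real definition lives in `jax_tpu_embedding.sparsecore.lib.nn.embedding`; the field types here are assumptions):

```python
import jax
from typing import Any, NamedTuple

class EmbeddingVariables(NamedTuple):
    table: jax.Array  # the embedding table itself
    slot: Any  # optimizer slot variables, e.g. AdagradSlotVariables
```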
@@ -142,7 +143,7 @@ def create_tables(
 def create_table_and_slot_variables(
     table_specs: Nested[TableSpec],
     keys: Nested[ArrayLike] | None = None,
-) -> Nested[ArrayLike]:
+) -> Nested[embedding.EmbeddingVariables]:
   """Creates and initializes embedding tables and slot variables.
 
   Args:
@@ -164,7 +165,7 @@ def create_table_and_slot_variables(
   def _create_table_and_slot_variables(
       table_spec: TableSpec,
       key: ArrayLike,
-  ) -> tuple[jax.Array, tuple[jax.Array, ...]]:
+  ) -> embedding.EmbeddingVariables:
     slot_initializers = table_spec.optimizer.slot_variables_initializers()
     num_slot_variables = len(keras.tree.flatten(slot_initializers))
     slot_keys = jnp.unstack(jax.random.split(key, num_slot_variables))
@@ -178,10 +179,10 @@ def _create_table_and_slot_variables(
         slot_initializers,
         slot_keys,
     )
-    return (table, slot_variables)
+    return embedding.EmbeddingVariables(table, slot_variables)
 
   # Initialize tables.
-  output: Nested[ArrayLike] = jax.tree.map(
+  output: Nested[embedding.EmbeddingVariables] = jax.tree.map(
       _create_table_and_slot_variables,
       table_specs,
       keys,
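The nested helper above splits one PRNG key into one key per slot variable, then pairs each key with its slot initializer (the mapping call itself sits just above the visible hunk). A standalone sketch of that pattern; the shape, dtype, and choice of initializer here are illustrative only:

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
slot_initializers = (jax.nn.initializers.zeros,)  # e.g. one Adagrad accumulator
slot_keys = jnp.unstack(jax.random.split(key, len(slot_initializers)))

# Each initializer consumes its own key, mirroring the tree-mapped call above.
slot_variables = tuple(
    init(k, (16, 8), jnp.float32)
    for init, k in zip(slot_initializers, slot_keys)
)
```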
@@ -311,14 +312,14 @@ def _create_samples(
 
 def stack_shard_and_put_tables(
     table_specs: Nested[TableSpec],
-    tables: Nested[jax.Array],
+    tables: Nested[embedding.EmbeddingVariables],
     num_shards: int,
     sharding: jax.sharding.Sharding,
-) -> dict[str, Nested[jax.Array]]:
+) -> dict[str, embedding.EmbeddingVariables]:
   sharded_tables = embedding_utils.stack_and_shard_tables(
       table_specs, tables, num_shards
   )
-  output: dict[str, Nested[jax.Array]] = jax.device_put(
+  output: dict[str, embedding.EmbeddingVariables] = jax.device_put(
       jax.tree.map(
           # Flatten shard dimension to allow auto-sharding to split the array.
           lambda table: table.reshape((-1, table.shape[-1])),
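The lambda's reshape collapses the per-shard leading dimension so that `jax.device_put` can split the resulting 2-D array row-wise across devices. A toy illustration of just the reshape (all shapes made up):

```python
import numpy as np

num_shards, rows_per_shard, dim = 4, 2, 3
stacked = np.arange(num_shards * rows_per_shard * dim, dtype=np.float32).reshape(
    num_shards, rows_per_shard, dim
)

# Flatten the shard dimension, as the lambda above does.
flat = stacked.reshape((-1, stacked.shape[-1]))
assert flat.shape == (num_shards * rows_per_shard, dim)  # (8, 3)
```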
@@ -469,27 +470,24 @@ def compute_expected_lookup_grad(
 def _update_table_and_slot_variables(
     table_spec: TableSpec,
     grad: jax.Array,
-    table_and_slot_variables: tuple[jax.Array, tuple[jax.Array, ...]],
-) -> tuple[
-    jax.Array,
-    embedding_spec.SGDSlotVariables | embedding_spec.AdagradSlotVariables,
-]:
+    table_and_slot_variables: embedding.EmbeddingVariables,
+) -> embedding.EmbeddingVariables:
   """Updates a table and its slot variables based on the gradient."""
-  table = table_and_slot_variables[0]
+  table = table_and_slot_variables.table
   optimizer = table_spec.optimizer
 
   # Adagrad: update and apply the gradient accumulator.
   if isinstance(optimizer, embedding_spec.AdagradOptimizerSpec):
-    accumulator = table_and_slot_variables[1][0]
+    accumulator = table_and_slot_variables.slot.accumulator
     accumulator = accumulator + grad * grad
     learning_rate = optimizer.get_learning_rate(0) / jnp.sqrt(accumulator)
-    return (
+    return embedding.EmbeddingVariables(
         table - learning_rate * grad,
         embedding_spec.AdagradSlotVariables(accumulator=accumulator),
     )
 
   # SGD
-  return (
+  return embedding.EmbeddingVariables(
       table - optimizer.get_learning_rate(0) * grad,
       embedding_spec.SGDSlotVariables(),
   )
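A quick numeric check of the Adagrad branch above, assuming a zero-initialized accumulator and a learning rate of 1.0 (all values illustrative):

```python
import jax.numpy as jnp

grad = jnp.array([0.5, -2.0])
table = jnp.array([1.0, 1.0])
accumulator = jnp.zeros_like(table)

accumulator = accumulator + grad * grad      # [0.25, 4.0]
learning_rate = 1.0 / jnp.sqrt(accumulator)  # [2.0, 0.5]
table = table - learning_rate * grad         # [0.0, 2.0]
```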
@@ -500,8 +498,8 @@ def compute_expected_updates(
     feature_samples: Nested[FeatureSamples],
     activation_gradients: Nested[jax.Array],
     table_specs: Nested[TableSpec],
-    table_and_slot_variables: Nested[jax.Array],
-) -> Nested[jax.Array]:
+    table_and_slot_variables: Nested[embedding.EmbeddingVariables],
+) -> Nested[embedding.EmbeddingVariables]:
   """Computes the expected updates for a given embedding lookup.
 
   Args:
@@ -522,7 +520,7 @@ def compute_expected_updates(
   )
 
   # Apply updates per table.
-  output: Nested[jax.Array] = jax.tree.map(
+  output: Nested[embedding.EmbeddingVariables] = jax.tree.map(
       _update_table_and_slot_variables,
       table_specs,
       table_grads,
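For reference, this three-tree `jax.tree.map` passes matching leaves from each tree to `_update_table_and_slot_variables`. A toy example of the same multi-tree pattern, using hypothetical dict trees with scalar leaves:

```python
import jax

table_specs = {"movie": 0, "user": 1}  # stand-ins for TableSpec leaves
table_grads = {"movie": 0.5, "user": -1.0}
tables = {"movie": 2.0, "user": 3.0}

# The callback receives one leaf from each tree, keyed identically.
updated = jax.tree.map(
    lambda spec, grad, table: table - 0.1 * grad,  # SGD-style update
    table_specs,
    table_grads,
    tables,
)
# {'movie': 1.95, 'user': 3.1}
```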