# Copyright 2024 RecML authors <[email protected]>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Embedding lookup ops."""

from collections.abc import Mapping, Sequence
import dataclasses
import functools
from typing import Any, TypeVar

from etils import epy
import jax
from jax.experimental import shard_map

with epy.lazy_imports():
  # pylint: disable=g-import-not-at-top
  from jax_tpu_embedding.sparsecore.lib.nn import embedding
  # pylint: enable=g-import-not-at-top


T = TypeVar("T")
Nested = T | Sequence[T] | Mapping[str, T]
FeatureSpec = Any


@dataclasses.dataclass
class SparsecoreParams:
  """Static parameters for a sparsecore embedding lookup.

  Attributes:
    feature_specs: Nested sparsecore feature specs for the lookups.
    abstract_mesh: Abstract mesh to shard the lookup over.
    data_axes: Mesh axes partitioning the inputs and activations.
    embedding_axes: Mesh axes partitioning the embedding tables.
    sharding_strategy: Table sharding strategy, e.g. "MOD".
  """

  feature_specs: Nested[FeatureSpec]
  abstract_mesh: jax.sharding.AbstractMesh
  data_axes: Sequence[str | None]
  embedding_axes: Sequence[str | None]
  sharding_strategy: str
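
# A minimal construction sketch, assuming `import numpy as np`, a recent JAX
# where `Mesh.abstract_mesh` is available, and a 1-D "data" mesh over all
# devices. `my_feature_specs` is a hypothetical nested structure of sparsecore
# feature specs built elsewhere with the `jax_tpu_embedding` spec APIs; the
# axis choices below are illustrative, not prescribed by this module:
#
#   mesh = jax.sharding.Mesh(np.array(jax.devices()), ("data",))
#   params = SparsecoreParams(
#       feature_specs=my_feature_specs,
#       abstract_mesh=mesh.abstract_mesh,
#       data_axes=("data",),
#       embedding_axes=("data", None),
#       sharding_strategy="MOD",
#   )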


@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
def sparsecore_lookup(
    sparsecore_params: SparsecoreParams,
    tables: Mapping[str, tuple[jax.Array, ...]],
    csr_inputs: tuple[jax.Array, ...],
):
  """Performs a sharded sparsecore embedding lookup."""
  return shard_map.shard_map(
      functools.partial(
          embedding.tpu_sparse_dense_matmul,
          global_device_count=sparsecore_params.abstract_mesh.size,
          feature_specs=sparsecore_params.feature_specs,
          sharding_strategy=sparsecore_params.sharding_strategy,
      ),
      mesh=sparsecore_params.abstract_mesh,
      in_specs=(
          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
          jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes),
      ),
      out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
      check_rep=False,
  )(csr_inputs, tables)
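
# A hedged usage sketch: `tables` maps table names to stacked table arrays and
# `csr_inputs` holds the preprocessed CSR arrays produced by the sparsecore
# input-preprocessing step; both names are placeholders, as is `params` from
# the sketch above. The result is a nested structure of dense activations
# matching `feature_specs`:
#
#   activations = sparsecore_lookup(params, tables, csr_inputs)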


def _emb_lookup_fwd(
    sparsecore_params: SparsecoreParams,
    tables: Mapping[str, tuple[jax.Array, ...]],
    csr_inputs: tuple[jax.Array, ...],
):
  """Forward pass: saves the tables and CSR inputs as residuals."""
  out = sparsecore_lookup(sparsecore_params, tables, csr_inputs)
  return out, (tables, csr_inputs)


def _emb_lookup_bwd(
    sparsecore_params: SparsecoreParams,
    res: tuple[Mapping[str, tuple[jax.Array, ...]], tuple[jax.Array, ...]],
    gradients: Nested[jax.Array],
) -> tuple[Nested[jax.Array], None]:
  """Backward pass for embedding lookup."""
  (tables, csr_inputs) = res

  emb_table_grads = shard_map.shard_map(
      functools.partial(
          embedding.tpu_sparse_dense_matmul_grad,
          feature_specs=sparsecore_params.feature_specs,
          sharding_strategy=sparsecore_params.sharding_strategy,
      ),
      mesh=sparsecore_params.abstract_mesh,
      in_specs=(
          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
          jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes),
      ),
      out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
      check_rep=False,
  )(gradients, csr_inputs, tables)

  # `tpu_sparse_dense_matmul_grad` returns a general mapping (usually a dict).
  # It may not be the same type as the embedding table (e.g. FrozenDict).
  # Here we use flatten / unflatten to ensure the types are the same.
  emb_table_grads = jax.tree.unflatten(
      jax.tree.structure(tables), jax.tree.leaves(emb_table_grads)
  )

  # No gradient flows to `csr_inputs`.
  return emb_table_grads, None


sparsecore_lookup.defvjp(_emb_lookup_fwd, _emb_lookup_bwd)
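
# End-to-end differentiation sketch (assuming `import jax.numpy as jnp` and
# the placeholder names above). With `nondiff_argnums=(0,)`, only `tables`
# receives a gradient; the custom VJP returns `None` for `csr_inputs`:
#
#   def loss_fn(tables):
#     activations = sparsecore_lookup(params, tables, csr_inputs)
#     return sum(jnp.sum(a) for a in jax.tree.leaves(activations))
#
#   table_grads = jax.grad(loss_fn)(tables)  # same pytree type as `tables`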