Skip to content

Commit c6acbb3

Browse files
[JAX SC] Raise error on updating FDO params after stacking.
PiperOrigin-RevId: 826192844
1 parent 3fca158 commit c6acbb3

File tree

3 files changed

+54
-10
lines changed

3 files changed

+54
-10
lines changed

jax_tpu_embedding/sparsecore/lib/core/pybind_input_preprocessing_test.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2923,10 +2923,12 @@ def test_combiner(
29232923
)
29242924

29252925
# Compute expected by re-adjusting the weights using a "sum" combiner.
2926-
table_spec.combiner = "sum"
2927-
table_spec.stacked_table_spec = None
2926+
table_spec_for_expectation = dataclasses.replace(table_spec, combiner="sum")
2927+
feature_spec_for_expectation = dataclasses.replace(
2928+
feature_spec, table_spec=table_spec_for_expectation
2929+
)
29282930
embedding.prepare_feature_specs_for_training(
2929-
feature_spec,
2931+
feature_spec_for_expectation,
29302932
global_device_count=1,
29312933
num_sc_per_device=4,
29322934
)
@@ -2941,7 +2943,7 @@ def test_combiner(
29412943
) = pybind_input_preprocessing.PreprocessSparseDenseMatmulInput(
29422944
[input_features],
29432945
[input_weights],
2944-
[feature_spec],
2946+
[feature_spec_for_expectation],
29452947
local_device_count=1,
29462948
global_device_count=1,
29472949
num_sc_per_device=4,

jax_tpu_embedding/sparsecore/lib/core/sparse_coo_input_preprocessing_test.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ def test_sparse_tensor_input(self, has_leading_dimension):
101101
indices_tensor = [sparse_tensor.indices]
102102
values_tensor = [sparse_tensor.values]
103103
dense_shape_tensor = [sparse_tensor.dense_shape]
104-
self.feature_spec.table_spec.suggested_coo_buffer_size_per_device = 64
105104
batch_number = 42
106105
(
107106
row_pointers_sparse,
@@ -148,7 +147,6 @@ def test_sparse_tensor_input(self, has_leading_dimension):
148147
dtype=np.int32,
149148
)
150149

151-
self.feature_spec.table_spec.suggested_coo_buffer_size_per_device = 64
152150
batch_number = 42
153151
local_device_count = 4
154152
global_device_count = 4
@@ -246,7 +244,6 @@ def test_sparse_tensor_input_with_empty_rows(self, has_leading_dimension):
246244
numpy_features = np.array(numpy_features, dtype=object)
247245
numpy_weights = np.array(numpy_weights, dtype=object)
248246

249-
self.feature_spec.table_spec.suggested_coo_buffer_size_per_device = 64
250247
batch_number = 42
251248
local_device_count = 4
252249
global_device_count = 4
@@ -1374,7 +1371,6 @@ def test_sparse_tensor_input(self, has_leading_dimension):
13741371
indices_tensor = [sparse_tensor.indices]
13751372
values_tensor = [sparse_tensor.values]
13761373
dense_shape_tensor = [sparse_tensor.dense_shape]
1377-
self.feature_spec.table_spec.suggested_coo_buffer_size_per_device = 64
13781374
batch_number = 42
13791375
(
13801376
row_pointers_sparse,
@@ -1420,7 +1416,6 @@ def test_sparse_tensor_input(self, has_leading_dimension):
14201416
],
14211417
dtype=np.int32,
14221418
)
1423-
self.feature_spec.table_spec.suggested_coo_buffer_size_per_device = 64
14241419
batch_number = 42
14251420
local_device_count = 4
14261421
global_device_count = 4

jax_tpu_embedding/sparsecore/lib/nn/embedding_spec.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import collections
2020
import dataclasses
2121
import inspect
22-
from typing import Callable, Sequence, TypeAlias
22+
from typing import Any, Callable, Sequence, TypeAlias
2323

2424
from flax import struct
2525
import jax
@@ -673,6 +673,52 @@ class TableSpec:
673673
"""Quantization config (min, max, num_buckets) which represent the float
674674
range and number of discrete integer buckets to use for quantization."""
675675

676+
_initialized: bool = dataclasses.field(
677+
init=False, default=False, compare=False
678+
)
679+
680+
def __setattr__(self, name: str, value: Any):
681+
# These attributes are allowed to be modified after stacking because
682+
# they are part of the stacking process itself (which sets
683+
# `stacked_table_spec` and `setting_in_stack`), or are used to update
684+
# stacking-related parameters like buffer sizes via replacing the
685+
# StackedTableSpec.
686+
allowed_after_stacking = (
687+
"_stacked_table_spec",
688+
"_setting_in_stack",
689+
"stacked_table_spec",
690+
"setting_in_stack",
691+
)
692+
fdo_params = (
693+
"max_ids_per_partition",
694+
"max_unique_ids_per_partition",
695+
"suggested_coo_buffer_size_per_device",
696+
)
697+
# Check if __post_init__ has run and if this is a stacked table
698+
if (
699+
name not in allowed_after_stacking
700+
and self._initialized
701+
and self.is_stacked()
702+
):
703+
if name in fdo_params:
704+
raise AttributeError(
705+
f"Cannot update parameter '{name}' on TableSpec '{self.name}' "
706+
"after it has been stacked. If you are trying to update FDO "
707+
"parameters like 'max_ids_per_partition', use "
708+
"embedding.update_preprocessing_parameters() to update the "
709+
"StackedTableSpec instead."
710+
)
711+
else:
712+
raise AttributeError(
713+
f"Cannot update parameter '{name}' on TableSpec '{self.name}' "
714+
"after it has been stacked. If you need to modify non-FDO "
715+
"parameters for a stacked table, create a new table spec using "
716+
"`dataclasses.replace(table_spec, ...)` and feature spec using "
717+
"`dataclasses.replace(feature_spec, table_spec=new_table_spec)` "
718+
"and then re-prepare feature specs for training."
719+
)
720+
super().__setattr__(name, value)
721+
676722
# This points to the stacked table spec which this table belongs to.
677723
# If this is None, this table is the top-most table.
678724
_stacked_table_spec: StackedTableSpec | None = dataclasses.field(
@@ -721,6 +767,7 @@ def __post_init__(
721767
row_offset_in_shard=0,
722768
shard_rotation=0,
723769
)
770+
object.__setattr__(self, "_initialized", True)
724771

725772

726773
@dataclasses.dataclass(eq=True, unsafe_hash=True, kw_only=True)

0 commit comments

Comments
 (0)