Commit 7a83464

Refactor input preprocessing to optimize the non-minibatching path. Remove bucketization overhead when minibatching is disabled or in pass 1.
PiperOrigin-RevId: 826579723
1 parent 00afdd2 commit 7a83464
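
The core of this change is promoting the runtime `create_buckets` flag to a compile-time template parameter (`kCreateBuckets`) and dispatching with `if constexpr`, so the hot grouping loop on the non-minibatching path is compiled with no bucket bookkeeping at all. A minimal sketch of that pattern, with hypothetical names (`ProcessImpl`, `Process`) rather than the library's actual API:

#include <cstdio>
#include <vector>

// Hypothetical stand-in for SortAndGroupCooTensorsPerLocalDeviceImpl: the
// bool is a template parameter, so each instantiation contains only one
// branch of the `if constexpr`.
template <bool kCreateBuckets>
int ProcessImpl(const std::vector<int>& items) {
  int acc = 0;
  for (int item : items) {
    if constexpr (kCreateBuckets) {
      acc += item % 4;  // bucket bookkeeping exists only in the <true> copy
    } else {
      acc += 1;         // lean path: no bucketing code is compiled in
    }
  }
  return acc;
}

// Thin runtime wrapper, mirroring how SortAndGroupCooTensorsPerLocalDevice
// picks the specialization from `create_buckets`.
int Process(const std::vector<int>& items, bool create_buckets) {
  return create_buckets ? ProcessImpl<true>(items) : ProcessImpl<false>(items);
}

int main() {
  std::printf("%d\n", Process({1, 2, 3}, /*create_buckets=*/false));  // prints 3
}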

File tree

1 file changed: +153 -27 lines changed

jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h

Lines changed: 153 additions & 27 deletions
@@ -187,7 +187,95 @@ struct LocalSparseCoreTensorGroupingContext {
   MatrixXi& kept_unique_ids_per_partition_per_bucket;
 };
 
-inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
+inline void GroupAndDeduplicateCooTensorsForLocalSparseCoreNoBuckets(
+    LocalSparseCoreTensorGroupingContext context) {
+  // Unpack context for readability.
+  const PreprocessSparseDenseMatmulInputOptions& options = context.options;
+  const StackedTableMetadata& stacked_table_metadata =
+      context.stacked_table_metadata;
+  const std::vector<CooFormat>& coo_tensors = context.coo_tensors;
+  PartitionedCooTensors& grouped_coo_tensors = context.grouped_coo_tensors;
+  StatsPerDevice& stats = context.stats;
+  MatrixXi& observed_ids = context.ids_per_sc_partition_per_bucket;
+  MatrixXi& observed_unique_ids = context.unique_ids_per_partition_per_bucket;
+  MatrixXi& kept_ids = context.kept_ids_per_sc_partition_per_bucket;
+  MatrixXi& kept_unique_ids = context.kept_unique_ids_per_partition_per_bucket;
+
+  const bool allow_id_dropping = options.allow_id_dropping;
+  const uint32_t global_sc_count = options.GetNumScs();
+  const int max_ids_per_partition =
+      stacked_table_metadata.max_ids_per_partition;
+  const int max_unique_ids_per_partition =
+      stacked_table_metadata.max_unique_ids_per_partition;
+  uint32_t prev_col_id = std::numeric_limits<uint32_t>::max();
+  uint32_t prev_row_id = std::numeric_limits<uint32_t>::max();
+  bool dropping_current_unique_col_id = false;
+  for (const uint64_t key : context.keys) {
+    // Step 1: Unpack key to get tensor coordinates.
+    const uint32_t index = key & CooFormat::kIndexMask;
+    const CooFormat& coo_tensor = coo_tensors[index];
+    const uint32_t col_id = coo_tensor.col_id;
+    const uint32_t global_sc_id = coo_tensor.col_id & (global_sc_count - 1);
+    const uint32_t row_id = coo_tensor.row_id;
+
+    // Step 2: Handle duplicates.
+    // An ID that is a duplicate of a previously non-dropped ID is merged.
+    // It does not count as a new ID for stats and does not go through dropping
+    // logic.
+    if (grouped_coo_tensors.MaybeMerge(/*bucket_id=*/0, coo_tensor)) {
+      continue;
+    }
+    // If the ID is a duplicate of the last seen ID, it must have been dropped
+    // (otherwise it would have been merged above), so drop this one too.
+    if (col_id == prev_col_id && row_id == prev_row_id) {
+      ++stats.dropped_id_count;
+      continue;
+    }
+
+    // Step 3: Update observed statistics for the new ID.
+    const bool is_new_col = col_id != prev_col_id;
+    // Update observed stats. These are never decremented and are used for
+    // reporting.
+    observed_ids(global_sc_id, 0) += 1;
+    if (is_new_col) {
+      observed_unique_ids(global_sc_id, 0) += 1;
+      dropping_current_unique_col_id =
+          (kept_unique_ids(global_sc_id, 0) + 1) >
+          max_unique_ids_per_partition;
+    }
+
+    // Step 4: Determine if the ID should be dropped based on capacity limits.
+    // We do NOT drop IDs when minibatching is enabled and we are in the
+    // first pass (`create_buckets=false`), as we need to detect limit
+    // overflows to decide if minibatching is required.
+    const bool can_drop_id =
+        !options.enable_minibatching;
+    const bool exceeds_ids_limit =
+        (kept_ids(global_sc_id, 0) + 1) > max_ids_per_partition;
+
+    // Step 5: Add ID to result or drop it.
+    if (can_drop_id && allow_id_dropping &&
+        (exceeds_ids_limit || dropping_current_unique_col_id)) {
+      // Dropped id.
+      ++stats.dropped_id_count;
+    } else {
+      grouped_coo_tensors.Add(context.local_sc_id, /*bucket_id=*/0, coo_tensor);
+      // Update kept counts.
+      kept_ids(global_sc_id, 0) += 1;
+      if (is_new_col) {
+        kept_unique_ids(global_sc_id, 0) += 1;
+      }
+    }
+
+    // Step 6: Update state for next iteration.
+    // This must be done regardless of whether the ID was dropped to ensure
+    // correct stats collection for subsequent IDs.
+    prev_col_id = col_id;
+    prev_row_id = row_id;
+  }
+}
+
+inline void GroupAndDeduplicateCooTensorsForLocalSparseCoreWithBuckets(
     LocalSparseCoreTensorGroupingContext context) {
   // Unpack context for readability.
   const PreprocessSparseDenseMatmulInputOptions& options = context.options;
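
The new `GroupAndDeduplicateCooTensorsForLocalSparseCoreNoBuckets` above is a sort-then-scan: after sorting the packed keys, duplicate IDs become adjacent and are merged or dropped, and per-partition capacity limits decide whether each remaining ID is kept. A self-contained sketch of that pattern, using simplified hypothetical types (the real code walks `CooFormat` tensors and per-SparseCore Eigen matrices):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <tuple>
#include <vector>

struct Id {
  uint32_t col;
  uint32_t row;
};

// Counts how many ids survive deduplication plus a unique-column cap,
// mirroring Steps 2-5 of the loop above in simplified form.
int CountKept(std::vector<Id> ids, int max_unique_cols) {
  // Sorting makes duplicates and equal columns adjacent; the real code gets
  // this ordering from hwy::VQSort over packed 64-bit grouping keys.
  std::sort(ids.begin(), ids.end(), [](const Id& a, const Id& b) {
    return std::tie(a.col, a.row) < std::tie(b.col, b.row);
  });
  int kept = 0;
  int unique_kept = 0;
  bool dropping_col = false;
  uint32_t prev_col = UINT32_MAX;
  uint32_t prev_row = UINT32_MAX;
  for (const Id& id : ids) {
    if (id.col == prev_col && id.row == prev_row) {
      continue;  // exact duplicate (the real code merges it into the kept id)
    }
    const bool new_col = id.col != prev_col;
    if (new_col) {
      // A column that would exceed the cap is dropped, along with all of its
      // remaining rows.
      dropping_col = (unique_kept + 1 > max_unique_cols);
    }
    if (!dropping_col) {
      ++kept;
      if (new_col) ++unique_kept;
    }
    prev_col = id.col;
    prev_row = id.row;
  }
  return kept;
}

int main() {
  // Three unique columns with a cap of two: ids in the third column drop.
  std::printf("%d\n",
              CountKept({{1, 0}, {1, 0}, {2, 5}, {3, 7}}, /*max_unique_cols=*/2));
}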
@@ -298,8 +386,8 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
 // NOTE: We use output buffers `max_ids_per_sc`, `max_unique_ids_per_sc`, and
 // `required_buffer_size_per_sc` because we fill values in a loop to a bigger
 // array.
-template <typename SplitType>
-PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
+template <bool kCreateBuckets, typename SplitType>
+PartitionedCooTensors SortAndGroupCooTensorsPerLocalDeviceImpl(
     const ExtractedCooTensors& extracted_coo_tensors,
     const StackedTableMetadata& stacked_table_metadata,
     const PreprocessSparseDenseMatmulInputOptions& options,
@@ -320,20 +408,18 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
   // This function can be called in two passes for minibatching. The logic for
   // stats collection and ID dropping depends on the pass.
   //
-  // Pass 1: Check if minibatching is required (`create_buckets` is false).
+  // Pass 1: Check if minibatching is required (`kCreateBuckets` is false).
   // - No IDs are dropped.
   // - Stats are collected on all observed IDs to compute splits.
   //
-  // Pass 2: Create buckets (`create_buckets` is true).
+  // Pass 2: Create buckets (`kCreateBuckets` is true).
   // - A dummy stats object is used (stats are not re-computed).
   // - IDs may be dropped if they exceed capacity.
-  const bool create_buckets = options.enable_minibatching &&
-                              (std::is_same_v<SplitType, MinibatchingSplit>);
 
   // Partition COO tensors among SparseCores for the local device (based on row
   // id).
   const int bucket_count =
-      create_buckets ? CooFormat::kMaxMinibatchingBuckets : 1;
+      kCreateBuckets ? CooFormat::kMaxMinibatchingBuckets : 1;
   PartitionedCooTensors grouped_coo_tensors(
       coo_tensors.size(), num_sc_per_device, global_sc_count, bucket_count);
 
@@ -367,31 +453,51 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
       //   local_embedding_id(32-num_scs bits), index(26 bits)].
       // Note that this assumes `num_scs` is a power of 2.
       keys.push_back(coo_tensors[coo_tensor_index].GetGroupingKey(
-          num_sc_bits, coo_tensor_index, create_buckets,
+          num_sc_bits, coo_tensor_index, kCreateBuckets,
           options.minibatching_bucketing_hash_fn));
     }
 
     // The expected allocation size may be uninitialized.
     DCHECK(expected_keys_size == 0 || keys.size() == expected_keys_size);
     hwy::VQSort(keys.data(), keys.size(), hwy::SortAscending());
 
-    internal::GroupAndDeduplicateCooTensorsForLocalSparseCore({
-        .keys = keys,
-        .coo_tensors = coo_tensors,
-        .stacked_table_metadata = stacked_table_metadata,
-        .options = options,
-        .create_buckets = create_buckets,
-        .local_sc_id = local_sc_id,
-        .grouped_coo_tensors = grouped_coo_tensors,
-        .ids_per_sc_partition_per_bucket = ids_per_sc_partition_per_bucket,
-        .unique_ids_per_partition_per_bucket =
-            unique_ids_per_partition_per_bucket,
-        .stats = stats,
-        .kept_ids_per_sc_partition_per_bucket =
-            kept_ids_per_sc_partition_per_bucket,
-        .kept_unique_ids_per_partition_per_bucket =
-            kept_unique_ids_per_partition_per_bucket,
-    });
+    if constexpr (kCreateBuckets) {
+      internal::GroupAndDeduplicateCooTensorsForLocalSparseCoreWithBuckets({
+          .keys = keys,
+          .coo_tensors = coo_tensors,
+          .stacked_table_metadata = stacked_table_metadata,
+          .options = options,
+          .create_buckets = kCreateBuckets,
+          .local_sc_id = local_sc_id,
+          .grouped_coo_tensors = grouped_coo_tensors,
+          .ids_per_sc_partition_per_bucket = ids_per_sc_partition_per_bucket,
+          .unique_ids_per_partition_per_bucket =
+              unique_ids_per_partition_per_bucket,
+          .stats = stats,
+          .kept_ids_per_sc_partition_per_bucket =
+              kept_ids_per_sc_partition_per_bucket,
+          .kept_unique_ids_per_partition_per_bucket =
+              kept_unique_ids_per_partition_per_bucket,
+      });
+    } else {
+      internal::GroupAndDeduplicateCooTensorsForLocalSparseCoreNoBuckets({
+          .keys = keys,
+          .coo_tensors = coo_tensors,
+          .stacked_table_metadata = stacked_table_metadata,
+          .options = options,
+          .create_buckets = kCreateBuckets,
+          .local_sc_id = local_sc_id,
+          .grouped_coo_tensors = grouped_coo_tensors,
+          .ids_per_sc_partition_per_bucket = ids_per_sc_partition_per_bucket,
+          .unique_ids_per_partition_per_bucket =
+              unique_ids_per_partition_per_bucket,
+          .stats = stats,
+          .kept_ids_per_sc_partition_per_bucket =
+              kept_ids_per_sc_partition_per_bucket,
+          .kept_unique_ids_per_partition_per_bucket =
+              kept_unique_ids_per_partition_per_bucket,
+      });
+    }
 
     grouped_coo_tensors.FillRemainingScBuckets();
 
@@ -427,7 +533,7 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
 
   // Only validate if creating minibatching buckets or when minibatching is
   // disabled, not when checking if minibatching is required.
-  if (!options.enable_minibatching || create_buckets)
+  if (!options.enable_minibatching || kCreateBuckets)
     internal::ValidateMaxIdsOrDie(
         observed_max_ids_per_bucket, observed_max_unique_ids_per_bucket,
         max_ids_per_partition, max_unique_ids_per_partition,
@@ -437,6 +543,26 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
   return grouped_coo_tensors;
 }
 
+template <typename SplitType>
+PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
+    const ExtractedCooTensors& extracted_coo_tensors,
+    const StackedTableMetadata& stacked_table_metadata,
+    const PreprocessSparseDenseMatmulInputOptions& options,
+    internal::StatsPerDevice& stats, SplitType& minibatching_split) {
+  const bool create_buckets =
+      options.enable_minibatching &&
+      std::is_same_v<SplitType, MinibatchingSplit>;
+  if (create_buckets) {
+    return SortAndGroupCooTensorsPerLocalDeviceImpl<true>(
+        extracted_coo_tensors, stacked_table_metadata, options, stats,
+        minibatching_split);
+  } else {
+    return SortAndGroupCooTensorsPerLocalDeviceImpl<false>(
+        extracted_coo_tensors, stacked_table_metadata, options, stats,
+        minibatching_split);
+  }
+}
+
 }  // namespace jax_sc_embedding
 
 #endif  // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_SORT_AND_GROUP_COO_TENSORS_IMPL_H_
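
As the grouping-key comment in the diff notes, `num_scs` is assumed to be a power of two, so the owning SparseCore partition falls out of a single mask: `col_id & (num_scs - 1)` equals `col_id % num_scs`. A tiny standalone check of that identity (illustration only, not part of the library):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t num_scs = 8;  // must be a power of two for the mask trick
  for (uint32_t col_id = 0; col_id < 64; ++col_id) {
    // Same value either way: masking with (num_scs - 1) is modulo num_scs.
    assert((col_id & (num_scs - 1)) == col_id % num_scs);
  }
  return 0;
}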
