@@ -187,7 +187,95 @@ struct LocalSparseCoreTensorGroupingContext {
   MatrixXi& kept_unique_ids_per_partition_per_bucket;
 };
 
-inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
+inline void GroupAndDeduplicateCooTensorsForLocalSparseCoreNoBuckets(
+    LocalSparseCoreTensorGroupingContext context) {
+  // Unpack context for readability.
+  const PreprocessSparseDenseMatmulInputOptions& options = context.options;
+  const StackedTableMetadata& stacked_table_metadata =
+      context.stacked_table_metadata;
+  const std::vector<CooFormat>& coo_tensors = context.coo_tensors;
+  PartitionedCooTensors& grouped_coo_tensors = context.grouped_coo_tensors;
+  StatsPerDevice& stats = context.stats;
+  MatrixXi& observed_ids = context.ids_per_sc_partition_per_bucket;
+  MatrixXi& observed_unique_ids = context.unique_ids_per_partition_per_bucket;
+  MatrixXi& kept_ids = context.kept_ids_per_sc_partition_per_bucket;
+  MatrixXi& kept_unique_ids = context.kept_unique_ids_per_partition_per_bucket;
+
+  const bool allow_id_dropping = options.allow_id_dropping;
+  const uint32_t global_sc_count = options.GetNumScs();
+  const int max_ids_per_partition =
+      stacked_table_metadata.max_ids_per_partition;
+  const int max_unique_ids_per_partition =
+      stacked_table_metadata.max_unique_ids_per_partition;
+  uint32_t prev_col_id = std::numeric_limits<uint32_t>::max();
+  uint32_t prev_row_id = std::numeric_limits<uint32_t>::max();
+  bool dropping_current_unique_col_id = false;
+  for (const uint64_t key : context.keys) {
+    // Step 1: Unpack key to get tensor coordinates.
+    const uint32_t index = key & CooFormat::kIndexMask;
+    const CooFormat& coo_tensor = coo_tensors[index];
+    const uint32_t col_id = coo_tensor.col_id;
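+    // `global_sc_count` is assumed to be a power of two (see the grouping-key
+    // comment below), so masking with `global_sc_count - 1` extracts the
+    // target SparseCore partition from the low bits of the column id.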
+    const uint32_t global_sc_id = coo_tensor.col_id & (global_sc_count - 1);
+    const uint32_t row_id = coo_tensor.row_id;
+
+    // Step 2: Handle duplicates.
+    // An ID that is a duplicate of a previously non-dropped ID is merged.
+    // It does not count as a new ID for stats and does not go through
+    // dropping logic.
+    if (grouped_coo_tensors.MaybeMerge(/*bucket_id=*/0, coo_tensor)) {
+      continue;
+    }
+    // If the ID is a duplicate of the last seen ID, it must have been dropped
+    // (otherwise it would have been merged above), so drop this one too.
+    if (col_id == prev_col_id && row_id == prev_row_id) {
+      ++stats.dropped_id_count;
+      continue;
+    }
+
+    // Step 3: Update observed statistics for the new ID. Observed stats are
+    // never decremented and are used for reporting.
+    const bool is_new_col = col_id != prev_col_id;
+    observed_ids(global_sc_id, 0) += 1;
+    if (is_new_col) {
+      observed_unique_ids(global_sc_id, 0) += 1;
+      dropping_current_unique_col_id =
+          (kept_unique_ids(global_sc_id, 0) + 1) >
+          max_unique_ids_per_partition;
+    }
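+    // The flag persists across rows of the same column, so all IDs of a
+    // column that exceeds the unique-id limit share the same dropping
+    // decision in Step 5.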
+
+    // Step 4: Determine if the ID should be dropped based on capacity limits.
+    // IDs are never dropped when minibatching is enabled: this no-buckets
+    // specialization then runs as the first pass, which must observe limit
+    // overflows to decide whether minibatching is required.
+    const bool can_drop_id = !options.enable_minibatching;
+    const bool exceeds_ids_limit =
+        (kept_ids(global_sc_id, 0) + 1) > max_ids_per_partition;
+
+    // Step 5: Add ID to result or drop it.
+    if (can_drop_id && allow_id_dropping &&
+        (exceeds_ids_limit || dropping_current_unique_col_id)) {
+      // Dropped id.
+      ++stats.dropped_id_count;
+    } else {
+      grouped_coo_tensors.Add(context.local_sc_id, /*bucket_id=*/0,
+                              coo_tensor);
+      // Update kept counts.
+      kept_ids(global_sc_id, 0) += 1;
+      if (is_new_col) {
+        kept_unique_ids(global_sc_id, 0) += 1;
+      }
+    }
+
+    // Step 6: Update state for the next iteration. This must happen
+    // regardless of whether the ID was dropped to ensure correct stats
+    // collection for subsequent IDs.
+    prev_col_id = col_id;
+    prev_row_id = row_id;
+  }
+}
+
+inline void GroupAndDeduplicateCooTensorsForLocalSparseCoreWithBuckets(
     LocalSparseCoreTensorGroupingContext context) {
   // Unpack context for readability.
   const PreprocessSparseDenseMatmulInputOptions& options = context.options;
@@ -298,8 +386,8 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
 // NOTE: We use output buffers `max_ids_per_sc`, `max_unique_ids_per_sc`, and
 // `required_buffer_size_per_sc` because we fill values into a larger array in
 // a loop.
-template <typename SplitType>
-PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
+template <bool kCreateBuckets, typename SplitType>
+PartitionedCooTensors SortAndGroupCooTensorsPerLocalDeviceImpl(
     const ExtractedCooTensors& extracted_coo_tensors,
     const StackedTableMetadata& stacked_table_metadata,
     const PreprocessSparseDenseMatmulInputOptions& options,
@@ -320,20 +408,18 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
   // This function can be called in two passes for minibatching. The logic for
   // stats collection and ID dropping depends on the pass.
   //
-  // Pass 1: Check if minibatching is required (`create_buckets` is false).
+  // Pass 1: Check if minibatching is required (`kCreateBuckets` is false).
   // - No IDs are dropped.
   // - Stats are collected on all observed IDs to compute splits.
   //
-  // Pass 2: Create buckets (`create_buckets` is true).
+  // Pass 2: Create buckets (`kCreateBuckets` is true).
   // - A dummy stats object is used (stats are not re-computed).
   // - IDs may be dropped if they exceed capacity.
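+  // The pass is selected at compile time via the `kCreateBuckets` template
+  // parameter (see the dispatcher at the bottom of this file).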
-  const bool create_buckets = options.enable_minibatching &&
-                              (std::is_same_v<SplitType, MinibatchingSplit>);
 
   // Partition COO tensors among SparseCores for the local device (based on
   // row id).
   const int bucket_count =
-      create_buckets ? CooFormat::kMaxMinibatchingBuckets : 1;
+      kCreateBuckets ? CooFormat::kMaxMinibatchingBuckets : 1;
   PartitionedCooTensors grouped_coo_tensors(
       coo_tensors.size(), num_sc_per_device, global_sc_count, bucket_count);
 
@@ -367,31 +453,51 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
     // local_embedding_id(32-num_scs bits), index(26 bits)].
     // Note that this assumes `num_scs` is a power of 2.
     keys.push_back(coo_tensors[coo_tensor_index].GetGroupingKey(
-        num_sc_bits, coo_tensor_index, create_buckets,
+        num_sc_bits, coo_tensor_index, kCreateBuckets,
         options.minibatching_bucketing_hash_fn));
   }
 
   // The expected allocation size may be uninitialized.
   DCHECK(expected_keys_size == 0 || keys.size() == expected_keys_size);
   hwy::VQSort(keys.data(), keys.size(), hwy::SortAscending());
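+  // After sorting, duplicate (sc_id, embedding id) entries are adjacent,
+  // which is what the deduplication passes below rely on.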
 
-  internal::GroupAndDeduplicateCooTensorsForLocalSparseCore({
-      .keys = keys,
-      .coo_tensors = coo_tensors,
-      .stacked_table_metadata = stacked_table_metadata,
-      .options = options,
-      .create_buckets = create_buckets,
-      .local_sc_id = local_sc_id,
-      .grouped_coo_tensors = grouped_coo_tensors,
-      .ids_per_sc_partition_per_bucket = ids_per_sc_partition_per_bucket,
-      .unique_ids_per_partition_per_bucket =
-          unique_ids_per_partition_per_bucket,
-      .stats = stats,
-      .kept_ids_per_sc_partition_per_bucket =
-          kept_ids_per_sc_partition_per_bucket,
-      .kept_unique_ids_per_partition_per_bucket =
-          kept_unique_ids_per_partition_per_bucket,
-  });
+  if constexpr (kCreateBuckets) {
+    internal::GroupAndDeduplicateCooTensorsForLocalSparseCoreWithBuckets({
+        .keys = keys,
+        .coo_tensors = coo_tensors,
+        .stacked_table_metadata = stacked_table_metadata,
+        .options = options,
+        .create_buckets = kCreateBuckets,
+        .local_sc_id = local_sc_id,
+        .grouped_coo_tensors = grouped_coo_tensors,
+        .ids_per_sc_partition_per_bucket = ids_per_sc_partition_per_bucket,
+        .unique_ids_per_partition_per_bucket =
+            unique_ids_per_partition_per_bucket,
+        .stats = stats,
+        .kept_ids_per_sc_partition_per_bucket =
+            kept_ids_per_sc_partition_per_bucket,
+        .kept_unique_ids_per_partition_per_bucket =
+            kept_unique_ids_per_partition_per_bucket,
+    });
+  } else {
+    internal::GroupAndDeduplicateCooTensorsForLocalSparseCoreNoBuckets({
+        .keys = keys,
+        .coo_tensors = coo_tensors,
+        .stacked_table_metadata = stacked_table_metadata,
+        .options = options,
+        .create_buckets = kCreateBuckets,
+        .local_sc_id = local_sc_id,
+        .grouped_coo_tensors = grouped_coo_tensors,
+        .ids_per_sc_partition_per_bucket = ids_per_sc_partition_per_bucket,
+        .unique_ids_per_partition_per_bucket =
+            unique_ids_per_partition_per_bucket,
+        .stats = stats,
+        .kept_ids_per_sc_partition_per_bucket =
+            kept_ids_per_sc_partition_per_bucket,
+        .kept_unique_ids_per_partition_per_bucket =
+            kept_unique_ids_per_partition_per_bucket,
+    });
+  }
 
   grouped_coo_tensors.FillRemainingScBuckets();
 
@@ -427,7 +533,7 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
 
   // Only validate if creating minibatching buckets or when minibatching is
   // disabled, not when checking if minibatching is required.
-  if (!options.enable_minibatching || create_buckets)
+  if (!options.enable_minibatching || kCreateBuckets)
     internal::ValidateMaxIdsOrDie(
         observed_max_ids_per_bucket, observed_max_unique_ids_per_bucket,
         max_ids_per_partition, max_unique_ids_per_partition,
@@ -437,6 +543,26 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
   return grouped_coo_tensors;
 }
 
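+// Dispatches the runtime bucketing decision to a compile-time
+// `kCreateBuckets` specialization, so the per-ID grouping loop is compiled
+// separately for the bucketed and non-bucketed cases.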
+template <typename SplitType>
+PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
+    const ExtractedCooTensors& extracted_coo_tensors,
+    const StackedTableMetadata& stacked_table_metadata,
+    const PreprocessSparseDenseMatmulInputOptions& options,
+    internal::StatsPerDevice& stats, SplitType& minibatching_split) {
+  const bool create_buckets =
+      options.enable_minibatching &&
+      std::is_same_v<SplitType, MinibatchingSplit>;
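+  // Buckets are only needed when minibatching is enabled and the caller
+  // supplies a `MinibatchingSplit`; otherwise the no-buckets path is taken.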
+  if (create_buckets) {
+    return SortAndGroupCooTensorsPerLocalDeviceImpl<true>(
+        extracted_coo_tensors, stacked_table_metadata, options, stats,
+        minibatching_split);
+  } else {
+    return SortAndGroupCooTensorsPerLocalDeviceImpl<false>(
+        extracted_coo_tensors, stacked_table_metadata, options, stats,
+        minibatching_split);
+  }
+}
+
 }  // namespace jax_sc_embedding
 
 #endif  // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_SORT_AND_GROUP_COO_TENSORS_IMPL_H_