@@ -126,6 +126,7 @@ struct TableState {
126126 MinibatchingSplit table_minibatching_split = 0 ;
127127 std::vector<ExtractedCooTensors> extracted_coo_tensors_per_device;
128128 std::vector<PartitionedCooTensors> partitioned_coo_tensors_per_device;
129+ std::vector<int > dropped_id_count_per_device;
129130
130131 TableState (const std::string& name,
131132 absl::Span<const StackedTableMetadata> metadata,
@@ -147,8 +148,9 @@ struct TableState {
147148 stats_per_host(options.local_device_count, options.GetNumScs(),
148149 options.num_sc_per_device),
149150 batch_size_for_device(0 ) {
150- extracted_coo_tensors_per_device.reserve (options.local_device_count );
151- partitioned_coo_tensors_per_device.reserve (options.local_device_count );
151+ extracted_coo_tensors_per_device.resize (options.local_device_count );
152+ partitioned_coo_tensors_per_device.resize (options.local_device_count );
153+ dropped_id_count_per_device.resize (options.local_device_count , 0 );
152154 }
153155};
154156
@@ -165,26 +167,40 @@ void ExtractSortAndGroupCooTensorsForTable(
165167 state.stacked_table_name );
166168 });
167169
170+ absl::BlockingCounter counter (options.local_device_count );
168171 for (int local_device = 0 ; local_device < options.local_device_count ;
169172 ++local_device) {
170- ExtractedCooTensors extracted_coo_tensors =
171- internal::ExtractCooTensorsForAllFeaturesPerLocalDevice (
172- state.stacked_table_metadata , input_batches, local_device, options);
173- state.extracted_coo_tensors_per_device .push_back (extracted_coo_tensors);
174- if (local_device == 0 )
175- state.batch_size_for_device = extracted_coo_tensors.batch_size_for_device ;
176- else
177- CHECK_EQ (state.batch_size_for_device ,
178- extracted_coo_tensors.batch_size_for_device );
179-
180- internal::StatsPerDevice stats_per_device =
181- state.stats_per_host .GetStatsPerDevice (local_device);
182- const PartitionedCooTensors grouped_coo_tensors =
183- SortAndGroupCooTensorsPerLocalDevice (
184- extracted_coo_tensors, state.stacked_table_metadata [0 ], options,
185- stats_per_device, state.table_minibatching_required );
186- state.partitioned_coo_tensors_per_device .push_back (grouped_coo_tensors);
187- state.stats_per_host .dropped_id_count += stats_per_device.dropped_id_count ;
173+ DeviceProcessingThreadPool ()->Schedule ([&, local_device] {
174+ state.extracted_coo_tensors_per_device [local_device] =
175+ internal::ExtractCooTensorsForAllFeaturesPerLocalDevice (
176+ state.stacked_table_metadata , input_batches, local_device,
177+ options);
178+
179+ internal::StatsPerDevice stats_per_device =
180+ state.stats_per_host .GetStatsPerDevice (local_device);
181+ state.partitioned_coo_tensors_per_device [local_device] =
182+ SortAndGroupCooTensorsPerLocalDevice (
183+ state.extracted_coo_tensors_per_device [local_device],
184+ state.stacked_table_metadata [0 ], options, stats_per_device,
185+ state.table_minibatching_required );
186+ state.dropped_id_count_per_device [local_device] =
187+ stats_per_device.dropped_id_count ;
188+ counter.DecrementCount ();
189+ });
190+ }
191+ counter.Wait ();
192+
193+ // Post-process results after all threads are done.
194+ state.batch_size_for_device =
195+ state.extracted_coo_tensors_per_device [0 ].batch_size_for_device ;
196+ state.stats_per_host .dropped_id_count = 0 ;
197+ for (int local_device = 0 ; local_device < options.local_device_count ;
198+ ++local_device) {
199+ DCHECK_EQ (state.batch_size_for_device ,
200+ state.extracted_coo_tensors_per_device [local_device]
201+ .batch_size_for_device );
202+ state.stats_per_host .dropped_id_count +=
203+ state.dropped_id_count_per_device [local_device];
188204 }
189205}
190206
@@ -518,9 +534,9 @@ PreprocessSparseDenseMatmulInput(
518534 // Stage 1: COO Extraction and Initial Sort/Group
519535 {
520536 tsl::profiler::TraceMe traceme (" ExtractSortAndGroupCooTensors" );
521- absl::BlockingCounter counter (stacked_tables .size ());
537+ absl::BlockingCounter counter (table_states .size ());
522538 for (auto & state : table_states) {
523- PreprocessingThreadPool ()->Schedule ([&, &state = state] {
539+ TableProcessingThreadPool ()->Schedule ([&, &state = state] {
524540 ExtractSortAndGroupCooTensorsForTable (state, input_batches, options);
525541 counter.DecrementCount ();
526542 });
@@ -536,9 +552,9 @@ PreprocessSparseDenseMatmulInput(
536552 if (options.enable_minibatching && global_minibatching_required) {
537553 {
538554 tsl::profiler::TraceMe traceme (" CreateMinibatchingBuckets" );
539- absl::BlockingCounter counter (stacked_tables .size ());
555+ absl::BlockingCounter counter (table_states .size ());
540556 for (auto & state : table_states) {
541- PreprocessingThreadPool ()->Schedule ([&, &state = state] {
557+ TableProcessingThreadPool ()->Schedule ([&, &state = state] {
542558 CreateMinibatchingBucketsForTable (state, options);
543559 counter.DecrementCount ();
544560 });
@@ -553,9 +569,9 @@ PreprocessSparseDenseMatmulInput(
553569 // Stage 3: Fill Device Buffers
554570 {
555571 tsl::profiler::TraceMe traceme (" FillDeviceBuffers" );
556- absl::BlockingCounter counter (stacked_tables .size ());
572+ absl::BlockingCounter counter (table_states .size ());
557573 for (auto & state : table_states) {
558- PreprocessingThreadPool ()->Schedule ([&, &state = state,
574+ TableProcessingThreadPool ()->Schedule ([&, &state = state,
559575 global_minibatching_required,
560576 global_minibatching_split] {
561577 FillDeviceBuffersForTable (state, options, global_minibatching_required,
0 commit comments