 #include <utility>
 #include <vector>
 
+#include "absl/algorithm/container.h"  // from @com_google_absl
 #include "absl/base/thread_annotations.h"  // from @com_google_absl
 #include "absl/container/flat_hash_map.h"  // from @com_google_absl
 #include "absl/log/check.h"  // from @com_google_absl
@@ -126,6 +127,7 @@ struct TableState {
   MinibatchingSplit table_minibatching_split = 0;
   std::vector<ExtractedCooTensors> extracted_coo_tensors_per_device;
   std::vector<PartitionedCooTensors> partitioned_coo_tensors_per_device;
+  std::vector<int> dropped_id_count_per_device;
 
   TableState(const std::string& name,
              absl::Span<const StackedTableMetadata> metadata,
@@ -147,8 +149,9 @@ struct TableState {
         stats_per_host(options.local_device_count, options.GetNumScs(),
                        options.num_sc_per_device),
         batch_size_for_device(0) {
-    extracted_coo_tensors_per_device.reserve(options.local_device_count);
-    partitioned_coo_tensors_per_device.reserve(options.local_device_count);
+    extracted_coo_tensors_per_device.resize(options.local_device_count);
+    partitioned_coo_tensors_per_device.resize(options.local_device_count);
+    dropped_id_count_per_device.resize(options.local_device_count, 0);
   }
 };
 
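Reviewer note: the reserve-to-resize change above is load-bearing, not cosmetic. Once the per-device loop bodies move onto the thread pool (below), concurrent push_back calls into a shared vector would race; pre-sizing the vectors lets each task write only its own index. A minimal, self-contained sketch of the pattern, with plain std::thread standing in for the project's pool (all names here are illustrative):

#include <numeric>
#include <thread>
#include <vector>

int main() {
  constexpr int kDeviceCount = 4;
  // resize(), not reserve(): each worker assigns results[i], so the
  // elements must already exist, and no two workers share a slot.
  std::vector<int> results;
  results.resize(kDeviceCount, 0);

  std::vector<std::thread> workers;
  for (int i = 0; i < kDeviceCount; ++i) {
    workers.emplace_back([&results, i] { results[i] = i * i; });
  }
  for (auto& t : workers) t.join();

  // Reducing is safe only after every writer has finished.
  return std::accumulate(results.begin(), results.end(), 0) == 14 ? 0 : 1;
}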
@@ -159,32 +162,42 @@ struct TableState {
 void ExtractSortAndGroupCooTensorsForTable(
     TableState& state,
     absl::Span<std::unique_ptr<AbstractInputBatch>> input_batches,
-    const PreprocessSparseDenseMatmulInputOptions& options) {
+    const PreprocessSparseDenseMatmulInputOptions& options,
+    absl::BlockingCounter* counter) {
   tsl::profiler::TraceMe traceme([&] {
     return absl::StrCat("InputPreprocessingTable-ExtractSortGroup-",
                         state.stacked_table_name);
   });
-
   for (int local_device = 0; local_device < options.local_device_count;
        ++local_device) {
-    ExtractedCooTensors extracted_coo_tensors =
-        internal::ExtractCooTensorsForAllFeaturesPerLocalDevice(
-            state.stacked_table_metadata, input_batches, local_device, options);
-    state.extracted_coo_tensors_per_device.push_back(extracted_coo_tensors);
-    if (local_device == 0)
-      state.batch_size_for_device = extracted_coo_tensors.batch_size_for_device;
-    else
-      CHECK_EQ(state.batch_size_for_device,
-               extracted_coo_tensors.batch_size_for_device);
-
-    internal::StatsPerDevice stats_per_device =
-        state.stats_per_host.GetStatsPerDevice(local_device);
-    const PartitionedCooTensors grouped_coo_tensors =
-        SortAndGroupCooTensorsPerLocalDevice(
-            extracted_coo_tensors, state.stacked_table_metadata[0], options,
-            stats_per_device, state.table_minibatching_required);
-    state.partitioned_coo_tensors_per_device.push_back(grouped_coo_tensors);
-    state.stats_per_host.dropped_id_count += stats_per_device.dropped_id_count;
+    PreprocessingThreadPool()->Schedule([&, local_device, &state = state] {
+      state.extracted_coo_tensors_per_device[local_device] =
+          internal::ExtractCooTensorsForAllFeaturesPerLocalDevice(
+              state.stacked_table_metadata, input_batches, local_device,
+              options);
+
+      internal::StatsPerDevice stats_per_device =
+          state.stats_per_host.GetStatsPerDevice(local_device);
+      state.partitioned_coo_tensors_per_device[local_device] =
+          SortAndGroupCooTensorsPerLocalDevice(
+              state.extracted_coo_tensors_per_device[local_device],
+              state.stacked_table_metadata[0], options, stats_per_device,
+              state.table_minibatching_required);
+      state.dropped_id_count_per_device[local_device] =
+          stats_per_device.dropped_id_count;
+      counter->DecrementCount();
+    });
+  }
+}
+
+void PostProcessTableState(TableState& state) {
+  state.stats_per_host.dropped_id_count =
+      absl::c_accumulate(state.dropped_id_count_per_device, 0LL);
+
+  state.batch_size_for_device =
+      state.extracted_coo_tensors_per_device[0].batch_size_for_device;
+  for (const auto& extracted_coo : state.extracted_coo_tensors_per_device) {
+    DCHECK_EQ(state.batch_size_for_device, extracted_coo.batch_size_for_device);
   }
 }
 
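Reviewer note: the shape of the change above is a fan-out from one task per table to one task per (table, device) pair, with the caller owning a single absl::BlockingCounter and blocking once per stage. A self-contained sketch of that pattern, assuming only Abseil's blocking_counter.h and substituting std::thread for PreprocessingThreadPool():

#include <thread>
#include <vector>

#include "absl/synchronization/blocking_counter.h"

void RunPerDevice(int /*table*/, int /*device*/) { /* per-pair work */ }

int main() {
  constexpr int kTables = 3, kDevices = 4;
  // One counter for the whole stage: each task decrements exactly once,
  // and the caller blocks until all kTables * kDevices tasks are done.
  absl::BlockingCounter counter(kTables * kDevices);
  std::vector<std::thread> pool;
  for (int t = 0; t < kTables; ++t) {
    for (int d = 0; d < kDevices; ++d) {
      pool.emplace_back([&counter, t, d] {
        RunPerDevice(t, d);
        counter.DecrementCount();
      });
    }
  }
  counter.Wait();  // Barrier: PostProcessTableState-style work goes here.
  for (auto& th : pool) th.join();
  return 0;
}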
@@ -194,29 +207,34 @@ void ExtractSortAndGroupCooTensorsForTable(
 // `state`: The TableState holding the COO tensors and statistics.
 // `options`: Preprocessing options.
 void CreateMinibatchingBucketsForTable(
-    TableState& state, const PreprocessSparseDenseMatmulInputOptions& options) {
+    TableState& state, const PreprocessSparseDenseMatmulInputOptions& options,
+    absl::BlockingCounter* counter) {
   tsl::profiler::TraceMe traceme([&] {
     return absl::StrCat("InputPreprocessingTable-CreateMinibatchingBuckets-",
                         state.stacked_table_name);
   });
   state.stats_per_host.dropped_id_count = 0;
   for (int local_device = 0; local_device < options.local_device_count;
        ++local_device) {
-    // Note: We create a dummy stats object here because we don't want to
-    // overwrite the stats from the first pass, which are authoritative.
-    // The only stat we care about from this second pass is the number of
-    // dropped IDs.
-    StatsPerHost dummy_stats_host(
-        /*local_device_count=*/1, options.GetNumScs(),
-        options.num_sc_per_device);
-    internal::StatsPerDevice dummy_stats =
-        dummy_stats_host.GetStatsPerDevice(0);
-    state.partitioned_coo_tensors_per_device[local_device] =
-        SortAndGroupCooTensorsPerLocalDevice(
-            state.extracted_coo_tensors_per_device[local_device],
-            state.stacked_table_metadata[0], options, dummy_stats,
-            state.table_minibatching_split);
-    state.stats_per_host.dropped_id_count += dummy_stats.dropped_id_count;
+    PreprocessingThreadPool()->Schedule([&, local_device, &state = state] {
+      // Note: We create a dummy stats object here because we don't want to
+      // overwrite the stats from the first pass, which are authoritative.
+      // The only stat we care about from this second pass is the number of
+      // dropped IDs.
+      StatsPerHost dummy_stats_host(
+          /*local_device_count=*/1, options.GetNumScs(),
+          options.num_sc_per_device);
+      internal::StatsPerDevice dummy_stats =
+          dummy_stats_host.GetStatsPerDevice(0);
+      state.partitioned_coo_tensors_per_device[local_device] =
+          SortAndGroupCooTensorsPerLocalDevice(
+              state.extracted_coo_tensors_per_device[local_device],
+              state.stacked_table_metadata[0], options, dummy_stats,
+              state.table_minibatching_split);
+      state.dropped_id_count_per_device[local_device] =
+          dummy_stats.dropped_id_count;
+      counter->DecrementCount();
+    });
   }
 }
 
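Reviewer note: writing dummy_stats.dropped_id_count into dropped_id_count_per_device[local_device] gives every task a private slot, so no atomic or mutex is needed; the totals are reduced after the barrier. A minimal sketch of that reduce step, assuming Abseil's container algorithms (the 0LL seed is what widens the sum to 64 bits):

#include <cstdint>
#include <vector>

#include "absl/algorithm/container.h"

int64_t TotalDroppedIds(const std::vector<int>& dropped_per_device) {
  // Seeding with 0LL makes the accumulator long long, so summing many
  // large per-device counts cannot overflow int.
  return absl::c_accumulate(dropped_per_device, 0LL);
}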
@@ -433,44 +451,38 @@ void PopulateOutput(TableState& state, PreprocessSparseDenseMatmulOutput& out,
 // `out`: The output structure to be populated with CSR arrays and stats.
 // `output_mutex`: Mutex to protect access to `out`.
 void FillDeviceBuffersForTable(
-    TableState& state, const PreprocessSparseDenseMatmulInputOptions& options,
-    bool global_minibatching_required,
-    MinibatchingSplit global_minibatching_split,
-    int row_pointers_size_per_bucket, PreprocessSparseDenseMatmulOutput& out,
-    absl::Mutex& output_mutex) {
+    TableState& state, PreprocessSparseDenseMatmulOutput& out,
+    const PreprocessSparseDenseMatmulInputOptions& options,
+    int row_pointers_size_per_bucket, bool global_minibatching_required,
+    MinibatchingSplit global_minibatching_split, absl::Mutex& output_mutex,
+    absl::BlockingCounter* counter) {
   tsl::profiler::TraceMe traceme([&] {
     return absl::StrCat("InputPreprocessingTable-FillBuffer-",
                         state.stacked_table_name);
   });
-  int table_dropped_ids = 0;
   for (int local_device = 0; local_device < options.local_device_count;
        ++local_device) {
-    PartitionedCooTensors& grouped_coo_tensors =
-        state.partitioned_coo_tensors_per_device[local_device];
-    if (options.enable_minibatching && global_minibatching_required) {
-      grouped_coo_tensors.Merge(global_minibatching_split);
-    }
+    PreprocessingThreadPool()->Schedule([&, local_device, &state = state] {
+      PartitionedCooTensors& grouped_coo_tensors =
+          state.partitioned_coo_tensors_per_device[local_device];
+      if (options.enable_minibatching && global_minibatching_required) {
+        grouped_coo_tensors.Merge(global_minibatching_split);
+      }
 
-    const int batch_size_per_sc = xla::CeilOfRatio(state.batch_size_for_device,
-                                                   options.num_sc_per_device);
-    const int coo_buffer_size_per_sc =
-        state.coo_buffer_size_per_device / options.num_sc_per_device;
-    internal::CsrArraysPerDevice csr_arrays_per_device =
-        state.csr_arrays_per_host.GetCsrArraysPerDevice(local_device);
-    FillLocalDeviceBuffer(grouped_coo_tensors, row_pointers_size_per_bucket,
-                          coo_buffer_size_per_sc, batch_size_per_sc, options,
-                          csr_arrays_per_device, table_dropped_ids);
-    state.stats_per_host.dropped_id_count += table_dropped_ids;
+      const int batch_size_per_sc = xla::CeilOfRatio(
+          state.batch_size_for_device, options.num_sc_per_device);
+      const int coo_buffer_size_per_sc =
+          state.coo_buffer_size_per_device / options.num_sc_per_device;
+      internal::CsrArraysPerDevice csr_arrays_per_device =
+          state.csr_arrays_per_host.GetCsrArraysPerDevice(local_device);
+      int table_dropped_ids = 0;
+      FillLocalDeviceBuffer(grouped_coo_tensors, row_pointers_size_per_bucket,
+                            coo_buffer_size_per_sc, batch_size_per_sc, options,
+                            csr_arrays_per_device, table_dropped_ids);
+      state.dropped_id_count_per_device[local_device] = table_dropped_ids;
+      counter->DecrementCount();
+    });
   }
-  // NOMUTANTS -- Informational.
-  CheckBufferUsage(
-      /*max_required_buffer_size_per_device=*/
-      state.stats_per_host.required_buffer_size.maxCoeff() *
-          options.num_sc_per_device,
-      state.coo_buffer_size_per_device, state.stacked_table_name,
-      options.batch_number);
-
-  PopulateOutput(state, out, output_mutex);
 }
 
 }  // namespace
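Reviewer note: CheckBufferUsage and PopulateOutput had to move out of FillDeviceBuffersForTable because the function now returns while its scheduled lambdas may still be writing stats_per_host; reading those stats before the counter reaches zero would be a data race. A reduced sketch of the hazard and the fix (the struct and values are hypothetical):

#include <thread>

#include "absl/synchronization/blocking_counter.h"

struct Stats { long required_buffer_size = 0; };  // hypothetical stand-in

int main() {
  Stats stats;
  absl::BlockingCounter counter(1);
  std::thread worker([&] {
    stats.required_buffer_size = 4096;  // stands in for the fill pass
    counter.DecrementCount();
  });
  // Reading stats.required_buffer_size here would race with the worker,
  // which is why the buffer checks now run only after counter.Wait().
  counter.Wait();
  const long observed = stats.required_buffer_size;  // safe: writers done
  worker.join();
  return observed == 4096 ? 0 : 1;
}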
@@ -518,14 +530,18 @@ PreprocessSparseDenseMatmulInput(
   // Stage 1: COO Extraction and Initial Sort/Group
   {
     tsl::profiler::TraceMe traceme("ExtractSortAndGroupCooTensors");
-    absl::BlockingCounter counter(stacked_tables.size());
+    absl::BlockingCounter counter(table_states.size() *
+                                  options.local_device_count);
     for (auto& state : table_states) {
-      PreprocessingThreadPool()->Schedule([&, &state = state] {
-        ExtractSortAndGroupCooTensorsForTable(state, input_batches, options);
-        counter.DecrementCount();
-      });
+      ExtractSortAndGroupCooTensorsForTable(state, input_batches, options,
+                                            &counter);
     }
     counter.Wait();
+
+    // Post-process results after all threads are done.
+    for (auto& state : table_states) {
+      PostProcessTableState(state);
+    }
   }
   TF_ASSIGN_OR_RETURN(bool global_minibatching_required,
                       SyncMinibatchingRequired(options, table_states));
@@ -536,14 +552,15 @@ PreprocessSparseDenseMatmulInput(
   if (options.enable_minibatching && global_minibatching_required) {
     {
       tsl::profiler::TraceMe traceme("CreateMinibatchingBuckets");
-      absl::BlockingCounter counter(stacked_tables.size());
+      absl::BlockingCounter counter(table_states.size() *
+                                    options.local_device_count);
       for (auto& state : table_states) {
-        PreprocessingThreadPool()->Schedule([&, &state = state] {
-          CreateMinibatchingBucketsForTable(state, options);
-          counter.DecrementCount();
-        });
+        CreateMinibatchingBucketsForTable(state, options, &counter);
       }
       counter.Wait();
+      for (auto& state : table_states) {
+        PostProcessTableState(state);
+      }
     }
 
     TF_ASSIGN_OR_RETURN(global_minibatching_split,
@@ -553,19 +570,28 @@ PreprocessSparseDenseMatmulInput(
   // Stage 3: Fill Device Buffers
   {
     tsl::profiler::TraceMe traceme("FillDeviceBuffers");
-    absl::BlockingCounter counter(stacked_tables.size());
+    absl::BlockingCounter counter(table_states.size() *
+                                  options.local_device_count);
     for (auto& state : table_states) {
-      PreprocessingThreadPool()->Schedule([&, &state = state,
-                                           global_minibatching_required,
-                                           global_minibatching_split] {
-        FillDeviceBuffersForTable(state, options, global_minibatching_required,
-                                  global_minibatching_split,
-                                  row_pointers_size_per_bucket, out,
-                                  output_mutex);
-        counter.DecrementCount();
-      });
+      FillDeviceBuffersForTable(
+          state, out, options, row_pointers_size_per_bucket,
+          global_minibatching_required, global_minibatching_split, output_mutex,
+          &counter);
     }
     counter.Wait();
+    for (auto& state : table_states) {
+      state.stats_per_host.dropped_id_count +=
+          absl::c_accumulate(state.dropped_id_count_per_device, 0LL);
+      // NOMUTANTS -- Informational.
+      CheckBufferUsage(
+          /*max_required_buffer_size_per_device=*/
+          state.stats_per_host.required_buffer_size.maxCoeff() *
+              options.num_sc_per_device,
+          state.coo_buffer_size_per_device, state.stacked_table_name,
+          options.batch_number);
+
+      PopulateOutput(state, out, output_mutex);
+    }
   }
 
   out.num_minibatches = global_minibatching_split.count() + 1;
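Reviewer note on the last line: assuming MinibatchingSplit behaves like a bitmask of split points, k set bits cut the id range into k + 1 buckets, hence count() + 1. A tiny illustration using std::bitset as a stand-in:

#include <bitset>

int main() {
  std::bitset<64> split;  // stand-in for MinibatchingSplit
  split.set(3).set(17);   // two split points...
  return split.count() + 1 == 3 ? 0 : 1;  // ...give three minibatches
}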