Commit d8f34cb

[JAX SC] Parallelize device loop for extraction, sorting, grouping, bucketing and buffer filling.

* `9.81%` geomean reduction in wall time (`~11%` with FDO), with a `0.97%` increase in CPU time and a `5.31%` reduction in cycles.
* Use a separate thread pool to avoid deadlocks. The fixed cost of scheduling should be less than 0.1%.
* Add default-constructible objects for parallelization.

PiperOrigin-RevId: 826509091

1 parent 3fca158 commit d8f34cb
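The core of the change is a fan-out/fan-in shape: each stage's helper now takes an `absl::BlockingCounter*` and schedules one leaf task per (table, local device) pair on the preprocessing pool, while the caller sizes the counter to tables × devices and is the only party that blocks. A minimal sketch of that shape, assuming the `tsl::thread::ThreadPool` API used in this repository; `ProcessDevice`, `ProcessTable`, and `RunStage` are illustrative names, not functions from this commit:

#include <vector>

#include "absl/synchronization/blocking_counter.h"
#include "tsl/platform/threadpool.h"

// Illustrative leaf work: each (table, device) pair owns one output slot, so
// tasks write to disjoint memory and need no locking. Assumes `slots` was
// pre-sized to the device count by the caller.
void ProcessDevice(std::vector<int>& slots, int device) {
  slots[device] = device;  // stand-in for extraction/sorting/grouping work
}

// Mirrors the new helper signatures: schedule leaf tasks and decrement a
// caller-owned counter, but never block inside the helper itself.
void ProcessTable(tsl::thread::ThreadPool& pool, std::vector<int>& slots,
                  int device_count, absl::BlockingCounter* counter) {
  for (int d = 0; d < device_count; ++d) {
    pool.Schedule([&slots, d, counter] {
      ProcessDevice(slots, d);
      counter->DecrementCount();
    });
  }
}

// One counter per stage, sized to tables * devices; only the caller waits.
void RunStage(tsl::thread::ThreadPool& pool,
              std::vector<std::vector<int>>& tables, int device_count) {
  absl::BlockingCounter counter(tables.size() * device_count);
  for (auto& slots : tables) {
    ProcessTable(pool, slots, device_count, &counter);
  }
  counter.Wait();  // fan-in: every (table, device) task has finished
}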

6 files changed (+136 additions, -92 deletions)

jax_tpu_embedding/sparsecore/lib/core/BUILD

Lines changed: 1 addition & 0 deletions
@@ -134,6 +134,7 @@ cc_library(
         ":input_preprocessing_util",
         ":partitioned_coo_tensors",
         ":sort_and_group_coo_tensors_impl",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/log",

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc

Lines changed: 114 additions & 88 deletions
@@ -23,6 +23,7 @@
 #include <utility>
 #include <vector>
 
+#include "absl/algorithm/container.h"  // from @com_google_absl
 #include "absl/base/thread_annotations.h"  // from @com_google_absl
 #include "absl/container/flat_hash_map.h"  // from @com_google_absl
 #include "absl/log/check.h"  // from @com_google_absl
@@ -126,6 +127,7 @@ struct TableState {
   MinibatchingSplit table_minibatching_split = 0;
   std::vector<ExtractedCooTensors> extracted_coo_tensors_per_device;
   std::vector<PartitionedCooTensors> partitioned_coo_tensors_per_device;
+  std::vector<int> dropped_id_count_per_device;
 
   TableState(const std::string& name,
              absl::Span<const StackedTableMetadata> metadata,
@@ -147,8 +149,9 @@
         stats_per_host(options.local_device_count, options.GetNumScs(),
                        options.num_sc_per_device),
         batch_size_for_device(0) {
-    extracted_coo_tensors_per_device.reserve(options.local_device_count);
-    partitioned_coo_tensors_per_device.reserve(options.local_device_count);
+    extracted_coo_tensors_per_device.resize(options.local_device_count);
+    partitioned_coo_tensors_per_device.resize(options.local_device_count);
+    dropped_id_count_per_device.resize(options.local_device_count, 0);
   }
 };
 
@@ -159,32 +162,42 @@
 void ExtractSortAndGroupCooTensorsForTable(
     TableState& state,
     absl::Span<std::unique_ptr<AbstractInputBatch>> input_batches,
-    const PreprocessSparseDenseMatmulInputOptions& options) {
+    const PreprocessSparseDenseMatmulInputOptions& options,
+    absl::BlockingCounter* counter) {
   tsl::profiler::TraceMe traceme([&] {
     return absl::StrCat("InputPreprocessingTable-ExtractSortGroup-",
                         state.stacked_table_name);
   });
-
   for (int local_device = 0; local_device < options.local_device_count;
        ++local_device) {
-    ExtractedCooTensors extracted_coo_tensors =
-        internal::ExtractCooTensorsForAllFeaturesPerLocalDevice(
-            state.stacked_table_metadata, input_batches, local_device, options);
-    state.extracted_coo_tensors_per_device.push_back(extracted_coo_tensors);
-    if (local_device == 0)
-      state.batch_size_for_device = extracted_coo_tensors.batch_size_for_device;
-    else
-      CHECK_EQ(state.batch_size_for_device,
-               extracted_coo_tensors.batch_size_for_device);
-
-    internal::StatsPerDevice stats_per_device =
-        state.stats_per_host.GetStatsPerDevice(local_device);
-    const PartitionedCooTensors grouped_coo_tensors =
-        SortAndGroupCooTensorsPerLocalDevice(
-            extracted_coo_tensors, state.stacked_table_metadata[0], options,
-            stats_per_device, state.table_minibatching_required);
-    state.partitioned_coo_tensors_per_device.push_back(grouped_coo_tensors);
-    state.stats_per_host.dropped_id_count += stats_per_device.dropped_id_count;
+    PreprocessingThreadPool()->Schedule([&, local_device, &state = state] {
+      state.extracted_coo_tensors_per_device[local_device] =
+          internal::ExtractCooTensorsForAllFeaturesPerLocalDevice(
+              state.stacked_table_metadata, input_batches, local_device,
+              options);
+
+      internal::StatsPerDevice stats_per_device =
+          state.stats_per_host.GetStatsPerDevice(local_device);
+      state.partitioned_coo_tensors_per_device[local_device] =
+          SortAndGroupCooTensorsPerLocalDevice(
+              state.extracted_coo_tensors_per_device[local_device],
+              state.stacked_table_metadata[0], options, stats_per_device,
+              state.table_minibatching_required);
+      state.dropped_id_count_per_device[local_device] =
+          stats_per_device.dropped_id_count;
+      counter->DecrementCount();
+    });
+  }
+}
+
+void PostProcessTableState(TableState& state) {
+  state.stats_per_host.dropped_id_count =
+      absl::c_accumulate(state.dropped_id_count_per_device, 0LL);
+
+  state.batch_size_for_device =
+      state.extracted_coo_tensors_per_device[0].batch_size_for_device;
+  for (const auto& extracted_coo : state.extracted_coo_tensors_per_device) {
+    DCHECK_EQ(state.batch_size_for_device, extracted_coo.batch_size_for_device);
   }
 }
 
@@ -194,29 +207,34 @@ void ExtractSortAndGroupCooTensorsForTable(
 // `state`: The TableState holding the COO tensors and statistics.
 // `options`: Preprocessing options.
 void CreateMinibatchingBucketsForTable(
-    TableState& state, const PreprocessSparseDenseMatmulInputOptions& options) {
+    TableState& state, const PreprocessSparseDenseMatmulInputOptions& options,
+    absl::BlockingCounter* counter) {
   tsl::profiler::TraceMe traceme([&] {
     return absl::StrCat("InputPreprocessingTable-CreateMinibatchingBuckets-",
                         state.stacked_table_name);
   });
   state.stats_per_host.dropped_id_count = 0;
   for (int local_device = 0; local_device < options.local_device_count;
        ++local_device) {
-    // Note: We create a dummy stats object here because we don't want to
-    // overwrite the stats from the first pass, which are authoritative.
-    // The only stat we care about from this second pass is the number of
-    // dropped IDs.
-    StatsPerHost dummy_stats_host(
-        /*local_device_count=*/1, options.GetNumScs(),
-        options.num_sc_per_device);
-    internal::StatsPerDevice dummy_stats =
-        dummy_stats_host.GetStatsPerDevice(0);
-    state.partitioned_coo_tensors_per_device[local_device] =
-        SortAndGroupCooTensorsPerLocalDevice(
-            state.extracted_coo_tensors_per_device[local_device],
-            state.stacked_table_metadata[0], options, dummy_stats,
-            state.table_minibatching_split);
-    state.stats_per_host.dropped_id_count += dummy_stats.dropped_id_count;
+    PreprocessingThreadPool()->Schedule([&, local_device, &state = state] {
+      // Note: We create a dummy stats object here because we don't want to
+      // overwrite the stats from the first pass, which are authoritative.
+      // The only stat we care about from this second pass is the number of
+      // dropped IDs.
+      StatsPerHost dummy_stats_host(
+          /*local_device_count=*/1, options.GetNumScs(),
+          options.num_sc_per_device);
+      internal::StatsPerDevice dummy_stats =
+          dummy_stats_host.GetStatsPerDevice(0);
+      state.partitioned_coo_tensors_per_device[local_device] =
+          SortAndGroupCooTensorsPerLocalDevice(
+              state.extracted_coo_tensors_per_device[local_device],
+              state.stacked_table_metadata[0], options, dummy_stats,
+              state.table_minibatching_split);
+      state.dropped_id_count_per_device[local_device] =
+          dummy_stats.dropped_id_count;
+      counter->DecrementCount();
+    });
   }
 }
 
@@ -433,44 +451,38 @@ void PopulateOutput(TableState& state, PreprocessSparseDenseMatmulOutput& out,
 // `out`: The output structure to be populated with CSR arrays and stats.
 // `output_mutex`: Mutex to protect access to `out`.
 void FillDeviceBuffersForTable(
-    TableState& state, const PreprocessSparseDenseMatmulInputOptions& options,
-    bool global_minibatching_required,
-    MinibatchingSplit global_minibatching_split,
-    int row_pointers_size_per_bucket, PreprocessSparseDenseMatmulOutput& out,
-    absl::Mutex& output_mutex) {
+    TableState& state, PreprocessSparseDenseMatmulOutput& out,
+    const PreprocessSparseDenseMatmulInputOptions& options,
+    int row_pointers_size_per_bucket, bool global_minibatching_required,
+    MinibatchingSplit global_minibatching_split, absl::Mutex& output_mutex,
+    absl::BlockingCounter* counter) {
   tsl::profiler::TraceMe traceme([&] {
     return absl::StrCat("InputPreprocessingTable-FillBuffer-",
                         state.stacked_table_name);
   });
-  int table_dropped_ids = 0;
   for (int local_device = 0; local_device < options.local_device_count;
        ++local_device) {
-    PartitionedCooTensors& grouped_coo_tensors =
-        state.partitioned_coo_tensors_per_device[local_device];
-    if (options.enable_minibatching && global_minibatching_required) {
-      grouped_coo_tensors.Merge(global_minibatching_split);
-    }
+    PreprocessingThreadPool()->Schedule([&, local_device, &state = state] {
+      PartitionedCooTensors& grouped_coo_tensors =
+          state.partitioned_coo_tensors_per_device[local_device];
+      if (options.enable_minibatching && global_minibatching_required) {
+        grouped_coo_tensors.Merge(global_minibatching_split);
+      }
 
-    const int batch_size_per_sc = xla::CeilOfRatio(state.batch_size_for_device,
-                                                   options.num_sc_per_device);
-    const int coo_buffer_size_per_sc =
-        state.coo_buffer_size_per_device / options.num_sc_per_device;
-    internal::CsrArraysPerDevice csr_arrays_per_device =
-        state.csr_arrays_per_host.GetCsrArraysPerDevice(local_device);
-    FillLocalDeviceBuffer(grouped_coo_tensors, row_pointers_size_per_bucket,
-                          coo_buffer_size_per_sc, batch_size_per_sc, options,
-                          csr_arrays_per_device, table_dropped_ids);
-    state.stats_per_host.dropped_id_count += table_dropped_ids;
+      const int batch_size_per_sc = xla::CeilOfRatio(
+          state.batch_size_for_device, options.num_sc_per_device);
+      const int coo_buffer_size_per_sc =
+          state.coo_buffer_size_per_device / options.num_sc_per_device;
+      internal::CsrArraysPerDevice csr_arrays_per_device =
+          state.csr_arrays_per_host.GetCsrArraysPerDevice(local_device);
+      int table_dropped_ids = 0;
+      FillLocalDeviceBuffer(grouped_coo_tensors, row_pointers_size_per_bucket,
+                            coo_buffer_size_per_sc, batch_size_per_sc, options,
+                            csr_arrays_per_device, table_dropped_ids);
+      state.dropped_id_count_per_device[local_device] = table_dropped_ids;
+      counter->DecrementCount();
+    });
   }
-  // NOMUTANTS -- Informational.
-  CheckBufferUsage(
-      /* max_required_buffer_size_per_device= */
-      state.stats_per_host.required_buffer_size.maxCoeff() *
-          options.num_sc_per_device,
-      state.coo_buffer_size_per_device, state.stacked_table_name,
-      options.batch_number);
-
-  PopulateOutput(state, out, output_mutex);
 }
 
 }  // namespace
@@ -518,14 +530,18 @@ PreprocessSparseDenseMatmulInput(
   // Stage 1: COO Extraction and Initial Sort/Group
   {
     tsl::profiler::TraceMe traceme("ExtractSortAndGroupCooTensors");
-    absl::BlockingCounter counter(stacked_tables.size());
+    absl::BlockingCounter counter(table_states.size() *
+                                  options.local_device_count);
     for (auto& state : table_states) {
-      PreprocessingThreadPool()->Schedule([&, &state = state] {
-        ExtractSortAndGroupCooTensorsForTable(state, input_batches, options);
-        counter.DecrementCount();
-      });
+      ExtractSortAndGroupCooTensorsForTable(state, input_batches, options,
+                                            &counter);
    }
    counter.Wait();
+
+    // Post-process results after all threads are done.
+    for (auto& state : table_states) {
+      PostProcessTableState(state);
+    }
  }
  TF_ASSIGN_OR_RETURN(bool global_minibatching_required,
                      SyncMinibatchingRequired(options, table_states));
@@ -536,14 +552,15 @@ PreprocessSparseDenseMatmulInput(
  if (options.enable_minibatching && global_minibatching_required) {
    {
      tsl::profiler::TraceMe traceme("CreateMinibatchingBuckets");
-      absl::BlockingCounter counter(stacked_tables.size());
+      absl::BlockingCounter counter(table_states.size() *
+                                    options.local_device_count);
      for (auto& state : table_states) {
-        PreprocessingThreadPool()->Schedule([&, &state = state] {
-          CreateMinibatchingBucketsForTable(state, options);
-          counter.DecrementCount();
-        });
+        CreateMinibatchingBucketsForTable(state, options, &counter);
      }
      counter.Wait();
+      for (auto& state : table_states) {
+        PostProcessTableState(state);
+      }
    }
 
    TF_ASSIGN_OR_RETURN(global_minibatching_split,
@@ -553,19 +570,28 @@ PreprocessSparseDenseMatmulInput(
  // Stage 3: Fill Device Buffers
  {
    tsl::profiler::TraceMe traceme("FillDeviceBuffers");
-    absl::BlockingCounter counter(stacked_tables.size());
+    absl::BlockingCounter counter(table_states.size() *
+                                  options.local_device_count);
    for (auto& state : table_states) {
-      PreprocessingThreadPool()->Schedule([&, &state = state,
-                                           global_minibatching_required,
-                                           global_minibatching_split] {
-        FillDeviceBuffersForTable(state, options, global_minibatching_required,
-                                  global_minibatching_split,
-                                  row_pointers_size_per_bucket, out,
-                                  output_mutex);
-        counter.DecrementCount();
-      });
+      FillDeviceBuffersForTable(
+          state, out, options, row_pointers_size_per_bucket,
+          global_minibatching_required, global_minibatching_split, output_mutex,
+          &counter);
    }
    counter.Wait();
+    for (auto& state : table_states) {
+      state.stats_per_host.dropped_id_count +=
+          absl::c_accumulate(state.dropped_id_count_per_device, 0LL);
+      // NOMUTANTS -- Informational.
+      CheckBufferUsage(
+          /* max_required_buffer_size_per_device= */
+          state.stats_per_host.required_buffer_size.maxCoeff() *
+              options.num_sc_per_device,
+          state.coo_buffer_size_per_device, state.stacked_table_name,
+          options.batch_number);
+
+      PopulateOutput(state, out, output_mutex);
+    }
  }
 
  out.num_minibatches = global_minibatching_split.count() + 1;
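The synchronization strategy in this file avoids atomics and locks entirely: every task writes its statistics into its own pre-sized `dropped_id_count_per_device` slot, and `PostProcessTableState` reduces the slots on the calling thread only after `counter.Wait()` returns. A condensed sketch of that idiom (type and field names here are illustrative, not the commit's):

#include <vector>

#include "absl/algorithm/container.h"

// Illustrative per-table state: one slot per device, sized before fan-out so
// concurrent writers touch disjoint elements and need no mutex.
struct State {
  std::vector<int> dropped_per_device;
  long long total_dropped = 0;
};

// Runs on a worker thread; writes only to its own slot.
void RecordDrops(State& state, int device, int drops) {
  state.dropped_per_device[device] = drops;
}

// Runs on the caller's thread after all workers have finished (i.e. after
// BlockingCounter::Wait), so the plain reads here race with nothing. The 0LL
// seed makes the accumulator long long, so large counts cannot overflow int.
void PostProcess(State& state) {
  state.total_dropped = absl::c_accumulate(state.dropped_per_device, 0LL);
}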

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.cc

Lines changed: 2 additions & 1 deletion
@@ -50,7 +50,8 @@ tsl::thread::ThreadPool* PreprocessingThreadPool() {
   static tsl::thread::ThreadPool* pool = []() {
     const int num_threads = GetThreadPoolSize();
     DCHECK_GE(num_threads, 1);
-    LOG(INFO) << "Creating thread pool for SparseCore input preprocessing: "
+    LOG(INFO) << "Creating thread pool for SparseCore input "
+                 "preprocessing: "
              << num_threads << " threads";
    auto thread_pool = new tsl::thread::ThreadPool(
        tsl::Env::Default(), tsl::ThreadOptions(), kScPool, num_threads,
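The commit message's "separate pool to avoid deadlocks" point alludes to a classic hazard: if a task running on a pool blocks waiting for sub-tasks queued on the same fixed-size pool, the workers can all fill up with blocked parents and the children never run. A hedged illustration of that hazard (deliberately deadlocking demo code, not from this repository; the new design sidesteps it by having only the caller's thread wait):

#include "absl/synchronization/blocking_counter.h"
#include "tsl/platform/env.h"
#include "tsl/platform/threadpool.h"

// WARNING: hangs by design. With a 1-thread pool, the outer task occupies the
// only worker and then blocks on `inner`; the queued inner task can never run.
void DeadlockDemo() {
  tsl::thread::ThreadPool pool(tsl::Env::Default(), "demo", /*num_threads=*/1);
  absl::BlockingCounter done(1);
  pool.Schedule([&] {
    absl::BlockingCounter inner(1);
    pool.Schedule([&] { inner.DecrementCount(); });  // queued, never runs
    inner.Wait();  // blocks the pool's only worker forever
    done.DecrementCount();
  });
  done.Wait();  // never returns
}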

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h

Lines changed: 1 addition & 0 deletions
@@ -222,6 +222,7 @@ struct ExtractedCooTensors {
   // grouping them. Might be lower after deduplication.
   std::vector<int> coo_tensors_per_sc;
 
+  ExtractedCooTensors() : ExtractedCooTensors(0, 0) {}
   ExtractedCooTensors(int num_sc_per_device, int batch_size_for_device)
       : batch_size_for_device(batch_size_for_device),
         coo_tensors_per_sc(num_sc_per_device, 0) {}

jax_tpu_embedding/sparsecore/lib/core/partitioned_coo_tensors.h

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ namespace jax_sc_embedding {
 
 class PartitionedCooTensors {
  public:
+  PartitionedCooTensors() : PartitionedCooTensors(0, 0, 0, 1) {}
  PartitionedCooTensors(int reserve_count, int num_sc_per_device,
                        uint32_t global_sc_count, int bucket_count_per_sc = 1)
      : coo_tensors_(),
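Both new default constructors exist for the same reason: the `TableState` containers switched from `reserve()` + `push_back` to `resize()` + indexed writes so that concurrent tasks can assign into disjoint slots, and `std::vector<T>::resize` must be able to default-construct `T`. A minimal illustration of the pattern with a hypothetical type:

#include <vector>

// Hypothetical payload type mirroring the delegating-constructor pattern in
// this commit: the default constructor makes the type usable with resize().
struct Payload {
  Payload() : Payload(0) {}  // required by std::vector<Payload>::resize
  explicit Payload(int size) : values(size) {}
  std::vector<int> values;
};

int main() {
  std::vector<Payload> per_device;
  per_device.resize(4);        // default-constructs 4 elements up front
  per_device[2] = Payload(8);  // each worker then writes only its own slot
}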

jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py

Lines changed: 17 additions & 3 deletions
@@ -11,17 +11,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Simple benchmarks for preprocessing input for sparse-dense matmul.
+r"""Simple benchmarks for preprocessing input for sparse-dense matmul.
 
 Example usage:
 
 On perflab comparing against HEAD:
-  benchy --perflab --runs=10 --reference=srcfs --benchmark_filter=all
+  benchy --perflab --runs=10 --reference=srcfs --benchmark_filter=all \
     :preprocess_input_benchmarks
 
 Or locally:
-  bazel run -c opt --dynamic_mode=off --copt=-gmlt :preprocess_input_benchmarks --
+  bazel run -c opt --dynamic_mode=off --copt=-gmlt :preprocess_input_benchmarks -- \
     --benchmark_filter=all --cpu_profile=/tmp/preprocess.prof
+
+The --benchmark_filter flag uses a regex to select benchmarks. For parameterized
+benchmarks, the name is typically formatted as:
+`[benchmark_name]/[param1]:[value1]/[param2]:[value2]`.
+Boolean parameters are often represented as 0 for False and 1 for True.
+
+For example, to run only the `sparse_coo` benchmarks:
+`--benchmark_filter=preprocess_input_benchmark_sparse_coo`
+
+To run only the `sparse_coo` benchmark where `has_leading_dimension` is `False`:
+`--benchmark_filter='preprocess_input_benchmark_sparse_coo/has_leading_dimension:0'`
+
+To run all benchmarks across all suites where `has_leading_dimension` is `False`:
+`--benchmark_filter='/has_leading_dimension:0'`
 """
 
 import concurrent
