Skip to content

Commit 1ce8e3b

Browse files
[JAX SC] Parallelize device loop for extraction, sorting and grouping.
* `9.28%` geomean reduction in wall time, with a `0.61%` CPU time increase and a `6.05%` reduction in cycles.
* Use a separate thread pool to avoid deadlocks; the fixed cost of scheduling should be less than 0.1%.
* Add default-constructible objects to support parallelization.

PiperOrigin-RevId: 826509091
1 parent 00afdd2 commit 1ce8e3b

File tree

6 files changed

+98
-35
lines changed

6 files changed

+98
-35
lines changed

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ struct TableState {
126126
MinibatchingSplit table_minibatching_split = 0;
127127
std::vector<ExtractedCooTensors> extracted_coo_tensors_per_device;
128128
std::vector<PartitionedCooTensors> partitioned_coo_tensors_per_device;
129+
std::vector<int> dropped_id_count_per_device;
129130

130131
TableState(const std::string& name,
131132
absl::Span<const StackedTableMetadata> metadata,
@@ -147,8 +148,9 @@ struct TableState {
147148
stats_per_host(options.local_device_count, options.GetNumScs(),
148149
options.num_sc_per_device),
149150
batch_size_for_device(0) {
150-
extracted_coo_tensors_per_device.reserve(options.local_device_count);
151-
partitioned_coo_tensors_per_device.reserve(options.local_device_count);
151+
extracted_coo_tensors_per_device.resize(options.local_device_count);
152+
partitioned_coo_tensors_per_device.resize(options.local_device_count);
153+
dropped_id_count_per_device.resize(options.local_device_count, 0);
152154
}
153155
};
154156

@@ -165,26 +167,40 @@ void ExtractSortAndGroupCooTensorsForTable(
165167
state.stacked_table_name);
166168
});
167169

170+
absl::BlockingCounter counter(options.local_device_count);
168171
for (int local_device = 0; local_device < options.local_device_count;
169172
++local_device) {
170-
ExtractedCooTensors extracted_coo_tensors =
171-
internal::ExtractCooTensorsForAllFeaturesPerLocalDevice(
172-
state.stacked_table_metadata, input_batches, local_device, options);
173-
state.extracted_coo_tensors_per_device.push_back(extracted_coo_tensors);
174-
if (local_device == 0)
175-
state.batch_size_for_device = extracted_coo_tensors.batch_size_for_device;
176-
else
177-
CHECK_EQ(state.batch_size_for_device,
178-
extracted_coo_tensors.batch_size_for_device);
179-
180-
internal::StatsPerDevice stats_per_device =
181-
state.stats_per_host.GetStatsPerDevice(local_device);
182-
const PartitionedCooTensors grouped_coo_tensors =
183-
SortAndGroupCooTensorsPerLocalDevice(
184-
extracted_coo_tensors, state.stacked_table_metadata[0], options,
185-
stats_per_device, state.table_minibatching_required);
186-
state.partitioned_coo_tensors_per_device.push_back(grouped_coo_tensors);
187-
state.stats_per_host.dropped_id_count += stats_per_device.dropped_id_count;
173+
DeviceProcessingThreadPool()->Schedule([&, local_device] {
174+
state.extracted_coo_tensors_per_device[local_device] =
175+
internal::ExtractCooTensorsForAllFeaturesPerLocalDevice(
176+
state.stacked_table_metadata, input_batches, local_device,
177+
options);
178+
179+
internal::StatsPerDevice stats_per_device =
180+
state.stats_per_host.GetStatsPerDevice(local_device);
181+
state.partitioned_coo_tensors_per_device[local_device] =
182+
SortAndGroupCooTensorsPerLocalDevice(
183+
state.extracted_coo_tensors_per_device[local_device],
184+
state.stacked_table_metadata[0], options, stats_per_device,
185+
state.table_minibatching_required);
186+
state.dropped_id_count_per_device[local_device] =
187+
stats_per_device.dropped_id_count;
188+
counter.DecrementCount();
189+
});
190+
}
191+
counter.Wait();
192+
193+
// Post-process results after all threads are done.
194+
state.batch_size_for_device =
195+
state.extracted_coo_tensors_per_device[0].batch_size_for_device;
196+
state.stats_per_host.dropped_id_count = 0;
197+
for (int local_device = 0; local_device < options.local_device_count;
198+
++local_device) {
199+
DCHECK_EQ(state.batch_size_for_device,
200+
state.extracted_coo_tensors_per_device[local_device]
201+
.batch_size_for_device);
202+
state.stats_per_host.dropped_id_count +=
203+
state.dropped_id_count_per_device[local_device];
188204
}
189205
}
190206

@@ -518,9 +534,9 @@ PreprocessSparseDenseMatmulInput(
518534
// Stage 1: COO Extraction and Initial Sort/Group
519535
{
520536
tsl::profiler::TraceMe traceme("ExtractSortAndGroupCooTensors");
521-
absl::BlockingCounter counter(stacked_tables.size());
537+
absl::BlockingCounter counter(table_states.size());
522538
for (auto& state : table_states) {
523-
PreprocessingThreadPool()->Schedule([&, &state = state] {
539+
TableProcessingThreadPool()->Schedule([&, &state = state] {
524540
ExtractSortAndGroupCooTensorsForTable(state, input_batches, options);
525541
counter.DecrementCount();
526542
});
@@ -536,9 +552,9 @@ PreprocessSparseDenseMatmulInput(
536552
if (options.enable_minibatching && global_minibatching_required) {
537553
{
538554
tsl::profiler::TraceMe traceme("CreateMinibatchingBuckets");
539-
absl::BlockingCounter counter(stacked_tables.size());
555+
absl::BlockingCounter counter(table_states.size());
540556
for (auto& state : table_states) {
541-
PreprocessingThreadPool()->Schedule([&, &state = state] {
557+
TableProcessingThreadPool()->Schedule([&, &state = state] {
542558
CreateMinibatchingBucketsForTable(state, options);
543559
counter.DecrementCount();
544560
});
@@ -553,9 +569,9 @@ PreprocessSparseDenseMatmulInput(
553569
// Stage 3: Fill Device Buffers
554570
{
555571
tsl::profiler::TraceMe traceme("FillDeviceBuffers");
556-
absl::BlockingCounter counter(stacked_tables.size());
572+
absl::BlockingCounter counter(table_states.size());
557573
for (auto& state : table_states) {
558-
PreprocessingThreadPool()->Schedule([&, &state = state,
574+
TableProcessingThreadPool()->Schedule([&, &state = state,
559575
global_minibatching_required,
560576
global_minibatching_split] {
561577
FillDeviceBuffersForTable(state, options, global_minibatching_required,

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.cc

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,14 @@ namespace jax_sc_embedding {
2828
namespace {
2929

3030
constexpr char kScEnv[] = "SPARSECORE_INPUT_PREPROCESSING_THREADS";
31-
constexpr char kScPool[] = "SparseCoreInputPreprocessingThreadPool";
31+
constexpr char kDevicePool[] = "SparseCoreDeviceProcessingThreadPool";
32+
constexpr char kTablePool[] = "SparseCoreTableProcessingThreadPool";
3233

3334
// Returns at least one but the minimum of NumSchedulableCPUs() and the value
3435
// specified by the environment variable
3536
// `SPARSECORE_INPUT_PREPROCESSING_THREADS`.
37+
// NOTE: This size applies to *each* thread pool (Device and Table). If the env
38+
// var is set to N, 2*N threads may be created in total.
3639
int GetThreadPoolSize() {
3740
int num_threads = tsl::port::NumSchedulableCPUs();
3841
if (const char* env = std::getenv(kScEnv); env != nullptr) {
@@ -46,14 +49,30 @@ int GetThreadPoolSize() {
4649

4750
} // namespace
4851

49-
tsl::thread::ThreadPool* PreprocessingThreadPool() {
52+
tsl::thread::ThreadPool* DeviceProcessingThreadPool() {
5053
static tsl::thread::ThreadPool* pool = []() {
5154
const int num_threads = GetThreadPoolSize();
5255
DCHECK_GE(num_threads, 1);
53-
LOG(INFO) << "Creating thread pool for SparseCore input preprocessing: "
56+
LOG(INFO) << "Creating device processing thread pool for SparseCore input "
57+
"preprocessing: "
5458
<< num_threads << " threads";
5559
auto thread_pool = new tsl::thread::ThreadPool(
56-
tsl::Env::Default(), tsl::ThreadOptions(), kScPool, num_threads,
60+
tsl::Env::Default(), tsl::ThreadOptions(), kDevicePool, num_threads,
61+
/*low_latency_hint=*/false);
62+
return thread_pool;
63+
}();
64+
return pool;
65+
}
66+
67+
tsl::thread::ThreadPool* TableProcessingThreadPool() {
68+
static tsl::thread::ThreadPool* pool = []() {
69+
const int num_threads = GetThreadPoolSize();
70+
DCHECK_GE(num_threads, 1);
71+
LOG(INFO) << "Creating table processing thread pool for SparseCore input "
72+
"preprocessing: "
73+
<< num_threads << " threads";
74+
auto thread_pool = new tsl::thread::ThreadPool(
75+
tsl::Env::Default(), tsl::ThreadOptions(), kTablePool, num_threads,
5776
/*low_latency_hint=*/false);
5877
return thread_pool;
5978
}();

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,20 @@
1818

1919
namespace jax_sc_embedding {
2020

21-
// Global thread pool for all computations done by input preprocessing.
22-
tsl::thread::ThreadPool* PreprocessingThreadPool();
21+
// We use two separate thread pools to handle nested parallelism in input
22+
// preprocessing. Table-level tasks are scheduled onto TableProcessingThreadPool,
23+
// and each of these tasks may schedule multiple device-level tasks onto
24+
// DeviceProcessingThreadPool.
25+
// If a single pool were used, it could lead to deadlock: if all threads in the
26+
// pool were occupied by table-level tasks blocked waiting for device-level
27+
// tasks to complete, no threads would be available to run the device-level
28+
// tasks, and the system would hang. Using separate pools prevents this issue.
29+
30+
// Thread pool for device-level computations in input preprocessing.
31+
tsl::thread::ThreadPool* DeviceProcessingThreadPool();
32+
33+
// Thread pool for table-level computations in input preprocessing.
34+
tsl::thread::ThreadPool* TableProcessingThreadPool();
2335

2436
} // namespace jax_sc_embedding
2537

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ struct ExtractedCooTensors {
222222
// grouping them. Might be lower after deduplication.
223223
std::vector<int> coo_tensors_per_sc;
224224

225+
ExtractedCooTensors() : ExtractedCooTensors(0, 0) {}
225226
ExtractedCooTensors(int num_sc_per_device, int batch_size_for_device)
226227
: batch_size_for_device(batch_size_for_device),
227228
coo_tensors_per_sc(num_sc_per_device, 0) {}

jax_tpu_embedding/sparsecore/lib/core/partitioned_coo_tensors.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ namespace jax_sc_embedding {
3232

3333
class PartitionedCooTensors {
3434
public:
35+
PartitionedCooTensors() : PartitionedCooTensors(0, 0, 0, 1) {}
3536
PartitionedCooTensors(int reserve_count, int num_sc_per_device,
3637
uint32_t global_sc_count, int bucket_count_per_sc = 1)
3738
: coo_tensors_(),

jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,31 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
"""Simple benchmarks for preprocessing input for sparse-dense matmul.
14+
r"""Simple benchmarks for preprocessing input for sparse-dense matmul.
1515
1616
Example usage:
1717
1818
On perflab comparing against HEAD:
19-
benchy --perflab --runs=10 --reference=srcfs --benchmark_filter=all
19+
benchy --perflab --runs=10 --reference=srcfs --benchmark_filter=all \
2020
:preprocess_input_benchmarks
2121
2222
Or locally:
23-
bazel run -c opt --dynamic_mode=off --copt=-gmlt :preprocess_input_benchmarks --
23+
bazel run -c opt --dynamic_mode=off --copt=-gmlt :preprocess_input_benchmarks -- \
2424
--benchmark_filter=all --cpu_profile=/tmp/preprocess.prof
25+
26+
The --benchmark_filter flag uses a regex to select benchmarks. For parameterized
27+
benchmarks, the name is typically formatted as:
28+
`[benchmark_name]/[param1]:[value1]/[param2]:[value2]`.
29+
Boolean parameters are often represented as 0 for False and 1 for True.
30+
31+
For example, to run only the `sparse_coo` benchmarks:
32+
`--benchmark_filter=preprocess_input_benchmark_sparse_coo`
33+
34+
To run only the `sparse_coo` benchmark where `has_leading_dimension` is `False`:
35+
`--benchmark_filter='preprocess_input_benchmark_sparse_coo/has_leading_dimension:0'`
36+
37+
To run all benchmarks across all suites where `has_leading_dimension` is `False`:
38+
`--benchmark_filter='/has_leading_dimension:0'`
2539
"""
2640

2741
import concurrent

0 commit comments

Comments (0)