[xla:cpu:xnn] Measure execution time of parallel task to decide the optimal number of workers

PiperOrigin-RevId: 716882217
ezhulenev authored and Google-ML-Automation committed Jan 20, 2025
1 parent 8b9abb4 commit 52992e6
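
In essence, the commit replaces a fixed fan-out with a measure-then-fan-out heuristic: run the first task inline, time it, and only spawn as many workers as the measured workload justifies. A minimal standalone sketch of that heuristic (illustrative names and types, not the XLA API; the real implementation is `ComputeOptimalNumWorkers` in `parallel_loop_runner.cc` below):

```cpp
#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <functional>

// Illustrative sketch: time the first task inline, extrapolate the remaining
// workload, and size the worker pool so that each worker gets roughly one
// `timeslice_ns` of compute. Assumes timeslice_ns > 0 and num_tasks > 1.
size_t EstimateNumWorkers(uint64_t timeslice_ns, size_t num_threads,
                          size_t num_tasks,
                          const std::function<void(size_t)>& task) {
  auto start = std::chrono::steady_clock::now();
  task(0);  // Run the first task in the caller thread to sample its cost.
  auto end = std::chrono::steady_clock::now();
  uint64_t task_ns = static_cast<uint64_t>(
      std::chrono::duration_cast<std::chrono::nanoseconds>(end - start)
          .count());

  // Assume the remaining tasks cost roughly the same as the first one.
  uint64_t workload_ns = (num_tasks - 1) * task_ns;
  size_t by_timeslice =
      static_cast<size_t>((workload_ns + timeslice_ns - 1) / timeslice_ns);
  return std::min({num_tasks - 1, num_threads, by_timeslice});
}
```

This keeps tiny loops on the calling thread (a single "worker") while large workloads still saturate the thread pool.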
Showing 5 changed files with 159 additions and 39 deletions.
5 changes: 5 additions & 0 deletions xla/backends/cpu/runtime/xnnpack/BUILD
@@ -46,9 +46,12 @@ cc_library(
deps = [
"//xla/tsl/concurrency:async_value",
"//xla/tsl/lib/math:math_util",
"//xla/tsl/platform:env",
"//xla/tsl/platform:logging",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/time",
"@eigen_archive//:eigen3",
],
)
@@ -66,6 +69,7 @@ xla_cc_test(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3",
],
@@ -203,6 +207,7 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3",
"@pthreadpool",
121 changes: 101 additions & 20 deletions xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc
@@ -24,10 +24,14 @@ limitations under the License.
#include <optional>
#include <utility>

#include "absl/base/attributes.h"
#include "absl/base/optimization.h"
#include "absl/log/check.h"
#include "absl/time/time.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/concurrency/chain.h"
#include "xla/tsl/lib/math/math_util.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/logging.h"

#define EIGEN_USE_THREADS
@@ -50,8 +54,12 @@ static tsl::AsyncValueRef<tsl::Chain> OkDoneEventSingleton() {
return singleton->AsRef();
}

ParallelLoopRunner::ParallelLoopRunner(
    const Eigen::ThreadPoolDevice* device,
    std::optional<absl::Duration> worker_timeslice)
    : done_event_(OkDoneEventSingleton()),
      device_(device),
      worker_timeslice_(worker_timeslice) {}

tsl::AsyncValueRef<tsl::Chain> ParallelLoopRunner::ResetDoneEvent() {
auto done_event = std::move(done_event_);
@@ -169,7 +177,7 @@ static void Parallelize(ParallelizeContext* ctx, uint16_t start_index,
}

template <typename ParallelTask>
ABSL_ATTRIBUTE_ALWAYS_INLINE void ParallelLoopRunner::Parallelize(
tsl::CountDownAsyncValueRef<tsl::Chain> count_down, size_t num_workers,
size_t num_tasks, ParallelTask&& parallel_task) {
DCHECK_EQ(count_down.count(), num_workers)
@@ -213,7 +221,7 @@ void ParallelLoopRunner::Parallelize(
}

template <typename Task>
ABSL_ATTRIBUTE_ALWAYS_INLINE void ParallelLoopRunner::ScheduleOne(Task&& task) {
auto event = tsl::MakeConstructedAsyncValueRef<tsl::Chain>();
done_event_.AndThen([event, task = std::forward<Task>(task)] {
task();
@@ -222,23 +230,103 @@ void ParallelLoopRunner::ScheduleOne(Task&& task) {
done_event_ = std::move(event);
}

// Computes the number of workers to use for a parallel operation by executing
// the first task, measuring its compute time, and estimating how many workers
// are needed so that each worker handles roughly `worker_timeslice` of
// compute.
template <typename ParallelTask>
ABSL_ATTRIBUTE_ALWAYS_INLINE size_t
ComputeOptimalNumWorkers(absl::Duration worker_timeslice, size_t num_threads,
                         size_t num_tasks, ParallelTask& parallel_task) {
// Run the first task in the caller thread to estimate the number of parallel
// workers that should be used for the parallel operation.
uint64_t start_ns = tsl::Env::Default()->NowNanos();
parallel_task(0);
uint64_t end_ns = tsl::Env::Default()->NowNanos();

// We assume that all tasks take roughly the same amount of compute and we
// can estimate the total workload duration by multiplying the number of
// remaining tasks by the duration of a single task.
size_t workload_ns = (num_tasks - 1) * (end_ns - start_ns);
size_t timeslice_ns = absl::ToInt64Nanoseconds(worker_timeslice);
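// Worked example (hypothetical numbers): if the first task took 10us and
// num_tasks = 1000, then workload_ns ~= 999 * 10'000 = 9'990'000, and a 100us
// timeslice yields CeilOfRatio(9'990'000, 100'000) = 100 workers before the
// caps below are applied.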

// Get the number of workers, so that each worker will take roughly
// `worker_timeslice` amount of compute. Don't create more workers than
// the number of threads in the thread pool or the number of tasks.
size_t num_workers =
std::min(std::min(num_tasks - 1, num_threads),
tsl::MathUtil::CeilOfRatio(workload_ns, timeslice_ns));
return std::min(num_workers, size_t{std::numeric_limits<uint16_t>::max()});
}

template <typename ParallelTask>
ABSL_ATTRIBUTE_ALWAYS_INLINE void ParallelLoopRunner::ScheduleAll(
size_t num_tasks, ParallelTask&& parallel_task) {
DCHECK_GT(num_tasks, 1) << "Expected at least two tasks";

// If the done event is already available and we have a worker timeslice, we
// can compute the optimal number of workers for the parallel operation and
// potentially avoid allocating the count-down counter altogether.
if (ABSL_PREDICT_TRUE(done_event_.IsConcrete() && worker_timeslice_)) {
size_t optimal_num_workers = ComputeOptimalNumWorkers(
*worker_timeslice_, num_threads(), num_tasks, parallel_task);

// Execute remaining tasks in the caller thread if we have a single worker.
if (ABSL_PREDICT_TRUE(optimal_num_workers == 1)) {
for (size_t i = 1; i < num_tasks; ++i) {
parallel_task(i);
}
return;
}

tsl::CountDownAsyncValueRef<tsl::Chain> count_down(optimal_num_workers);
done_event_ = count_down.AsRef();

// Parallelize the remaining tasks (skip the first task that was executed
// when we were computing the number of workers).
Parallelize(std::move(count_down), optimal_num_workers, num_tasks - 1,
[parallel_task = std::forward<ParallelTask>(parallel_task)](
size_t task_index) { parallel_task(task_index + 1); });
return;
}

// If `done_event_` is not available, we start with at most `num_threads()`
// workers as we can't run more parallel workers than the number of threads in
// the thread pool. Later we might adjust the number of workers when it's safe
// to execute the first task to measure the execution time.
size_t num_workers = std::min(std::min(num_tasks, num_threads()),
size_t{std::numeric_limits<uint16_t>::max()});

tsl::CountDownAsyncValueRef<tsl::Chain> count_down(num_workers);
auto count_down_done = count_down.AsRef();

auto schedule_all =
    [this, num_workers, num_tasks, count_down = std::move(count_down),
     parallel_task = std::forward<ParallelTask>(parallel_task)]() mutable {
// If we don't have a worker timeslice, we can parallelize the task
// immediately using the pre-computed number of workers.
if (ABSL_PREDICT_FALSE(!worker_timeslice_)) {
Parallelize(std::move(count_down), num_workers, num_tasks,
std::move(parallel_task));
return;
}

// Compute the optimal number of workers by executing the first task.
size_t optimal_num_workers = ComputeOptimalNumWorkers(
*worker_timeslice_, num_threads(), num_tasks, parallel_task);
DCHECK_LE(optimal_num_workers, num_workers);

// Count down for the workers that we don't need.
count_down.CountDown(num_workers - optimal_num_workers);
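// NOTE: `count_down` was created with `num_workers` slots, so the unused
// slots are retired up front; the chained done event then completes once
// the remaining `optimal_num_workers` workers finish their tasks.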

// Parallelize the remaining tasks (skip the first task that was
// executed when we were computing the number of workers).
Parallelize(std::move(count_down), optimal_num_workers, num_tasks - 1,
[parallel_task = std::move(parallel_task)](
size_t task_index) { parallel_task(task_index + 1); });
};

done_event_.AndThen(std::move(schedule_all));
done_event_ = std::move(count_down_done);
}

@@ -376,8 +464,6 @@ void ParallelLoopRunner::Parallelize(size_t range, size_t tile,

// Fast path for the degenerate parallel loop with single task.
if (ABSL_PREDICT_TRUE(num_tasks == 1)) {
DCHECK_EQ(range, tile) << "Expected range to be equal to tile";

// Execute task in the caller thread if done event is already available.
if (ABSL_PREDICT_TRUE(done_event_.IsConcrete())) {
task(0, range);
@@ -405,8 +491,6 @@ void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j,

// Fast path for the degenerate parallel loop with single task.
if (ABSL_PREDICT_TRUE(num_tasks == 1)) {
DCHECK_EQ(range_j, tile_j) << "Expected range to be equal to tile";

// Execute task in the caller thread if done event is already available.
if (ABSL_PREDICT_TRUE(done_event_.IsConcrete())) {
task(0, 0, range_j);
@@ -435,9 +519,6 @@ void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j,

// Fast path for the degenerate parallel loop with single task.
if (ABSL_PREDICT_TRUE(num_tasks == 1)) {
DCHECK_EQ(range_j, tile_j) << "Expected range to be equal to tile";
DCHECK_EQ(range_k, tile_k) << "Expected range to be equal to tile";

// Execute task in the caller thread if done event is already available.
if (ABSL_PREDICT_TRUE(done_event_.IsConcrete())) {
task(0, 0, 0, range_j, range_k);
15 changes: 14 additions & 1 deletion xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h
@@ -23,6 +23,7 @@ limitations under the License.
#include <optional>

#include "absl/container/fixed_array.h"
#include "absl/time/time.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/concurrency/chain.h"

@@ -43,6 +44,12 @@ namespace xla::cpu {
// Parallel loop runner is an implementation of the `pthreadpool` API adaptor
// for XLA:CPU runtime.
//
// Parallel loop runner can be configured by the `worker_timeslice` parameter,
// which defines the approximate amount of compute (in terms of wall time) that
// each persistent worker will handle. We rely on this parameter to avoid
// scheduling too many workers into the thread pool, because for tiny tasks the
// overheads can be prohibitively expensive.
//
// WARNING: ParallelLoopRunner is not thread-safe, and must be externally
// synchronized by the user.
class ParallelLoopRunner {
@@ -56,7 +63,9 @@ class ParallelLoopRunner {
#endif

public:
explicit ParallelLoopRunner(
    const Eigen::ThreadPoolDevice* device,
    std::optional<absl::Duration> worker_timeslice = std::nullopt);

// Takes ownership of the runner and returns a done event. After the done
// event is transferred to the caller, it is illegal to schedule more parallel
@@ -202,6 +211,10 @@ class ParallelLoopRunner {
// pools for different NUMA nodes, and we have to be able to switch between
// them from run to run.
std::atomic<const Eigen::ThreadPoolDevice*> device_;

// The approximate amount of compute (in terms of wall time) that each
// persistent worker should handle.
std::optional<absl::Duration> worker_timeslice_;
};

} // namespace xla::cpu
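
To make the new knob concrete, here is a hypothetical usage sketch. The constructor and `absl::Microseconds` follow the signatures in this diff; the thread-pool setup and the `(offset, extent)` task shape are assumptions inferred from the 1D fast path (`task(0, range)`) above:

```cpp
#define EIGEN_USE_THREADS

#include <cstddef>

#include "absl/time/time.h"
#include "unsupported/Eigen/CXX11/ThreadPool"
#include "xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h"

void Example() {
  Eigen::ThreadPool thread_pool(/*num_threads=*/8);
  Eigen::ThreadPoolDevice device(&thread_pool, thread_pool.NumThreads());

  // Ask the runner to aim for ~100us of compute per persistent worker. Tiny
  // loops will then run inline on the caller thread instead of fanning out.
  xla::cpu::ParallelLoopRunner runner(
      &device, /*worker_timeslice=*/absl::Microseconds(100));

  // Hypothetical 1D tiled loop: each invocation processes the slice
  // [offset, offset + extent) of the iteration space.
  runner.Parallelize(/*range=*/4096, /*tile=*/128,
                     [](size_t offset, size_t extent) {
                       // ... compute on the [offset, offset + extent) slice ...
                     });
}
```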