[xla:cpu:xnn] Measure execution time of parallel task to decide the optimal number of workers #21580

Merged · 1 commit · Jan 20, 2025
5 changes: 5 additions & 0 deletions xla/backends/cpu/runtime/xnnpack/BUILD
@@ -46,9 +46,12 @@ cc_library(
deps = [
"//xla/tsl/concurrency:async_value",
"//xla/tsl/lib/math:math_util",
"//xla/tsl/platform:env",
"//xla/tsl/platform:logging",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/time",
"@eigen_archive//:eigen3",
],
)
@@ -66,6 +69,7 @@ xla_cc_test(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3",
],
@@ -203,6 +207,7 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3",
"@pthreadpool",
121 changes: 101 additions & 20 deletions xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc
@@ -24,10 +24,14 @@ limitations under the License.
#include <optional>
#include <utility>

#include "absl/base/attributes.h"
#include "absl/base/optimization.h"
#include "absl/log/check.h"
#include "absl/time/time.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/concurrency/chain.h"
#include "xla/tsl/lib/math/math_util.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/logging.h"

#define EIGEN_USE_THREADS
@@ -50,8 +54,12 @@ static tsl::AsyncValueRef<tsl::Chain> OkDoneEventSingleton() {
return singleton->AsRef();
}

ParallelLoopRunner::ParallelLoopRunner(const Eigen::ThreadPoolDevice* device)
: done_event_(OkDoneEventSingleton()), device_(device) {}
ParallelLoopRunner::ParallelLoopRunner(
const Eigen::ThreadPoolDevice* device,
std::optional<absl::Duration> worker_timeslice)
: done_event_(OkDoneEventSingleton()),
device_(device),
worker_timeslice_(worker_timeslice) {}

tsl::AsyncValueRef<tsl::Chain> ParallelLoopRunner::ResetDoneEvent() {
auto done_event = std::move(done_event_);
@@ -169,7 +177,7 @@ static void Parallelize(ParallelizeContext* ctx, uint16_t start_index,
}

template <typename ParallelTask>
void ParallelLoopRunner::Parallelize(
ABSL_ATTRIBUTE_ALWAYS_INLINE void ParallelLoopRunner::Parallelize(
tsl::CountDownAsyncValueRef<tsl::Chain> count_down, size_t num_workers,
size_t num_tasks, ParallelTask&& parallel_task) {
DCHECK_EQ(count_down.count(), num_workers)
@@ -213,7 +221,7 @@ void ParallelLoopRunner::Parallelize(
}

template <typename Task>
void ParallelLoopRunner::ScheduleOne(Task&& task) {
ABSL_ATTRIBUTE_ALWAYS_INLINE void ParallelLoopRunner::ScheduleOne(Task&& task) {
auto event = tsl::MakeConstructedAsyncValueRef<tsl::Chain>();
done_event_.AndThen([event, task = std::forward<Task>(task)] {
task();
@@ -222,23 +230,103 @@ void ParallelLoopRunner::ScheduleOne(Task&& task) {
done_event_ = std::move(event);
}

// Compute the number of workers that should be used for parallel operation, by
// executing the first task, measuring the compute time and estimating how many
// workers are needed, so that each worker will handle `worker_timeslice` amount
// of compute.
template <typename ParallelTask>
void ParallelLoopRunner::ScheduleAll(size_t num_tasks,
ParallelTask&& parallel_task) {
// We use at most `num_threads()` workers as we can't run more parallel
// workers than the number of threads in the thread pool.
ABSL_ATTRIBUTE_ALWAYS_INLINE size_t
ComputeOptimalNumWorkers(absl::Duration worker_timeslice, size_t num_threads,
size_t num_tasks, ParallelTask& parallel_task) {
// Run first task in the caller thread, to estimate the number of parallel
// workers that should be used for parallel operation.
uint64_t start_ns = tsl::Env::Default()->NowNanos();
parallel_task(0);
uint64_t end_ns = tsl::Env::Default()->NowNanos();

// We assume that all tasks take roughly the same amount of compute and we
// can estimate the total workload duration by multiplying the number of
// remaining tasks by the duration of a single task.
size_t workload_ns = (num_tasks - 1) * (end_ns - start_ns);
size_t timeslice_ns = absl::ToInt64Nanoseconds(worker_timeslice);

// Get the number of workers, so that each worker will take roughly
// `worker_timeslice` amount of compute. Don't create more workers than
// the number of threads in the thread pool or the number of tasks.
size_t num_workers =
std::min(std::min(num_tasks - 1, num_threads),
tsl::MathUtil::CeilOfRatio(workload_ns, timeslice_ns));
return std::min(num_workers, size_t{std::numeric_limits<uint16_t>::max()});
}

template <typename ParallelTask>
ABSL_ATTRIBUTE_ALWAYS_INLINE void ParallelLoopRunner::ScheduleAll(
size_t num_tasks, ParallelTask&& parallel_task) {
DCHECK_GT(num_tasks, 1) << "Expected at least two tasks";

// If done event is already available and we have a worker timeslice, we can
// compute the optimal number of workers for the parallel operation and
// potentially avoid allocating count down counter altogether.
if (ABSL_PREDICT_TRUE(done_event_.IsConcrete() && worker_timeslice_)) {
size_t optimal_num_workers = ComputeOptimalNumWorkers(
*worker_timeslice_, num_threads(), num_tasks, parallel_task);

// Execute remaining tasks in the caller thread if we have a single worker.
if (ABSL_PREDICT_TRUE(optimal_num_workers == 1)) {
for (size_t i = 1; i < num_tasks; ++i) {
parallel_task(i);
}
return;
}

tsl::CountDownAsyncValueRef<tsl::Chain> count_down(optimal_num_workers);
done_event_ = count_down.AsRef();

// Parallelize the remaining tasks (skip the first task that was executed
// when we were computing the number of workers).
Parallelize(std::move(count_down), optimal_num_workers, num_tasks - 1,
[parallel_task = std::forward<ParallelTask>(parallel_task)](
size_t task_index) { parallel_task(task_index + 1); });
return;
}

// If `done_event_` is not available, we start with at most `num_threads()`
// workers as we can't run more parallel workers than the number of threads in
// the thread pool. Later we might adjust the number of workers when it's safe
// to execute the first task to measure the execution time.
size_t num_workers = std::min(std::min(num_tasks, num_threads()),
size_t{std::numeric_limits<uint16_t>::max()});

tsl::CountDownAsyncValueRef<tsl::Chain> count_down(num_workers);
auto count_down_done = count_down.AsRef();

done_event_.AndThen(
auto schedule_all =
[this, num_workers, num_tasks, count_down = std::move(count_down),
parallel_task = std::forward<ParallelTask>(parallel_task)] {
Parallelize(std::move(count_down), num_workers, num_tasks,
std::move(parallel_task));
});
parallel_task = std::forward<ParallelTask>(parallel_task)]() mutable {
// If we don't have a worker timeslice, we can parallelize the task
// immediately using pre-computed number of workers.
if (ABSL_PREDICT_FALSE(!worker_timeslice_)) {
Parallelize(std::move(count_down), num_workers, num_tasks,
std::move(parallel_task));
return;
}

// Compute the optimal number of workers by executing the first task.
size_t optimal_num_workers = ComputeOptimalNumWorkers(
*worker_timeslice_, num_threads(), num_tasks, parallel_task);
DCHECK_LE(optimal_num_workers, num_workers);

// Count down for the workers that we don't need.
count_down.CountDown(num_workers - optimal_num_workers);

// Parallelize the remaining tasks (skip the first task that was
// executed when we were computing the number of workers).
Parallelize(std::move(count_down), optimal_num_workers, num_tasks - 1,
[parallel_task = std::move(parallel_task)](
size_t task_index) { parallel_task(task_index + 1); });
};

done_event_.AndThen(std::move(schedule_all));
done_event_ = std::move(count_down_done);
}

@@ -376,8 +464,6 @@ void ParallelLoopRunner::Parallelize(size_t range, size_t tile,

// Fast path for the degenerate parallel loop with single task.
if (ABSL_PREDICT_TRUE(num_tasks == 1)) {
DCHECK_EQ(range, tile) << "Expected range to be equal to tile";

// Execute task in the caller thread if done event is already available.
if (ABSL_PREDICT_TRUE(done_event_.IsConcrete())) {
task(0, range);
@@ -405,8 +491,6 @@ void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j,

// Fast path for the degenerate parallel loop with single task.
if (ABSL_PREDICT_TRUE(num_tasks == 1)) {
DCHECK_EQ(range_j, tile_j) << "Expected range to be equal to tile";

// Execute task in the caller thread if done event is already available.
if (ABSL_PREDICT_TRUE(done_event_.IsConcrete())) {
task(0, 0, range_j);
@@ -435,9 +519,6 @@ void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j,

// Fast path for the degenerate parallel loop with single task.
if (ABSL_PREDICT_TRUE(num_tasks == 1)) {
DCHECK_EQ(range_j, tile_j) << "Expected range to be equal to tile";
DCHECK_EQ(range_k, tile_k) << "Expected range to be equal to tile";

// Execute task in the caller thread if done event is already available.
if (ABSL_PREDICT_TRUE(done_event_.IsConcrete())) {
task(0, 0, 0, range_j, range_k);
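For reference, the estimation in `ComputeOptimalNumWorkers` reduces to a few lines of arithmetic: time the first task inline, assume the remaining tasks cost roughly the same, and pick enough workers so that each one handles about `worker_timeslice` of compute. The sketch below is a standalone rendering of that logic under those assumptions; the helper name `EstimateNumWorkers` and the plain ceiling division (the diff uses `tsl::MathUtil::CeilOfRatio`) are illustrative, not part of the change.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>

#include "absl/time/time.h"
#include "xla/tsl/platform/env.h"

// Sketch of the worker-count estimation; assumes a non-zero timeslice and a
// roughly uniform per-task cost.
template <typename ParallelTask>
size_t EstimateNumWorkers(absl::Duration worker_timeslice, size_t num_threads,
                          size_t num_tasks, ParallelTask& parallel_task) {
  // Run the first task in the caller thread and measure its wall time.
  uint64_t start_ns = tsl::Env::Default()->NowNanos();
  parallel_task(0);
  uint64_t end_ns = tsl::Env::Default()->NowNanos();

  // Estimated cost of the remaining work, assuming uniform task cost.
  size_t workload_ns = (num_tasks - 1) * (end_ns - start_ns);
  size_t timeslice_ns = absl::ToInt64Nanoseconds(worker_timeslice);

  // Enough workers so each handles ~`worker_timeslice` of compute, capped by
  // the remaining tasks, the thread pool size, and uint16_t::max().
  size_t by_timeslice = (workload_ns + timeslice_ns - 1) / timeslice_ns;
  size_t num_workers = std::min({num_tasks - 1, num_threads, by_timeslice});
  return std::min(num_workers, size_t{std::numeric_limits<uint16_t>::max()});
}
```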
15 changes: 14 additions & 1 deletion xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h
@@ -23,6 +23,7 @@ limitations under the License.
#include <optional>

#include "absl/container/fixed_array.h"
#include "absl/time/time.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/concurrency/chain.h"

@@ -43,6 +44,12 @@ namespace xla::cpu {
// Parallel loop runner is an implementation of the `pthreadpool` API adaptor
// for XLA:CPU runtime.
//
// Parallel loop runner can be configured by the `worker_timeslice` parameter,
// which defines the approximate amount of compute (in terms of wall time) that
// each persistent worker will handle. We rely on this parameter to avoid
// scheduling too many workers into the thread pool, because for tiny tasks the
// overheads can be prohibitively expensive.
//
// WARNING: ParallelLoopRunner is not thread-safe, and must be externally
// synchronized by the user.
class ParallelLoopRunner {
@@ -56,7 +63,9 @@ class ParallelLoopRunner {
#endif

public:
explicit ParallelLoopRunner(const Eigen::ThreadPoolDevice* device);
explicit ParallelLoopRunner(
const Eigen::ThreadPoolDevice* device,
std::optional<absl::Duration> worker_timeslice = std::nullopt);

// Takes ownership of the runner and returns a done event. After the done
// event is transferred to the caller, it is illegal to schedule more parallel
@@ -202,6 +211,10 @@ class ParallelLoopRunner {
// pools for different NUMA nodes, and we have to be able to switch between
// them from run to run.
std::atomic<const Eigen::ThreadPoolDevice*> device_;

// The approximate amount of compute (in terms of wall time) that each
// persistent worker should handle.
std::optional<absl::Duration> worker_timeslice_;
};

} // namespace xla::cpu
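A minimal usage sketch of the new constructor parameter, mirroring the thread-pool setup used in the tests below; the include paths, the 8-thread pool, and the 100µs value are assumptions for illustration, and passing `std::nullopt` keeps the previous behavior.

```cpp
#define EIGEN_USE_THREADS

#include "absl/time/time.h"
#include "unsupported/Eigen/CXX11/Tensor"
#include "xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/threadpool.h"

void Example() {
  tsl::thread::ThreadPool threads(tsl::Env::Default(), "example", 8);
  Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
                                 threads.NumThreads());

  // Ask each persistent worker to handle roughly 100us of compute; the runner
  // measures the first task of a loop to decide how many workers to schedule.
  xla::cpu::ParallelLoopRunner runner(
      &device, /*worker_timeslice=*/absl::Microseconds(100));

  // Loops dispatched through `runner.Parallelize(...)` now size their worker
  // count from that measurement (see parallel_loop_runner.cc above).
}
```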
53 changes: 36 additions & 17 deletions xla/backends/cpu/runtime/xnnpack/parallel_loop_runner_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/cleanup/cleanup.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/env.h"
@@ -38,7 +39,7 @@ limitations under the License.
namespace xla::cpu {
namespace {

TEST(ParallelLoopRunnerTest, WorkQueueSimple) {
TEST(ParallelLoopRunnerWorkerTest, WorkQueueSimple) {
ParallelLoopRunner::WorkQueue queue(20, 10);

EXPECT_EQ(queue.Pop(0), std::make_optional(0));
@@ -48,7 +49,7 @@ TEST(ParallelLoopRunnerTest, WorkQueueSimple) {
EXPECT_EQ(queue.Pop(1), std::make_optional(2));
}

TEST(ParallelLoopRunnerTest, WorkQueueEmptyPartitions) {
TEST(ParallelLoopRunnerWorkerTest, WorkQueueEmptyPartitions) {
ParallelLoopRunner::WorkQueue queue(1, 10);

EXPECT_EQ(queue.Pop(0), std::make_optional(0));
@@ -59,7 +60,7 @@ TEST(ParallelLoopRunnerTest, WorkQueueEmptyPartitions) {
}
}

TEST(ParallelLoopRunnerTest, WorkQueue) {
TEST(ParallelLoopRunnerWorkerTest, WorkQueue) {
for (size_t size : {1, 2, 4, 8, 16, 32, 64}) {
for (size_t num_partitions : {1, 2, 3, 4, 5, 6, 7, 8}) {
ParallelLoopRunner::WorkQueue queue(size, num_partitions);
@@ -79,7 +80,7 @@ TEST(ParallelLoopRunnerTest, WorkQueue) {
}
}

TEST(ParallelLoopRunnerTest, Worker) {
TEST(ParallelLoopRunnerWorkerTest, Worker) {
for (size_t size : {1, 2, 4, 8, 16, 32, 64}) {
for (size_t num_partitions : {1, 2, 3, 4, 5, 6, 7, 8}) {
// We check that no matter what is the initial partition, the worker
@@ -103,7 +104,7 @@ TEST(ParallelLoopRunnerTest, Worker) {
}
}

TEST(ParallelLoopRunnerTest, WorkerConcurrency) {
TEST(ParallelLoopRunnerWorkerTest, WorkerConcurrency) {
tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8);

size_t size = 1024;
@@ -129,11 +130,14 @@ TEST(ParallelLoopRunnerTest, WorkerConcurrency) {
EXPECT_EQ(num_tasks.load(), size);
}

TEST(ParallelLoopRunnerTest, Parallelize1D) {
class ParallelLoopRunnerTest
: public testing::TestWithParam<std::optional<absl::Duration>> {};

TEST_P(ParallelLoopRunnerTest, Parallelize1D) {
tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8);
Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
threads.NumThreads());
ParallelLoopRunner runner(&device);
ParallelLoopRunner runner(&device, GetParam());

constexpr int32_t d0 = 128;

@@ -153,11 +157,11 @@ TEST(ParallelLoopRunnerTest, Parallelize1D) {
[](int32_t value) { return value == 5; }));
}

TEST(ParallelLoopRunnerTest, Parallelize1DTile1D) {
TEST_P(ParallelLoopRunnerTest, Parallelize1DTile1D) {
tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8);
Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
threads.NumThreads());
ParallelLoopRunner runner(&device);
ParallelLoopRunner runner(&device, GetParam());

constexpr int32_t d0 = 128;

@@ -181,11 +185,11 @@ TEST(ParallelLoopRunnerTest, Parallelize1DTile1D) {
[](int32_t value) { return value == 5; }));
}

TEST(ParallelLoopRunnerTest, Parallelize2DTile1D) {
TEST_P(ParallelLoopRunnerTest, Parallelize2DTile1D) {
tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8);
Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
threads.NumThreads());
ParallelLoopRunner runner(&device);
ParallelLoopRunner runner(&device, GetParam());

constexpr int32_t d0 = 4;
constexpr int32_t d1 = 39;
@@ -210,11 +214,11 @@ TEST(ParallelLoopRunnerTest, Parallelize2DTile1D) {
[](int32_t value) { return value == 5; }));
}

TEST(ParallelLoopRunnerTest, Parallelize3DTile2D) {
TEST_P(ParallelLoopRunnerTest, Parallelize3DTile2D) {
tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8);
Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
threads.NumThreads());
ParallelLoopRunner runner(&device);
ParallelLoopRunner runner(&device, GetParam());

constexpr int32_t d0 = 4;
constexpr int32_t d1 = 39;
@@ -243,6 +247,13 @@ TEST(ParallelLoopRunnerTest, Parallelize3DTile2D) {
[](int32_t value) { return value == 5; }));
}

INSTANTIATE_TEST_SUITE_P(ParallelLoopRunner, ParallelLoopRunnerTest,
testing::Values(std::nullopt, absl::Nanoseconds(100),
absl::Nanoseconds(500),
absl::Microseconds(1),
absl::Microseconds(10),
absl::Milliseconds(1)));

//===----------------------------------------------------------------------===//
// Performance benchmarks.
//===----------------------------------------------------------------------===//
@@ -265,7 +276,11 @@ static void BM_Parallelize2DTile1D(benchmark::State& state) {
tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8);
Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
threads.NumThreads());
ParallelLoopRunner runner(&device);

size_t timeslice = state.range(0);
ParallelLoopRunner runner(
&device, timeslice ? std::make_optional(absl::Nanoseconds(timeslice))
: std::nullopt);

size_t range = 4;
size_t tile = 1;
@@ -276,13 +291,17 @@ static void BM_Parallelize2DTile1D(benchmark::State& state) {
}
}

BENCHMARK(BM_Parallelize2DTile1D);
BENCHMARK(BM_Parallelize2DTile1D)->Arg(0)->Arg(100)->Arg(10000);

static void BM_Parallelize3DTile2D(benchmark::State& state) {
tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8);
Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
threads.NumThreads());
ParallelLoopRunner runner(&device);

size_t timeslice = state.range(0);
ParallelLoopRunner runner(
&device, timeslice ? std::make_optional(absl::Nanoseconds(timeslice))
: std::nullopt);

size_t range = 4;
size_t tile = 1;
@@ -294,7 +313,7 @@ static void BM_Parallelize3DTile2D(benchmark::State& state) {
}
}

BENCHMARK(BM_Parallelize3DTile2D);
BENCHMARK(BM_Parallelize3DTile2D)->Arg(0)->Arg(100)->Arg(10000);

} // namespace
} // namespace xla::cpu
4 changes: 3 additions & 1 deletion xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "pthreadpool.h"
#include "xla/backends/cpu/runtime/thunk.h"
@@ -137,7 +138,8 @@ absl::StatusOr<XnnFusionThunk::XnnRuntime> XnnFusionThunk::CreateXnnRuntime(

// If XLA is compiled with custom pthreadpool, use it in XNNPACK runtime,
// otherwise we'll run all XNNPACK operations in the default pthreadpool.
runtime.runner = std::make_unique<ParallelLoopRunner>(device);
runtime.runner = std::make_unique<ParallelLoopRunner>(
device, /*worker_timeslice=*/absl::Microseconds(100));
if (use_custom_threadpool) {
runtime.threadpool = CreateCustomPthreadpool(runtime.runner.get());
} else {
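As a worked example of the 100µs default (illustrative numbers, not a measurement): if the first task of a loop takes 2µs and the loop has 501 tasks, the estimated remaining workload is 500 × 2µs = 1ms, so the runner requests ceil(1ms / 100µs) = 10 workers, further capped by the thread-pool size and the number of remaining tasks.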