Skip to content

Commit b71a4c0

Browse files
Switch RunHandler to use an inference friendly thread pool.
Each inference has a dedicated work queue. All threads steal work in the priority order of the requests (currently arrival time). Note that there is one pool for both intra and inter work. However, because inter work can be blocking, some threads are not allowed to steal it. PiperOrigin-RevId: 254257458
1 parent 7ed84ad commit b71a4c0

File tree

5 files changed

+528
-96
lines changed

5 files changed

+528
-96
lines changed

Diff for: tensorflow/core/BUILD

+17
Original file line numberDiff line numberDiff line change
@@ -4743,6 +4743,23 @@ tf_cc_test(
47434743
],
47444744
)
47454745

4746+
tf_cc_test(
4747+
name = "framework_run_handler_test",
4748+
size = "small",
4749+
srcs = ["framework/run_handler_test.cc"],
4750+
linkstatic = tf_kernel_tests_linkstatic(),
4751+
deps = [
4752+
":framework_internal",
4753+
":lib",
4754+
":lib_internal",
4755+
":test",
4756+
":test_main",
4757+
"//third_party/eigen3",
4758+
"@com_google_absl//absl/memory",
4759+
"@com_google_absl//absl/synchronization",
4760+
],
4761+
)
4762+
47464763
tf_cc_test(
47474764
name = "common_runtime_partitioning_utils_test",
47484765
size = "small",

Diff for: tensorflow/core/common_runtime/direct_session.cc

+37-2
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ limitations under the License.
6666
#include "tensorflow/core/lib/strings/str_util.h"
6767
#include "tensorflow/core/lib/strings/strcat.h"
6868
#include "tensorflow/core/platform/byte_order.h"
69+
#include "tensorflow/core/platform/cpu_info.h"
6970
#include "tensorflow/core/platform/logging.h"
7071
#include "tensorflow/core/platform/mutex.h"
7172
#include "tensorflow/core/platform/tracing.h"
@@ -251,8 +252,38 @@ std::atomic_int_fast64_t DirectSession::step_id_counter_(1);
251252

252253
static RunHandlerPool* GetOrCreateRunHandlerPool(
253254
const SessionOptions& options) {
255+
int num_inter_threads = 0;
256+
int num_intra_threads = 0;
257+
static const int env_num_inter_threads = NumInterOpThreadsFromEnvironment();
258+
static const int env_num_intra_threads = NumIntraOpThreadsFromEnvironment();
259+
if (env_num_inter_threads > 0) {
260+
num_inter_threads = env_num_inter_threads;
261+
}
262+
if (env_num_intra_threads > 0) {
263+
num_intra_threads = env_num_intra_threads;
264+
}
265+
266+
if (num_inter_threads == 0) {
267+
if (options.config.session_inter_op_thread_pool_size() > 0) {
268+
// Note due to ShouldUseRunHandler we are guaranteed that
269+
// run_options.inter_op_thread_pool() == 0
270+
num_inter_threads =
271+
options.config.session_inter_op_thread_pool(0).num_threads();
272+
}
273+
if (num_inter_threads == 0) {
274+
num_inter_threads = NumInterOpThreadsFromSessionOptions(options);
275+
}
276+
}
277+
278+
if (num_intra_threads == 0) {
279+
num_intra_threads = options.config.intra_op_parallelism_threads();
280+
if (num_intra_threads == 0) {
281+
num_intra_threads = port::NumSchedulableCPUs();
282+
}
283+
}
284+
254285
static RunHandlerPool* pool =
255-
new RunHandlerPool(NumInterOpThreadsFromSessionOptions(options));
286+
new RunHandlerPool(num_inter_threads, num_intra_threads);
256287
return pool;
257288
}
258289

@@ -630,7 +661,7 @@ Status DirectSession::RunInternal(
630661
if (ShouldUseRunHandlerPool(run_options) &&
631662
run_options.experimental().use_run_handler_pool()) {
632663
VLOG(1) << "Using RunHandler to scheduler inter-op closures.";
633-
handler = GetOrCreateRunHandlerPool(options_)->Get();
664+
handler = GetOrCreateRunHandlerPool(options_)->Get(step_id);
634665
}
635666
auto* handler_ptr = handler.get();
636667

@@ -663,6 +694,10 @@ Status DirectSession::RunInternal(
663694
device_thread_pool->Schedule(std::move(c));
664695
};
665696
}
697+
if (handler != nullptr) {
698+
args.user_intra_op_threadpool = handler->AsIntraThreadPoolInterface();
699+
}
700+
666701
item.executor->RunAsync(args, barrier->Get());
667702
}
668703

0 commit comments

Comments
 (0)