@@ -66,6 +66,7 @@ limitations under the License.
66
66
#include " tensorflow/core/lib/strings/str_util.h"
67
67
#include " tensorflow/core/lib/strings/strcat.h"
68
68
#include " tensorflow/core/platform/byte_order.h"
69
+ #include " tensorflow/core/platform/cpu_info.h"
69
70
#include " tensorflow/core/platform/logging.h"
70
71
#include " tensorflow/core/platform/mutex.h"
71
72
#include " tensorflow/core/platform/tracing.h"
@@ -251,8 +252,38 @@ std::atomic_int_fast64_t DirectSession::step_id_counter_(1);
251
252
252
253
static RunHandlerPool* GetOrCreateRunHandlerPool (
253
254
const SessionOptions& options) {
255
+ int num_inter_threads = 0 ;
256
+ int num_intra_threads = 0 ;
257
+ static const int env_num_inter_threads = NumInterOpThreadsFromEnvironment ();
258
+ static const int env_num_intra_threads = NumIntraOpThreadsFromEnvironment ();
259
+ if (env_num_inter_threads > 0 ) {
260
+ num_inter_threads = env_num_inter_threads;
261
+ }
262
+ if (env_num_intra_threads > 0 ) {
263
+ num_intra_threads = env_num_intra_threads;
264
+ }
265
+
266
+ if (num_inter_threads == 0 ) {
267
+ if (options.config .session_inter_op_thread_pool_size () > 0 ) {
268
+ // Note due to ShouldUseRunHandler we are guaranteed that
269
+ // run_options.inter_op_thread_pool() == 0
270
+ num_inter_threads =
271
+ options.config .session_inter_op_thread_pool (0 ).num_threads ();
272
+ }
273
+ if (num_inter_threads == 0 ) {
274
+ num_inter_threads = NumInterOpThreadsFromSessionOptions (options);
275
+ }
276
+ }
277
+
278
+ if (num_intra_threads == 0 ) {
279
+ num_intra_threads = options.config .intra_op_parallelism_threads ();
280
+ if (num_intra_threads == 0 ) {
281
+ num_intra_threads = port::NumSchedulableCPUs ();
282
+ }
283
+ }
284
+
254
285
static RunHandlerPool* pool =
255
- new RunHandlerPool (NumInterOpThreadsFromSessionOptions (options) );
286
+ new RunHandlerPool (num_inter_threads, num_intra_threads );
256
287
return pool;
257
288
}
258
289
@@ -630,7 +661,7 @@ Status DirectSession::RunInternal(
630
661
if (ShouldUseRunHandlerPool (run_options) &&
631
662
run_options.experimental ().use_run_handler_pool ()) {
632
663
VLOG (1 ) << " Using RunHandler to scheduler inter-op closures." ;
633
- handler = GetOrCreateRunHandlerPool (options_)->Get ();
664
+ handler = GetOrCreateRunHandlerPool (options_)->Get (step_id );
634
665
}
635
666
auto * handler_ptr = handler.get ();
636
667
@@ -663,6 +694,10 @@ Status DirectSession::RunInternal(
663
694
device_thread_pool->Schedule (std::move (c));
664
695
};
665
696
}
697
+ if (handler != nullptr ) {
698
+ args.user_intra_op_threadpool = handler->AsIntraThreadPoolInterface ();
699
+ }
700
+
666
701
item.executor ->RunAsync (args, barrier->Get ());
667
702
}
668
703
0 commit comments