Skip to content

Commit 89df3ff

Browse files
Added modified pull request to handle setting number of threads for PyTorch
1 parent 4fa7daa commit 89df3ff

File tree

3 files changed

+96
-0
lines changed

3 files changed

+96
-0
lines changed

src/libtorch.cc

+76
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@
5656
#include <cuda_runtime_api.h>
5757
#endif // TRITON_ENABLE_GPU
5858

59+
// for thread control
60+
// https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html#runtime-api
61+
// https://github.com/pytorch/pytorch/blob/v2.2.1-rc3/aten/src/ATen/Parallel.h#L133
62+
#include <ATen/Parallel.h>
63+
64+
5965
//
6066
// PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
6167
//
@@ -465,6 +471,76 @@ ModelState::ParseParameters()
465471
" for model instance '" + Name() + "'")
466472
.c_str());
467473
}
474+
475+
// If "INTRA_OP_THREAD_COUNT" is not present in 'parameters' then no update
476+
// is made to 'intra_op_thread_count', which by default will take all
477+
// threads
478+
int intra_op_thread_count = -1;
479+
err = ParseParameterInt(
480+
params, "INTRA_OP_THREAD_COUNT", &intra_op_thread_count);
481+
if (err != nullptr) {
482+
if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
483+
return err;
484+
} else {
485+
TRITONSERVER_ErrorDelete(err);
486+
}
487+
} else {
488+
if (intra_op_thread_count > 0) {
489+
try {
490+
at::set_num_threads(intra_op_thread_count);
491+
LOG_MESSAGE(
492+
TRITONSERVER_LOG_INFO,
493+
(std::string("Intra op thread count is set to ") +
494+
std::to_string(intra_op_thread_count) + " for model instance '" +
495+
Name() + "'")
496+
.c_str());
497+
} catch (c10::Error &e) {
498+
LOG_MESSAGE(
499+
TRITONSERVER_LOG_WARN,
500+
(std::string("Could not set intra op thread count is set to ") +
501+
std::to_string(intra_op_thread_count) + " for model instance '" +
502+
Name() + "'. Using value: " + std::to_string(at::get_num_threads()))
503+
.c_str());
504+
505+
506+
}
507+
}
508+
}
509+
510+
// If "INTER_OP_THREAD_COUNT" is not present in 'parameters' then no update
511+
// is made to 'inter_op_thread_count', which by default will take all
512+
// threads
513+
int inter_op_thread_count = -1;
514+
err = ParseParameterInt(
515+
params, "INTER_OP_THREAD_COUNT", &inter_op_thread_count);
516+
if (err != nullptr) {
517+
if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
518+
return err;
519+
} else {
520+
TRITONSERVER_ErrorDelete(err);
521+
}
522+
} else {
523+
if (inter_op_thread_count > 0) {
524+
try {
525+
at::set_num_interop_threads(inter_op_thread_count);
526+
LOG_MESSAGE(
527+
TRITONSERVER_LOG_INFO,
528+
(std::string("Inter op thread count is set to ") +
529+
std::to_string(inter_op_thread_count) + " for model instance '" +
530+
Name() + "'")
531+
.c_str());
532+
} catch (c10::Error &e) {
533+
LOG_MESSAGE(
534+
TRITONSERVER_LOG_WARN,
535+
(std::string("Could not set intra op thread count is set to ") +
536+
std::to_string(intra_op_thread_count) + " for model instance '" +
537+
Name() + "'. Using value: " + std::to_string(at::get_num_interop_threads()))
538+
.c_str());
539+
540+
541+
}
542+
}
543+
}
468544
}
469545

470546
return nullptr;

src/libtorch_utils.cc

+13
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,19 @@ ParseParameter(
149149
return nullptr;
150150
}
151151

152+
TRITONSERVER_Error*
ParseParameterInt(
    triton::common::TritonJson::Value& params, const std::string& mkey,
    int* value)
{
  // Look up 'mkey' in 'params' and parse its string value as an integer
  // into 'value'. Any failure (e.g. key not found, unparsable value) is
  // returned to the caller as a TRITONSERVER_Error.
  std::string raw_value;
  RETURN_IF_ERROR(GetParameterValue(params, mkey, &raw_value));
  RETURN_IF_ERROR(ParseIntValue(raw_value, value));

  return nullptr;  // success
}
163+
164+
152165
#ifdef TRITON_ENABLE_GPU
153166
TRITONSERVER_Error*
154167
ConvertCUDAStatusToTritonError(

src/libtorch_utils.h

+7
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,11 @@ TRITONSERVER_Error* ParseParameter(
6262
triton::common::TritonJson::Value& params, const std::string& mkey,
6363
bool* value);
6464

65+
// If the key 'mkey' is present in 'params' then update 'value' with the
// integer value associated with that key. If 'mkey' is not present in
// 'params' (or its value cannot be parsed as an integer) an error is
// returned instead; callers that treat the parameter as optional should
// check for TRITONSERVER_ERROR_NOT_FOUND and delete the error.
TRITONSERVER_Error* ParseParameterInt(
    triton::common::TritonJson::Value& params, const std::string& mkey,
    int* value);
71+
6572
}}} // namespace triton::backend::pytorch

0 commit comments

Comments
 (0)