|
56 | 56 | #include <cuda_runtime_api.h>
|
57 | 57 | #endif // TRITON_ENABLE_GPU
|
58 | 58 |
|
| 59 | +// for thread control |
| 60 | +// https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html#runtime-api |
| 61 | +// https://github.com/pytorch/pytorch/blob/v2.2.1-rc3/aten/src/ATen/Parallel.h#L133 |
| 62 | +#include <ATen/Parallel.h> |
| 63 | + |
| 64 | + |
59 | 65 | //
|
60 | 66 | // PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
|
61 | 67 | //
|
@@ -465,6 +471,76 @@ ModelState::ParseParameters()
|
465 | 471 | " for model instance '" + Name() + "'")
|
466 | 472 | .c_str());
|
467 | 473 | }
|
| 474 | + |
| 475 | + // If "INTRA_OP_THREAD_COUNT" is not present in 'parameters' then no update |
| 476 | + // is made to 'intra_op_thread_count', which by default will take all |
| 477 | + // threads |
| 478 | + int intra_op_thread_count = -1; |
| 479 | + err = ParseParameterInt( |
| 480 | + params, "INTRA_OP_THREAD_COUNT", &intra_op_thread_count); |
| 481 | + if (err != nullptr) { |
| 482 | + if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) { |
| 483 | + return err; |
| 484 | + } else { |
| 485 | + TRITONSERVER_ErrorDelete(err); |
| 486 | + } |
| 487 | + } else { |
| 488 | + if (intra_op_thread_count > 0) { |
| 489 | + try { |
| 490 | + at::set_num_threads(intra_op_thread_count); |
| 491 | + LOG_MESSAGE( |
| 492 | + TRITONSERVER_LOG_INFO, |
| 493 | + (std::string("Intra op thread count is set to ") + |
| 494 | + std::to_string(intra_op_thread_count) + " for model instance '" + |
| 495 | + Name() + "'") |
| 496 | + .c_str()); |
| 497 | + } catch (c10::Error &e) { |
| 498 | + LOG_MESSAGE( |
| 499 | + TRITONSERVER_LOG_WARN, |
| 500 | +            (std::string("Could not set intra op thread count to ") + |
| 501 | + std::to_string(intra_op_thread_count) + " for model instance '" + |
| 502 | + Name() + "'. Using value: " + std::to_string(at::get_num_threads())) |
| 503 | + .c_str()); |
| 504 | + |
| 505 | + |
| 506 | + } |
| 507 | + } |
| 508 | + } |
| 509 | + |
| 510 | + // If "INTER_OP_THREAD_COUNT" is not present in 'parameters' then no update |
| 511 | + // is made to 'inter_op_thread_count', which by default will take all |
| 512 | + // threads |
| 513 | + int inter_op_thread_count = -1; |
| 514 | + err = ParseParameterInt( |
| 515 | + params, "INTER_OP_THREAD_COUNT", &inter_op_thread_count); |
| 516 | + if (err != nullptr) { |
| 517 | + if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) { |
| 518 | + return err; |
| 519 | + } else { |
| 520 | + TRITONSERVER_ErrorDelete(err); |
| 521 | + } |
| 522 | + } else { |
| 523 | + if (inter_op_thread_count > 0) { |
| 524 | + try { |
| 525 | + at::set_num_interop_threads(inter_op_thread_count); |
| 526 | + LOG_MESSAGE( |
| 527 | + TRITONSERVER_LOG_INFO, |
| 528 | + (std::string("Inter op thread count is set to ") + |
| 529 | + std::to_string(inter_op_thread_count) + " for model instance '" + |
| 530 | + Name() + "'") |
| 531 | + .c_str()); |
| 532 | + } catch (c10::Error &e) { |
| 533 | + LOG_MESSAGE( |
| 534 | + TRITONSERVER_LOG_WARN, |
| 535 | +            (std::string("Could not set inter op thread count to ") + |
| 536 | +            std::to_string(inter_op_thread_count) + " for model instance '" + |
| 537 | + Name() + "'. Using value: " + std::to_string(at::get_num_interop_threads())) |
| 538 | + .c_str()); |
| 539 | + |
| 540 | + |
| 541 | + } |
| 542 | + } |
| 543 | + } |
468 | 544 | }
|
469 | 545 |
|
470 | 546 | return nullptr;
|
|
0 commit comments