56 | 56 | #include <cuda_runtime_api.h>
57 | 57 | #endif  // TRITON_ENABLE_GPU
58 | 58 |
   | 59 | +// Default method to call on PyTorch modules.
   | 60 | +const std::string DEFAULT_MODULE_METHOD_NAME = "forward";
   | 61 | +
59 | 62 | //
60 | 63 | // PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
61 | 64 | //
@@ -105,6 +108,7 @@ class ModelState : public BackendModel {
105 | 108 |   {
106 | 109 |     return model_outputs_;
107 | 110 |   }
    | 111 | +  const std::string& ModuleMethodName() { return module_method_name_; }
108 | 112 |
109 | 113 |  private:
110 | 114 |   ModelState(TRITONBACKEND_Model* triton_model);
@@ -147,6 +151,10 @@ class ModelState : public BackendModel {
147 | 151 |   // is specified both in the output section and state section, it indicates
148 | 152 |   // that the backend must return the output state to the client too.
149 | 153 |   std::map<std::string, std::pair<int64_t, int64_t>> model_outputs_;
    | 154 | +
    | 155 | +  // Method to call on the PyTorch module.
    | 156 | +  // Defaults to DEFAULT_MODULE_METHOD_NAME ("forward").
    | 157 | +  std::string module_method_name_;
150 | 158 | };
151 | 159 |
152 | 160 | TRITONSERVER_Error*
@@ -224,7 +232,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
224 | 232 |       enable_inference_mode_(true), enable_cache_cleaning_(false),
225 | 233 |       enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}),
226 | 234 |       enable_jit_profiling_pair_({false, true}),
227 |     | -      enable_jit_executor_pair_({false, true})
    | 235 | +      enable_jit_executor_pair_({false, true}),
    | 236 | +      module_method_name_(DEFAULT_MODULE_METHOD_NAME)
228 | 237 | {
229 | 238 | }
230 | 239 |
@@ -465,6 +474,30 @@ ModelState::ParseParameters()
465 | 474 |             " for model instance '" + Name() + "'")
466 | 475 |                .c_str());
467 | 476 |     }
    | 477 | +
    | 478 | +    // If 'MODULE_METHOD_NAME' is not present in 'parameters' then
    | 479 | +    // 'module_method_name_' is set to 'DEFAULT_MODULE_METHOD_NAME' ('forward').
    | 480 | +    std::string module_method_name = DEFAULT_MODULE_METHOD_NAME;
    | 481 | +    err = GetParameterValue(params, "MODULE_METHOD_NAME", &module_method_name);
    | 482 | +    if (err != nullptr) {
    | 483 | +      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
    | 484 | +        return err;
    | 485 | +      } else {
    | 486 | +        LOG_MESSAGE(
    | 487 | +            TRITONSERVER_LOG_INFO,
    | 488 | +            (std::string("module_method_name is not specified") +
    | 489 | +             " for model instance '" + Name() + "'")
    | 490 | +                .c_str());
    | 491 | +        TRITONSERVER_ErrorDelete(err);
    | 492 | +      }
    | 493 | +    } else {
    | 494 | +      module_method_name_ = module_method_name;
    | 495 | +      LOG_MESSAGE(
    | 496 | +          TRITONSERVER_LOG_INFO,
    | 497 | +          (std::string("module_method_name is ") + module_method_name_ +
    | 498 | +           " for model instance '" + Name() + "'")
    | 499 | +              .c_str());
    | 500 | +    }
468 | 501 |   }
469 | 502 |
470 | 503 |   return nullptr;
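
For reference, a model would opt into this parameter through its config.pbtxt. A minimal sketch, assuming a hypothetical TorchScript module that exports a method named "custom_method":

    parameters: {
      key: "MODULE_METHOD_NAME"
      value: {
        string_value: "custom_method"
      }
    }

When the parameter is omitted, the backend falls back to calling "forward", matching the pre-change behavior.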
@@ -886,7 +919,19 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
886 | 919 |   // configuration specifies only those.
887 | 920 |   std::vector<std::string> allowed_inputs;
888 | 921 |
889 |     | -  const torch::jit::Method& method = torch_model_->get_method("forward");
    | 922 | +  // Check that the method exists in the model; return an error if absent.
    | 923 | +  const auto method_name = model_state_->ModuleMethodName();
    | 924 | +  const auto method_handle = torch_model_->find_method(method_name);
    | 925 | +  if (!method_handle.has_value()) {
    | 926 | +    return TRITONSERVER_ErrorNew(
    | 927 | +        TRITONSERVER_ERROR_INVALID_ARG,
    | 928 | +        (std::string("unable to find method '") + method_name +
    | 929 | +         "' in model '" + model_path_ + "'")
    | 930 | +            .c_str());
    | 931 | +  }
    | 932 | +
    | 933 | +  // Get the method schema and validate the inputs.
    | 934 | +  const torch::jit::Method& method = method_handle.value();
890 | 935 |   const auto& schema = method.function().getSchema();
891 | 936 |   const std::vector<c10::Argument>& arguments = schema.arguments();
892 | 937 |
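
As a standalone illustration of the find_method()/schema pattern used here, a minimal LibTorch sketch; the model path "model.pt" and the exported method name "encode" are hypothetical placeholders, not part of this change:

    #include <torch/script.h>

    #include <iostream>

    int main() {
      // Load a serialized TorchScript module from a placeholder path.
      torch::jit::Module module = torch::jit::load("model.pt");

      // find_method() returns c10::optional<torch::jit::Method>, so a
      // missing method can be reported without catching the exception
      // that get_method() would throw.
      auto method = module.find_method("encode");
      if (!method.has_value()) {
        std::cerr << "method 'encode' not found" << std::endl;
        return 1;
      }

      // The schema describes the method's arguments; for module methods
      // the first argument is the implicit 'self'.
      std::cout << method->function().getSchema() << std::endl;
      return 0;
    }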
@@ -1529,18 +1574,24 @@ ModelInstanceState::Execute(
1529 | 1574 |     torch::NoGradGuard no_grad;
1530 | 1575 |
1531 | 1576 |     // If input is a dictionary, prepare dictionary from 'input_tensors'.
     | 1577 | +    std::string module_method_name = model_state_->ModuleMethodName();
     | 1578 | +    std::vector<c10::IValue> inputs;
1532 | 1579 |     if (is_dict_input_) {
1533 |      | -      torch::Dict<std::string, torch::Tensor> input_dict;
     | 1580 | +      c10::Dict<std::string, at::Tensor> dict;
1534 | 1581 |       for (auto& input_index : input_index_map_) {
1535 | 1582 |         torch::jit::IValue ival = (*input_tensors)[input_index.second];
1536 |      | -        input_dict.insert(input_index.first, ival.toTensor());
     | 1583 | +        dict.insert(input_index.first, ival.toTensor());
1537 | 1584 |       }
1538 |      | -      std::vector<torch::jit::IValue> input_dict_ivalue = {input_dict};
1539 |      | -      model_outputs_ = torch_model_->forward(input_dict_ivalue);
     | 1585 | +      inputs.push_back(dict);
1540 | 1586 |     } else {
1541 |      | -      model_outputs_ = torch_model_->forward(*input_tensors);
     | 1587 | +      for (auto& input_tensor : *input_tensors) {
     | 1588 | +        inputs.push_back(input_tensor.toTensor());
     | 1589 | +      }
1542 | 1590 |     }
1543 | 1591 |
     | 1592 | +    // Run the configured method on the model.
     | 1593 | +    model_outputs_ = torch_model_->get_method(module_method_name)(inputs);
     | 1594 | +
1544 | 1595 |     if (model_outputs_.isTuple()) {
1545 | 1596 |       auto model_outputs_tuple = model_outputs_.toTuple();
1546 | 1597 |       size_t op_index = 0;
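
The same invocation works outside the backend. A minimal sketch of calling a named method with a packed input stack, mirroring the dict-input branch above; the model path, the input name "INPUT0", and the use of the default "forward" method are placeholder assumptions:

    #include <torch/script.h>

    #include <iostream>

    int main() {
      torch::jit::Module module = torch::jit::load("model.pt");

      // A single c10::Dict is passed as one IValue argument, just as the
      // dict-input branch pushes one dict into 'inputs'.
      c10::Dict<std::string, at::Tensor> dict;
      dict.insert("INPUT0", torch::ones({2, 2}));

      std::vector<c10::IValue> inputs;
      inputs.push_back(dict);

      // get_method() throws if the method is absent; the backend guards
      // against this by calling find_method() during ValidateInputs().
      c10::IValue outputs = module.get_method("forward")(inputs);
      std::cout << outputs << std::endl;
      return 0;
    }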