@@ -61,6 +61,8 @@
 // https://github.com/pytorch/pytorch/blob/v2.2.1-rc3/aten/src/ATen/Parallel.h#L133
 #include <ATen/Parallel.h>
 
+// Default forward method to call on PyTorch modules
+const std::string DEFAULT_MODULE_METHOD_NAME = "forward";
 
 //
 // PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
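The constant above names the conventional TorchScript entry point, but a scripted module can export other methods besides forward, which is what the rest of this change takes advantage of. A minimal standalone LibTorch sketch of looking up and invoking a method by name (the file model.pt and the method name predict are assumptions for illustration):

#include <torch/script.h>
#include <iostream>

int main() {
  // Load a module serialized with torch.jit.save(); "model.pt" is a placeholder.
  torch::jit::Module module = torch::jit::load("model.pt");

  // Any exported method can be fetched by name and called like forward().
  std::vector<c10::IValue> inputs{torch::ones({2, 3})};
  c10::IValue output = module.get_method("predict")(inputs);
  std::cout << output.toTensor() << std::endl;
  return 0;
}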
@@ -111,6 +113,7 @@ class ModelState : public BackendModel {
   {
     return model_outputs_;
   }
+  const std::string& ModuleMethodName() { return module_method_name_; }
 
  private:
   ModelState(TRITONBACKEND_Model* triton_model);
@@ -153,6 +156,10 @@ class ModelState : public BackendModel {
   // is specified both in the output section and state section, it indicates
   // that the backend must return the output state to the client too.
   std::map<std::string, std::pair<int64_t, int64_t>> model_outputs_;
+
+  // Method to call on PyTorch Module.
+  // Defaults to DEFAULT_MODULE_METHOD_NAME.
+  std::string module_method_name_;
 };
 
 TRITONSERVER_Error*
@@ -230,7 +237,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
       enable_inference_mode_(true), enable_cache_cleaning_(false),
       enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}),
       enable_jit_profiling_pair_({false, true}),
-      enable_jit_executor_pair_({false, true})
+      enable_jit_executor_pair_({false, true}),
+      module_method_name_(DEFAULT_MODULE_METHOD_NAME)
 {
 }
 
@@ -519,6 +527,30 @@ ModelState::ParseParameters()
             .c_str());
       }
     }
+
+    // If 'MODULE_METHOD_NAME' is not present in 'parameters' then
+    // 'module_method_name_' is set to 'DEFAULT_MODULE_METHOD_NAME' ('forward').
+    std::string module_method_name = DEFAULT_MODULE_METHOD_NAME;
+    err = GetParameterValue(params, "MODULE_METHOD_NAME", &module_method_name);
+    if (err != nullptr) {
+      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
+        return err;
+      } else {
+        LOG_MESSAGE(
+            TRITONSERVER_LOG_INFO,
+            (std::string("module_method_name is not specified") +
+             " for model instance '" + Name() + "'")
+                .c_str());
+        TRITONSERVER_ErrorDelete(err);
+      }
+    } else {
+      module_method_name_ = module_method_name;
+      LOG_MESSAGE(
+          TRITONSERVER_LOG_INFO,
+          (std::string("module_method_name is ") + module_method_name_ +
+           " for model instance '" + Name() + "'")
+              .c_str());
+    }
   }
 
   return nullptr;
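The parsing above follows the backend's existing parameter convention, so a model would select a non-default method from its config.pbtxt with a parameters entry along these lines (the method name predict is a placeholder):

parameters: {
  key: "MODULE_METHOD_NAME"
  value: {
    string_value: "predict"
  }
}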
@@ -940,7 +972,20 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   // configuration specifies only those.
   std::vector<std::string> allowed_inputs;
 
-  const torch::jit::Method& method = torch_model_->get_method("forward");
+  // First check if method exists in the model and throw an error if absent
+  const auto methodNameToExecute = model_state_->ModuleMethodName();
+  const auto optionalMethodHandle =
+      torch_model_->find_method(methodNameToExecute);
+  if (!optionalMethodHandle.has_value()) {
+    return TRITONSERVER_ErrorNew(
+        TRITONSERVER_ERROR_INVALID_ARG,
+        (std::string("unable to find method '") + methodNameToExecute +
+         "' in model '" + model_path_ + "'")
+            .c_str());
+  }
+
+  // Get the method schema and validate the inputs
+  const torch::jit::Method& method = optionalMethodHandle.value();
   const auto& schema = method.function().getSchema();
   const std::vector<c10::Argument>& arguments = schema.arguments();
 
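For context on the API used above: find_method returns a c10::optional<torch::jit::Method>, so a missing method can be reported as a Triton error instead of surfacing as an exception (get_method throws on an unknown name). A small self-contained sketch of the lookup and the schema walk the surrounding validation performs (model.pt and predict are placeholders):

#include <torch/script.h>
#include <iostream>

int main() {
  torch::jit::Module module = torch::jit::load("model.pt");

  // find_method() returns c10::optional<Method>; no exception on a miss.
  if (auto method = module.find_method("predict")) {
    const c10::FunctionSchema& schema = method->function().getSchema();
    // The first argument in a module method's schema is the implicit 'self'.
    for (const c10::Argument& arg : schema.arguments()) {
      std::cout << arg.name() << std::endl;
    }
  }
  return 0;
}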
@@ -1583,18 +1628,24 @@ ModelInstanceState::Execute(
   torch::NoGradGuard no_grad;
 
   // If input is a dictionary, prepare dictionary from 'input_tensors'.
+  std::string module_method_name = model_state_->ModuleMethodName();
+  std::vector<c10::IValue> inputs;
   if (is_dict_input_) {
-    torch::Dict<std::string, torch::Tensor> input_dict;
+    c10::Dict<std::string, at::Tensor> dict;
     for (auto& input_index : input_index_map_) {
       torch::jit::IValue ival = (*input_tensors)[input_index.second];
-      input_dict.insert(input_index.first, ival.toTensor());
+      dict.insert(input_index.first, ival.toTensor());
     }
-    std::vector<torch::jit::IValue> input_dict_ivalue = {input_dict};
-    model_outputs_ = torch_model_->forward(input_dict_ivalue);
+    inputs.push_back(dict);
   } else {
-    model_outputs_ = torch_model_->forward(*input_tensors);
+    for (auto& input_tensor : *input_tensors) {
+      inputs.push_back(input_tensor.toTensor());
+    }
   }
 
+  // Actually run the method on the model.
+  model_outputs_ = torch_model_->get_method(module_method_name)(inputs);
+
   if (model_outputs_.isTuple()) {
     auto model_outputs_tuple = model_outputs_.toTuple();
     size_t op_index = 0;
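With this rewrite, Execute() funnels both input styles into a single std::vector<c10::IValue> and one get_method() call, instead of hard-coding forward(). A condensed sketch of the dictionary calling convention (module file, tensor names, and method name are placeholders; the target method must accept a Dict[str, Tensor] argument):

#include <torch/script.h>

int main() {
  torch::jit::Module module = torch::jit::load("model.pt");

  // The whole dictionary travels as one positional IValue argument.
  c10::Dict<std::string, at::Tensor> dict;
  dict.insert("INPUT__0", torch::ones({1, 4}));
  dict.insert("INPUT__1", torch::zeros({1, 4}));

  std::vector<c10::IValue> inputs;
  inputs.push_back(dict);
  c10::IValue outputs = module.get_method("predict")(inputs);
  return 0;
}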