Commit e4cc7a6

Support calling custom method names via MODULE_METHOD_NAME (fixes triton-inference-server/server#5209)
Signed-off-by: Christian Bruckdorfer <[email protected]>
1 parent 0931f9d commit e4cc7a6

2 files changed: +72 −7

README.md (+14)

````diff
@@ -176,6 +176,20 @@ key: "ENABLE_CACHE_CLEANING"
 }
 ```
 
+* `MODULE_METHOD_NAME`: String flag to specify which method on the PyTorch model is called.
+  Default value is `forward`.
+
+  The section of the model config file specifying this parameter will look like:
+
+  ```
+  parameters: {
+    key: "MODULE_METHOD_NAME"
+    value: {
+      string_value: "custom_method"
+    }
+  }
+  ```
+
 * Additional Optimizations: Three additional boolean parameters are available to disable
   certain Torch optimizations that can sometimes cause latency regressions in models with
   complex execution modes and dynamic shapes. If not specified, all are enabled by default.
````
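
For context, the backend resolves this parameter to a LibTorch method lookup at runtime. Below is a minimal standalone sketch of the same call path outside Triton, assuming a placeholder model file `model.pt` whose `custom_method` takes a single tensor:

```cpp
#include <torch/script.h>

#include <iostream>
#include <vector>

int main() {
  // Placeholder path; any TorchScript model exposing 'custom_method' works.
  torch::jit::Module module = torch::jit::load("model.pt");

  // Invoke the non-default method by name, as the backend does when
  // MODULE_METHOD_NAME is set. The single-tensor input is an assumption.
  std::vector<c10::IValue> inputs{torch::ones({1, 3})};
  c10::IValue output = module.get_method("custom_method")(inputs);

  std::cout << output << std::endl;
  return 0;
}
```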

src/libtorch.cc (+58 −7)
```diff
@@ -56,6 +56,9 @@
 #include <cuda_runtime_api.h>
 #endif  // TRITON_ENABLE_GPU
 
+// Default forward method to call on PyTorch modules
+const std::string DEFAULT_MODULE_METHOD_NAME = "forward";
+
 //
 // PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
 //
```
```diff
@@ -105,6 +108,7 @@ class ModelState : public BackendModel {
   {
     return model_outputs_;
   }
+  const std::string& ModuleMethodName() { return module_method_name_; }
 
  private:
   ModelState(TRITONBACKEND_Model* triton_model);
```
```diff
@@ -147,6 +151,10 @@ class ModelState : public BackendModel {
   // is specified both in the output section and state section, it indicates
   // that the backend must return the output state to the client too.
   std::map<std::string, std::pair<int64_t, int64_t>> model_outputs_;
+
+  // Method to call on the PyTorch module.
+  // Defaults to DEFAULT_MODULE_METHOD_NAME.
+  std::string module_method_name_;
 };
 
 TRITONSERVER_Error*
```
```diff
@@ -224,7 +232,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
       enable_inference_mode_(true), enable_cache_cleaning_(false),
       enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}),
       enable_jit_profiling_pair_({false, true}),
-      enable_jit_executor_pair_({false, true})
+      enable_jit_executor_pair_({false, true}),
+      module_method_name_(DEFAULT_MODULE_METHOD_NAME)
 {
 }
```

```diff
@@ -465,6 +474,30 @@ ModelState::ParseParameters()
             " for model instance '" + Name() + "'")
             .c_str());
     }
+
+    // If 'MODULE_METHOD_NAME' is not present in 'parameters' then
+    // 'module_method_name_' is set to 'DEFAULT_MODULE_METHOD_NAME' ('forward').
+    std::string module_method_name = DEFAULT_MODULE_METHOD_NAME;
+    err = GetParameterValue(params, "MODULE_METHOD_NAME", &module_method_name);
+    if (err != nullptr) {
+      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
+        return err;
+      } else {
+        LOG_MESSAGE(
+            TRITONSERVER_LOG_INFO,
+            (std::string("module_method_name is not specified") +
+             " for model instance '" + Name() + "'")
+                .c_str());
+        TRITONSERVER_ErrorDelete(err);
+      }
+    } else {
+      module_method_name_ = module_method_name;
+      LOG_MESSAGE(
+          TRITONSERVER_LOG_INFO,
+          (std::string("module_method_name is ") + module_method_name_ +
+           " for model instance '" + Name() + "'")
+              .c_str());
+    }
   }
 
   return nullptr;
```
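
The lookup above treats `TRITONSERVER_ERROR_NOT_FOUND` as benign (keep the default and log) while propagating any other error. A dependency-free sketch of that fallback pattern, with a plain `std::map` standing in for Triton's parameter JSON and `GetParameterValue` helper:

```cpp
#include <iostream>
#include <map>
#include <string>

// Placeholder stand-in for the model config's "parameters" section.
using Params = std::map<std::string, std::string>;

// Look up 'key'; fall back to 'fallback' when the key is absent,
// mirroring the backend's NOT_FOUND handling for MODULE_METHOD_NAME.
std::string ParamOrDefault(
    const Params& params, const std::string& key, const std::string& fallback)
{
  auto it = params.find(key);
  if (it == params.end()) {
    std::cout << key << " is not specified, using '" << fallback << "'\n";
    return fallback;
  }
  return it->second;
}

int main() {
  Params params{{"MODULE_METHOD_NAME", "custom_method"}};
  std::cout << ParamOrDefault(params, "MODULE_METHOD_NAME", "forward") << "\n";
  return 0;
}
```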
```diff
@@ -886,7 +919,19 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   // configuration specifies only those.
   std::vector<std::string> allowed_inputs;
 
-  const torch::jit::Method& method = torch_model_->get_method("forward");
+  // First check if the method exists in the model and return an error if absent.
+  const auto methodNameToExecute = model_state_->ModuleMethodName();
+  const auto optionalMethodHandle = torch_model_->find_method(methodNameToExecute);
+  if (!optionalMethodHandle.has_value()) {
+    return TRITONSERVER_ErrorNew(
+        TRITONSERVER_ERROR_INVALID_ARG,
+        (std::string("unable to find method '") + methodNameToExecute +
+         "' in model '" + model_path_ + "'")
+            .c_str());
+  }
+
+  // Get the method schema and validate the inputs.
+  const torch::jit::Method& method = optionalMethodHandle.value();
   const auto& schema = method.function().getSchema();
   const std::vector<c10::Argument>& arguments = schema.arguments();
```
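
Unlike `get_method`, which throws for a missing method, `find_method` returns an empty `c10::optional`, which is what lets the backend surface a clean `INVALID_ARG` error instead. A small sketch of the same probe-then-inspect sequence, assuming a placeholder `model.pt` and `custom_method`:

```cpp
#include <torch/script.h>

#include <iostream>

int main() {
  torch::jit::Module module = torch::jit::load("model.pt");  // placeholder path

  // find_method() returns c10::optional<torch::jit::Method>: empty when the
  // model exposes no method with that name.
  auto method = module.find_method("custom_method");  // placeholder name
  if (!method.has_value()) {
    std::cerr << "unable to find method 'custom_method'" << std::endl;
    return 1;
  }

  // Same schema access the backend uses to validate configured inputs.
  const torch::jit::Method& m = method.value();
  const auto& schema = m.function().getSchema();
  for (const c10::Argument& arg : schema.arguments()) {
    std::cout << "argument: " << arg.name() << std::endl;
  }
  return 0;
}
```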

```diff
@@ -1529,18 +1574,24 @@ ModelInstanceState::Execute(
     torch::NoGradGuard no_grad;
 
     // If input is a dictionary, prepare dictionary from 'input_tensors'.
+    std::string module_method_name = model_state_->ModuleMethodName();
+    std::vector<c10::IValue> inputs;
     if (is_dict_input_) {
-      torch::Dict<std::string, torch::Tensor> input_dict;
+      c10::Dict<std::string, at::Tensor> dict;
       for (auto& input_index : input_index_map_) {
         torch::jit::IValue ival = (*input_tensors)[input_index.second];
-        input_dict.insert(input_index.first, ival.toTensor());
+        dict.insert(input_index.first, ival.toTensor());
       }
-      std::vector<torch::jit::IValue> input_dict_ivalue = {input_dict};
-      model_outputs_ = torch_model_->forward(input_dict_ivalue);
+      inputs.push_back(dict);
     } else {
-      model_outputs_ = torch_model_->forward(*input_tensors);
+      for (auto& input_tensor : *input_tensors) {
+        inputs.push_back(input_tensor.toTensor());
+      }
     }
 
+    // Actually run the method on the model.
+    model_outputs_ = torch_model_->get_method(module_method_name)(inputs);
+
     if (model_outputs_.isTuple()) {
       auto model_outputs_tuple = model_outputs_.toTuple();
       size_t op_index = 0;
```
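
With both input branches now producing a `std::vector<c10::IValue>`, the invocation collapses to a single `get_method(...)(inputs)` call. A sketch of the dictionary path in isolation (placeholder model path, method name, and input keys), showing how the whole `c10::Dict` travels as one `IValue` argument:

```cpp
#include <torch/script.h>

#include <iostream>
#include <vector>

int main() {
  torch::jit::Module module = torch::jit::load("model.pt");  // placeholder path

  // Dictionary-input path: all named tensors go into one c10::Dict, which is
  // then passed to the method as a single IValue argument.
  c10::Dict<std::string, at::Tensor> dict;
  dict.insert("INPUT0", torch::ones({2, 2}));   // placeholder key/value
  dict.insert("INPUT1", torch::zeros({2, 2}));  // placeholder key/value

  std::vector<c10::IValue> inputs;
  inputs.push_back(dict);

  // Equivalent of torch_model_->get_method(module_method_name)(inputs).
  c10::IValue outputs = module.get_method("custom_method")(inputs);
  std::cout << outputs << std::endl;
  return 0;
}
```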
