Skip to content

Commit 7b63f0f

Browse files
tanmayv25 and iceychris authored
Support calling custom method names via MODULE_METHOD_NAME (fixes triton-inference-server/server#5209) (#127)
Signed-off-by: Christian Bruckdorfer <[email protected]> Co-authored-by: Christian Bruckdorfer <[email protected]>
1 parent 5c97507 commit 7b63f0f

File tree

2 files changed

+72
-7
lines changed

2 files changed

+72
-7
lines changed

Diff for: README.md

+14
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,20 @@ key: "INTRA_OP_THREAD_COUNT"
217217
}
218218
```
219219

220+
* `MODULE_METHOD_NAME`:
221+
222+
String flag to specify which method on the PyTorch model is being called.
223+
Default value is `forward`.
224+
225+
```
226+
parameters: {
227+
key: "MODULE_METHOD_NAME"
228+
value: {
229+
string_value:"custom_method"
230+
}
231+
}
232+
```
233+
220234
* Additional Optimizations: Three additional boolean parameters are available to disable
221235
certain Torch optimizations that can sometimes cause latency regressions in models with
222236
complex execution modes and dynamic shapes. If not specified, all are enabled by default.

Diff for: src/libtorch.cc

+58-7
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@
6161
// https://github.com/pytorch/pytorch/blob/v2.2.1-rc3/aten/src/ATen/Parallel.h#L133
6262
#include <ATen/Parallel.h>
6363

64+
// Default forward method to call on PyTorch modules
65+
const std::string DEFAULT_MODULE_METHOD_NAME = "forward";
6466

6567
//
6668
// PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
@@ -111,6 +113,7 @@ class ModelState : public BackendModel {
111113
{
112114
return model_outputs_;
113115
}
116+
const std::string& ModuleMethodName() { return module_method_name_; }
114117

115118
private:
116119
ModelState(TRITONBACKEND_Model* triton_model);
@@ -153,6 +156,10 @@ class ModelState : public BackendModel {
153156
// is specified both in the output section and state section, it indicates
154157
// that the backend must return the output state to the client too.
155158
std::map<std::string, std::pair<int64_t, int64_t>> model_outputs_;
159+
160+
// Method to call on PyTorch Module.
161+
// Defaults to DEFAULT_MODULE_METHOD_NAME.
162+
std::string module_method_name_;
156163
};
157164

158165
TRITONSERVER_Error*
@@ -230,7 +237,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
230237
enable_inference_mode_(true), enable_cache_cleaning_(false),
231238
enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}),
232239
enable_jit_profiling_pair_({false, true}),
233-
enable_jit_executor_pair_({false, true})
240+
enable_jit_executor_pair_({false, true}),
241+
module_method_name_(DEFAULT_MODULE_METHOD_NAME)
234242
{
235243
}
236244

@@ -519,6 +527,30 @@ ModelState::ParseParameters()
519527
.c_str());
520528
}
521529
}
530+
531+
// If 'MODULE_METHOD_NAME' is not present in 'parameters' then
532+
// 'module_method_name_' is set to 'DEFAULT_MODULE_METHOD_NAME' ('forward').
533+
std::string module_method_name = DEFAULT_MODULE_METHOD_NAME;
534+
err = GetParameterValue(params, "MODULE_METHOD_NAME", &module_method_name);
535+
if (err != nullptr) {
536+
if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
537+
return err;
538+
} else {
539+
LOG_MESSAGE(
540+
TRITONSERVER_LOG_INFO,
541+
(std::string("module_method_name is not specified") +
542+
" for model instance '" + Name() + "'")
543+
.c_str());
544+
TRITONSERVER_ErrorDelete(err);
545+
}
546+
} else {
547+
module_method_name_ = module_method_name;
548+
LOG_MESSAGE(
549+
TRITONSERVER_LOG_INFO,
550+
(std::string("module_method_name is ") + module_method_name_ +
551+
" for model instance '" + Name() + "'")
552+
.c_str());
553+
}
522554
}
523555

524556
return nullptr;
@@ -940,7 +972,20 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
940972
// configuration specifies only those.
941973
std::vector<std::string> allowed_inputs;
942974

943-
const torch::jit::Method& method = torch_model_->get_method("forward");
975+
// First check if method exists in the model and throw an error if absent
976+
const auto methodNameToExecute = model_state_->ModuleMethodName();
977+
const auto optionalMethodHandle =
978+
torch_model_->find_method(methodNameToExecute);
979+
if (!optionalMethodHandle.has_value()) {
980+
return TRITONSERVER_ErrorNew(
981+
TRITONSERVER_ERROR_INVALID_ARG,
982+
(std::string("unable to find method '") + methodNameToExecute +
983+
"' in model '" + model_path_ + "'")
984+
.c_str());
985+
}
986+
987+
// Get the method schema and validate the inputs
988+
const torch::jit::Method& method = optionalMethodHandle.value();
944989
const auto& schema = method.function().getSchema();
945990
const std::vector<c10::Argument>& arguments = schema.arguments();
946991

@@ -1583,18 +1628,24 @@ ModelInstanceState::Execute(
15831628
torch::NoGradGuard no_grad;
15841629

15851630
// If input is a dictionary, prepare dictionary from 'input_tensors'.
1631+
std::string module_method_name = model_state_->ModuleMethodName();
1632+
std::vector<c10::IValue> inputs;
15861633
if (is_dict_input_) {
1587-
torch::Dict<std::string, torch::Tensor> input_dict;
1634+
c10::Dict<std::string, at::Tensor> dict;
15881635
for (auto& input_index : input_index_map_) {
15891636
torch::jit::IValue ival = (*input_tensors)[input_index.second];
1590-
input_dict.insert(input_index.first, ival.toTensor());
1637+
dict.insert(input_index.first, ival.toTensor());
15911638
}
1592-
std::vector<torch::jit::IValue> input_dict_ivalue = {input_dict};
1593-
model_outputs_ = torch_model_->forward(input_dict_ivalue);
1639+
inputs.push_back(dict);
15941640
} else {
1595-
model_outputs_ = torch_model_->forward(*input_tensors);
1641+
for (auto& input_tensor : *input_tensors) {
1642+
inputs.push_back(input_tensor.toTensor());
1643+
}
15961644
}
15971645

1646+
// Actually run the method on the model.
1647+
model_outputs_ = torch_model_->get_method(module_method_name)(inputs);
1648+
15981649
if (model_outputs_.isTuple()) {
15991650
auto model_outputs_tuple = model_outputs_.toTuple();
16001651
size_t op_index = 0;

0 commit comments

Comments (0)