Skip to content

Commit 75f87cd

Browse files
committed
Support calling custom method names via METHOD_TO_CALL (fixes triton-inference-server/server#5209)
1 parent 588c6ac commit 75f87cd

File tree

2 files changed

+55
-5
lines changed

2 files changed

+55
-5
lines changed

README.md

+14
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,20 @@ complex execution modes and dynamic shapes. If not specified, all are enabled by
206206

207207
`ENABLE_TENSOR_FUSER`
208208

209+
* `METHOD_TO_CALL`: String parameter specifying which method to call on the PyTorch model.
210+
Default value is `forward`.
211+
212+
The section of model config file specifying this parameter will look like:
213+
214+
```
215+
parameters: {
216+
key: "METHOD_TO_CALL"
217+
value: {
218+
string_value: "custom_method"
219+
}
220+
}
221+
```
222+
209223
### Important Note
210224

211225
* The execution of PyTorch model on GPU is asynchronous in nature. See

src/libtorch.cc

+41-5
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ class ModelState : public BackendModel {
103103

104104
bool EnabledWeightSharing() { return enable_weight_sharing_; }
105105
const std::vector<std::string>& ModelOutputs() { return output_names_; }
106+
const std::string& MethodToCall() { return method_to_call_; }
106107

107108
private:
108109
ModelState(TRITONBACKEND_Model* triton_model);
@@ -145,6 +146,10 @@ class ModelState : public BackendModel {
145146
// List of all the outputs specified in the output section of model
146147
// configuration.
147148
std::vector<std::string> output_names_;
149+
150+
// Method to call on PyTorch Module.
151+
// Defaults to "forward".
152+
std::string method_to_call_;
148153
};
149154

150155
TRITONSERVER_Error*
@@ -180,7 +185,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
180185
enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}),
181186
enable_jit_profiling_pair_({false, true}),
182187
enable_jit_executor_pair_({false, true}),
183-
enable_nvfuser_pair_({false, false})
188+
enable_nvfuser_pair_({false, false}),
189+
method_to_call_("forward")
184190
{
185191
output_names_.clear();
186192

@@ -454,6 +460,29 @@ ModelState::ParseParameters()
454460
" for model instance '" + Name() + "'")
455461
.c_str());
456462
}
463+
464+
// If 'METHOD_TO_CALL' is not present in 'parameters' then
465+
// 'method_to_call_' keeps its default value of "forward".
466+
std::string method_to_call = "forward";
467+
err = GetParameterValue(params, "METHOD_TO_CALL", &method_to_call);
468+
if (err != nullptr) {
469+
if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
470+
return err;
471+
} else {
472+
LOG_MESSAGE(
473+
TRITONSERVER_LOG_INFO, (std::string("method_to_call is not specified") +
474+
" for model instance '" + Name() + "'")
475+
.c_str());
476+
TRITONSERVER_ErrorDelete(err);
477+
}
478+
} else {
479+
method_to_call_ = method_to_call;
480+
LOG_MESSAGE(
481+
TRITONSERVER_LOG_INFO, (std::string("method_to_call is ") +
482+
method_to_call_ +
483+
" for model instance '" + Name() + "'")
484+
.c_str());
485+
}
457486
}
458487

459488
return nullptr;
@@ -764,7 +793,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
764793
// configuration specifies only those.
765794
std::vector<std::string> allowed_inputs;
766795

767-
const torch::jit::Method& method = torch_model_->get_method("forward");
796+
const torch::jit::Method& method = torch_model_->get_method(model_state_->MethodToCall());
768797
const auto& schema = method.function().getSchema();
769798
const std::vector<c10::Argument>& arguments = schema.arguments();
770799

@@ -1324,16 +1353,23 @@ ModelInstanceState::Execute(
13241353
torch::NoGradGuard no_grad;
13251354

13261355
// If input is a dictionary, prepare dictionary from 'input_tensors'.
1356+
std::string method_to_call = model_state_->MethodToCall();
13271357
if (is_dict_input_) {
13281358
torch::Dict<std::string, torch::Tensor> input_dict;
13291359
for (auto& input_index : input_index_map_) {
13301360
torch::jit::IValue ival = (*input_tensors)[input_index.second];
13311361
input_dict.insert(input_index.first, ival.toTensor());
13321362
}
1333-
std::vector<torch::jit::IValue> input_dict_ivalue = {input_dict};
1334-
model_outputs_ = torch_model_->forward(input_dict_ivalue);
1363+
auto typ = c10::DictType::create(c10::StringType::get(), c10::TensorType::get());
1364+
auto inp = c10::impl::GenericList(typ);
1365+
inp.emplace_back(input_dict);
1366+
model_outputs_ = torch_model_->run_method(method_to_call, inp);
13351367
} else {
1336-
model_outputs_ = torch_model_->forward(*input_tensors);
1368+
auto inp = c10::impl::GenericList(c10::TensorType::get());
1369+
for (auto& input_tensor : *input_tensors) {
1370+
inp.emplace_back(input_tensor.toTensor());
1371+
}
1372+
model_outputs_ = torch_model_->run_method(method_to_call, inp);
13371373
}
13381374

13391375
if (model_outputs_.isTuple()) {

0 commit comments

Comments
 (0)