Commit c6a9e1c

Support calling custom method names via MODULE_METHOD_NAME (fixes triton-inference-server/server#5209)
1 parent 588c6ac commit c6a9e1c

3 files changed: +80 −22

README.md

+14 −0
@@ -206,6 +206,20 @@ complex execution modes and dynamic shapes. If not specified, all are enabled by

 `ENABLE_TENSOR_FUSER`

+* `MODULE_METHOD_NAME`: String flag to specify which method on the PyTorch model is being called.
+  Default value is `forward`.
+
+  The section of model config file specifying this parameter will look like:
+
+  ```
+  parameters: {
+    key: "MODULE_METHOD_NAME"
+    value: {
+      string_value: "custom_method"
+    }
+  }
+  ```
+
 ### Important Note

 * The execution of PyTorch model on GPU is asynchronous in nature. See
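For orientation only (not part of this commit's diff), the following is a minimal LibTorch sketch of what the new parameter amounts to: instead of always invoking `forward`, the backend looks up the configured method by name and calls it. The file name `model.pt`, the input shape, and the method name `custom_method` are placeholder assumptions.

```
// Illustrative sketch only: load a TorchScript module and call a method by
// name, as the backend does with the MODULE_METHOD_NAME parameter.
// "model.pt" and "custom_method" are hypothetical names.
#include <torch/script.h>

#include <iostream>
#include <vector>

int main()
{
  torch::jit::script::Module module = torch::jit::load("model.pt");

  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones({1, 3}));

  // get_method() throws if the module does not expose a method with this name.
  // Assumes the method returns a single tensor.
  torch::jit::IValue output = module.get_method("custom_method")(inputs);
  std::cout << output.toTensor() << std::endl;
  return 0;
}
```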

src/libtorch.cc

+65 −21
@@ -25,7 +25,9 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 #include <stdint.h>
+
 #include <exception>
+
 #include "libtorch_utils.h"
 #include "triton/backend/backend_common.h"
 #include "triton/backend/backend_input_collector.h"
@@ -53,6 +55,9 @@
 #include <cuda_runtime_api.h>
 #endif  // TRITON_ENABLE_GPU

+// Default forward method to call on PyTorch modules
+const std::string DEFAULT_MODULE_METHOD_NAME = "forward";
+
 //
 // PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
 //
@@ -103,6 +108,7 @@ class ModelState : public BackendModel {

   bool EnabledWeightSharing() { return enable_weight_sharing_; }
   const std::vector<std::string>& ModelOutputs() { return output_names_; }
+  const std::string& ModuleMethodName() { return module_method_name_; }

  private:
   ModelState(TRITONBACKEND_Model* triton_model);
@@ -145,6 +151,10 @@ class ModelState : public BackendModel {
   // List of all the outputs specified in the output section of model
   // configuration.
   std::vector<std::string> output_names_;
+
+  // Method to call on PyTorch Module.
+  // Defaults to DEFAULT_MODULE_METHOD_NAME.
+  std::string module_method_name_;
 };

 TRITONSERVER_Error*
@@ -180,7 +190,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
       enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}),
       enable_jit_profiling_pair_({false, true}),
       enable_jit_executor_pair_({false, true}),
-      enable_nvfuser_pair_({false, false})
+      enable_nvfuser_pair_({false, false}),
+      module_method_name_(DEFAULT_MODULE_METHOD_NAME)
 {
   output_names_.clear();
@@ -454,6 +465,30 @@ ModelState::ParseParameters()
            " for model instance '" + Name() + "'")
               .c_str());
    }
+
+    // If 'MODULE_METHOD_NAME' is not present in 'parameters' then
+    // 'module_method_name_' is set to 'DEFAULT_MODULE_METHOD_NAME' ('forward').
+    std::string module_method_name = DEFAULT_MODULE_METHOD_NAME;
+    err = GetParameterValue(params, "MODULE_METHOD_NAME", &module_method_name);
+    if (err != nullptr) {
+      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
+        return err;
+      } else {
+        LOG_MESSAGE(
+            TRITONSERVER_LOG_INFO,
+            (std::string("module_method_name is not specified") +
+             " for model instance '" + Name() + "'")
+                .c_str());
+        TRITONSERVER_ErrorDelete(err);
+      }
+    } else {
+      module_method_name_ = module_method_name;
+      LOG_MESSAGE(
+          TRITONSERVER_LOG_INFO,
+          (std::string("module_method_name is ") + module_method_name_ +
+           " for model instance '" + Name() + "'")
+              .c_str());
+    }
   }

   return nullptr;
@@ -764,7 +799,8 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   // configuration specifies only those.
   std::vector<std::string> allowed_inputs;

-  const torch::jit::Method& method = torch_model_->get_method("forward");
+  const torch::jit::Method& method =
+      torch_model_->get_method(model_state_->ModuleMethodName());
   const auto& schema = method.function().getSchema();
   const std::vector<c10::Argument>& arguments = schema.arguments();
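Note that input validation now introspects the schema of whichever method is configured rather than hard-coding `forward`. A hedged, standalone illustration of the same LibTorch schema API follows; the module path and method name are assumptions.

```
// Sketch: list the argument names of a named TorchScript method through its
// schema, the same get_method()/getSchema() calls used in ValidateInputs().
// "model.pt" and "custom_method" are hypothetical names.
#include <torch/script.h>

#include <iostream>

int main()
{
  torch::jit::script::Module module = torch::jit::load("model.pt");
  const torch::jit::Method& method = module.get_method("custom_method");

  const c10::FunctionSchema& schema = method.function().getSchema();
  for (const c10::Argument& arg : schema.arguments()) {
    // The first argument is the implicit 'self' of the scripted module.
    std::cout << arg.name() << std::endl;
  }
  return 0;
}
```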

@@ -1312,30 +1348,36 @@ ModelInstanceState::Execute(
       torch::jit::overrideCanFuseOnCPU(false);
       torch::jit::overrideCanFuseOnGPU(false);
       torch::jit::setTensorExprFuserEnabled(false);
-       torch::jit::fuser::cuda::setEnabled(true);
+      torch::jit::fuser::cuda::setEnabled(true);
     } else {
       torch::jit::overrideCanFuseOnCPU(true);
       torch::jit::overrideCanFuseOnGPU(true);
       torch::jit::setTensorExprFuserEnabled(true);
-       torch::jit::fuser::cuda::setEnabled(false);
+      torch::jit::fuser::cuda::setEnabled(false);
     }
   }

   torch::NoGradGuard no_grad;

   // If input is a dictionary, prepare dictionary from 'input_tensors'.
+  std::string module_method_name = model_state_->ModuleMethodName();
+  std::vector<c10::IValue> inputs;
   if (is_dict_input_) {
-    torch::Dict<std::string, torch::Tensor> input_dict;
+    c10::Dict<std::string, at::Tensor> dict;
     for (auto& input_index : input_index_map_) {
       torch::jit::IValue ival = (*input_tensors)[input_index.second];
-      input_dict.insert(input_index.first, ival.toTensor());
+      dict.insert(input_index.first, ival.toTensor());
     }
-    std::vector<torch::jit::IValue> input_dict_ivalue = {input_dict};
-    model_outputs_ = torch_model_->forward(input_dict_ivalue);
+    inputs.push_back(dict);
   } else {
-    model_outputs_ = torch_model_->forward(*input_tensors);
+    for (auto& input_tensor : *input_tensors) {
+      inputs.push_back(input_tensor.toTensor());
+    }
   }

+  // Actually run the method on the model.
+  model_outputs_ = torch_model_->get_method(module_method_name)(inputs);
+
   if (model_outputs_.isTuple()) {
     auto model_outputs_tuple = model_outputs_.toTuple();
     size_t op_index = 0;
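The reworked Execute() path above funnels both the dictionary-input case and the positional case into one `std::vector<c10::IValue>` before dispatching to the configured method. A minimal sketch of the dictionary variant outside the backend follows; the file name, tensor keys, and method name are assumptions.

```
// Sketch: pass a string-to-tensor dictionary as a single IValue argument to a
// named method, mirroring the is_dict_input_ branch above.
// "model.pt", "INPUT0"/"INPUT1", and "custom_method" are hypothetical names.
#include <torch/script.h>

#include <vector>

int main()
{
  torch::jit::script::Module module = torch::jit::load("model.pt");

  c10::Dict<std::string, at::Tensor> dict;
  dict.insert("INPUT0", torch::rand({1, 4}));
  dict.insert("INPUT1", torch::rand({1, 4}));

  // The whole dictionary travels as one positional argument.
  std::vector<c10::IValue> inputs;
  inputs.push_back(dict);

  c10::IValue outputs = module.get_method("custom_method")(inputs);
  (void)outputs;
  return 0;
}
```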
@@ -1761,9 +1803,9 @@ ModelInstanceState::SetInputTensors(

       batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
     }
-  }
-  else {
-    batchn_shape = std::vector<int64_t>(input_shape, input_shape + input_dims_count);
+  } else {
+    batchn_shape =
+        std::vector<int64_t>(input_shape, input_shape + input_dims_count);
     if (supports_batching_) {
       batchn_shape[0] = total_batch_size;
     }
@@ -1772,8 +1814,8 @@
   // The input must be in contiguous CPU/GPU memory.
   std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
   if (device_.is_cpu()) {
-    alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
-                        {TRITONSERVER_MEMORY_CPU, 0}};
+    alloc_perference = {
+        {TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};
   } else {
     alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
   }
@@ -1887,9 +1929,11 @@ ModelInstanceState::ReadOutputTensors(

       // Output tensors may not reside on the same device as model
       torch::Device tensor_device = output_flat.device();
-      const auto memory_type = (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
-                                                                     : TRITONSERVER_MEMORY_GPU;
-      const auto memory_id = (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
+      const auto memory_type = (tensor_device.type() == torch::kCPU)
+                                   ? TRITONSERVER_MEMORY_CPU
+                                   : TRITONSERVER_MEMORY_GPU;
+      const auto memory_id =
+          (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();

       // Batch output doesn't support string data type yet, as it is not trivial
       // to parse string output
@@ -1906,16 +1950,16 @@ ModelInstanceState::ReadOutputTensors(
           return TRITONSERVER_ErrorNew(
               TRITONSERVER_ERROR_INVALID_ARG,
               (std::string("output '") + name +
-                  "' is a scalar which is not supported.")
+               "' is a scalar which is not supported.")
                   .c_str());
         }

         responder.ProcessTensor(
-            name, output_dtype, batchn_shape, output_buffer,
-            memory_type, memory_id);
+            name, output_dtype, batchn_shape, output_buffer, memory_type,
+            memory_id);
       } else {
         responder.ProcessBatchOutput(
-             name, *batch_output, output_buffer, memory_type, memory_id);
+            name, *batch_output, output_buffer, memory_type, memory_id);
       }
     } else if (output_tensors[op_index].isList()) {
       // Custom handling for string/bytes tensor...

src/libtorch_utils.cc

+1 −1
@@ -152,7 +152,7 @@ ParseParameter(
 #ifdef TRITON_ENABLE_GPU
 TRITONSERVER_Error*
 ConvertCUDAStatusToTritonError(
-    cudaError_t cuda_error,TRITONSERVER_Error_Code code, const char* msg)
+    cudaError_t cuda_error, TRITONSERVER_Error_Code code, const char* msg)
 {
   if (cuda_error != cudaSuccess) {
     return TRITONSERVER_ErrorNew(
