@@ -25,7 +25,9 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 #include <stdint.h>
+
 #include <exception>
+
 #include "libtorch_utils.h"
 #include "triton/backend/backend_common.h"
 #include "triton/backend/backend_input_collector.h"
@@ -53,6 +55,9 @@
 #include <cuda_runtime_api.h>
 #endif  // TRITON_ENABLE_GPU

+// Default forward method to call on PyTorch modules
+const std::string DEFAULT_MODULE_METHOD_NAME = "forward";
+
 //
 // PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
 //
@@ -103,6 +108,7 @@ class ModelState : public BackendModel {
 
   bool EnabledWeightSharing() { return enable_weight_sharing_; }
   const std::vector<std::string>& ModelOutputs() { return output_names_; }
+  const std::string& ModuleMethodName() { return module_method_name_; }
 
  private:
   ModelState(TRITONBACKEND_Model* triton_model);
@@ -145,6 +151,10 @@ class ModelState : public BackendModel {
   // List of all the outputs specified in the output section of model
   // configuration.
   std::vector<std::string> output_names_;
+
+  // Method to call on PyTorch Module.
+  // Defaults to DEFAULT_MODULE_METHOD_NAME.
+  std::string module_method_name_;
 };
 
 TRITONSERVER_Error*
@@ -180,7 +190,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
       enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}),
       enable_jit_profiling_pair_({false, true}),
       enable_jit_executor_pair_({false, true}),
-      enable_nvfuser_pair_({false, false})
+      enable_nvfuser_pair_({false, false}),
+      module_method_name_(DEFAULT_MODULE_METHOD_NAME)
 {
   output_names_.clear();
 
@@ -454,6 +465,30 @@ ModelState::ParseParameters()
            " for model instance '" + Name() + "'")
               .c_str());
     }
+
+    // If 'MODULE_METHOD_NAME' is not present in 'parameters' then
+    // 'module_method_name_' is set to 'DEFAULT_MODULE_METHOD_NAME' ('forward').
+    std::string module_method_name = DEFAULT_MODULE_METHOD_NAME;
+    err = GetParameterValue(params, "MODULE_METHOD_NAME", &module_method_name);
+    if (err != nullptr) {
+      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
+        return err;
+      } else {
+        LOG_MESSAGE(
+            TRITONSERVER_LOG_INFO,
+            (std::string("module_method_name is not specified") +
+             " for model instance '" + Name() + "'")
+                .c_str());
+        TRITONSERVER_ErrorDelete(err);
+      }
+    } else {
+      module_method_name_ = module_method_name;
+      LOG_MESSAGE(
+          TRITONSERVER_LOG_INFO,
+          (std::string("module_method_name is ") + module_method_name_ +
+           " for model instance '" + Name() + "'")
+              .c_str());
+    }
   }
 
   return nullptr;
@@ -764,7 +799,8 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
   // configuration specifies only those.
   std::vector<std::string> allowed_inputs;
 
-  const torch::jit::Method& method = torch_model_->get_method("forward");
+  const torch::jit::Method& method =
+      torch_model_->get_method(model_state_->ModuleMethodName());
   const auto& schema = method.function().getSchema();
   const std::vector<c10::Argument>& arguments = schema.arguments();
@@ -1312,30 +1348,36 @@ ModelInstanceState::Execute(
       torch::jit::overrideCanFuseOnCPU(false);
       torch::jit::overrideCanFuseOnGPU(false);
       torch::jit::setTensorExprFuserEnabled(false);
-      torch::jit::fuser::cuda::setEnabled(true);
+      torch::jit::fuser::cuda::setEnabled(true);
     } else {
       torch::jit::overrideCanFuseOnCPU(true);
       torch::jit::overrideCanFuseOnGPU(true);
       torch::jit::setTensorExprFuserEnabled(true);
-      torch::jit::fuser::cuda::setEnabled(false);
+      torch::jit::fuser::cuda::setEnabled(false);
     }
   }
 
   torch::NoGradGuard no_grad;
 
   // If input is a dictionary, prepare dictionary from 'input_tensors'.
+  std::string module_method_name = model_state_->ModuleMethodName();
+  std::vector<c10::IValue> inputs;
   if (is_dict_input_) {
-    torch::Dict<std::string, torch::Tensor> input_dict;
+    c10::Dict<std::string, at::Tensor> dict;
     for (auto& input_index : input_index_map_) {
       torch::jit::IValue ival = (*input_tensors)[input_index.second];
-      input_dict.insert(input_index.first, ival.toTensor());
+      dict.insert(input_index.first, ival.toTensor());
     }
-    std::vector<torch::jit::IValue> input_dict_ivalue = {input_dict};
-    model_outputs_ = torch_model_->forward(input_dict_ivalue);
+    inputs.push_back(dict);
   } else {
-    model_outputs_ = torch_model_->forward(*input_tensors);
+    for (auto& input_tensor : *input_tensors) {
+      inputs.push_back(input_tensor.toTensor());
+    }
   }
 
+  // Actually run the method on the model.
+  model_outputs_ = torch_model_->get_method(module_method_name)(inputs);
+
   if (model_outputs_.isTuple()) {
     auto model_outputs_tuple = model_outputs_.toTuple();
     size_t op_index = 0;
@@ -1761,9 +1803,9 @@ ModelInstanceState::SetInputTensors(
 
         batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
       }
-    }
-    else {
-      batchn_shape = std::vector<int64_t>(input_shape, input_shape + input_dims_count);
+    } else {
+      batchn_shape =
+          std::vector<int64_t>(input_shape, input_shape + input_dims_count);
       if (supports_batching_) {
         batchn_shape[0] = total_batch_size;
       }
@@ -1772,8 +1814,8 @@ ModelInstanceState::SetInputTensors(
     // The input must be in contiguous CPU/GPU memory.
     std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_perference;
     if (device_.is_cpu()) {
-      alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
-                          {TRITONSERVER_MEMORY_CPU, 0}};
+      alloc_perference = {
+          {TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};
     } else {
       alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index()}};
     }
@@ -1887,9 +1929,11 @@ ModelInstanceState::ReadOutputTensors(
 
       // Output tensors may not reside on the same device as model
       torch::Device tensor_device = output_flat.device();
-      const auto memory_type = (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
-                                                                     : TRITONSERVER_MEMORY_GPU;
-      const auto memory_id = (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
+      const auto memory_type = (tensor_device.type() == torch::kCPU)
+                                   ? TRITONSERVER_MEMORY_CPU
+                                   : TRITONSERVER_MEMORY_GPU;
+      const auto memory_id =
+          (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
 
       // Batch output doesn't support string data type yet, as it is not trivial
       // to parse string output
@@ -1906,16 +1950,16 @@ ModelInstanceState::ReadOutputTensors(
           return TRITONSERVER_ErrorNew(
               TRITONSERVER_ERROR_INVALID_ARG,
               (std::string("output '") + name +
-               "' is a scalar which is not supported.")
+               "' is a scalar which is not supported.")
                   .c_str());
         }
 
         responder.ProcessTensor(
-            name, output_dtype, batchn_shape, output_buffer,
-            memory_type, memory_id);
+            name, output_dtype, batchn_shape, output_buffer, memory_type,
+            memory_id);
       } else {
         responder.ProcessBatchOutput(
-            name, *batch_output, output_buffer, memory_type, memory_id);
+            name, *batch_output, output_buffer, memory_type, memory_id);
       }
     } else if (output_tensors[op_index].isList()) {
       // Custom handling for string/bytes tensor...
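
For reference, the net effect of this change: the backend reads an optional `MODULE_METHOD_NAME` entry from the model configuration's `parameters` section (typically something like `parameters: { key: "MODULE_METHOD_NAME" value: { string_value: "predict" } }` in config.pbtxt) and invokes that method on the TorchScript module via `get_method()` instead of always calling `forward()`. The snippet below is a minimal standalone LibTorch sketch of that call pattern, not part of this diff; the file name `model.pt` and the method name `predict` are hypothetical, and such a method would have to be exported when the module is scripted (e.g. with `@torch.jit.export`).

```cpp
#include <torch/script.h>

#include <iostream>
#include <string>
#include <vector>

int main()
{
  // Load a serialized TorchScript module; "model.pt" is a placeholder path.
  torch::jit::script::Module module = torch::jit::load("model.pt");

  // Counterpart of ModelState::ModuleMethodName(): default to "forward"
  // unless a different method name was configured.
  const std::string method_name = "predict";

  // get_method() fails if the module does not define the method, which is
  // why the backend resolves the configured name in ValidateInputs().
  torch::jit::Method method = module.get_method(method_name);

  // Collect positional inputs as IValues, as Execute() does with 'inputs'.
  std::vector<c10::IValue> inputs;
  inputs.push_back(torch::ones({1, 3}));

  // Invoke the configured method; assumes it returns a single tensor.
  c10::IValue output = method(inputs);
  std::cout << output.toTensor() << std::endl;
  return 0;
}
```

As in `Execute()` above, a dictionary-style model would instead receive its inputs as one `c10::Dict<std::string, at::Tensor>` IValue, while positional inputs are pushed onto the vector one tensor at a time.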