diff --git a/torch2trt/converters/mul.py b/torch2trt/converters/mul.py index eefd744c..ebbcccdf 100644 --- a/torch2trt/converters/mul.py +++ b/torch2trt/converters/mul.py @@ -3,6 +3,7 @@ @tensorrt_converter('torch.mul') +@tensorrt_converter('torch.Tensor.mul_') @tensorrt_converter('torch.Tensor.__imul__') @tensorrt_converter('torch.Tensor.__mul__') @tensorrt_converter('torch.Tensor.__rmul__') diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py index 6a33a9ee..206dcdb8 100644 --- a/torch2trt/torch2trt.py +++ b/torch2trt/torch2trt.py @@ -555,10 +555,15 @@ def torch2trt(module, outputs = (outputs,) ctx.mark_outputs(outputs, output_names) - builder.max_workspace_size = max_workspace_size - builder.fp16_mode = fp16_mode builder.max_batch_size = max_batch_size - builder.strict_type_constraints = strict_type_constraints + config = builder.create_builder_config() + config.max_workspace_size = max_workspace_size + + if strict_type_constraints: + config.set_flag(trt.BuilderFlag.STRICT_TYPES) + + if fp16_mode: + config.set_flag(trt.BuilderFlag.FP16) if int8_mode: @@ -566,7 +571,7 @@ def torch2trt(module, if int8_calib_dataset is None: int8_calib_dataset = TensorBatchDataset(inputs_in) - builder.int8_mode = True + config.set_flag(trt.BuilderFlag.INT8) #Making sure not to run calibration with QAT mode on if not 'qat_mode' in kwargs: @@ -575,7 +580,7 @@ def torch2trt(module, inputs, int8_calib_dataset, batch_size=int8_calib_batch_size, algorithm=int8_calib_algorithm ) - engine = builder.build_cuda_engine(network) + engine = builder.build_engine(network, config) module_trt = TRTModule(engine, input_names, output_names)