Skip to content

[feat] Save/load serialized engine to/from disk #950

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 143 additions & 116 deletions torch2trt/torch2trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import copy
import numpy as np
import io
import os
from collections import defaultdict
import importlib

Expand Down Expand Up @@ -538,6 +539,7 @@ def torch2trt(module,
onnx_opset=None,
max_batch_size=None,
avg_timing_iterations=None,
engine_file_path = './trt_engine/engine.plan',
**kwargs):

# capture arguments to provide to context
Expand All @@ -559,151 +561,176 @@ def torch2trt(module,
input_flattener = Flattener.from_value(inputs)
output_flattener = Flattener.from_value(outputs)

# infer default parameters from dataset

if min_shapes == None:
min_shapes_flat = [tuple(t) for t in dataset.min_shapes(flat=True)]
else:
min_shapes_flat = input_flattener.flatten(min_shapes)

if max_shapes == None:
max_shapes_flat = [tuple(t) for t in dataset.max_shapes(flat=True)]
else:
max_shapes_flat = input_flattener.flatten(max_shapes)

if opt_shapes == None:
opt_shapes_flat = [tuple(t) for t in dataset.median_numel_shapes(flat=True)]
else:
opt_shapes_flat = input_flattener.flatten(opt_shapes)

# handle legacy max_batch_size
if max_batch_size is not None:
min_shapes_flat = [(1,) + s[1:] for s in min_shapes_flat]
max_shapes_flat = [(max_batch_size,) + s[1:] for s in max_shapes_flat]

dynamic_axes_flat = infer_dynamic_axes(min_shapes_flat, max_shapes_flat)

if default_device_type == trt.DeviceType.DLA:
for value in dynamic_axes_flat:
if len(value) > 0:
raise ValueError('Dataset cannot have multiple shapes when using DLA')

logger = trt.Logger(log_level)
builder = trt.Builder(logger)
config = builder.create_builder_config()

if input_names is None:
input_names = default_input_names(input_flattener.size)
if output_names is None:
output_names = default_output_names(output_flattener.size)

if use_onnx:
import onnx_graphsurgeon as gs
import onnx

module_flat = Flatten(module, input_flattener, output_flattener)
inputs_flat = input_flattener.flatten(inputs)

f = io.BytesIO()
torch.onnx.export(
module_flat,
inputs_flat,
f,
input_names=input_names,
output_names=output_names,
dynamic_axes={
name: {int(axis): f'input_{index}_axis_{axis}' for axis in dynamic_axes_flat[index]}
for index, name in enumerate(input_names)
},
opset_version=onnx_opset
)
f.seek(0)
def build_engine():
# infer default parameters from dataset

if min_shapes == None:
min_shapes_flat = [tuple(t) for t in dataset.min_shapes(flat=True)]
else:
min_shapes_flat = input_flattener.flatten(min_shapes)

if max_shapes == None:
max_shapes_flat = [tuple(t) for t in dataset.max_shapes(flat=True)]
else:
max_shapes_flat = input_flattener.flatten(max_shapes)

onnx_graph = gs.import_onnx(onnx.load(f))
onnx_graph.fold_constants().cleanup()
if opt_shapes == None:
opt_shapes_flat = [tuple(t) for t in dataset.median_numel_shapes(flat=True)]
else:
opt_shapes_flat = input_flattener.flatten(opt_shapes)

# handle legacy max_batch_size
if max_batch_size is not None:
min_shapes_flat = [(1,) + s[1:] for s in min_shapes_flat]
max_shapes_flat = [(max_batch_size,) + s[1:] for s in max_shapes_flat]

f = io.BytesIO()
onnx.save(gs.export_onnx(onnx_graph), f)
f.seek(0)
dynamic_axes_flat = infer_dynamic_axes(min_shapes_flat, max_shapes_flat)

onnx_bytes = f.read()
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
parser.parse(onnx_bytes)
if default_device_type == trt.DeviceType.DLA:
for value in dynamic_axes_flat:
if len(value) > 0:
raise ValueError('Dataset cannot have multiple shapes when using DLA')

else:
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
with ConversionContext(network, torch2trt_kwargs=kwargs, builder_config=config, logger=logger) as ctx:

logger = trt.Logger(log_level)
builder = trt.Builder(logger)
config = builder.create_builder_config()

if use_onnx:
import onnx_graphsurgeon as gs
import onnx

module_flat = Flatten(module, input_flattener, output_flattener)
inputs_flat = input_flattener.flatten(inputs)

ctx.add_inputs(inputs_flat, input_names, dynamic_axes=dynamic_axes_flat)
f = io.BytesIO()
torch.onnx.export(
module_flat,
inputs_flat,
f,
input_names=input_names,
output_names=output_names,
dynamic_axes={
name: {int(axis): f'input_{index}_axis_{axis}' for axis in dynamic_axes_flat[index]}
for index, name in enumerate(input_names)
},
opset_version=onnx_opset
)
f.seek(0)

outputs = module(*inputs)
onnx_graph = gs.import_onnx(onnx.load(f))
onnx_graph.fold_constants().cleanup()

outputs_flat = output_flattener.flatten(outputs)
ctx.mark_outputs(outputs_flat, output_names)

# set max workspace size
if trt_version() < "10.0":
config.max_workspace_size = max_workspace_size

f = io.BytesIO()
onnx.save(gs.export_onnx(onnx_graph), f)
f.seek(0)

# set the number of averaging timing iterations (only when avg_timing_iterations is given)
if avg_timing_iterations is not None:
config.avg_timing_iterations = avg_timing_iterations
onnx_bytes = f.read()
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
parser.parse(onnx_bytes)

if fp16_mode:
config.set_flag(trt.BuilderFlag.FP16)
else:
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
with ConversionContext(network, torch2trt_kwargs=kwargs, builder_config=config, logger=logger) as ctx:

config.default_device_type = default_device_type
if gpu_fallback:
config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
config.DLA_core = dla_core

if strict_type_constraints:
config.set_flag(trt.BuilderFlag.STRICT_TYPES)
inputs_flat = input_flattener.flatten(inputs)

ctx.add_inputs(inputs_flat, input_names, dynamic_axes=dynamic_axes_flat)

outputs = module(*inputs)

outputs_flat = output_flattener.flatten(outputs)
ctx.mark_outputs(outputs_flat, output_names)

# set max workspace size
if trt_version() < "10.0":
config.max_workspace_size = max_workspace_size

# set the number of averaging timing iterations (only when avg_timing_iterations is given)
if avg_timing_iterations is not None:
config.avg_timing_iterations = avg_timing_iterations

if fp16_mode:
config.set_flag(trt.BuilderFlag.FP16)

if int8_mode:
config.default_device_type = default_device_type
if gpu_fallback:
config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
config.DLA_core = dla_core

# default to use input tensors for calibration
if int8_calib_dataset is None:
int8_calib_dataset = dataset
if strict_type_constraints:
config.set_flag(trt.BuilderFlag.STRICT_TYPES)

config.set_flag(trt.BuilderFlag.INT8)
if int8_mode:

# Make sure calibration is not run in QAT mode (i.e. whenever 'qat_mode' is present in kwargs)
if not 'qat_mode' in kwargs:
calibrator = DatasetCalibrator(
int8_calib_dataset, algorithm=int8_calib_algorithm
# default to use input tensors for calibration
if int8_calib_dataset is None:
int8_calib_dataset = dataset

config.set_flag(trt.BuilderFlag.INT8)

# Make sure calibration is not run in QAT mode (i.e. whenever 'qat_mode' is present in kwargs)
if not 'qat_mode' in kwargs:
calibrator = DatasetCalibrator(
int8_calib_dataset, algorithm=int8_calib_algorithm
)
config.int8_calibrator = calibrator

# OPTIMIZATION PROFILE
profile = builder.create_optimization_profile()
for index, name in enumerate(input_names):
profile.set_shape(
name,
min_shapes_flat[index],
opt_shapes_flat[index],
max_shapes_flat[index]
)
config.int8_calibrator = calibrator

# OPTIMIZATION PROFILE
profile = builder.create_optimization_profile()
for index, name in enumerate(input_names):
profile.set_shape(
name,
min_shapes_flat[index],
opt_shapes_flat[index],
max_shapes_flat[index]
)
config.add_optimization_profile(profile)
config.add_optimization_profile(profile)

if int8_mode:
config.set_calibration_profile(profile)
if int8_mode:
config.set_calibration_profile(profile)

# BUILD ENGINE
# BUILD ENGINE

if trt_version() < "10.0":
engine = builder.build_engine(network, config)
else:
engine = builder.build_serialized_network(network, config)
if trt_version() < "10.0":
engine = builder.build_engine(network, config)
else:
engine = builder.build_serialized_network(network, config)

# SAVE ENGINE
os.makedirs(os.path.dirname(engine_file_path), exist_ok=True)
if trt_version() < "10.0":
with open(engine_file_path, "wb") as f:
f.write(engine.serialize())
else:
with open(engine_file_path, "wb") as f:
f.write(engine)
return engine, network

load_engine = False
if os.path.exists(engine_file_path):
# If a serialized engine exists, use it instead of building an engine.
print("Reading engine from file {}.".format(engine_file_path))
try:
with open(engine_file_path,
"rb") as f, trt.Logger(log_level) as logger, trt.Runtime(logger) as runtime:
engine = runtime.deserialize_cuda_engine(f.read())
if engine is not None:
load_engine = True
except:
print("Failed to load engine from file {}. run build_engine().".format(engine_file_path))
if not load_engine:
engine, network = build_engine()

module_trt = TRTModule(engine, input_names, output_names, input_flattener=input_flattener, output_flattener=output_flattener)

if keep_network:
if not load_engine and keep_network:
module_trt.network = network

return module_trt
Expand Down