Skip to content

Commit 07f69b1

Browse files
Bycobmergify[bot]
authored andcommitted
feat: add onnx export for torchvision models
1 parent 5bd9bdc commit 07f69b1

8 files changed

+165
-16
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ build/*
66
model/
77
models/
88
tools/build-cpp-netlib/
9+
__pycache__

ci/devel-trt.Dockerfile

+1-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ RUN for url in \
112112
; do curl -L -s -o /tmp/p.deb $url && dpkg -i /tmp/p.deb && rm -rf /tmp/p.deb; done
113113

114114
RUN python3 -m pip install --upgrade pip
115-
RUN python3 -m pip install torch
115+
RUN python3 -m pip install torch torchvision
116116

117117
RUN apt clean -y
118118
ADD ci/gitconfig /etc/gitconfig

ci/devel.Dockerfile

+1-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ RUN for url in \
115115
; do curl -L -s -o /tmp/p.deb $url && dpkg -i /tmp/p.deb && rm -rf /tmp/p.deb; done
116116

117117
RUN python3 -m pip install --upgrade pip
118-
RUN python3 -m pip install torch
118+
RUN python3 -m pip install torch torchvision
119119

120120
RUN apt clean -y
121121
ADD ci/gitconfig /etc/gitconfig

tests/CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -483,3 +483,8 @@ if (USE_SIMSEARCH)
483483
endif()
484484

485485
endif()
486+
487+
# Python tests
488+
add_test(NAME ut_python
489+
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}/tests"
490+
COMMAND python3 -m unittest ut_python -v)

tests/ut_python/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
import unittest
2+
from .ut_tools_torch import *

tests/ut_python/temp/.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*

tests/ut_python/ut_tools_torch.py

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import unittest
2+
import os
3+
import subprocess
4+
import torch
5+
import torchvision
6+
7+
_test_dir = os.path.dirname(__file__)
8+
_temp_dir = os.path.join(_test_dir, "temp")
9+
10+
def get_detection_input(batch_size=1):
11+
"""
12+
Sample input for detection models, usable for tracing or testing
13+
"""
14+
return (
15+
torch.rand(batch_size, 3, 224, 224),
16+
torch.full((batch_size,), 0).long(),
17+
torch.Tensor([1, 1, 200, 200]).repeat((batch_size, 1)),
18+
torch.full((batch_size,), 1).long(),
19+
)
20+
21+
class TestTorchvisionExport(unittest.TestCase):
22+
23+
def setUp(self):
24+
os.chdir(os.path.join(_test_dir, "../../tools/torch"))
25+
26+
def test_resnet50_export(self):
27+
# Export model (not pretrained because we don't have permission for the cache)
28+
subprocess.run(["python3", "trace_torchvision.py", "-vp", "resnet50", "-o", _temp_dir])
29+
model_file = os.path.join(_temp_dir, "resnet50.pt")
30+
self.assertTrue(os.path.exists(model_file), model_file)
31+
32+
# Export to onnx
33+
subprocess.run(["python3", "trace_torchvision.py", "-vp", "resnet50", "-o", _temp_dir, "--to-onnx", "--weights", model_file])
34+
onnx_file = os.path.join(_temp_dir, "resnet50.onnx")
35+
self.assertTrue(os.path.exists(onnx_file), onnx_file)
36+
37+
def test_fasterrcnn_export(self):
38+
# Export model (not pretrained because we don't have permission for the cache)
39+
subprocess.run(["python3", "trace_torchvision.py", "-vp", "fasterrcnn_resnet50_fpn", "-o", _temp_dir])
40+
model_file = os.path.join(_temp_dir, "fasterrcnn_resnet50_fpn-cls91.pt")
41+
self.assertTrue(os.path.exists(model_file), model_file)
42+
43+
# Test inference
44+
rfcnn = torch.jit.load(model_file)
45+
rfcnn.train()
46+
model_loss, model_preds = rfcnn(*get_detection_input())
47+
self.assertTrue(model_loss > 0)
48+
49+
rfcnn.eval()
50+
model_loss, model_preds = rfcnn(torch.rand(1, 3, 224, 224))
51+
self.assertTrue("boxes" in model_preds[0])
52+
53+
# Export to onnx
54+
subprocess.run(["python3", "trace_torchvision.py", "-vp", "fasterrcnn_resnet50_fpn", "-o", _temp_dir, "--to-onnx", "--weights", model_file])
55+
onnx_file = os.path.join(_temp_dir, "fasterrcnn_resnet50_fpn-cls91.onnx")
56+
self.assertTrue(os.path.exists(onnx_file), onnx_file)
57+
58+
def tearDown(self):
59+
print("Removing all files in %s" % _temp_dir)
60+
ignore=[".gitignore"]
61+
for f in os.listdir(_temp_dir):
62+
removed = os.path.join(_temp_dir, f)
63+
if f in ignore:
64+
print("Ignore %s" % removed)
65+
else:
66+
print("Remove %s" % removed)
67+
os.remove(removed)
68+
69+
if __name__ == '__main__':
70+
unittest.main()

tools/torch/trace_torchvision.py

+84-14
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
parser.add_argument('--backbone', type=str, help="Backbone for detection models")
3737
parser.add_argument('--print-models', action='store_true', help="Print all the available models names and exit")
3838
parser.add_argument('--to-dd-native', action='store_true', help="Prepare the model so that the weights can be loaded on native model with dede")
39+
parser.add_argument('--to-onnx', action="store_true", help="If specified, export to onnx instead of jit.")
40+
parser.add_argument('--weights', type=str, help="If not None, these weights will be embedded in the model before exporting")
3941
parser.add_argument('-a', "--all", action='store_true', help="Export all available models")
4042
parser.add_argument('-v', "--verbose", action='store_true', help="Set logging level to INFO")
4143
parser.add_argument('-o', "--output-dir", default=".", type=str, help="Output directory for traced models")
@@ -44,6 +46,9 @@
4446
parser.add_argument('--cpu', action='store_true', help="Force models to be exported for CPU device")
4547
parser.add_argument('--num_classes', type=int, help="Number of classes")
4648
parser.add_argument('--trace', action='store_true', help="Whether to trace model instead of scripting")
49+
parser.add_argument('--batch_size', type=int, default=1, help="When exporting with fixed batch size, this will be the batch size of the model")
50+
parser.add_argument('--img_width', type=int, default=224, help="Width of the image when exporting with fixed image size")
51+
parser.add_argument('--img_height', type=int, default=224, help="Height of the image when exporting with fixed image size")
4752

4853
args = parser.parse_args()
4954

@@ -112,15 +117,43 @@ def forward(self, x, ids = None, bboxes = None, labels = None):
112117

113118
return loss, predictions
114119

115-
def get_detection_input():
120+
121+
class DetectionModel_PredictOnly(torch.nn.Module):
122+
"""
123+
Adapt input and output of the model to make it exportable to
124+
ONNX
125+
"""
126+
def __init__(self, model):
127+
super(DetectionModel_PredictOnly, self).__init__()
128+
self.model = model
129+
130+
def forward(self, x):
131+
l_x = [x[i] for i in range(x.shape[0])]
132+
predictions = self.model(l_x)
133+
# To dede format
134+
pred_list = list()
135+
for i in range(x.shape[0]):
136+
pred_list.append(
137+
torch.cat((
138+
torch.full(predictions[i]["labels"].shape, i, dtype=float).unsqueeze(1),
139+
predictions[i]["labels"].unsqueeze(1).float(),
140+
predictions[i]["scores"].unsqueeze(1),
141+
predictions[i]["boxes"]), dim=1))
142+
143+
return torch.cat(pred_list)
144+
145+
def get_image_input(batch_size=1, img_width=224, img_height=224):
146+
return torch.rand(batch_size, 3, img_width, img_height)
147+
148+
def get_detection_input(batch_size=1, img_width=224, img_height=224):
116149
"""
117150
Sample input for detection models, usable for tracing or testing
118151
"""
119152
return (
120-
torch.rand(1, 3, 224, 224),
121-
torch.full((1,), 0).long(),
122-
torch.Tensor([1, 1, 200, 200]).unsqueeze(0),
123-
torch.full((1,), 1).long(),
153+
torch.rand(batch_size, 3, img_width, img_height),
154+
torch.arange(0, batch_size).long(),
155+
torch.Tensor([1, 1, 200, 200]).repeat((batch_size, 1)),
156+
torch.full((batch_size,), 1).long(),
124157
)
125158

126159
model_classes = {
@@ -230,7 +263,7 @@ def get_detection_input():
230263
else:
231264
if args.backbone:
232265
raise RuntimeError("--backbone is only supported with models \"fasterrcnn\" or \"retinanet\".")
233-
model = model_classes[mname](pretrained=args.pretrained, progress=args.verbose)
266+
model = model_classes[mname](pretrained=args.pretrained, pretrained_backbone=args.pretrained, progress=args.verbose)
234267

235268
if args.num_classes:
236269
logging.info("Using num_classes = %d" % args.num_classes)
@@ -246,9 +279,17 @@ def get_detection_input():
246279
# replace pretrained head
247280
model.head = M.detection.retinanet.RetinaNetHead(in_channels, num_anchors, args.num_classes)
248281

249-
detect_model = DetectionModel(model)
250-
detect_model.train()
251-
script_module = torch.jit.script(detect_model)
282+
if args.to_onnx:
283+
model = DetectionModel_PredictOnly(model)
284+
model.eval()
285+
else:
286+
model = DetectionModel(model)
287+
model.train()
288+
script_module = torch.jit.script(model)
289+
290+
if args.num_classes is None:
291+
# TODO dont hard code this
292+
args.num_classes = 91
252293

253294
else:
254295
kwargs = {}
@@ -264,16 +305,45 @@ def get_detection_input():
264305

265306
model.eval()
266307

267-
268308
# tracing or scripting model (default)
269309
if args.trace:
270-
example = torch.rand(1, 3, 224, 224)
310+
example = get_image_input(args.batch_size, args.img_width, args.img_height)
271311
script_module = torch.jit.trace(model, example)
272312
else:
273313
script_module = torch.jit.script(model)
314+
315+
filename = os.path.join(
316+
args.output_dir,
317+
mname
318+
+ ("-pretrained" if args.pretrained else "")
319+
+ ("-" + args.backbone if args.backbone else "")
320+
+ ("-cls" + str(args.num_classes) if args.num_classes else "")
321+
+ ".pt")
322+
323+
if args.weights:
324+
# load weights
325+
weights = torch.jit.load(args.weights).state_dict()
274326

275-
filename = os.path.join(args.output_dir, mname + ("-pretrained" if args.pretrained else "") + ("-" + args.backbone if args.backbone else "") + "-cls" + str(args.num_classes) + ".pt")
276-
logging.info("Saving to %s", filename)
277-
script_module.save(filename)
327+
if args.to_onnx:
328+
logging.info("Apply weights from %s to the onnx model" % args.weights)
329+
model.load_state_dict(weights, strict=True)
330+
else:
331+
logging.info("Apply weights from %s to the jit model" % args.weights)
332+
script_module.load_state_dict(weights, strict=True)
333+
334+
if args.to_onnx:
335+
logging.info("Export model to onnx (%s)" % filename)
336+
# remove extension
337+
filename = filename[:-3] + ".onnx"
338+
example = get_image_input(args.batch_size, args.img_width, args.img_height)
339+
torch.onnx.export(
340+
model, example, filename,
341+
export_params=True, verbose=args.verbose,
342+
opset_version=11, do_constant_folding=True,
343+
input_names=["input"], output_names=["output"])
344+
# dynamic_axes={"input":{0:"batch_size"},"output":{0:"batch_size"}}
345+
else:
346+
logging.info("Saving to %s", filename)
347+
script_module.save(filename)
278348

279349
logging.info("Done")

0 commit comments

Comments
 (0)