 parser.add_argument('--backbone', type=str, help="Backbone for detection models")
 parser.add_argument('--print-models', action='store_true', help="Print all the available model names and exit")
 parser.add_argument('--to-dd-native', action='store_true', help="Prepare the model so that the weights can be loaded on a native model with dede")
+parser.add_argument('--to-onnx', action="store_true", help="If specified, export to ONNX instead of JIT.")
+parser.add_argument('--weights', type=str, help="If not None, these weights will be embedded in the model before exporting")
 parser.add_argument('-a', "--all", action='store_true', help="Export all available models")
 parser.add_argument('-v', "--verbose", action='store_true', help="Set logging level to INFO")
 parser.add_argument('-o', "--output-dir", default=".", type=str, help="Output directory for traced models")
 parser.add_argument('--cpu', action='store_true', help="Force models to be exported for CPU device")
 parser.add_argument('--num_classes', type=int, help="Number of classes")
 parser.add_argument('--trace', action='store_true', help="Whether to trace the model instead of scripting it")
+parser.add_argument('--batch_size', type=int, default=1, help="When exporting with a fixed batch size, this will be the batch size of the model")
+parser.add_argument('--img_width', type=int, default=224, help="Width of the image when exporting with a fixed image size")
+parser.add_argument('--img_height', type=int, default=224, help="Height of the image when exporting with a fixed image size")

 args = parser.parse_args()
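
Taken together, the new options allow exporting a fixed-shape ONNX model with previously trained weights baked in, e.g. "--to-onnx --weights my_weights.pt --img_width 512 --img_height 512" (the weights filename here is only illustrative).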
@@ -112,15 +117,43 @@ def forward(self, x, ids=None, bboxes=None, labels=None):

         return loss, predictions

-def get_detection_input():
+
+class DetectionModel_PredictOnly(torch.nn.Module):
+    """
+    Adapt input and output of the model to make it exportable to
+    ONNX
+    """
+    def __init__(self, model):
+        super(DetectionModel_PredictOnly, self).__init__()
+        self.model = model
+
+    def forward(self, x):
+        l_x = [x[i] for i in range(x.shape[0])]
+        predictions = self.model(l_x)
+        # To dede format
+        pred_list = list()
+        for i in range(x.shape[0]):
+            pred_list.append(
+                torch.cat((
+                    torch.full(predictions[i]["labels"].shape, i, dtype=float).unsqueeze(1),
+                    predictions[i]["labels"].unsqueeze(1).float(),
+                    predictions[i]["scores"].unsqueeze(1),
+                    predictions[i]["boxes"]), dim=1))
+
+        return torch.cat(pred_list)
+
+def get_image_input(batch_size=1, img_width=224, img_height=224):
+    return torch.rand(batch_size, 3, img_width, img_height)
+
+def get_detection_input(batch_size=1, img_width=224, img_height=224):
     """
     Sample input for detection models, usable for tracing or testing
     """
     return (
-        torch.rand(1, 3, 224, 224),
-        torch.full((1,), 0).long(),
-        torch.Tensor([1, 1, 200, 200]).unsqueeze(0),
-        torch.full((1,), 1).long(),
+        torch.rand(batch_size, 3, img_width, img_height),
+        torch.arange(0, batch_size).long(),
+        torch.Tensor([1, 1, 200, 200]).repeat((batch_size, 1)),
+        torch.full((batch_size,), 1).long(),
     )

 model_classes = {
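
For orientation (a sketch, not part of the commit): DetectionModel_PredictOnly returns one row per detection, laid out as [batch_index, label, score, x1, y1, x2, y2], so a caller can read the output back along these lines, with a made-up result tensor:

    # hypothetical output holding a single detection in image 0
    out = torch.tensor([[0.0, 1.0, 0.9, 10.0, 10.0, 50.0, 80.0]])
    for row in out:
        batch_idx, label, score = int(row[0]), int(row[1]), float(row[2])
        box = row[3:].tolist()  # [x1, y1, x2, y2]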
@@ -230,7 +263,7 @@ def get_detection_input():
         else:
             if args.backbone:
                 raise RuntimeError("--backbone is only supported with models \"fasterrcnn\" or \"retinanet\".")
-            model = model_classes[mname](pretrained=args.pretrained, progress=args.verbose)
+            model = model_classes[mname](pretrained=args.pretrained, pretrained_backbone=args.pretrained, progress=args.verbose)

         if args.num_classes:
             logging.info("Using num_classes = %d" % args.num_classes)
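
Note: torchvision's detection constructors default to pretrained_backbone=True, so before this change a backbone checkpoint could be downloaded even when --pretrained was not set; passing pretrained_backbone=args.pretrained ties both to the same flag.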
@@ -246,9 +279,17 @@ def get_detection_input():
             # replace pretrained head
             model.head = M.detection.retinanet.RetinaNetHead(in_channels, num_anchors, args.num_classes)

-        detect_model = DetectionModel(model)
-        detect_model.train()
-        script_module = torch.jit.script(detect_model)
+        if args.to_onnx:
+            model = DetectionModel_PredictOnly(model)
+            model.eval()
+        else:
+            model = DetectionModel(model)
+            model.train()
+            script_module = torch.jit.script(model)
+
+        if args.num_classes is None:
+            # TODO: don't hard-code this
+            args.num_classes = 91

     else:
         kwargs = {}
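
(91 matches the COCO label set that torchvision's pretrained detection models are trained on; setting it here keeps a class count available for the output filename built below when --num_classes is not given.)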
@@ -264,16 +305,45 @@ def get_detection_input():

         model.eval()

-
         # tracing or scripting model (default)
         if args.trace:
-            example = torch.rand(1, 3, 224, 224)
+            example = get_image_input(args.batch_size, args.img_width, args.img_height)
             script_module = torch.jit.trace(model, example)
         else:
             script_module = torch.jit.script(model)
+
+    filename = os.path.join(
+        args.output_dir,
+        mname
+        + ("-pretrained" if args.pretrained else "")
+        + ("-" + args.backbone if args.backbone else "")
+        + ("-cls" + str(args.num_classes) if args.num_classes else "")
+        + ".pt")
+
+    if args.weights:
+        # load weights
+        weights = torch.jit.load(args.weights).state_dict()

-    filename = os.path.join(args.output_dir, mname + ("-pretrained" if args.pretrained else "") + ("-" + args.backbone if args.backbone else "") + "-cls" + str(args.num_classes) + ".pt")
-    logging.info("Saving to %s", filename)
-    script_module.save(filename)
+        if args.to_onnx:
+            logging.info("Apply weights from %s to the onnx model" % args.weights)
+            model.load_state_dict(weights, strict=True)
+        else:
+            logging.info("Apply weights from %s to the jit model" % args.weights)
+            script_module.load_state_dict(weights, strict=True)
+
+    if args.to_onnx:
+        logging.info("Export model to onnx (%s)" % filename)
+        # swap the .pt extension for .onnx
+        filename = filename[:-3] + ".onnx"
+        example = get_image_input(args.batch_size, args.img_width, args.img_height)
+        torch.onnx.export(
+            model, example, filename,
+            export_params=True, verbose=args.verbose,
+            opset_version=11, do_constant_folding=True,
+            input_names=["input"], output_names=["output"])
+        # dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
+    else:
+        logging.info("Saving to %s", filename)
+        script_module.save(filename)

 logging.info("Done")
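
For completeness, a minimal sketch (not part of this commit, assuming the onnxruntime package and an illustrative filename) of running an exported detection model:

    import numpy as np
    import onnxruntime

    sess = onnxruntime.InferenceSession("fasterrcnn-cls91.onnx")
    # shape must match the fixed export size: (batch_size, 3, img_width, img_height)
    image = np.random.rand(1, 3, 224, 224).astype(np.float32)
    detections = sess.run(["output"], {"input": image})[0]
    # each row: [batch_index, label, score, x1, y1, x2, y2]

Uncommenting the dynamic_axes argument above and passing it to torch.onnx.export would mark the batch dimension as dynamic, so the exported model would accept batches of any size instead of the fixed --batch_size.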