Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pytorch to onnx to mxnet #176

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
__pycache__
*.onnx
*.pt
*.pth
*.tar
61 changes: 39 additions & 22 deletions convert_to_onnx.py → export.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,18 @@
from __future__ import print_function
import os
import argparse
import torch
import torch.backends.cudnn as cudnn
import numpy as np
from data import cfg_mnet, cfg_re50
from layers.functions.prior_box import PriorBox
from utils.nms.py_cpu_nms import py_cpu_nms
import cv2
from models.retinaface import RetinaFace
from utils.box_utils import decode, decode_landm
from utils.timer import Timer


parser = argparse.ArgumentParser(description='Test')
parser.add_argument('-m', '--trained_model', default='./weights/mobilenet0.25_Final.pth',
type=str, help='Trained state_dict file path to open')
parser.add_argument('--network', default='mobile0.25', help='Backbone network mobile0.25 or resnet50')
parser.add_argument('--long_side', default=640, help='when origin_size is false, long_side is scaled size(320 or 640 for long side)')
parser.add_argument('--cpu', action="store_true", default=True, help='Use cpu inference')
parser.add_argument('--hw', default="cpu", help='device for inference')
parser.add_argument('--mode', default='onnx', help='export to onnx or torch script')
parser.add_argument('output')

args = parser.parse_args()

Expand Down Expand Up @@ -60,29 +54,52 @@ def load_model(model, pretrained_path, load_to_cpu):


if __name__ == '__main__':
    args = parser.parse_args()
    torch.set_grad_enabled(False)  # export only; autograd bookkeeping is unnecessary

    # Pick the backbone config; fail fast on an unknown network name instead of
    # letting cfg=None reach RetinaFace and crash with an opaque error.
    if args.network == "mobile0.25":
        cfg = cfg_mnet
    elif args.network == "resnet50":
        cfg = cfg_re50
    else:
        raise ValueError("unsupported --network: {}".format(args.network))

    net = RetinaFace(cfg=cfg, phase='test')
    # 'eia' (AWS Elastic Inference) models are loaded on CPU before tracing.
    load_to_cpu = args.hw == 'cpu' or args.hw == 'eia'
    net = load_model(net, args.trained_model, load_to_cpu)
    net.eval()
    print('Finished loading model!')
    print(net)

    # ------------------------ export -----------------------------
    if args.mode == 'onnx':
        device = torch.device("cpu" if args.hw == 'cpu' else "cuda")
        net = net.to(device)
        output_onnx = args.output
        print("==> Exporting model to ONNX format at '{}'".format(output_onnx))
        input_names = ["input0"]
        output_names = ["output0"]
        # Fixed-size dummy input: the exported graph is specialized to this shape.
        inputs = torch.randn(1, 3, args.long_side, args.long_side).to(device)
        # Use the public torch.onnx.export; the original called the private
        # torch.onnx._export, which newer PyTorch releases removed.
        torch.onnx.export(net, inputs, output_onnx,
                          export_params=True, verbose=False,
                          input_names=input_names,
                          output_names=output_names,
                          opset_version=11)
    elif args.mode == 'tscript':
        output_mod = args.output
        inputs = torch.randn(1, 3, args.long_side, args.long_side)
        # A second, batch-2 input lets torch.jit.trace verify the trace was not
        # accidentally specialized to batch size 1.
        check_inputs = [torch.randn(1, 3, args.long_side, args.long_side),
                        torch.randn(2, 3, args.long_side, args.long_side)]
        if args.hw == 'cpu' or args.hw == 'gpu':
            device = torch.device("cpu" if args.hw == 'cpu' else "cuda")
            net = net.to(device)
            traced_model = torch.jit.trace(net, inputs,
                                           check_inputs=check_inputs)
        elif args.hw == 'eia':
            # Target the Elastic Inference accelerator while tracing.
            with torch.jit.optimized_execution(True,
                                               {'target_device': 'eia:0'}):
                traced_model = torch.jit.trace(net, inputs,
                                               check_inputs=check_inputs)
        else:
            # Previously an unknown --hw fell through to a NameError on
            # traced_model at the save below; fail with a clear message.
            raise ValueError("unsupported --hw: {}".format(args.hw))
        torch.jit.save(traced_model, output_mod)
    else:
        raise ValueError("unsupported --mode: {}".format(args.mode))
7 changes: 7 additions & 0 deletions export_models.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/bin/bash
# Export the pretrained RetinaFace weights to TorchScript and ONNX.
# -e: stop at the first failed export instead of silently continuing;
# -u: error on unset variables; -o pipefail: a failure anywhere in a
# pipeline fails the pipeline.
set -euo pipefail

python3 export.py -m weights/mobilenet0.25_Final.pth --network mobile0.25 --mode tscript mnet.0.25.pt
python3 export.py -m weights/Resnet50_Final.pth --network resnet50 --mode tscript resnet50.pt

python3 export.py -m weights/mobilenet0.25_Final.pth --network mobile0.25 mnet.0.25.onnx
python3 export.py -m weights/Resnet50_Final.pth --network resnet50 resnet50.onnx
4 changes: 4 additions & 0 deletions weights/download_all_pretrain.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/bin/bash
# Download all pretrained RetinaFace weight files from Google Drive.
# Fail fast: without -e a failed download is silently skipped and later
# export steps work on missing files.
set -euo pipefail
./google_download.sh 14KX6VqF69MdSPk3Tr9PlDYbq7ArpdNUW Resnet50_Final.pth
./google_download.sh 1q36RaTZnpHVl4vRuNypoEMVWiiwCqhuD mobilenetV1X0.25_pretrain.tar
./google_download.sh 15zP8BP-5IvWXWZoYTNdvUJUiBqZ1hxu1 mobilenet0.25_Final.pth
8 changes: 8 additions & 0 deletions weights/google_download.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/bin/bash
# Download a (possibly large) file from Google Drive.
# Usage: google_download.sh <file-id> <output-filename>
#
# Large files need a confirmation token, which Google returns together with a
# session cookie: fetch the token first, then download using the cookie jar.
set -euox pipefail
FILEID=$1
FILENAME=$2
COOKIE_JAR=/tmp/cookies.txt
# Always remove the cookie jar, even when a download fails (the original only
# removed it on success, leaking it otherwise). rm -f, not -rf: it is a file.
trap 'rm -f "$COOKIE_JAR"' EXIT
# Note: the redundant "FILEID=$1" re-assignment inside the original command
# substitution has been dropped; FILEID is already set above.
cookie=$(wget --quiet --save-cookies "$COOKIE_JAR" --keep-session-cookies \
    --no-check-certificate "https://docs.google.com/uc?export=download&id=$FILEID" -O- | \
    sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')
# Quote $FILENAME (and $COOKIE_JAR) so names with spaces do not word-split.
wget --load-cookies "$COOKIE_JAR" \
    "https://docs.google.com/uc?export=download&confirm=$cookie&id=$FILEID" -O "$FILENAME"