Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for MPS on MacOS with ARM chips (Apple Silicon) #214

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
*.pyc
*__pycache__*
*core.*
_ext
tmp
*.o*
*~
*.idea
*.mp4
*.h5
*.pth
*.egg-info

/build
/dist
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ We also provide a set of Face Detector for edge device in [here](https://github.
| Pytorch (original image scale) | 90.70% | 88.16% | 73.82% |
| Mxnet | 88.72% | 86.97% | 79.19% |
| Mxnet(original image scale) | 89.58% | 87.11% | 69.12% |
<p align="center"><img src="curve/Widerface.jpg" width="640"\></p>
<p align="center"><img src="retinaface/curve/Widerface.jpg" width="640"\></p>

## FDDB Performance.
| FDDB(pytorch) | performance |
|:-|:-:|
| Mobilenet0.25 | 98.64% |
| Resnet50 | 99.22% |
<p align="center"><img src="curve/FDDB.png" width="640"\></p>
<p align="center"><img src="retinaface/curve/FDDB.png" width="640"\></p>

### Contents
- [Installation](#installation)
Expand Down Expand Up @@ -112,7 +112,7 @@ python test_fddb.py --trained_model weight_file --network mobile0.25 or resnet50

3. Download [eval_tool](https://bitbucket.org/marcopede/face-eval) to evaluate the performance.

<p align="center"><img src="curve/1.jpg" width="640"\></p>
<p align="center"><img src="retinaface/curve/1.jpg" width="640"\></p>

## TensorRT
-[TensorRT](https://github.com/wang-xinyu/tensorrtx/tree/master/retinaface)
Expand Down
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[build-system]
requires = [
"setuptools>=64.0.0",
"wheel",
]
build-backend = "setuptools.build_meta"

1 change: 1 addition & 0 deletions retinaface/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .inference_framework import RetinaFaceDetector
52 changes: 34 additions & 18 deletions convert_to_onnx.py → retinaface/convert_to_onnx.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,18 @@
from __future__ import print_function
import os

import argparse
import torch
import torch.backends.cudnn as cudnn
import numpy as np

from data import cfg_mnet, cfg_re50
from layers.functions.prior_box import PriorBox
from utils.nms.py_cpu_nms import py_cpu_nms
import cv2
from models.retinaface import RetinaFace
from utils.box_utils import decode, decode_landm
from utils.timer import Timer


parser = argparse.ArgumentParser(description='Test')
parser.add_argument('-m', '--trained_model', default='./weights/mobilenet0.25_Final.pth',
type=str, help='Trained state_dict file path to open')
parser.add_argument('--network', default='mobile0.25', help='Backbone network mobile0.25 or resnet50')
parser.add_argument('--long_side', default=640, help='when origin_size is false, long_side is scaled size(320 or 640 for long side)')
parser.add_argument('--cpu', action="store_true", default=True, help='Use cpu inference')
parser.add_argument('--cpu', action="store_true", default=False, help='Use cpu inference')

args = parser.parse_args()

Expand All @@ -43,13 +37,17 @@ def remove_prefix(state_dict, prefix):
return {f(key): value for key, value in state_dict.items()}


def load_model(model, pretrained_path, load_to_cpu):
def load_model(model, pretrained_path, device):
print('Loading pretrained model from {}'.format(pretrained_path))
if load_to_cpu:
pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
else:
if 'cuda' in device or device=='gpu':
device = torch.cuda.current_device()
pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
elif device=='mps':
device = torch.device('mps')
pretrained_dict = torch.load(pretrained_path, map_location=device)
else:
pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)

if "state_dict" in pretrained_dict.keys():
pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
else:
Expand All @@ -66,13 +64,31 @@ def load_model(model, pretrained_path, load_to_cpu):
cfg = cfg_mnet
elif args.network == "resnet50":
cfg = cfg_re50

if args.cpu:
print('--> load model and config files to CPU')
device = "cpu"
elif torch.cuda.is_available():
print('--> load model and config files to GPU')
device = "cuda"
elif torch.mps.is_available():
print('--> load model and config files to MPS')
device = "mps"
else:
raise RuntimeError('No GPU or MPS found. Please use "--cpu"')

# net and model
net = RetinaFace(cfg=cfg, phase = 'test')
net = load_model(net, args.trained_model, args.cpu)
net = load_model(net, args.trained_model, device=device)
net.eval()
print('Finished loading model!')
print(net)
device = torch.device("cpu" if args.cpu else "cuda")
print('--> Finished loading model!')
# print(net)

if device == "cuda" and torch.cuda.is_available:
cudnn.benchmark = True

# device = torch.device("cpu" if args.cpu else "cuda")
device = torch.device(device)
net = net.to(device)

# ------------------------ export -----------------------------
Expand All @@ -82,7 +98,7 @@ def load_model(model, pretrained_path, load_to_cpu):
output_names = ["output0"]
inputs = torch.randn(1, 3, args.long_side, args.long_side).to(device)

torch_out = torch.onnx._export(net, inputs, output_onnx, export_params=True, verbose=False,
torch_out = torch.onnx.export(net, inputs, output_onnx, export_params=True, verbose=False,
input_names=input_names, output_names=output_names)


File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion data/data_augment.py → retinaface/data/data_augment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import cv2
import numpy as np
import random
from utils.box_utils import matrix_iof
from retinaface.utils.box_utils import matrix_iof


def _crop(image, boxes, labels, landm, img_dim):
Expand Down
3 changes: 0 additions & 3 deletions data/wider_face.py → retinaface/data/wider_face.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import os
import os.path
import sys
import torch
import torch.utils.data as data
import cv2
Expand Down
124 changes: 87 additions & 37 deletions detect.py → retinaface/detect.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,18 @@
from __future__ import print_function
import os

import argparse
import torch
import torch.backends.cudnn as cudnn
import numpy as np
from data import cfg_mnet, cfg_re50
from layers.functions.prior_box import PriorBox
from utils.nms.py_cpu_nms import py_cpu_nms
import cv2
from models.retinaface import RetinaFace
from utils.box_utils import decode, decode_landm
import time
import cv2

parser = argparse.ArgumentParser(description='Retinaface')
from retinaface.data import cfg_mnet, cfg_re50
from retinaface.layers.functions.prior_box import PriorBox
from retinaface.utils.nms.py_cpu_nms import py_cpu_nms
from retinaface.models.retinaface import RetinaFace
from retinaface.utils.box_utils import decode, decode_landm

parser.add_argument('-m', '--trained_model', default='./weights/Resnet50_Final.pth',
type=str, help='Trained state_dict file path to open')
parser.add_argument('--network', default='resnet50', help='Backbone network mobile0.25 or resnet50')
parser.add_argument('--cpu', action="store_true", default=False, help='Use cpu inference')
parser.add_argument('--confidence_threshold', default=0.02, type=float, help='confidence_threshold')
parser.add_argument('--top_k', default=5000, type=int, help='top_k')
parser.add_argument('--nms_threshold', default=0.4, type=float, help='nms_threshold')
parser.add_argument('--keep_top_k', default=750, type=int, help='keep_top_k')
parser.add_argument('-s', '--save_image', action="store_true", default=True, help='show detection results')
parser.add_argument('--vis_thres', default=0.6, type=float, help='visualization_threshold')
args = parser.parse_args()


def check_keys(model, pretrained_state_dict):
Expand All @@ -33,58 +21,114 @@ def check_keys(model, pretrained_state_dict):
used_pretrained_keys = model_keys & ckpt_keys
unused_pretrained_keys = ckpt_keys - model_keys
missing_keys = model_keys - ckpt_keys
print('Missing keys:{}'.format(len(missing_keys)))
print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
print('Used keys:{}'.format(len(used_pretrained_keys)))
assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
return True


def remove_prefix(state_dict, prefix):
''' Old style model is stored with all names of parameters sharing common prefix 'module.' '''
print('remove prefix \'{}\''.format(prefix))
f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
return {f(key): value for key, value in state_dict.items()}


def load_model(model, pretrained_path, load_to_cpu):
def load_model(model, pretrained_path, device, url_file_name=None):
print('Loading pretrained model from {}'.format(pretrained_path))
if load_to_cpu:
pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
else:

url_flag = False
if pretrained_path[:8] == 'https://':
url_flag = True

if 'cuda' in device or device=='gpu':
device = torch.cuda.current_device()
pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
if url_flag:
pretrained_dict = torch.hub.load_state_dict_from_url(pretrained_path,
map_location=lambda storage, loc: storage.cuda(device),
file_name=url_file_name)
else:
pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
elif device=='mps':
device = torch.device('mps')
if url_flag:
pretrained_dict = torch.hub.load_state_dict_from_url(pretrained_path,
map_location=device,
file_name=url_file_name)
else:
pretrained_dict = torch.load(pretrained_path, map_location=device)
else:
if url_flag:
pretrained_dict = torch.hub.load_state_dict_from_url(pretrained_path,
map_location=lambda storage, loc: storage,
file_name=url_file_name)
else:
pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)

if "state_dict" in pretrained_dict.keys():
pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
else:
pretrained_dict = remove_prefix(pretrained_dict, 'module.')

check_keys(model, pretrained_dict)
model.load_state_dict(pretrained_dict, strict=False)

return model


if __name__ == '__main__':

parser = argparse.ArgumentParser(description='Retinaface')

parser.add_argument('-m', '--trained_model', default='./weights/Resnet50_Final.pth',
type=str, help='Trained state_dict file path to open')
parser.add_argument('--network', default='resnet50', help='Backbone network mobile0.25 or resnet50')
parser.add_argument('--cpu', action="store_true", default=False, help='Use cpu inference')
parser.add_argument('--confidence_threshold', default=0.02, type=float, help='confidence_threshold')
parser.add_argument('--top_k', default=5000, type=int, help='top_k')
parser.add_argument('--nms_threshold', default=0.4, type=float, help='nms_threshold')
parser.add_argument('--keep_top_k', default=750, type=int, help='keep_top_k')
parser.add_argument('-s', '--save_image', action="store_true", default=False, help='show detection results')
parser.add_argument('--vis_thres', default=0.6, type=float, help='visualization_threshold')
args = parser.parse_args()

torch.set_grad_enabled(False)
cfg = None
if args.network == "mobile0.25":
cfg = cfg_mnet
elif args.network == "resnet50":
cfg = cfg_re50

if args.cpu:
print('--> load model and config files to CPU')
device = "cpu"
elif torch.cuda.is_available():
print('--> load model and config files to GPU')
device = "cuda"
elif torch.mps.is_available():
print('--> load model and config files to MPS')
device = "mps"
else:
raise RuntimeError('No GPU or MPS found. Please use "--cpu"')

# net and model
net = RetinaFace(cfg=cfg, phase = 'test')
net = load_model(net, args.trained_model, args.cpu)
net = load_model(net, args.trained_model, device=device)
net.eval()
print('Finished loading model!')
print(net)
cudnn.benchmark = True
device = torch.device("cpu" if args.cpu else "cuda")
print('--> Finished loading model!')
# print(net)

if device == "cuda" and torch.cuda.is_available:
cudnn.benchmark = True

# device = torch.device("cpu" if args.cpu else "cuda")
device = torch.device(device)
net = net.to(device)

resize = 1

total_time = 0
n_loops = 100

# testing begin
for i in range(100):
image_path = "./curve/test.jpg"
for i in range(n_loops):
image_path = "curve/test.jpg"
img_raw = cv2.imread(image_path, cv2.IMREAD_COLOR)

img = np.float32(img_raw)
Expand All @@ -99,7 +143,11 @@ def load_model(model, pretrained_path, load_to_cpu):

tic = time.time()
loc, conf, landms = net(img) # forward pass
print('net forward time: {:.4f}'.format(time.time() - tic))

time_det = time.time() - tic
# print('net forward time: {:.4f}'.format(time_det))

total_time += time_det

priorbox = PriorBox(cfg, image_size=(im_height, im_width))
priors = priorbox.forward()
Expand Down Expand Up @@ -166,3 +214,5 @@ def load_model(model, pretrained_path, load_to_cpu):
name = "test.jpg"
cv2.imwrite(name, img_raw)

avg_time = total_time / n_loops
print(f'--> Average time: {avg_time:.4f} for {n_loops} loops')
Loading