1
1
import cv2
2
2
import numpy as np
3
- import onnxruntime as ort
4
3
5
4
from fdfat .utils import box_utils
6
5
7
6
# Execution providers handed to onnxruntime.InferenceSession, in priority
# order. Alternative configurations are kept for quick switching:
# ONNX_BACKENDS = ['CoreMLExecutionProvider', 'CPUExecutionProvider']
# ONNX_BACKENDS = ['CoreMLExecutionProvider']
ONNX_BACKENDS = ['CPUExecutionProvider']
class InferModelBackend:
    """Integer identifiers for the supported inference backends."""

    ONNX = 0
    TFLITE = 1


class InferModel:
    """Base wrapper around an inference runtime (ONNX Runtime or TFLite).

    Loads the model, records the expected input resolution and layout,
    and provides image preprocessing. Subclasses implement ``_predict``.
    """

    def __init__(self, model_path, channel_first=True, backend=InferModelBackend.ONNX):
        """Load the model with the requested backend.

        Args:
            model_path: Path to the .onnx or .tflite model file.
            channel_first: True if the model expects NCHW input. Only
                honored for the ONNX backend; TFLite models are assumed
                NHWC and this flag is forced to False.
            backend: One of ``InferModelBackend.ONNX`` / ``InferModelBackend.TFLITE``.

        Raises:
            AttributeError: If ``backend`` is not a supported value.
        """
        self.model_path = model_path
        self.backend = backend
        self.channel_first = channel_first

        if self.backend == InferModelBackend.ONNX:

            import onnxruntime as ort

            # Single-threaded, fully-optimized session: this wrapper targets
            # lightweight per-frame inference, not batch throughput.
            sess_options = ort.SessionOptions()
            sess_options.intra_op_num_threads = 1
            sess_options.inter_op_num_threads = 1
            sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
            sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
            self.session = ort.InferenceSession(self.model_path, sess_options, providers=ONNX_BACKENDS)

            if channel_first:
                _, _, self.input_height, self.input_width = self.session.get_inputs()[0].shape
            else:
                _, self.input_height, self.input_width, _ = self.session.get_inputs()[0].shape

        elif self.backend == InferModelBackend.TFLITE:
            # Prefer the lightweight tflite_runtime package; fall back to the
            # full TensorFlow distribution when it is not installed.
            # (was a bare `except:` — only a missing package should be caught)
            try:
                import tflite_runtime.interpreter as tflite
            except ImportError:
                import tensorflow.lite as tflite

            self.interpreter = tflite.Interpreter(model_path=self.model_path)
            self.interpreter.allocate_tensors()

            # Get input and output tensors.
            self.input_details = self.interpreter.get_input_details()
            self.output_details = self.interpreter.get_output_details()

            # TFLite models are NHWC; force channel-last preprocessing.
            _, self.input_height, self.input_width, _ = self.input_details[0]["shape"]
            self.channel_first = False
        else:
            raise AttributeError(f"Backend ({self.backend}) is not supported")

    def preprocess(self, img):
        """Resize and normalize an image into a model input tensor.

        Args:
            img: HxWx3 uint8 image (presumably BGR from cv2 — TODO confirm).

        Returns:
            float32 array of shape (1, C, H, W) when channel-first,
            otherwise (1, H, W, C), with values roughly in [-1, 1].
        """
        img = cv2.resize(img, (self.input_width, self.input_height))
        img_mean = np.array([127, 127, 127])
        img = (img - img_mean) / 128  # approx. scale to [-1, 1]

        if self.channel_first:
            img = np.transpose(img, [2, 0, 1])  # HWC -> CHW

        img = np.expand_dims(img, axis=0)  # add batch dimension
        img = img.astype(np.float32)

        return img

    def _predict(self, img):
        """Run one forward pass on a preprocessed input.

        Subclasses must override; the base stub used to silently return
        None (`pass`), which hid missing implementations.
        """
        raise NotImplementedError
75
class FaceDetector(InferModel):
    """Face detection model; available with the ONNX backend only."""

    def __init__(self, model_path, channel_first=True, backend=InferModelBackend.ONNX):
        """Create the detector, rejecting any backend other than ONNX."""
        # Guard clause: fail fast before touching the base-class loader.
        if backend != InferModelBackend.ONNX:
            raise AttributeError(f"Backend ({backend}) is not supported for FaceDetector")
        super().__init__(model_path, channel_first, backend)
def postprocess (self , width , height , confidences , boxes , prob_threshold , iou_threshold = 0.3 , top_k = - 1 ):
45
84
boxes = boxes [0 ]
@@ -82,13 +121,21 @@ def predict(self, ori_img, threshold=0.5):
82
121
83
122
return boxes , probs
84
123
85
class LandmarkAligner(InferModel):
    """Facial landmark regression model; runs on either backend."""

    def _predict(self, img):
        """Run one forward pass on a preprocessed input and return the raw output."""
        if self.backend == InferModelBackend.ONNX:
            # Empty output-name list -> onnxruntime returns every model output;
            # the first one holds the landmark vector.
            return self.session.run([], {"input": img})[0]
        if self.backend == InferModelBackend.TFLITE:
            self.interpreter.set_tensor(self.input_details[0]['index'], img)
            self.interpreter.invoke()
            return self.interpreter.get_tensor(self.output_details[0]['index'])
87
134
def predict (self , ori_img , have_face_cls = False ):
88
135
height , width , _ = ori_img .shape
89
136
90
137
img = self .preprocess (ori_img )
91
- lmk = self .session . run ([], { 'input' : img })[ 0 ]
138
+ lmk = self ._predict ( img )
92
139
93
140
if have_face_cls :
94
141
lmk , face_cls = lmk [0 ][:70 * 2 ].reshape ((70 ,2 )), lmk [0 ][70 * 2 ]
@@ -103,7 +150,7 @@ def predict(self, ori_img, have_face_cls=False):
103
150
return lmk , face_cls
104
151
105
152
return lmk
106
-
153
+
107
154
def predict_frame (self , frame , bbox , have_face_cls = False ):
108
155
109
156
fheight , fwidth , _ = frame .shape
@@ -112,7 +159,6 @@ def predict_frame(self, frame, bbox, have_face_cls=False):
112
159
if bw * bh == 0 :
113
160
return np .zeros ((70 ,2 )), 0
114
161
115
-
116
162
face_img = frame [lmk_box [1 ]:lmk_box [3 ], lmk_box [0 ]:lmk_box [2 ], :]
117
163
lmk = self .predict (face_img , have_face_cls = have_face_cls )
118
164
0 commit comments