
Commit 70bf245

Support TFLite inference
1 parent b7aa4ef commit 70bf245

3 files changed: +74 -23 lines changed


fdfat/engine/base.py

+2 -1
@@ -90,4 +90,5 @@ def load_checkpoint(self, checkpoint, map_location='cpu', epoch_info=True):
         self.start_epoch = checkpoint['epoch']
         self.best_epoch_loss = checkpoint['best_fit']
         self.best_epoch_no = checkpoint['best_epoch']
-        LOGGER.info(f"Loaded checkpoint epoch {checkpoint['epoch']}")
+        LOGGER.info(f"Loaded checkpoint epoch {checkpoint['epoch']}")
+

fdfat/nn/onnx.py → fdfat/nn/infer.py (renamed)

+66 -20
@@ -1,45 +1,84 @@
 import cv2
 import numpy as np
-import onnxruntime as ort

 from fdfat.utils import box_utils

 # ONNX_BACKENDS = ['CoreMLExecutionProvider', 'CPUExecutionProvider']
 # ONNX_BACKENDS = ['CoreMLExecutionProvider']
 ONNX_BACKENDS = ['CPUExecutionProvider']

-class ONNXModel:
+class InferModelBackend:

-    def __init__(self, model_path, channel_first=True):
+    ONNX = 0
+    TFLITE = 1
+
+class InferModel:
+
+    def __init__(self, model_path, channel_first=True, backend=InferModelBackend.ONNX):

         self.model_path = model_path
-        # self.input_width, self.input_height = input_size
+        self.backend = backend
+        self.channel_first = channel_first
+
+        if self.backend == InferModelBackend.ONNX:
+
+            import onnxruntime as ort
+
+            sess_options = ort.SessionOptions()
+            sess_options.intra_op_num_threads = 1
+            sess_options.inter_op_num_threads = 1
+            sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
+            sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+            self.session = ort.InferenceSession(self.model_path, sess_options, providers=ONNX_BACKENDS)
+
+            if channel_first:
+                _, _, self.input_height, self.input_width = self.session.get_inputs()[0].shape
+            else:
+                _, self.input_height, self.input_width, _ = self.session.get_inputs()[0].shape
+
+        elif self.backend == InferModelBackend.TFLITE:
+            try:
+                import tflite_runtime.interpreter as tflite
+            except:
+                import tensorflow.lite as tflite
+
+
+            self.interpreter = tflite.Interpreter(model_path=self.model_path)
+            self.interpreter.allocate_tensors()

-        sess_options = ort.SessionOptions()
-        sess_options.intra_op_num_threads = 1
-        sess_options.inter_op_num_threads = 1
-        sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
-        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
-        self.session = ort.InferenceSession(self.model_path, sess_options, providers=ONNX_BACKENDS)
+            # Get input and output tensors.
+            self.input_details = self.interpreter.get_input_details()
+            self.output_details = self.interpreter.get_output_details()

-        if channel_first:
-            _, _, self.input_height, self.input_width = self.session.get_inputs()[0].shape
+            _, self.input_height, self.input_width, _ = self.input_details[0]["shape"]
+            self.channel_first = False
         else:
-            _, self.input_height, self.input_width, _ = self.session.get_inputs()[0].shape
+            raise AttributeError(f"Backend ({self.backend}) is not supported")

     def preprocess(self, img):

         img = cv2.resize(img, (self.input_width, self.input_height))
         img_mean = np.array([127, 127, 127])
         img = (img - img_mean) / 128

-        img = np.transpose(img, [2, 0, 1])
+        if self.channel_first:
+            img = np.transpose(img, [2, 0, 1])
+
         img = np.expand_dims(img, axis=0)
         img = img.astype(np.float32)

         return img

-class FaceDetector(ONNXModel):
+    def _predict(self, img):
+        pass
+
+class FaceDetector(InferModel):
+
+    def __init__(self, model_path, channel_first=True, backend=InferModelBackend.ONNX):
+        if backend != InferModelBackend.ONNX:
+            raise AttributeError(f"Backend ({backend}) is not supported for FaceDetector")
+        else:
+            super().__init__(model_path, channel_first, backend)

     def postprocess(self, width, height, confidences, boxes, prob_threshold, iou_threshold=0.3, top_k=-1):
         boxes = boxes[0]
@@ -82,13 +121,21 @@ def predict(self, ori_img, threshold=0.5):
 
         return boxes, probs

-class LandmarkAligner(ONNXModel):
-
+class LandmarkAligner(InferModel):
+
+    def _predict(self, img):
+        if self.backend == InferModelBackend.ONNX:
+            return self.session.run([], {"input": img})[0]
+        elif self.backend == InferModelBackend.TFLITE:
+            self.interpreter.set_tensor(self.input_details[0]['index'], img)
+            self.interpreter.invoke()
+            return self.interpreter.get_tensor(self.output_details[0]['index'])
+
     def predict(self, ori_img, have_face_cls=False):
         height, width, _ = ori_img.shape

         img = self.preprocess(ori_img)
-        lmk = self.session.run([], {'input': img})[0]
+        lmk = self._predict(img)

         if have_face_cls:
             lmk, face_cls = lmk[0][:70*2].reshape((70,2)), lmk[0][70*2]
@@ -103,7 +150,7 @@ def predict(self, ori_img, have_face_cls=False):
             return lmk, face_cls

         return lmk
-
+
     def predict_frame(self, frame, bbox, have_face_cls=False):

         fheight, fwidth, _ = frame.shape
@@ -112,7 +159,6 @@ def predict_frame(self, frame, bbox, have_face_cls=False):
         if bw*bh == 0:
             return np.zeros((70,2)), 0

-
         face_img = frame[lmk_box[1]:lmk_box[3], lmk_box[0]:lmk_box[2], :]
         lmk = self.predict(face_img, have_face_cls=have_face_cls)

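For context, a minimal usage sketch of the new backend switch follows. The model paths and the face crop are placeholders (not part of this commit); it assumes a landmark model exported both to ONNX and to TFLite.

import cv2

from fdfat.nn.infer import LandmarkAligner, InferModelBackend

# Placeholder input, not shipped with this commit: a cropped face image.
face_img = cv2.imread("face_crop.jpg")

# ONNX backend (default): NCHW input, so channel_first stays True.
aligner_onnx = LandmarkAligner("landmark.onnx", backend=InferModelBackend.ONNX)
lmk_onnx = aligner_onnx.predict(face_img)

# TFLite backend: input size is read from the interpreter and
# channel_first is forced to False (NHWC input).
aligner_tflite = LandmarkAligner("landmark.tflite", backend=InferModelBackend.TFLITE)
lmk_tflite = aligner_tflite.predict(face_img)

print(lmk_onnx.shape, lmk_tflite.shape)

Note that FaceDetector still accepts only the ONNX backend; constructing it with a TFLite model raises AttributeError.
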
fdfat/tracking/facial_sort.py

+6 -2
@@ -5,7 +5,7 @@
 import zmq
 import zmq.decorators as zmqd

-from fdfat.nn.onnx import LandmarkAligner, FaceDetector
+from fdfat.nn.infer import LandmarkAligner, FaceDetector, InferModelBackend
 from fdfat.tracking.sort import SORT
 from fdfat.utils import box_utils
 from fdfat.utils import profiler
@@ -53,7 +53,11 @@ def __init__(self, args):
         self.frame_id = 0
         self.current_faces = [] # to visualize

-        self.landmark = LandmarkAligner(self.args.track_landmark)
+        if self.args.track_landmark.endswith(".onnx"):
+            self.infer_backend = InferModelBackend.ONNX
+        elif self.args.track_landmark.endswith(".tflite"):
+            self.infer_backend = InferModelBackend.TFLITE
+        self.landmark = LandmarkAligner(self.args.track_landmark, backend=self.infer_backend)

     def run(self):
         self._run()
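The tracker picks the backend from the model file extension in FacialSORT.__init__ above. Below is a standalone sketch of that selection logic; the helper name pick_backend and the explicit error for unrecognized extensions are illustrative additions (the commit itself leaves self.infer_backend unset in that case).

from fdfat.nn.infer import InferModelBackend

def pick_backend(model_path: str) -> int:
    # Mirror the extension check in FacialSORT.__init__.
    if model_path.endswith(".onnx"):
        return InferModelBackend.ONNX
    if model_path.endswith(".tflite"):
        return InferModelBackend.TFLITE
    # Hypothetical guard, not present in the commit.
    raise ValueError(f"Cannot infer a backend from model path: {model_path}")

For example, pick_backend("model.tflite") returns InferModelBackend.TFLITE.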
