1
1
import time
2
+ import tensorflow as tf
3
+ physical_devices = tf .config .experimental .list_physical_devices ('GPU' )
4
+ if len (physical_devices ) > 0 :
5
+ tf .config .experimental .set_memory_growth (physical_devices [0 ], True )
2
6
from absl import app , flags , logging
3
7
from absl .flags import FLAGS
4
8
import core .utils as utils
5
- from core .yolov4 import YOLOv4 , YOLOv3 , YOLOv3_tiny , decode
9
+ from core .yolov4 import filter_boxes
10
+ from tensorflow .python .saved_model import tag_constants
6
11
from PIL import Image
7
- from core .config import cfg
8
12
import cv2
9
13
import numpy as np
10
- import tensorflow as tf
14
+ from tensorflow .compat .v1 import ConfigProto
15
+ from tensorflow .compat .v1 import InteractiveSession
11
16
12
- flags .DEFINE_string ('framework' , 'tf' , '(tf, tflite' )
13
- flags .DEFINE_string ('weights' , './data /yolov4.weights ' ,
17
+ flags .DEFINE_string ('framework' , 'tf' , '(tf, tflite, trt ' )
18
+ flags .DEFINE_string ('weights' , './checkpoints /yolov4-416 ' ,
14
19
'path to weights file' )
15
- flags .DEFINE_integer ('size' , 608 , 'resize images to' )
20
+ flags .DEFINE_integer ('size' , 416 , 'resize images to' )
16
21
flags .DEFINE_boolean ('tiny' , False , 'yolo or yolo-tiny' )
17
22
flags .DEFINE_string ('model' , 'yolov4' , 'yolov3 or yolov4' )
18
- flags .DEFINE_string ('video' , './data/road.avi' , 'path to input video' )
23
+ flags .DEFINE_string ('video' , './data/road.mp4' , 'path to input video' )
24
+ flags .DEFINE_float ('iou' , 0.45 , 'iou threshold' )
25
+ flags .DEFINE_float ('score' , 0.25 , 'score threshold' )
19
26
20
27
def main (_argv ):
21
- if FLAGS .tiny :
22
- STRIDES = np .array (cfg .YOLO .STRIDES_TINY )
23
- ANCHORS = utils .get_anchors (cfg .YOLO .ANCHORS_TINY , FLAGS .tiny )
24
- else :
25
- STRIDES = np .array (cfg .YOLO .STRIDES )
26
- if FLAGS .model == 'yolov4' :
27
- ANCHORS = utils .get_anchors (cfg .YOLO .ANCHORS , FLAGS .tiny )
28
- else :
29
- ANCHORS = utils .get_anchors (cfg .YOLO .ANCHORS_V3 , FLAGS .tiny )
30
- NUM_CLASS = len (utils .read_class_names (cfg .YOLO .CLASSES ))
31
- XYSCALE = cfg .YOLO .XYSCALE
28
+ config = ConfigProto ()
29
+ config .gpu_options .allow_growth = True
30
+ session = InteractiveSession (config = config )
31
+ STRIDES , ANCHORS , NUM_CLASS , XYSCALE = utils .load_config (FLAGS )
32
32
input_size = FLAGS .size
33
33
video_path = FLAGS .video
34
34
35
35
print ("Video from: " , video_path )
36
36
vid = cv2 .VideoCapture (video_path )
37
37
38
- if FLAGS .framework == 'tf' :
39
- input_layer = tf .keras .layers .Input ([input_size , input_size , 3 ])
40
- if FLAGS .tiny :
41
- feature_maps = YOLOv3_tiny (input_layer , NUM_CLASS )
42
- bbox_tensors = []
43
- for i , fm in enumerate (feature_maps ):
44
- bbox_tensor = decode (fm , NUM_CLASS , i )
45
- bbox_tensors .append (bbox_tensor )
46
- model = tf .keras .Model (input_layer , bbox_tensors )
47
- utils .load_weights_tiny (model , FLAGS .weights )
48
- else :
49
- if FLAGS .model == 'yolov3' :
50
- feature_maps = YOLOv3 (input_layer , NUM_CLASS )
51
- bbox_tensors = []
52
- for i , fm in enumerate (feature_maps ):
53
- bbox_tensor = decode (fm , NUM_CLASS , i )
54
- bbox_tensors .append (bbox_tensor )
55
- model = tf .keras .Model (input_layer , bbox_tensors )
56
- utils .load_weights_v3 (model , FLAGS .weights )
57
- elif FLAGS .model == 'yolov4' :
58
- feature_maps = YOLOv4 (input_layer , NUM_CLASS )
59
- bbox_tensors = []
60
- for i , fm in enumerate (feature_maps ):
61
- bbox_tensor = decode (fm , NUM_CLASS , i )
62
- bbox_tensors .append (bbox_tensor )
63
- model = tf .keras .Model (input_layer , bbox_tensors )
64
-
65
- if FLAGS .weights .split ("." )[len (FLAGS .weights .split ("." )) - 1 ] == "weights" :
66
- utils .load_weights (model , FLAGS .weights )
67
- else :
68
- model .load_weights (FLAGS .weights ).expect_partial ()
69
-
70
- model .summary ()
71
- else :
72
- # Load TFLite model and allocate tensors.
38
+ if FLAGS .framework == 'tflite' :
73
39
interpreter = tf .lite .Interpreter (model_path = FLAGS .weights )
74
40
interpreter .allocate_tensors ()
75
- # Get input and output tensors.
76
41
input_details = interpreter .get_input_details ()
77
42
output_details = interpreter .get_output_details ()
78
43
print (input_details )
79
44
print (output_details )
45
+ else :
46
+ saved_model_loaded = tf .saved_model .load (FLAGS .weights , tags = [tag_constants .SERVING ])
47
+ infer = saved_model_loaded .signatures ['serving_default' ]
80
48
81
49
while True :
82
50
return_value , frame = vid .read ()
@@ -86,26 +54,37 @@ def main(_argv):
86
54
else :
87
55
raise ValueError ("No image! Try with another video format" )
88
56
frame_size = frame .shape [:2 ]
89
- image_data = utils .image_preprocess (np .copy (frame ), [input_size , input_size ])
57
+ image_data = cv2 .resize (frame , (input_size , input_size ))
58
+ image_data = image_data / 255.
90
59
image_data = image_data [np .newaxis , ...].astype (np .float32 )
91
60
prev_time = time .time ()
92
61
93
- if FLAGS .framework == 'tf' :
94
- pred_bbox = model .predict (image_data )
95
- else :
62
+ if FLAGS .framework == 'tflite' :
96
63
interpreter .set_tensor (input_details [0 ]['index' ], image_data )
97
64
interpreter .invoke ()
98
- pred_bbox = [interpreter .get_tensor (output_details [i ]['index' ]) for i in range (len (output_details ))]
99
-
100
- if FLAGS .model == 'yolov4' :
101
- pred_bbox = utils .postprocess_bbbox (pred_bbox , ANCHORS , STRIDES , XYSCALE )
65
+ pred = [interpreter .get_tensor (output_details [i ]['index' ]) for i in range (len (output_details ))]
66
+ if FLAGS .model == 'yolov4' and FLAGS .tiny == True :
67
+ boxes , pred_conf = filter_boxes (pred [1 ], pred [0 ], score_threshold = 0.25 )
68
+ else :
69
+ boxes , pred_conf = filter_boxes (pred [0 ], pred [1 ], score_threshold = 0.25 )
102
70
else :
103
- pred_bbox = utils .postprocess_bbbox (pred_bbox , ANCHORS , STRIDES )
104
-
105
- bboxes = utils .postprocess_boxes (pred_bbox , frame_size , input_size , 0.25 )
106
- bboxes = utils .nms (bboxes , 0.213 , method = 'nms' )
71
+ batch_data = tf .constant (image_data )
72
+ pred_bbox = infer (batch_data )
73
+ for key , value in pred_bbox .items ():
74
+ boxes = value [:, :, 0 :4 ]
75
+ pred_conf = value [:, :, 4 :]
107
76
108
- image = utils .draw_bbox (frame , bboxes )
77
+ boxes , scores , classes , valid_detections = tf .image .combined_non_max_suppression (
78
+ boxes = tf .reshape (boxes , (tf .shape (boxes )[0 ], - 1 , 1 , 4 )),
79
+ scores = tf .reshape (
80
+ pred_conf , (tf .shape (pred_conf )[0 ], - 1 , tf .shape (pred_conf )[- 1 ])),
81
+ max_output_size_per_class = 50 ,
82
+ max_total_size = 50 ,
83
+ iou_threshold = FLAGS .iou ,
84
+ score_threshold = FLAGS .score
85
+ )
86
+ pred_bbox = [boxes .numpy (), scores .numpy (), classes .numpy (), valid_detections .numpy ()]
87
+ image = utils .draw_bbox (frame , pred_bbox )
109
88
curr_time = time .time ()
110
89
exec_time = curr_time - prev_time
111
90
result = np .asarray (image )
0 commit comments