Skip to content

Commit 9ab36aa

Browse files
committed
update training code
1 parent 4bd8993 commit 9ab36aa

8 files changed

+171
-259
lines changed

core/dataset.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
class Dataset(object):
1414
"""implement Dataset here"""
1515

16-
def __init__(self, is_training: bool, dataset_type: str = "converted_coco", tiny: bool = False):
16+
def __init__(self, FLAGS, is_training: bool, dataset_type: str = "converted_coco"):
17+
self.tiny = FLAGS.tiny
18+
self.strides, self.anchors, NUM_CLASS, XYSCALE = utils.load_config(FLAGS)
1719
self.dataset_type = dataset_type
1820

1921
self.annot_path = (
@@ -28,12 +30,8 @@ def __init__(self, is_training: bool, dataset_type: str = "converted_coco", tiny
2830
self.data_aug = cfg.TRAIN.DATA_AUG if is_training else cfg.TEST.DATA_AUG
2931

3032
self.train_input_sizes = cfg.TRAIN.INPUT_SIZE
31-
self.strides = (
32-
np.array(cfg.YOLO.STRIDES_TINY) if tiny else np.array(cfg.YOLO.STRIDES)
33-
)
3433
self.classes = utils.read_class_names(cfg.YOLO.CLASSES)
3534
self.num_classes = len(self.classes)
36-
self.anchors = np.array(utils.get_anchors(cfg.YOLO.ANCHORS))
3735
self.anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE
3836
self.max_bbox_per_scale = 150
3937

core/utils.py

+13-71
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,19 @@
55
import tensorflow as tf
66
from core.config import cfg
77

8+
def load_freeze_layer(model='yolov4', tiny=False):
9+
if tiny:
10+
if model == 'yolov3':
11+
freeze_layouts = ['conv2d_9', 'conv2d_12']
12+
else:
13+
freeze_layouts = ['conv2d_17', 'conv2d_20']
14+
else:
15+
if model == 'yolov3':
16+
freeze_layouts = ['conv2d_58', 'conv2d_66', 'conv2d_74']
17+
else:
18+
freeze_layouts = ['conv2d_93', 'conv2d_101', 'conv2d_109']
19+
return freeze_layouts
20+
821
def load_weights(model, weights_file, model_name='yolov4', is_tiny=False):
922
if is_tiny:
1023
if model_name == 'yolov3':
@@ -89,7 +102,6 @@ def get_anchors(anchors_path, tiny=False):
89102
else:
90103
return anchors.reshape(3, 3, 2)
91104

92-
93105
def image_preprocess(image, target_size, gt_boxes=None):
94106

95107
ih, iw = target_size
@@ -112,7 +124,6 @@ def image_preprocess(image, target_size, gt_boxes=None):
112124
gt_boxes[:, [1, 3]] = gt_boxes[:, [1, 3]] * scale + dh
113125
return image_paded, gt_boxes
114126

115-
116127
def draw_bbox(image, bboxes, classes=read_class_names(cfg.YOLO.CLASSES), show_label=True):
117128
num_classes = len(classes)
118129
image_h, image_w, _ = image.shape
@@ -149,10 +160,8 @@ def draw_bbox(image, bboxes, classes=read_class_names(cfg.YOLO.CLASSES), show_la
149160

150161
cv2.putText(image, bbox_mess, (c1[0], np.float32(c1[1] - 2)), cv2.FONT_HERSHEY_SIMPLEX,
151162
fontScale, (0, 0, 0), bbox_thick // 2, lineType=cv2.LINE_AA)
152-
153163
return image
154164

155-
156165
def bbox_iou(bboxes1, bboxes2):
157166
"""
158167
@param bboxes1: (a, b, ..., 4)
@@ -316,7 +325,6 @@ def bbox_ciou(bboxes1, bboxes2):
316325

317326
return ciou
318327

319-
320328
def nms(bboxes, iou_threshold, sigma=0.3, method='nms'):
321329
"""
322330
:param bboxes: (xmin, ymin, xmax, ymax, score, class)
@@ -354,72 +362,6 @@ def nms(bboxes, iou_threshold, sigma=0.3, method='nms'):
354362

355363
return best_bboxes
356364

357-
def diounms_sort(bboxes, iou_threshold, sigma=0.3, method='nms', beta_nms=0.6):
358-
best_bboxes = []
359-
return best_bboxes
360-
def postprocess_bbbox(pred_bbox, ANCHORS, STRIDES, XYSCALE=[1,1,1]):
361-
for i, pred in enumerate(pred_bbox):
362-
conv_shape = pred.shape
363-
output_size = conv_shape[1]
364-
conv_raw_dxdy = pred[:, :, :, :, 0:2]
365-
conv_raw_dwdh = pred[:, :, :, :, 2:4]
366-
xy_grid = np.meshgrid(np.arange(output_size), np.arange(output_size))
367-
xy_grid = np.expand_dims(np.stack(xy_grid, axis=-1), axis=2) # [gx, gy, 1, 2]
368-
369-
xy_grid = np.tile(tf.expand_dims(xy_grid, axis=0), [1, 1, 1, 3, 1])
370-
xy_grid = xy_grid.astype(np.float)
371-
372-
# pred_xy = (tf.sigmoid(conv_raw_dxdy) + xy_grid) * STRIDES[i]
373-
pred_xy = ((tf.sigmoid(conv_raw_dxdy) * XYSCALE[i]) - 0.5 * (XYSCALE[i] - 1) + xy_grid) * STRIDES[i]
374-
# pred_wh = (tf.exp(conv_raw_dwdh) * ANCHORS[i]) * STRIDES[i]
375-
pred_wh = (tf.exp(conv_raw_dwdh) * ANCHORS[i])
376-
pred[:, :, :, :, 0:4] = tf.concat([pred_xy, pred_wh], axis=-1)
377-
378-
pred_bbox = [tf.reshape(x, (-1, tf.shape(x)[-1])) for x in pred_bbox]
379-
pred_bbox = tf.concat(pred_bbox, axis=0)
380-
return pred_bbox
381-
def postprocess_boxes(pred_bbox, org_img_shape, input_size, score_threshold):
382-
383-
valid_scale=[0, np.inf]
384-
pred_bbox = np.array(pred_bbox)
385-
386-
pred_xywh = pred_bbox[:, 0:4]
387-
pred_conf = pred_bbox[:, 4]
388-
pred_prob = pred_bbox[:, 5:]
389-
390-
# # (1) (x, y, w, h) --> (xmin, ymin, xmax, ymax)
391-
pred_coor = np.concatenate([pred_xywh[:, :2] - pred_xywh[:, 2:] * 0.5,
392-
pred_xywh[:, :2] + pred_xywh[:, 2:] * 0.5], axis=-1)
393-
# # (2) (xmin, ymin, xmax, ymax) -> (xmin_org, ymin_org, xmax_org, ymax_org)
394-
org_h, org_w = org_img_shape
395-
resize_ratio = min(input_size / org_w, input_size / org_h)
396-
397-
dw = (input_size - resize_ratio * org_w) / 2
398-
dh = (input_size - resize_ratio * org_h) / 2
399-
400-
pred_coor[:, 0::2] = 1.0 * (pred_coor[:, 0::2] - dw) / resize_ratio
401-
pred_coor[:, 1::2] = 1.0 * (pred_coor[:, 1::2] - dh) / resize_ratio
402-
403-
# # (3) clip some boxes those are out of range
404-
pred_coor = np.concatenate([np.maximum(pred_coor[:, :2], [0, 0]),
405-
np.minimum(pred_coor[:, 2:], [org_w - 1, org_h - 1])], axis=-1)
406-
invalid_mask = np.logical_or((pred_coor[:, 0] > pred_coor[:, 2]), (pred_coor[:, 1] > pred_coor[:, 3]))
407-
pred_coor[invalid_mask] = 0
408-
409-
# # (4) discard some invalid boxes
410-
bboxes_scale = np.sqrt(np.multiply.reduce(pred_coor[:, 2:4] - pred_coor[:, 0:2], axis=-1))
411-
scale_mask = np.logical_and((valid_scale[0] < bboxes_scale), (bboxes_scale < valid_scale[1]))
412-
413-
# # (5) discard some boxes with low scores
414-
classes = np.argmax(pred_prob, axis=-1)
415-
scores = pred_conf * pred_prob[np.arange(len(pred_coor)), classes]
416-
# scores = pred_prob[np.arange(len(pred_coor)), classes]
417-
score_mask = scores > score_threshold
418-
mask = np.logical_and(scale_mask, score_mask)
419-
coors, scores, classes = pred_coor[mask], scores[mask], classes[mask]
420-
421-
return np.concatenate([coors, scores[:, np.newaxis], classes[:, np.newaxis]], axis=-1)
422-
423365
def freeze_all(model, frozen=True):
424366
model.trainable = not frozen
425367
if isinstance(model, tf.keras.Model):

core/yolov4.py

+23
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,29 @@ def decode(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE=[1,
168168
else:
169169
return decode_tf(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=i, XYSCALE=XYSCALE)
170170

171+
def decode_train(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=0, XYSCALE=[1, 1, 1]):
172+
conv_output = tf.reshape(conv_output,
173+
(tf.shape(conv_output)[0], output_size, output_size, 3, 5 + NUM_CLASS))
174+
175+
conv_raw_dxdy, conv_raw_dwdh, conv_raw_conf, conv_raw_prob = tf.split(conv_output, (2, 2, 1, NUM_CLASS),
176+
axis=-1)
177+
178+
xy_grid = tf.meshgrid(tf.range(output_size), tf.range(output_size))
179+
xy_grid = tf.expand_dims(tf.stack(xy_grid, axis=-1), axis=2) # [gx, gy, 1, 2]
180+
xy_grid = tf.tile(tf.expand_dims(xy_grid, axis=0), [tf.shape(conv_output)[0], 1, 1, 3, 1])
181+
182+
xy_grid = tf.cast(xy_grid, tf.float32)
183+
184+
pred_xy = ((tf.sigmoid(conv_raw_dxdy) * XYSCALE[i]) - 0.5 * (XYSCALE[i] - 1) + xy_grid) * \
185+
STRIDES[i]
186+
pred_wh = (tf.exp(conv_raw_dwdh) * ANCHORS[i])
187+
pred_xywh = tf.concat([pred_xy, pred_wh], axis=-1)
188+
189+
pred_conf = tf.sigmoid(conv_raw_conf)
190+
pred_prob = tf.sigmoid(conv_raw_prob)
191+
192+
return tf.concat([pred_xywh, pred_conf, pred_prob], axis=-1)
193+
171194
def decode_tf(conv_output, output_size, NUM_CLASS, STRIDES, ANCHORS, i=0, XYSCALE=[1, 1, 1]):
172195
conv_output = tf.reshape(conv_output,
173196
(tf.shape(conv_output)[0], output_size, output_size, 3, 5 + NUM_CLASS))

detect.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import time
2-
31
import tensorflow as tf
42
physical_devices = tf.config.experimental.list_physical_devices('GPU')
53
if len(physical_devices) > 0:
@@ -10,17 +8,16 @@
108
from core.yolov4 import filter_boxes
119
from tensorflow.python.saved_model import tag_constants
1210
from PIL import Image
13-
from core.config import cfg
1411
import cv2
1512
import numpy as np
1613
from tensorflow.compat.v1 import ConfigProto
1714
from tensorflow.compat.v1 import InteractiveSession
1815

1916
flags.DEFINE_string('framework', 'tf', '(tf, tflite, trt')
20-
flags.DEFINE_string('weights', './checkpoints/yolov4-416',
17+
flags.DEFINE_string('weights', './checkpoints/yolov4-tiny-416',
2118
'path to weights file')
2219
flags.DEFINE_integer('size', 416, 'resize images to')
23-
flags.DEFINE_boolean('tiny', False, 'yolo or yolo-tiny')
20+
flags.DEFINE_boolean('tiny', True, 'yolo or yolo-tiny')
2421
flags.DEFINE_string('model', 'yolov4', 'yolov3 or yolov4')
2522
flags.DEFINE_string('image', './data/kite.jpg', 'path to input image')
2623
flags.DEFINE_string('output', 'result.png', 'path to output image')

detectvideo.py

+46-67
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,50 @@
11
import time
2+
import tensorflow as tf
3+
physical_devices = tf.config.experimental.list_physical_devices('GPU')
4+
if len(physical_devices) > 0:
5+
tf.config.experimental.set_memory_growth(physical_devices[0], True)
26
from absl import app, flags, logging
37
from absl.flags import FLAGS
48
import core.utils as utils
5-
from core.yolov4 import YOLOv4, YOLOv3, YOLOv3_tiny, decode
9+
from core.yolov4 import filter_boxes
10+
from tensorflow.python.saved_model import tag_constants
611
from PIL import Image
7-
from core.config import cfg
812
import cv2
913
import numpy as np
10-
import tensorflow as tf
14+
from tensorflow.compat.v1 import ConfigProto
15+
from tensorflow.compat.v1 import InteractiveSession
1116

12-
flags.DEFINE_string('framework', 'tf', '(tf, tflite')
13-
flags.DEFINE_string('weights', './data/yolov4.weights',
17+
flags.DEFINE_string('framework', 'tf', '(tf, tflite, trt')
18+
flags.DEFINE_string('weights', './checkpoints/yolov4-416',
1419
'path to weights file')
15-
flags.DEFINE_integer('size', 608, 'resize images to')
20+
flags.DEFINE_integer('size', 416, 'resize images to')
1621
flags.DEFINE_boolean('tiny', False, 'yolo or yolo-tiny')
1722
flags.DEFINE_string('model', 'yolov4', 'yolov3 or yolov4')
18-
flags.DEFINE_string('video', './data/road.avi', 'path to input video')
23+
flags.DEFINE_string('video', './data/road.mp4', 'path to input video')
24+
flags.DEFINE_float('iou', 0.45, 'iou threshold')
25+
flags.DEFINE_float('score', 0.25, 'score threshold')
1926

2027
def main(_argv):
21-
if FLAGS.tiny:
22-
STRIDES = np.array(cfg.YOLO.STRIDES_TINY)
23-
ANCHORS = utils.get_anchors(cfg.YOLO.ANCHORS_TINY, FLAGS.tiny)
24-
else:
25-
STRIDES = np.array(cfg.YOLO.STRIDES)
26-
if FLAGS.model == 'yolov4':
27-
ANCHORS = utils.get_anchors(cfg.YOLO.ANCHORS, FLAGS.tiny)
28-
else:
29-
ANCHORS = utils.get_anchors(cfg.YOLO.ANCHORS_V3, FLAGS.tiny)
30-
NUM_CLASS = len(utils.read_class_names(cfg.YOLO.CLASSES))
31-
XYSCALE = cfg.YOLO.XYSCALE
28+
config = ConfigProto()
29+
config.gpu_options.allow_growth = True
30+
session = InteractiveSession(config=config)
31+
STRIDES, ANCHORS, NUM_CLASS, XYSCALE = utils.load_config(FLAGS)
3232
input_size = FLAGS.size
3333
video_path = FLAGS.video
3434

3535
print("Video from: ", video_path )
3636
vid = cv2.VideoCapture(video_path)
3737

38-
if FLAGS.framework == 'tf':
39-
input_layer = tf.keras.layers.Input([input_size, input_size, 3])
40-
if FLAGS.tiny:
41-
feature_maps = YOLOv3_tiny(input_layer, NUM_CLASS)
42-
bbox_tensors = []
43-
for i, fm in enumerate(feature_maps):
44-
bbox_tensor = decode(fm, NUM_CLASS, i)
45-
bbox_tensors.append(bbox_tensor)
46-
model = tf.keras.Model(input_layer, bbox_tensors)
47-
utils.load_weights_tiny(model, FLAGS.weights)
48-
else:
49-
if FLAGS.model == 'yolov3':
50-
feature_maps = YOLOv3(input_layer, NUM_CLASS)
51-
bbox_tensors = []
52-
for i, fm in enumerate(feature_maps):
53-
bbox_tensor = decode(fm, NUM_CLASS, i)
54-
bbox_tensors.append(bbox_tensor)
55-
model = tf.keras.Model(input_layer, bbox_tensors)
56-
utils.load_weights_v3(model, FLAGS.weights)
57-
elif FLAGS.model == 'yolov4':
58-
feature_maps = YOLOv4(input_layer, NUM_CLASS)
59-
bbox_tensors = []
60-
for i, fm in enumerate(feature_maps):
61-
bbox_tensor = decode(fm, NUM_CLASS, i)
62-
bbox_tensors.append(bbox_tensor)
63-
model = tf.keras.Model(input_layer, bbox_tensors)
64-
65-
if FLAGS.weights.split(".")[len(FLAGS.weights.split(".")) - 1] == "weights":
66-
utils.load_weights(model, FLAGS.weights)
67-
else:
68-
model.load_weights(FLAGS.weights).expect_partial()
69-
70-
model.summary()
71-
else:
72-
# Load TFLite model and allocate tensors.
38+
if FLAGS.framework == 'tflite':
7339
interpreter = tf.lite.Interpreter(model_path=FLAGS.weights)
7440
interpreter.allocate_tensors()
75-
# Get input and output tensors.
7641
input_details = interpreter.get_input_details()
7742
output_details = interpreter.get_output_details()
7843
print(input_details)
7944
print(output_details)
45+
else:
46+
saved_model_loaded = tf.saved_model.load(FLAGS.weights, tags=[tag_constants.SERVING])
47+
infer = saved_model_loaded.signatures['serving_default']
8048

8149
while True:
8250
return_value, frame = vid.read()
@@ -86,26 +54,37 @@ def main(_argv):
8654
else:
8755
raise ValueError("No image! Try with another video format")
8856
frame_size = frame.shape[:2]
89-
image_data = utils.image_preprocess(np.copy(frame), [input_size, input_size])
57+
image_data = cv2.resize(frame, (input_size, input_size))
58+
image_data = image_data / 255.
9059
image_data = image_data[np.newaxis, ...].astype(np.float32)
9160
prev_time = time.time()
9261

93-
if FLAGS.framework == 'tf':
94-
pred_bbox = model.predict(image_data)
95-
else:
62+
if FLAGS.framework == 'tflite':
9663
interpreter.set_tensor(input_details[0]['index'], image_data)
9764
interpreter.invoke()
98-
pred_bbox = [interpreter.get_tensor(output_details[i]['index']) for i in range(len(output_details))]
99-
100-
if FLAGS.model == 'yolov4':
101-
pred_bbox = utils.postprocess_bbbox(pred_bbox, ANCHORS, STRIDES, XYSCALE)
65+
pred = [interpreter.get_tensor(output_details[i]['index']) for i in range(len(output_details))]
66+
if FLAGS.model == 'yolov4' and FLAGS.tiny == True:
67+
boxes, pred_conf = filter_boxes(pred[1], pred[0], score_threshold=0.25)
68+
else:
69+
boxes, pred_conf = filter_boxes(pred[0], pred[1], score_threshold=0.25)
10270
else:
103-
pred_bbox = utils.postprocess_bbbox(pred_bbox, ANCHORS, STRIDES)
104-
105-
bboxes = utils.postprocess_boxes(pred_bbox, frame_size, input_size, 0.25)
106-
bboxes = utils.nms(bboxes, 0.213, method='nms')
71+
batch_data = tf.constant(image_data)
72+
pred_bbox = infer(batch_data)
73+
for key, value in pred_bbox.items():
74+
boxes = value[:, :, 0:4]
75+
pred_conf = value[:, :, 4:]
10776

108-
image = utils.draw_bbox(frame, bboxes)
77+
boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression(
78+
boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
79+
scores=tf.reshape(
80+
pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])),
81+
max_output_size_per_class=50,
82+
max_total_size=50,
83+
iou_threshold=FLAGS.iou,
84+
score_threshold=FLAGS.score
85+
)
86+
pred_bbox = [boxes.numpy(), scores.numpy(), classes.numpy(), valid_detections.numpy()]
87+
image = utils.draw_bbox(frame, pred_bbox)
10988
curr_time = time.time()
11089
exec_time = curr_time - prev_time
11190
result = np.asarray(image)

0 commit comments

Comments
 (0)