Commit 42cf415 (parent: 306aa03)
Showing 9 changed files with 680 additions and 1 deletion.
.gitignore (new file)
# Ignore all TFLite model weights in this directory,
**.tflite
# but keep the .gitignore itself so the directory stays tracked.
!.gitignore
detect.py (new file)
#!/usr/bin/env python3
#
# Copyright 2021.
# ozora-ogino

from typing import Tuple

import cv2
import numpy as np
from tflite_runtime.interpreter import Interpreter


class Detect(object):
    """YOLOv5 TFLite detect model."""

    def __init__(
        self,
        model_file: str,
        conf_thr: float,
    ):
        # Load the model into memory.
        self.interpreter = Interpreter(model_file)
        self.interpreter.allocate_tensors()

        self.input_details = self.interpreter.get_input_details()
        # TFLite input shape is (1, height, width, channels).
        _, self.height, self.width, _ = self.input_details[0]["shape"]
        self.output_details = self.interpreter.get_output_details()
        self.conf_thr = conf_thr

    def detect(self, img: np.ndarray, box_type="xywh") -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Detect objects.

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray]: Boxes (N, 4), scores (N,)
            and class indices (N,), where N is the number of detections that
            survive the confidence filter.
        """
        img = self.preprocess(img)
        output_data = self._detect(img)
        boxes, scores, class_idx = self.postprocess(output_data, box_type)
        return boxes, scores, class_idx

    def _detect(self, img: np.ndarray):
        """Inference."""
        self.interpreter.set_tensor(self.input_details[0]["index"], img)
        self.interpreter.invoke()
        # Get the output tensor, shape (1, 25200, 7).
        output_data = self.interpreter.get_tensor(self.output_details[0]["index"])
        return output_data

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """Preprocess."""
        # Resize to the model input size; cv2.resize expects (width, height).
        img = cv2.resize(img, (self.width, self.height))
        # BGR -> RGB
        img = img[:, :, [2, 1, 0]]
        # Normalize to [0, 1].
        img = img / 255.0
        # Add the batch dimension.
        img = np.expand_dims(img, axis=0)
        return img.astype(np.float32)

    def postprocess(self, output_data, box_type: str):
        """Postprocess."""
        output_data = output_data[0]
        # xywh (center x, center y, width, height), normalized to [0, 1].
        boxes = output_data[..., :4]
        conf = output_data[..., 4:5]
        cls = np.argmax(output_data[..., 5:], axis=1).astype(np.float32).reshape(-1, 1)

        conf = np.squeeze(conf, axis=1)
        cls = np.squeeze(cls, axis=1)

        if box_type == "xyxy":
            # xywh -> xyxy
            boxes = self.to_xyxy(boxes)

        # Filter by confidence threshold.
        idxs = np.where(conf > self.conf_thr)
        boxes = boxes[idxs]
        cls = cls[idxs]
        conf = conf[idxs]

        return boxes, conf, cls

    def to_xyxy(self, boxes: np.ndarray) -> np.ndarray:
        """Convert xywh to xyxy."""
        x, y, w, h = boxes[..., 0], boxes[..., 1], boxes[..., 2], boxes[..., 3]
        boxes = np.array([x - w / 2, y - h / 2, x + w / 2, y + h / 2])
        # (4, n) -> (n, 4)
        boxes = boxes.transpose((1, 0))
        return boxes
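
For reference, a minimal usage sketch of the class above. The weight path is taken from the entry script's --model default below; the image path is a placeholder:

import cv2
from detect import Detect

# Load the TFLite weight (default path from the entry script).
detector = Detect("./models/yolov5n6-fp16.tflite", conf_thr=0.2)
# Any BGR image read with OpenCV works here; the file name is hypothetical.
img = cv2.imread("./data/sample.jpg")
boxes, scores, class_idx = detector.detect(img, box_type="xyxy")
# Boxes are normalized to [0, 1]; multiply by the image size to get pixels.
print(boxes.shape, scores.shape, class_idx.shape)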
Entry-point script (new file)
#!/usr/bin/env python3
#
# Copyright 2021.
# ozora-ogino

import argparse
import os
import time

import cv2
import numpy as np

from detect import Detect
from streams import VideoStream
from tracker import Tracker


def _detect_person(
    detect: Detect,
    frame: np.ndarray,
    confidence: float,
    iou_threshold: float,
) -> np.ndarray:
    """Detect person objects in a frame.

    Returns:
        np.ndarray: Array like [xmin, ymin, xmax, ymax, score].
    """
    # Detect objects in the frame.
    boxes, scores, class_idx = detect.detect(frame)

    # NMS. Depending on the OpenCV version, NMSBoxes returns a flat array
    # or an (N, 1) array, so flatten the indices before indexing.
    idx = cv2.dnn.NMSBoxes(boxes, scores, confidence, iou_threshold)
    idx = np.array(idx, dtype=int).flatten()
    boxes = boxes[idx]
    scores = scores[idx]
    class_idx = class_idx[idx]

    # Keep only person objects (class index = 0).
    person_idx = np.where(class_idx == 0)[0]
    boxes = boxes[person_idx]
    scores = scores[person_idx]

    # Convert to xyxy and scale boxes to the frame size.
    H, W = frame.shape[:2]
    boxes = detect.to_xyxy(boxes) * np.array([W, H, W, H])

    # dets: [xmin, ymin, xmax, ymax, score]
    dets = np.concatenate([boxes.astype(int), scores.reshape(-1, 1)], axis=1)
    return dets


def main(
    src: str,
    dest: str,
    model: str,
    video_fmt: str,
    confidence: float,
    iou_threshold: float,
):
    """Track human objects and count the number of people.

    Args:
        src (str): Source video.
        dest (str): Directory to save results.
        model (str): Path to tflite weight.
        video_fmt (str): Format of the output video ("mp4" or "avi").
        confidence (float): Confidence threshold.
        iou_threshold (float): IoU threshold for NMS.
    """
    if not os.path.exists(dest):
        os.mkdir(dest)

    # The line to count crossings over.
    border = [(0, 500), (1920, 500)]
    tracker = Tracker(border)
    detect = Detect(model, confidence)
    stream = VideoStream(src)
    writer = None

    total_frames = len(stream)
    if total_frames:
        print(f"Total frames: {total_frames}")

    while True:
        # Read the next frame from the stream.
        is_finish, frame = stream.next()

        if not is_finish:
            break

        start = time.time()
        dets = _detect_person(detect, frame, confidence, iou_threshold)
        end = time.time()

        # Update the tracker and draw bounding boxes on the frame.
        frame = tracker.update(frame, dets)

        # Executed only the first time.
        if writer is None:
            # Initialize the video writer.
            codecs = {"mp4": "MP4V", "avi": "MJPG"}
            output_video = os.path.join(dest, f"result.{video_fmt}")
            fourcc = cv2.VideoWriter_fourcc(*codecs[video_fmt])
            writer = cv2.VideoWriter(output_video, fourcc, 30, (frame.shape[1], frame.shape[0]), True)

        # Estimate the total time.
        second_per_frame = end - start
        print(f"Computation time per frame: {second_per_frame:.4f} seconds")
        print(f"Estimated total time: {second_per_frame * total_frames:.4f} seconds")

        # Save the frame as an image and append it to the video.
        cv2.imwrite(os.path.join(dest, "detect.jpg"), frame)
        writer.write(frame)

    if writer is not None:
        writer.release()
    stream.release()
    print("Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", help="Path to video source.", default="./data/TownCentreXVID.mp4")
    parser.add_argument("--dest", help="Path to output directory.", default="./outputs/")
    parser.add_argument("--model", help="Path to YOLOv5 tflite file.", default="./models/yolov5n6-fp16.tflite")
    parser.add_argument("--video-fmt", help="Format of output video file.", choices=["mp4", "avi"], default="avi")
    parser.add_argument("--confidence", type=float, default=0.2, help="Confidence threshold.")
    parser.add_argument("--iou-threshold", type=float, default=0.2, help="IoU threshold for NMS.")
    args = vars(parser.parse_args())
    main(**args)
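
A typical invocation, assuming the script above is saved as main.py (the filename is not shown in this view) and the default paths exist:

python3 main.py --src ./data/TownCentreXVID.mp4 --dest ./outputs/ --model ./models/yolov5n6-fp16.tflite --video-fmt avi

Results are written to the destination directory as detect.jpg (the most recent frame) and result.avi.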