Skip to content

Commit

Permalink
Add src
Browse files Browse the repository at this point in the history
  • Loading branch information
ozora-ogino committed Dec 1, 2021
1 parent 306aa03 commit 42cf415
Show file tree
Hide file tree
Showing 9 changed files with 680 additions and 1 deletion.
7 changes: 6 additions & 1 deletion data/video2img.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
#!/usr/bin/env python3
#
# Copyright 2021.
# ozora-ogino

import os
from argparse import ArgumentParser

Expand All @@ -24,7 +29,7 @@ def _main(video: str, save_dir: str, limit_frames: int) -> None:

# Video to images.
while success and count < limit_frames:
cv2.imwrite(os.path.join(save_dir, f"frame{count}.jpg"), image)
cv2.imwrite(os.path.join(save_dir, f"frame{str(count).zfill(5)}.jpg"), image)
success, image = vidcap.read()
print("Read a new frame: ", success)
count += 1
Expand Down
2 changes: 2 additions & 0 deletions models/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
**.tflite
!.gitignore
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ protected-access,
unnecessary-lambda,
logging-format-interpolation,
broad-except,
bare-except
"""
enable = "C0303" # Trailing whitespace
87 changes: 87 additions & 0 deletions src/detect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#!/usr/bin/env python3
#
# Copyright 2021.
# ozora-ogino

from typing import Tuple

import cv2
import numpy as np
from tflite_runtime.interpreter import Interpreter


class Detect:
    """YOLOv5 tflite detect model.

    Wraps a TFLite interpreter: resizes/normalises a BGR frame, runs inference
    and returns the predictions whose confidence exceeds ``conf_thr``.
    """

    def __init__(
        self,
        model_file: str,
        conf_thr: float,
    ):
        """Load the model and cache its tensor metadata.

        Args:
            model_file (str): Path to the YOLOv5 ``.tflite`` file.
            conf_thr (float): Predictions whose confidence is not greater than
                this value are discarded in ``postprocess``.
        """
        # Load model to memory.
        self.interpreter = Interpreter(model_file)
        self.interpreter.allocate_tensors()

        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()
        # The image input tensor is laid out as NHWC: (batch, height, width, channels).
        _, self.height, self.width, _ = self.input_details[0]["shape"]
        self.conf_thr = conf_thr

    def detect(self, img: np.ndarray, box_type: str = "xywh") -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Detect objects.

        Args:
            img (np.ndarray): Frame in BGR channel order (it is flipped to RGB in ``preprocess``).
            box_type (str): ``"xywh"`` (default) keeps the model's box format,
                ``"xyxy"`` converts boxes to corner format.

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray]: ``(boxes, scores, class_idx)`` with shapes
            (n, 4), (n,) and (n,), where n is the number of predictions above ``conf_thr``.
        """
        img = self.preprocess(img)
        output_data = self._detect(img)
        boxes, scores, class_idx = self.postprocess(output_data, box_type)
        return boxes, scores, class_idx

    def _detect(self, img: np.ndarray) -> np.ndarray:
        """Run inference on a preprocessed batch and return the raw output tensor."""
        # Look the tensor index up instead of assuming the input is tensor 0.
        self.interpreter.set_tensor(self.input_details[0]["index"], img)
        self.interpreter.invoke()
        # Raw predictions, one row per candidate: [x, y, w, h, conf, class scores...].
        return self.interpreter.get_tensor(self.output_details[0]["index"])

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """Resize, convert BGR -> RGB, scale to [0, 1] and add a batch axis."""
        # Resize (cv2 expects dsize as (width, height)).
        img = cv2.resize(img, (self.width, self.height))
        # BGR -> RGB
        img = img[:, :, [2, 1, 0]]
        # Normalize
        img = img / 255.0
        img = np.expand_dims(img, axis=0)
        return img.astype(np.float32)

    def postprocess(self, output_data: np.ndarray, box_type: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Split raw predictions into boxes/scores/classes and apply the confidence threshold.

        Args:
            output_data (np.ndarray): Raw model output of shape (1, n, 5 + num_classes).
            box_type (str): ``"xyxy"`` converts the boxes to corner format; any
                other value leaves them as xywh.

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray]: Filtered ``(boxes, conf, cls)``;
            ``cls`` holds class indices stored as float32.
        """
        predictions = output_data[0]
        # xywh
        boxes = predictions[..., :4]
        conf = predictions[..., 4]
        cls = np.argmax(predictions[..., 5:], axis=-1).astype(np.float32)

        if box_type == "xyxy":
            # xywh -> xyxy
            boxes = self.to_xyxy(boxes)

        # Filter by confidence threshold.
        keep = conf > self.conf_thr
        return boxes[keep], conf[keep], cls[keep]

    def to_xyxy(self, boxes: np.ndarray) -> np.ndarray:
        """Convert centre-format boxes (x, y, w, h) to corner format (xmin, ymin, xmax, ymax).

        Works on any array whose last axis has length 4 and keeps its leading shape.
        """
        x, y, w, h = boxes[..., 0], boxes[..., 1], boxes[..., 2], boxes[..., 3]
        half_w, half_h = w / 2, h / 2
        # Stack on the last axis so the result has the same layout as the input.
        return np.stack([x - half_w, y - half_h, x + half_w, y + half_h], axis=-1)
130 changes: 130 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#!/usr/bin/env python3
#
# Copyright 2021.
# ozora-ogino

import argparse
import os
import time

import cv2
import numpy as np

from detect import Detect
from streams import VideoStream
from tracker import Tracker


def _detect_person(
    detect: Detect,
    frame: np.ndarray,
    confidence: float,
    iou_threshold: float,
) -> np.ndarray:
    """Detect person objects in a frame.

    Args:
        detect (Detect): Detector used to run inference on the frame.
        frame (np.ndarray): Frame to analyse; its first two axes are (height, width).
        confidence (float): Score threshold passed to NMS.
        iou_threshold (float): IoU threshold for NMS.

    Returns:
        np.ndarray: Array of shape (n, 5) whose rows are
        [xmin, ymin, xmax, ymax, score], with coordinates in frame pixels.
    """

    # Detect objects in the frame. Boxes come back as centre-format (x, y, w, h);
    # they are multiplied by the frame size below, so presumably normalised to [0, 1].
    boxes, scores, class_idx = detect.detect(frame)
    if len(boxes) == 0:
        # Nothing passed the confidence threshold: skip NMS entirely.
        return np.empty((0, 5))

    # NMS. cv2.dnn.NMSBoxes expects (left, top, width, height) rectangles, so move
    # the box centres to top-left corners on a copy before computing overlaps.
    rects = np.array(boxes, dtype=np.float64)
    rects[:, :2] -= rects[:, 2:4] / 2
    idx = cv2.dnn.NMSBoxes(rects.tolist(), scores.tolist(), confidence, iou_threshold)
    # Depending on the OpenCV version the result is a flat array, an (n, 1)
    # array or an empty tuple; normalise it to a flat integer index.
    idx = np.asarray(idx, dtype=int).reshape(-1)
    boxes = boxes[idx]
    scores = scores[idx]
    class_idx = class_idx[idx]

    # Filter only person object (class index = 0).
    person_idx = np.where(class_idx == 0)[0]
    boxes = boxes[person_idx]
    scores = scores[person_idx]

    # Scale boxes by frame size.
    H, W = frame.shape[:2]
    boxes = detect.to_xyxy(boxes) * np.array([W, H, W, H])

    # dets: [xmin, ymin, xmax, ymax, score]
    dets = np.concatenate([boxes.astype(int), scores.reshape(-1, 1)], axis=1)
    return dets


def main(
    src: str,
    dest: str,
    model: str,
    video_fmt: str,
    confidence: float,
    iou_threshold: float,
):
    """Track human objects and count the number of human.

    Reads frames from ``src``, detects and tracks people in each one, and
    writes the annotated frames to ``dest`` as ``result.<video_fmt>`` plus a
    ``detect.jpg`` snapshot that is overwritten on every frame.

    Args:
        src (str): Source video.
        dest (str): Directory to save results.
        model (str): Path to tflite weight.
        video_fmt (str): Output container, ``"mp4"`` or ``"avi"``.
        confidence (float): Confidence threshold.
        iou_threshold (float): IoU threshold for NMS.

    Raises:
        KeyError: If ``video_fmt`` is not one of the supported formats.
    """
    # makedirs also creates missing parents and tolerates an existing directory.
    os.makedirs(dest, exist_ok=True)

    # The line to count.
    # NOTE(review): hard-coded; presumably assumes a 1920-px-wide video - confirm.
    border = [(0, 500), (1920, 500)]
    codecs = {"mp4": "MP4V", "avi": "MJPG"}

    tracker = Tracker(border)
    detect = Detect(model, confidence)
    stream = VideoStream(src)
    writer = None

    total_frames = len(stream)
    if total_frames:
        print(f"Total frames: {total_frames}")

    try:
        while True:
            # Read the next frame from stream; a falsy flag means no frame was returned.
            has_frame, frame = stream.next()
            if not has_frame:
                break

            start = time.time()
            dets = _detect_person(detect, frame, confidence, iou_threshold)
            end = time.time()

            # Update tracker and draw bounding boxes in frame.
            frame = tracker.update(frame, dets)

            # Executed only first time.
            if writer is None:
                # Initialize video writer (the frame size is only known once a frame exists).
                output_video = os.path.join(dest, f"result.{video_fmt}")
                fourcc = cv2.VideoWriter_fourcc(*codecs[video_fmt])
                writer = cv2.VideoWriter(output_video, fourcc, 30, (frame.shape[1], frame.shape[0]), True)

                # Estimate total time.
                second_per_frame = end - start
                print(f"Computation time per a frame: {second_per_frame:.4f} seconds")
                print(f"Estimated total time: {second_per_frame * total_frames:.4f}")

            # Save frame as an image and video.
            cv2.imwrite(os.path.join(dest, "detect.jpg"), frame)
            writer.write(frame)
    finally:
        # Release resources even on failure; the writer does not exist if no frame was read.
        if writer is not None:
            writer.release()
        stream.release()

    print("Done!")


if __name__ == "__main__":
    # Command-line entry point: every flag maps onto the keyword argument of
    # main() with the same name (dashes become underscores).
    cli_options = (
        ("--src", {"default": "./data/TownCentreXVID.mp4", "help": "Path to video source."}),
        ("--dest", {"default": "./outputs/", "help": "Path to output directory"}),
        ("--model", {"default": "./models/yolov5n6-fp16.tflite", "help": "Path to YOLOv5 tflite file"}),
        ("--video-fmt", {"default": "avi", "choices": ["mp4", "avi"], "help": "Format of output video file."}),
        ("--confidence", {"default": 0.2, "type": float, "help": "Confidence threshold."}),
        ("--iou-threshold", {"default": 0.2, "type": float, "help": "IoU threshold for NMS."}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in cli_options:
        cli.add_argument(flag, **spec)

    main(**vars(cli.parse_args()))
Loading

0 comments on commit 42cf415

Please sign in to comment.