-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
619d56f
commit 423c00f
Showing
4 changed files
with
94 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from src.detect import Detect | ||
|
||
|
||
class TestDetect: | ||
@pytest.fixture(autouse=True) | ||
def _setup(self): | ||
self.detect = Detect("./models/yolov5n6-fp16.tflite", conf_thr=0.4) | ||
|
||
def test_detect(self): | ||
dummy_img = np.random.randn(300, 300, 3) | ||
boxes, scores, class_idx = self.detect.detect(dummy_img) | ||
assert boxes.shape[1] == 4 | ||
assert isinstance(boxes, np.ndarray) and isinstance(scores, np.ndarray) and isinstance(class_idx, np.ndarray) | ||
|
||
def test_preprocess(self): | ||
dummy_img = np.random.randn(300, 300, 3) | ||
result = self.detect.preprocess(dummy_img) | ||
assert result.shape == (1, self.detect.height, self.detect.width, 3) | ||
assert result.dtype == np.float32 | ||
|
||
def test_to_xyxy(self): | ||
xywh = np.array([[100, 100, 200, 200]]) | ||
expect = np.array([[0, 0, 200, 200]]) | ||
result = self.detect.to_xyxy(xywh) | ||
assert (result == expect).all() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from src.utils import check_direction, direction_config, is_intersect | ||
|
||
# pylint:disable=unexpected-keyword-arg | ||
|
||
|
||
class TestCheckDirection: | ||
def test_true(self): | ||
"""Test true case.""" | ||
directions = { | ||
"right": {"prev_center": [0, 0], "current_center": [20, 0], "expect": True}, | ||
"left": {"prev_center": [10, 0], "current_center": [0, 0], "expect": True}, | ||
"top": {"prev_center": [0, 10], "current_center": [0, 0], "expect": True}, | ||
"bottom": {"prev_center": [0, 0], "current_center": [0, 10], "expect": True}, | ||
} | ||
for direction_str, args in directions.items(): | ||
expect = args.pop("expect") | ||
result = check_direction(**args, direction=direction_config[direction_str]) | ||
assert result == expect | ||
|
||
def test_false(self): | ||
"""Test false case.""" | ||
directions = { | ||
"right": {"prev_center": [0, 0], "current_center": [0, 0], "expect": False}, | ||
# This is right. | ||
"left": {"prev_center": [0, 0], "current_center": [10, 0], "expect": False}, | ||
# This is bottom. | ||
"top": {"prev_center": [0, 0], "current_center": [0, 10], "expect": False}, | ||
# This is top. | ||
"bottom": {"prev_center": [0, 10], "current_center": [0, 0], "expect": False}, | ||
} | ||
for direction_str, args in directions.items(): | ||
expect = args.pop("expect") | ||
result = check_direction(**args, direction=direction_config[direction_str]) | ||
assert result == expect | ||
|
||
def test_direction_none(self): | ||
"""Check if always return true when direction is set None.""" | ||
args = [ | ||
{"prev_center": [0, 0], "current_center": [0, 0]}, # No movement. | ||
{"prev_center": [0, 0], "current_center": [10, 0]}, # Right | ||
{"prev_center": [10, 0], "current_center": [0, 0]}, # Left. | ||
{"prev_center": [0, 10], "current_center": [0, 0]}, # Top. | ||
{"prev_center": [0, 0], "current_center": [0, 10]}, # Bottom. | ||
] | ||
for arg in args: | ||
# If the direction is None, always return True. | ||
result = check_direction(**arg, direction=None) | ||
assert result == True | ||
|
||
|
||
class TestIsIntersect: | ||
def test_true(self): | ||
"""Test true case.""" | ||
args = {"A": [10, 0], "B": [10, 30], "C": [0, 10], "D": [30, 0]} | ||
result = is_intersect(**args) | ||
assert result == True | ||
|
||
def test_false(self): | ||
"""Test false case.""" | ||
args = {"A": [10, 0], "B": [10, 30], "C": [0, 10], "D": [0, 0]} | ||
result = is_intersect(**args) | ||
assert result == False |