Skip to content

Commit bc4d1a2

Browse files
committed
Add tests about v7 onnx
1 parent 2190148 commit bc4d1a2

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

tests/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ def inference_v7_cfg():
4747
return get_cfg(overrides=["task=inference", "model=v7"])
4848

4949

50+
@pytest.fixture(scope="session")
51+
def inference_v7_onnx_cfg():
52+
return get_cfg(overrides=["task=inference", "model=v7", "task.fast_inference=onnx"])
53+
54+
5055
@pytest.fixture(scope="session")
5156
def device():
5257
return torch.device("cuda" if torch.cuda.is_available() else "cpu")

tests/test_utils/test_deploy_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import pytest
2+
import torch
3+
4+
from yolo.config.config import Config
5+
from yolo.utils.deploy_utils import FastModelLoader
6+
7+
8+
def test_load_v7_onnx(inference_v7_onnx_cfg: Config):
9+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10+
model = FastModelLoader(inference_v7_onnx_cfg).load_model(device)
11+
assert hasattr(model, "num_classes")
12+
13+
14+
def test_infer_v7_onnx(inference_v7_onnx_cfg: Config):
15+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16+
model = FastModelLoader(inference_v7_onnx_cfg).load_model(device)
17+
image_data = torch.zeros(1, 3, 640, 640, dtype=torch.float32)
18+
predict = model(image_data)
19+
assert "Main" in predict
20+
predictions = predict["Main"]
21+
assert len(predictions) == 3
22+
assert predictions[0].shape == (1, 255, 80, 80)
23+
assert predictions[1].shape == (1, 255, 40, 40)
24+
assert predictions[2].shape == (1, 255, 20, 20)

0 commit comments

Comments
 (0)