diff --git a/requirements-dev.txt b/requirements-dev.txt
index 968d2e6..6a4fd00 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,5 +1,7 @@
 -r requirements.txt
 gradio
+onnx
+onnxruntime
 pytest
 pytest-cov
 pre-commit
diff --git a/tests/conftest.py b/tests/conftest.py
index 540d007..9c0b42c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -48,6 +48,11 @@ def inference_v7_cfg():
     return get_cfg(overrides=["task=inference", "model=v7"])
 
 
+@pytest.fixture(scope="session")
+def inference_v7_onnx_cfg():
+    return get_cfg(overrides=["task=inference", "model=v7", "task.fast_inference=onnx"])
+
+
 @pytest.fixture(scope="session")
 def device():
     return torch.device("cuda" if torch.cuda.is_available() else "cpu")
diff --git a/tests/test_utils/test_deploy_utils.py b/tests/test_utils/test_deploy_utils.py
new file mode 100644
index 0000000..5f1a0f9
--- /dev/null
+++ b/tests/test_utils/test_deploy_utils.py
@@ -0,0 +1,21 @@
+import torch
+
+from yolo.config.config import Config
+from yolo.utils.deploy_utils import FastModelLoader
+
+
+def test_load_v7_onnx(inference_v7_onnx_cfg: Config, device: torch.device):
+    model = FastModelLoader(inference_v7_onnx_cfg).load_model(device)
+    assert hasattr(model, "num_classes")
+
+
+def test_infer_v7_onnx(inference_v7_onnx_cfg: Config, device: torch.device):
+    model = FastModelLoader(inference_v7_onnx_cfg).load_model(device)
+    image_data = torch.zeros(1, 3, 640, 640, dtype=torch.float32)
+    predict = model(image_data)
+    assert "Main" in predict
+    predictions = predict["Main"]
+    assert len(predictions) == 3
+    assert predictions[0].shape == (1, 255, 80, 80)
+    assert predictions[1].shape == (1, 255, 40, 40)
+    assert predictions[2].shape == (1, 255, 20, 20)
diff --git a/yolo/utils/deploy_utils.py b/yolo/utils/deploy_utils.py
index 4a0db99..5d9c5c0 100644
--- a/yolo/utils/deploy_utils.py
+++ b/yolo/utils/deploy_utils.py
@@ -47,8 +47,10 @@ def onnx_forward(self: InferenceSession, x: Tensor):
             if idx % 3 == 2:
                 model_outputs.append(layer_output)
                 layer_output = []
-        if len(model_outputs) == 6:
+        if len(model_outputs) == 6:  # yolov9: six heads, keep the three "Main" outputs
             model_outputs = model_outputs[:3]
+        elif len(model_outputs) == 1:  # yolov7: one group of three anchor heads, unwrap it
+            model_outputs = model_outputs[0]
         return {"Main": model_outputs}
 
     InferenceSession.__call__ = onnx_forward
@@ -60,9 +62,11 @@ def onnx_forward(self: InferenceSession, x: Tensor):
         try:
             ort_session = InferenceSession(self.model_path, providers=providers)
             logger.info(":rocket: Using ONNX as MODEL frameworks!")
         except Exception as e:
            logger.warning(f"🈳 Error loading ONNX model: {e}")
             ort_session = self._create_onnx_model(providers)
+        # required by Anc2Box; set once so both load paths get it
+        ort_session.num_classes = self.class_num
         return ort_session
 
     def _create_onnx_model(self, providers):