Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] yolov7 onnx #86

Open
wants to merge 5 commits into
base: DEPLOY
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
-r requirements.txt
gradio
onnx
onnxruntime
pytest
pytest-cov
pre-commit
Expand Down
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ def inference_v7_cfg():
return get_cfg(overrides=["task=inference", "model=v7"])


@pytest.fixture(scope="session")
def inference_v7_onnx_cfg():
    """Session-scoped config for YOLOv7 inference using the ONNX fast-inference backend."""
    overrides = ["task=inference", "model=v7", "task.fast_inference=onnx"]
    return get_cfg(overrides=overrides)


@pytest.fixture(scope="session")
def device():
    """Session-scoped torch device: CUDA when available, otherwise CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
Expand Down
24 changes: 24 additions & 0 deletions tests/test_utils/test_deploy_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest
import torch

from yolo.config.config import Config
from yolo.utils.deploy_utils import FastModelLoader


def test_load_v7_onnx(inference_v7_onnx_cfg: Config, device: torch.device):
    """Loading the YOLOv7 ONNX model must expose `num_classes` on the session.

    Uses the session-scoped `device` fixture from conftest instead of
    re-deriving the CUDA/CPU choice inline, keeping device selection in
    one place across the test suite.
    """
    model = FastModelLoader(inference_v7_onnx_cfg).load_model(device)
    # `num_classes` is attached by FastModelLoader.load_model; Anc2Box relies on it.
    assert hasattr(model, "num_classes")


def test_infer_v7_onnx(inference_v7_onnx_cfg: Config, device: torch.device):
    """End-to-end ONNX inference on a zero image yields the 3 YOLOv7 head outputs.

    Uses the session-scoped `device` fixture from conftest instead of
    re-deriving the CUDA/CPU choice inline. Expected output shapes are
    the three YOLOv7 feature-map strides (8/16/32) for a 640x640 input
    with 255 channels per head.
    """
    model = FastModelLoader(inference_v7_onnx_cfg).load_model(device)
    image_data = torch.zeros(1, 3, 640, 640, dtype=torch.float32)
    predict = model(image_data)
    assert "Main" in predict
    predictions = predict["Main"]
    expected_shapes = [(1, 255, 80, 80), (1, 255, 40, 40), (1, 255, 20, 20)]
    assert len(predictions) == len(expected_shapes)
    # Check each pyramid level's output shape against its expected stride.
    for prediction, expected_shape in zip(predictions, expected_shapes):
        assert prediction.shape == expected_shape
8 changes: 7 additions & 1 deletion yolo/utils/deploy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ def onnx_forward(self: InferenceSession, x: Tensor):
if idx % 3 == 2:
model_outputs.append(layer_output)
layer_output = []
if len(model_outputs) == 6:
if len(model_outputs) == 6: # yolov9
model_outputs = model_outputs[:3]
elif len(model_outputs) == 1: # yolov7
model_outputs = model_outputs[0]
return {"Main": model_outputs}

InferenceSession.__call__ = onnx_forward
Expand All @@ -60,9 +62,13 @@ def onnx_forward(self: InferenceSession, x: Tensor):
try:
ort_session = InferenceSession(self.model_path, providers=providers)
logger.info(":rocket: Using ONNX as MODEL frameworks!")
# required by Anc2Box
ort_session.num_classes = self.class_num
except Exception as e:
logger.warning(f"🈳 Error loading ONNX model: {e}")
ort_session = self._create_onnx_model(providers)
# required by Anc2Box
ort_session.num_classes = self.class_num
return ort_session

def _create_onnx_model(self, providers):
Expand Down