From c11433b049bb1ef4f6f50e3f947cf64a47f98dd5 Mon Sep 17 00:00:00 2001 From: Ramon Date: Tue, 4 Feb 2025 08:53:59 +0100 Subject: [PATCH 01/15] WIP exports --- yolo/config/task/export.yaml | 3 +++ yolo/lazy.py | 5 ++++- yolo/tools/solver.py | 22 +++++++++++++++++- yolo/utils/export_utils.py | 43 ++++++++++++++++++++++++++++++++++++ 4 files changed, 71 insertions(+), 2 deletions(-) create mode 100644 yolo/config/task/export.yaml create mode 100644 yolo/utils/export_utils.py diff --git a/yolo/config/task/export.yaml b/yolo/config/task/export.yaml new file mode 100644 index 00000000..0256be30 --- /dev/null +++ b/yolo/config/task/export.yaml @@ -0,0 +1,3 @@ +task: export + +format: onnx diff --git a/yolo/lazy.py b/yolo/lazy.py index 0f1cc55b..00c90921 100644 --- a/yolo/lazy.py +++ b/yolo/lazy.py @@ -8,7 +8,7 @@ sys.path.append(str(project_root)) from yolo.config.config import Config -from yolo.tools.solver import InferenceModel, TrainModel, ValidateModel +from yolo.tools.solver import ExportModel, InferenceModel, TrainModel, ValidateModel from yolo.utils.logging_utils import setup @@ -39,6 +39,9 @@ def main(cfg: Config): if cfg.task.task == "inference": model = InferenceModel(cfg) trainer.predict(model) + if cfg.task.task == "export": + model = ExportModel(cfg) + model.export() if __name__ == "__main__": diff --git a/yolo/tools/solver.py b/yolo/tools/solver.py index 8246a66d..9a34b7da 100644 --- a/yolo/tools/solver.py +++ b/yolo/tools/solver.py @@ -10,8 +10,9 @@ from yolo.tools.drawer import draw_bboxes from yolo.tools.loss_functions import create_loss_function from yolo.utils.bounding_box_utils import create_converter, to_metrics_format +from yolo.utils.export_utils import ModelExporter from yolo.utils.model_utils import PostProcess, create_optimizer, create_scheduler - +from yolo.utils.logger import logger class BaseModel(LightningModule): def __init__(self, cfg: Config): @@ -139,3 +140,22 @@ def _save_image(self, img, batch_idx): save_image_path = Path(self.trainer.default_root_dir) / f"frame{batch_idx:03d}.png" img.save(save_image_path) print(f"💾 Saved visualize image at {save_image_path}") + + +class ExportModel(BaseModel): + def __init__(self, cfg: Config): + super().__init__(cfg) + self.cfg = cfg + self.format = cfg.task.format + self.model_exporter = ModelExporter(self.cfg, self.model) + + def export(self): + if self.format == 'onnx': + self.model_exporter.export_onnx() + if self.format == 'tflite': + self.model_exporter.export_flite() + if self.format == 'coreml': + self.model_exporter.export_coreml() + + + diff --git a/yolo/utils/export_utils.py b/yolo/utils/export_utils.py new file mode 100644 index 00000000..ef71edd7 --- /dev/null +++ b/yolo/utils/export_utils.py @@ -0,0 +1,43 @@ + +from yolo.config.config import Config +from yolo.model.yolo import YOLO +from yolo.utils.logger import logger +from pathlib import Path + +class ModelExporter(): + def __init__(self, cfg: Config, model: YOLO): + self.model = model + self.cfg = cfg + self.class_num = cfg.dataset.class_num + self.format = self.cfg.task.format + if cfg.weight == True: + cfg.weight = Path("weights") / f"{cfg.model.name}.pt" + self.model_path = f"{Path(self.cfg.weight).stem}.{self.format}" + + def export_onnx(self): + logger.info(f":package: Exporting model to onnx format") + import torch + dummy_input = torch.ones((1, 3, *self.cfg.image_size)) + + # TODO move duplicated export code also used in fast inference to a separate file + torch.onnx.export( + self.model, + dummy_input, + self.model_path, + input_names=["input"], + 
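# dynamic_axes below marks dimension 0 of the input and output tensors as a
# symbolic "batch_size", so the exported ONNX graph accepts any batch size at
# inference time instead of being fixed to the (1, 3, H, W) dummy input shape.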
output_names=["output"], + dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, + ) + + def export_flite(self): + logger.info(f":package: Exporting model to tflite format") + logger.info(f":construction: Not implemented yet") + + def export_coreml(self): + logger.info(f":package: Exporting model to coreml format") + logger.info(f":construction: Not implemented yet") + import torch + dummy_input = torch.ones((1, 3, *self.cfg.image_size)) + traced_model = torch.jit.trace(self.model, dummy_input) + out = traced_model(dummy_input) + \ No newline at end of file From 6554034cdece6d19aa66dc24ec415a5c6efb3a81 Mon Sep 17 00:00:00 2001 From: Ramon Date: Mon, 17 Feb 2025 09:07:18 +0100 Subject: [PATCH 02/15] =?UTF-8?q?=F0=9F=8D=8F=20Add=20CoreMl=20Export?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- yolo/model/yolo.py | 17 +++++++++++++---- yolo/tools/solver.py | 7 ++++--- yolo/utils/export_utils.py | 24 +++++++++++++++++++----- 3 files changed, 36 insertions(+), 12 deletions(-) diff --git a/yolo/model/yolo.py b/yolo/model/yolo.py index cc9ce20b..918a3eb1 100644 --- a/yolo/model/yolo.py +++ b/yolo/model/yolo.py @@ -21,8 +21,9 @@ class YOLO(nn.Module): parameters, and any other relevant configuration details. """ - def __init__(self, model_cfg: ModelConfig, class_num: int = 80): + def __init__(self, model_cfg: ModelConfig, class_num: int = 80, export : bool =False): super(YOLO, self).__init__() + self.export = export self.num_classes = class_num self.layer_map = get_layer_map() # Get the map Dict[str: Module] self.model: List[YOLOLayer] = nn.ModuleList() @@ -80,8 +81,16 @@ def forward(self, x): y[-1] = x if layer.usable: y[index] = x - if layer.output: + + # On export we want to trace the model with torch.jit.trace + # Dicts and tuples are not supported by torch.jit.trace + # This is possible because we have one output tag on export (Main) + if layer.output and self.export == True: + output = x + output = [tensor for tpl in output for tensor in tpl] + elif layer.output: output[layer.tags] = x + return output def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]): @@ -152,7 +161,7 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]): self.model.load_state_dict(model_state_dict) -def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO: +def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80, export : bool = False) -> YOLO: """Constructs and returns a model from a Dictionary configuration file. Args: @@ -162,7 +171,7 @@ def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, YOLO: An instance of the model defined by the given configuration. 
""" OmegaConf.set_struct(model_cfg, False) - model = YOLO(model_cfg, class_num) + model = YOLO(model_cfg, class_num, export=export) if weight_path: if weight_path == True: weight_path = Path("weights") / f"{model_cfg.name}.pt" diff --git a/yolo/tools/solver.py b/yolo/tools/solver.py index 9a34b7da..0fbacae6 100644 --- a/yolo/tools/solver.py +++ b/yolo/tools/solver.py @@ -15,9 +15,9 @@ from yolo.utils.logger import logger class BaseModel(LightningModule): - def __init__(self, cfg: Config): + def __init__(self, cfg: Config, export: bool = False): super().__init__() - self.model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight) + self.model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight, export=export) def forward(self, x): return self.model(x) @@ -144,7 +144,8 @@ def _save_image(self, img, batch_idx): class ExportModel(BaseModel): def __init__(self, cfg: Config): - super().__init__(cfg) + cfg.model.model.auxiliary = {} + super().__init__(cfg, export=True) self.cfg = cfg self.format = cfg.task.format self.model_exporter = ModelExporter(self.cfg, self.model) diff --git a/yolo/utils/export_utils.py b/yolo/utils/export_utils.py index ef71edd7..84fc1c7d 100644 --- a/yolo/utils/export_utils.py +++ b/yolo/utils/export_utils.py @@ -19,25 +19,39 @@ def export_onnx(self): import torch dummy_input = torch.ones((1, 3, *self.cfg.image_size)) + onnx_model_path = f"{Path(self.cfg.weight).stem}.onnx" + # TODO move duplicated export code also used in fast inference to a separate file torch.onnx.export( self.model, dummy_input, - self.model_path, + onnx_model_path, input_names=["input"], output_names=["output"], - dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, + dynamic_axes=None #{"input": {0: "batch_size"}, "output": {0: "batch_size"}}, ) + return onnx_model_path def export_flite(self): logger.info(f":package: Exporting model to tflite format") logger.info(f":construction: Not implemented yet") def export_coreml(self): logger.info(f":package: Exporting model to coreml format") - logger.info(f":construction: Not implemented yet") + import torch dummy_input = torch.ones((1, 3, *self.cfg.image_size)) + + self.model.eval() traced_model = torch.jit.trace(self.model, dummy_input) - out = traced_model(dummy_input) - \ No newline at end of file + + import coremltools as ct + import logging + logging.getLogger("coremltools").disabled = True + model_from_trace = ct.convert( + traced_model, + inputs=[ct.TensorType(shape=dummy_input.shape)], + convert_to="neuralnetwork", + ) + model_from_trace.save(f"{Path(self.cfg.weight).stem}.mlmodel") + logger.info(f":white_check_mark: Model exported to coreml format") From 61895c28e4be267e330cce032eec95065c7ba8b4 Mon Sep 17 00:00:00 2001 From: Ramon Date: Wed, 19 Feb 2025 09:49:20 +0100 Subject: [PATCH 03/15] =?UTF-8?q?=F0=9F=8D=8F=20Add=20scripts=20for=20Core?= =?UTF-8?q?ML=20and=20ONNX=20export,=20refactor=20YOLO=20model=20forward?= =?UTF-8?q?=20loop,=20undo=20export=20param,=20use=20FastModelLoader=20in?= =?UTF-8?q?=20InferenceModel?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- yolo/model/yolo.py | 25 +++++++------- yolo/tools/solver.py | 20 ++++++++--- yolo/utils/deploy_utils.py | 68 ++++++++++++++++++++++++++++---------- yolo/utils/export_utils.py | 46 ++++++++++++++++---------- 4 files changed, 107 insertions(+), 52 deletions(-) diff --git a/yolo/model/yolo.py b/yolo/model/yolo.py index 918a3eb1..088216ec 100644 --- 
a/yolo/model/yolo.py +++ b/yolo/model/yolo.py @@ -21,9 +21,8 @@ class YOLO(nn.Module): parameters, and any other relevant configuration details. """ - def __init__(self, model_cfg: ModelConfig, class_num: int = 80, export : bool =False): + def __init__(self, model_cfg: ModelConfig, class_num: int = 80): super(YOLO, self).__init__() - self.export = export self.num_classes = class_num self.layer_map = get_layer_map() # Get the map Dict[str: Module] self.model: List[YOLOLayer] = nn.ModuleList() @@ -72,25 +71,27 @@ def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]): def forward(self, x): y = {0: x} output = dict() - for index, layer in enumerate(self.model, start=1): + + # Use a simple loop instead of enumerate() + # Needed for torch export compatibility + index = 1 + for layer in self.model: if isinstance(layer.source, list): model_input = [y[idx] for idx in layer.source] else: model_input = y[layer.source] + x = layer(model_input) y[-1] = x + if layer.usable: y[index] = x - # On export we want to trace the model with torch.jit.trace - # Dicts and tuples are not supported by torch.jit.trace - # This is possible because we have one output tag on export (Main) - if layer.output and self.export == True: - output = x - output = [tensor for tpl in output for tensor in tpl] - elif layer.output: + if layer.output: output[layer.tags] = x + index += 1 + return output def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]): @@ -161,7 +162,7 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]): self.model.load_state_dict(model_state_dict) -def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80, export : bool = False) -> YOLO: +def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO: """Constructs and returns a model from a Dictionary configuration file. Args: @@ -171,7 +172,7 @@ def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, YOLO: An instance of the model defined by the given configuration. 
""" OmegaConf.set_struct(model_cfg, False) - model = YOLO(model_cfg, class_num, export=export) + model = YOLO(model_cfg, class_num) if weight_path: if weight_path == True: weight_path = Path("weights") / f"{model_cfg.name}.pt" diff --git a/yolo/tools/solver.py b/yolo/tools/solver.py index 0fbacae6..0a48f61c 100644 --- a/yolo/tools/solver.py +++ b/yolo/tools/solver.py @@ -10,6 +10,7 @@ from yolo.tools.drawer import draw_bboxes from yolo.tools.loss_functions import create_loss_function from yolo.utils.bounding_box_utils import create_converter, to_metrics_format +from yolo.utils.deploy_utils import FastModelLoader from yolo.utils.export_utils import ModelExporter from yolo.utils.model_utils import PostProcess, create_optimizer, create_scheduler from yolo.utils.logger import logger @@ -17,7 +18,7 @@ class BaseModel(LightningModule): def __init__(self, cfg: Config, export: bool = False): super().__init__() - self.model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight, export=export) + self.model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight) def forward(self, x): return self.model(x) @@ -110,15 +111,20 @@ def configure_optimizers(self): class InferenceModel(BaseModel): def __init__(self, cfg: Config): + cfg.model.model.auxiliary = {} super().__init__(cfg) + # super().__init__(cfg) self.cfg = cfg - # TODO: Add FastModel self.predict_loader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task) def setup(self, stage): self.vec2box = create_converter( self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device ) + + if self.cfg.task.fast_inference: + self.fast_model = FastModelLoader(self.cfg, self.model).load_model(self.device) + self.post_process = PostProcess(self.vec2box, self.cfg.task.nms) def predict_dataloader(self): @@ -126,7 +132,11 @@ def predict_dataloader(self): def predict_step(self, batch, batch_idx): images, rev_tensor, origin_frame = batch - predicts = self.post_process(self(images), rev_tensor=rev_tensor) + if self.fast_model: + predictions = self.fast_model(images) + else: + predictions = self(images) + predicts = self.post_process(predictions, rev_tensor=rev_tensor) img = draw_bboxes(origin_frame, predicts, idx2label=self.cfg.dataset.class_list) if getattr(self.predict_loader, "is_stream", None): fps = self._display_stream(img) @@ -145,10 +155,10 @@ def _save_image(self, img, batch_idx): class ExportModel(BaseModel): def __init__(self, cfg: Config): cfg.model.model.auxiliary = {} - super().__init__(cfg, export=True) + super().__init__(cfg) self.cfg = cfg self.format = cfg.task.format - self.model_exporter = ModelExporter(self.cfg, self.model) + self.model_exporter = ModelExporter(self.cfg, self.model, format=self.format) def export(self): if self.format == 'onnx': diff --git a/yolo/utils/deploy_utils.py b/yolo/utils/deploy_utils.py index 4a0db991..bb35006b 100644 --- a/yolo/utils/deploy_utils.py +++ b/yolo/utils/deploy_utils.py @@ -4,23 +4,30 @@ from torch import Tensor from yolo.config.config import Config -from yolo.model.yolo import create_model +from yolo.model.yolo import YOLO, create_model +from yolo.utils.export_utils import ModelExporter from yolo.utils.logger import logger class FastModelLoader: - def __init__(self, cfg: Config): + def __init__(self, cfg: Config, model: YOLO): self.cfg = cfg - self.compiler = cfg.task.fast_inference + self.model = model + self.compiler : str = cfg.task.fast_inference self.class_num = cfg.dataset.class_num self._validate_compiler() 
if cfg.weight == True: cfg.weight = Path("weights") / f"{cfg.model.name}.pt" - self.model_path = f"{Path(cfg.weight).stem}.{self.compiler}" + + extension : str = self.compiler + if self.compiler == "coreml": + extension = "mlpackage" + + self.model_path = f"{Path(cfg.weight).stem}.{extension}" def _validate_compiler(self): - if self.compiler not in ["onnx", "trt", "deploy"]: + if self.compiler not in ["onnx", "trt", "deploy", "coreml"]: logger.warning(f":warning: Compiler '{self.compiler}' is not supported. Using original model.") self.compiler = None if self.cfg.device == "mps" and self.compiler == "trt": @@ -30,6 +37,8 @@ def _validate_compiler(self): def load_model(self, device): if self.compiler == "onnx": return self._load_onnx_model(device) + elif self.compiler == "coreml": + return self._load_coreml_model(device) elif self.compiler == "trt": return self._load_trt_model().to(device) elif self.compiler == "deploy": @@ -37,6 +46,9 @@ def load_model(self, device): return create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).to(device) def _load_onnx_model(self, device): + + # TODO install onnxruntime or onnxruntime-gpu if needed + from onnxruntime import InferenceSession def onnx_forward(self: InferenceSession, x: Tensor): @@ -55,6 +67,8 @@ def onnx_forward(self: InferenceSession, x: Tensor): if device == "cpu": providers = ["CPUExecutionProvider"] + elif device == "coreml": + providers = ["CoreMLExecutionProvider"] else: providers = ["CUDAExecutionProvider"] try: @@ -67,21 +81,39 @@ def onnx_forward(self: InferenceSession, x: Tensor): def _create_onnx_model(self, providers): from onnxruntime import InferenceSession - from torch.onnx import export - - model = create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).eval() - dummy_input = torch.ones((1, 3, *self.cfg.image_size)) - export( - model, - dummy_input, - self.model_path, - input_names=["input"], - output_names=["output"], - dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, - ) - logger.info(f":inbox_tray: ONNX model saved to {self.model_path}") + model_exporter = ModelExporter(self.cfg, self.model, format='onnx', model_path=self.model_path) + model_exporter.export_onnx(dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}) return InferenceSession(self.model_path, providers=providers) + def _load_coreml_model(self, device): + from coremltools import models + + def coreml_forward(self, x: Tensor): + x = x.cpu().numpy() + model_outputs, layer_output = [], [] + predictions = self.predict({"x": x}) + for idx, key in enumerate(sorted(predictions.keys())): + layer_output.append(torch.from_numpy(predictions[key]).to(device)) + if idx % 3 == 2: + model_outputs.append(layer_output) + layer_output = [] + return {"Main": model_outputs} + + models.MLModel.__call__ = coreml_forward + + try: + model_coreml = models.MLModel(self.model_path) + logger.info(":rocket: Using CoreML as MODEL frameworks!") + except FileNotFoundError: + logger.warning(f"🈳 No found model weight at {self.model_path}") + model_coreml = self._create_coreml_model() + + return model_coreml + + def _create_coreml_model(self): + model_exporter = ModelExporter(self.cfg, self.model, format='coreml', model_path=self.model_path) + model_exporter.export_coreml() + def _load_trt_model(self): from torch2trt import TRTModule diff --git a/yolo/utils/export_utils.py b/yolo/utils/export_utils.py index 84fc1c7d..0516f1d3 100644 --- a/yolo/utils/export_utils.py +++ b/yolo/utils/export_utils.py 
@@ -1,36 +1,41 @@ +from typing import Dict, List, Optional from yolo.config.config import Config from yolo.model.yolo import YOLO from yolo.utils.logger import logger from pathlib import Path class ModelExporter(): - def __init__(self, cfg: Config, model: YOLO): + def __init__(self, cfg: Config, model: YOLO, format: str, model_path: Optional[str] = None): self.model = model self.cfg = cfg self.class_num = cfg.dataset.class_num - self.format = self.cfg.task.format + self.format = format if cfg.weight == True: cfg.weight = Path("weights") / f"{cfg.model.name}.pt" - self.model_path = f"{Path(self.cfg.weight).stem}.{self.format}" + if model_path: + self.model_path = model_path + else: + self.model_path = f"{Path(self.cfg.weight).stem}.{self.format}" - def export_onnx(self): + def export_onnx(self, dynamic_axes : Optional[Dict[str, Dict[int, str]]] = None): logger.info(f":package: Exporting model to onnx format") import torch dummy_input = torch.ones((1, 3, *self.cfg.image_size)) onnx_model_path = f"{Path(self.cfg.weight).stem}.onnx" - # TODO move duplicated export code also used in fast inference to a separate file torch.onnx.export( self.model, dummy_input, onnx_model_path, input_names=["input"], output_names=["output"], - dynamic_axes=None #{"input": {0: "batch_size"}, "output": {0: "batch_size"}}, + dynamic_axes=dynamic_axes, ) + logger.info(f":inbox_tray: ONNX model saved to {self.model_path}") + return onnx_model_path def export_flite(self): logger.info(f":package: Exporting model to tflite format") @@ -40,18 +45,25 @@ def export_coreml(self): logger.info(f":package: Exporting model to coreml format") import torch - dummy_input = torch.ones((1, 3, *self.cfg.image_size)) self.model.eval() - traced_model = torch.jit.trace(self.model, dummy_input) - - import coremltools as ct + example_inputs = (torch.rand(1, 3, *self.cfg.image_size),) + exported_program = torch.export.export(self.model, example_inputs) + import logging + import coremltools as ct + + # Convert to Core ML program using the Unified Conversion API. 
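# The nine named outputs map onto the three detection scales (small, medium,
# large); each scale contributes its class scores, its per-anchor box
# distribution features, and its box offsets - the (cls, anchor, box) triple
# each Main detection head emits. Converting a torch.export ExportedProgram
# this way assumes a recent coremltools release with torch.export support.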
logging.getLogger("coremltools").disabled = True - model_from_trace = ct.convert( - traced_model, - inputs=[ct.TensorType(shape=dummy_input.shape)], - convert_to="neuralnetwork", - ) - model_from_trace.save(f"{Path(self.cfg.weight).stem}.mlmodel") - logger.info(f":white_check_mark: Model exported to coreml format") + + output_names : List[str] = [ + "1_class_scores_small", "2_box_features_small", "3_bbox_deltas_small", + "4_class_scores_medium", "5_box_features_medium", "6_bbox_deltas_medium", + "7_class_scores_large", "8_box_features_large", "9_bbox_deltas_large" + ] + + model_from_export = ct.convert(exported_program, + outputs=[ct.TensorType(name=name) for name in output_names]) + + model_from_export.save(f"{Path(self.cfg.weight).stem}.mlpackage") + logger.info(f":white_check_mark: Model exported to coreml format") \ No newline at end of file From 50a62797c9c96f8074ab36c8c7c4151f9757fb82 Mon Sep 17 00:00:00 2001 From: Ramon Date: Thu, 20 Feb 2025 07:51:06 +0100 Subject: [PATCH 04/15] =?UTF-8?q?=F0=9F=A7=A0=20Add=20TFLite=20export=20fu?= =?UTF-8?q?nctionality=20and=20update=20model=20loading=20logic?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- yolo/tools/solver.py | 2 +- yolo/utils/deploy_utils.py | 56 ++++++++++++++++++++++++++++++++++++-- yolo/utils/export_utils.py | 34 +++++++++++++++++++---- 3 files changed, 83 insertions(+), 9 deletions(-) diff --git a/yolo/tools/solver.py b/yolo/tools/solver.py index 0a48f61c..f3f53d75 100644 --- a/yolo/tools/solver.py +++ b/yolo/tools/solver.py @@ -164,7 +164,7 @@ def export(self): if self.format == 'onnx': self.model_exporter.export_onnx() if self.format == 'tflite': - self.model_exporter.export_flite() + self.model_exporter.export_tflite() if self.format == 'coreml': self.model_exporter.export_coreml() diff --git a/yolo/utils/deploy_utils.py b/yolo/utils/deploy_utils.py index bb35006b..685d1c35 100644 --- a/yolo/utils/deploy_utils.py +++ b/yolo/utils/deploy_utils.py @@ -27,7 +27,7 @@ def __init__(self, cfg: Config, model: YOLO): self.model_path = f"{Path(cfg.weight).stem}.{extension}" def _validate_compiler(self): - if self.compiler not in ["onnx", "trt", "deploy", "coreml"]: + if self.compiler not in ["onnx", "trt", "deploy", "coreml", "tflite"]: logger.warning(f":warning: Compiler '{self.compiler}' is not supported. 
Using original model.") self.compiler = None if self.cfg.device == "mps" and self.compiler == "trt": @@ -37,6 +37,8 @@ def _validate_compiler(self): def load_model(self, device): if self.compiler == "onnx": return self._load_onnx_model(device) + if self.compiler == "tflite": + return self._load_tflite_model(device) elif self.compiler == "coreml": return self._load_coreml_model(device) elif self.compiler == "trt": @@ -45,6 +47,49 @@ def load_model(self, device): self.cfg.model.model.auxiliary = {} return create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).to(device) + def _load_tflite_model(self, device): + + if not Path(self.model_path).exists(): + self._create_tflite_model() + + from ai_edge_litert.interpreter import Interpreter + + try: + interpreter = Interpreter(model_path=self.model_path) + interpreter.allocate_tensors() + logger.info(":rocket: Using TensorFlow Lite as MODEL framework!") + except Exception as e: + logger.warning(f"🈳 Error loading TFLite model: {e}") + interpreter = self._create_tflite_model() + + def tflite_forward(self: Interpreter, x: Tensor): + + # Get input & output tensor details + input_details = self.get_input_details() + output_details = sorted(self.get_output_details(), key=lambda d: d['name']) # Sort by 'name' + + # Convert input tensor to NumPy and assign it to the model + x_numpy = x.cpu().numpy() + self.set_tensor(input_details[0]['index'], x_numpy) + + model_outputs, layer_output = [], [] + x_numpy = x.cpu().numpy() + self.set_tensor(input_details[0]['index'], x_numpy) + self.invoke() + for idx, output_detail in enumerate(output_details): + predict = self.get_tensor(output_detail['index']) + layer_output.append(torch.from_numpy(predict).to(device)) + if idx % 3 == 2: + model_outputs.append(layer_output) + layer_output = [] + if len(model_outputs) == 6: + model_outputs = model_outputs[:3] + return {"Main": model_outputs} + + Interpreter.__call__ = tflite_forward + + return interpreter + def _load_onnx_model(self, device): # TODO install onnxruntime or onnxruntime-gpu if needed @@ -101,15 +146,22 @@ def coreml_forward(self, x: Tensor): models.MLModel.__call__ = coreml_forward + if not Path(self.model_path).exists(): + self._create_coreml_model() + try: model_coreml = models.MLModel(self.model_path) logger.info(":rocket: Using CoreML as MODEL frameworks!") except FileNotFoundError: logger.warning(f"🈳 No found model weight at {self.model_path}") - model_coreml = self._create_coreml_model() + return None return model_coreml + def _create_tflite_model(self): + model_exporter = ModelExporter(self.cfg, self.model, format='tflite', model_path=self.model_path) + model_exporter.export_tflite() + def _create_coreml_model(self): model_exporter = ModelExporter(self.cfg, self.model, format='coreml', model_path=self.model_path) model_exporter.export_coreml() diff --git a/yolo/utils/export_utils.py b/yolo/utils/export_utils.py index 0516f1d3..a83635d9 100644 --- a/yolo/utils/export_utils.py +++ b/yolo/utils/export_utils.py @@ -18,28 +18,50 @@ def __init__(self, cfg: Config, model: YOLO, format: str, model_path: Optional[s else: self.model_path = f"{Path(self.cfg.weight).stem}.{self.format}" - def export_onnx(self, dynamic_axes : Optional[Dict[str, Dict[int, str]]] = None): + def export_onnx(self, dynamic_axes : Optional[Dict[str, Dict[int, str]]] = None, model_path: Optional[str] = None): logger.info(f":package: Exporting model to onnx format") import torch dummy_input = torch.ones((1, 3, *self.cfg.image_size)) - onnx_model_path = 
f"{Path(self.cfg.weight).stem}.onnx" + if model_path: + onnx_model_path = model_path + else: + onnx_model_path = self.model_path + + # onnx_model_path = f"{Path(self.cfg.weight).stem}.onnx" + output_names : List[str] = [ + "1_class_scores_small", "2_box_features_small", "3_bbox_deltas_small", + "4_class_scores_medium", "5_box_features_medium", "6_bbox_deltas_medium", + "7_class_scores_large", "8_box_features_large", "9_bbox_deltas_large" + ] torch.onnx.export( self.model, dummy_input, onnx_model_path, input_names=["input"], - output_names=["output"], + output_names=output_names, dynamic_axes=dynamic_axes, ) - logger.info(f":inbox_tray: ONNX model saved to {self.model_path}") + + + logger.info(f":inbox_tray: ONNX model saved to {onnx_model_path}") return onnx_model_path - def export_flite(self): + + def export_tflite(self): logger.info(f":package: Exporting model to tflite format") - logger.info(f":construction: Not implemented yet") + + import torch + self.model.eval() + example_inputs = (torch.rand(1, 3, *self.cfg.image_size),) + + import ai_edge_torch + edge_model = ai_edge_torch.convert(self.model, example_inputs) + edge_model.export(self.model_path) + + logger.info(f":white_check_mark: Model exported to tflite format") def export_coreml(self): logger.info(f":package: Exporting model to coreml format") From 5246d42f46bb7561160144db9527b341f15733f7 Mon Sep 17 00:00:00 2001 From: Ramon Date: Thu, 20 Feb 2025 08:36:45 +0100 Subject: [PATCH 05/15] =?UTF-8?q?=F0=9F=90=9B=20Bugfix=20pipeline?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- yolo/tools/solver.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/yolo/tools/solver.py b/yolo/tools/solver.py index f3f53d75..73b6c5aa 100644 --- a/yolo/tools/solver.py +++ b/yolo/tools/solver.py @@ -111,7 +111,8 @@ def configure_optimizers(self): class InferenceModel(BaseModel): def __init__(self, cfg: Config): - cfg.model.model.auxiliary = {} + if not hasattr(cfg.model.model, 'auxiliary'): + cfg.model.model.auxiliary = {} super().__init__(cfg) # super().__init__(cfg) self.cfg = cfg @@ -154,7 +155,8 @@ def _save_image(self, img, batch_idx): class ExportModel(BaseModel): def __init__(self, cfg: Config): - cfg.model.model.auxiliary = {} + if not hasattr(cfg.model.model, 'auxiliary'): + cfg.model.model.auxiliary = {} super().__init__(cfg) self.cfg = cfg self.format = cfg.task.format From a5cbf063eeef842f26617748703d43cdbf4ab037 Mon Sep 17 00:00:00 2001 From: Ramon Date: Thu, 20 Feb 2025 09:58:34 +0100 Subject: [PATCH 06/15] =?UTF-8?q?=F0=9F=90=9B=20Bugfix=20pipeline?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- yolo/tools/solver.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/yolo/tools/solver.py b/yolo/tools/solver.py index 73b6c5aa..2083c245 100644 --- a/yolo/tools/solver.py +++ b/yolo/tools/solver.py @@ -111,7 +111,7 @@ def configure_optimizers(self): class InferenceModel(BaseModel): def __init__(self, cfg: Config): - if not hasattr(cfg.model.model, 'auxiliary'): + if hasattr(cfg.model.model, 'auxiliary'): cfg.model.model.auxiliary = {} super().__init__(cfg) # super().__init__(cfg) @@ -133,7 +133,7 @@ def predict_dataloader(self): def predict_step(self, batch, batch_idx): images, rev_tensor, origin_frame = batch - if self.fast_model: + if hasattr(self, 'fast_model') and self.fast_model: predictions = self.fast_model(images) else: predictions = self(images) @@ -155,7 +155,7 @@ def _save_image(self, img, 
batch_idx): class ExportModel(BaseModel): def __init__(self, cfg: Config): - if not hasattr(cfg.model.model, 'auxiliary'): + if hasattr(cfg.model.model, 'auxiliary'): cfg.model.model.auxiliary = {} super().__init__(cfg) self.cfg = cfg From 3d4c28e3e28ada18524891418f36ec0dfb9df21c Mon Sep 17 00:00:00 2001 From: Ramon Date: Thu, 20 Feb 2025 14:49:50 +0100 Subject: [PATCH 07/15] =?UTF-8?q?=F0=9F=93=9C=20Sort=20imports?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- yolo/tools/solver.py | 2 +- yolo/utils/export_utils.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/yolo/tools/solver.py b/yolo/tools/solver.py index 38bd2f93..a1903689 100644 --- a/yolo/tools/solver.py +++ b/yolo/tools/solver.py @@ -12,8 +12,8 @@ from yolo.utils.bounding_box_utils import create_converter, to_metrics_format from yolo.utils.deploy_utils import FastModelLoader from yolo.utils.export_utils import ModelExporter -from yolo.utils.model_utils import PostProcess, create_optimizer, create_scheduler from yolo.utils.logger import logger +from yolo.utils.model_utils import PostProcess, create_optimizer, create_scheduler class BaseModel(LightningModule): def __init__(self, cfg: Config, export: bool = False): diff --git a/yolo/utils/export_utils.py b/yolo/utils/export_utils.py index a83635d9..ceb6982e 100644 --- a/yolo/utils/export_utils.py +++ b/yolo/utils/export_utils.py @@ -1,9 +1,8 @@ - +from pathlib import Path from typing import Dict, List, Optional from yolo.config.config import Config from yolo.model.yolo import YOLO from yolo.utils.logger import logger -from pathlib import Path class ModelExporter(): def __init__(self, cfg: Config, model: YOLO, format: str, model_path: Optional[str] = None): From 7f7d286afc57a291077fa43d59a756316ed833ab Mon Sep 17 00:00:00 2001 From: Ramon Date: Thu, 20 Feb 2025 14:54:09 +0100 Subject: [PATCH 08/15] =?UTF-8?q?=F0=9F=93=9C=20Sort=20imports?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- yolo/utils/export_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yolo/utils/export_utils.py b/yolo/utils/export_utils.py index ceb6982e..b0d54f2f 100644 --- a/yolo/utils/export_utils.py +++ b/yolo/utils/export_utils.py @@ -1,5 +1,5 @@ -from pathlib import Path from typing import Dict, List, Optional +from pathlib import Path from yolo.config.config import Config from yolo.model.yolo import YOLO from yolo.utils.logger import logger From f13d1df24ca9813bb2a375172faad0ae9cd1d969 Mon Sep 17 00:00:00 2001 From: Ramon Date: Thu, 20 Feb 2025 15:00:18 +0100 Subject: [PATCH 09/15] =?UTF-8?q?=F0=9F=A7=B9=20Clean=20code=20formats?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- yolo/tools/solver.py | 16 +++++------ yolo/utils/deploy_utils.py | 25 +++++++++-------- yolo/utils/export_utils.py | 57 ++++++++++++++++++++++++-------------- 3 files changed, 56 insertions(+), 42 deletions(-) diff --git a/yolo/tools/solver.py b/yolo/tools/solver.py index a1903689..83dd5228 100644 --- a/yolo/tools/solver.py +++ b/yolo/tools/solver.py @@ -15,6 +15,7 @@ from yolo.utils.logger import logger from yolo.utils.model_utils import PostProcess, create_optimizer, create_scheduler + class BaseModel(LightningModule): def __init__(self, cfg: Config, export: bool = False): super().__init__() @@ -111,7 +112,7 @@ def configure_optimizers(self): class InferenceModel(BaseModel): def __init__(self, cfg: Config): - if hasattr(cfg.model.model, 
'auxiliary'): + if hasattr(cfg.model.model, "auxiliary"): cfg.model.model.auxiliary = {} super().__init__(cfg) # super().__init__(cfg) @@ -133,7 +134,7 @@ def predict_dataloader(self): def predict_step(self, batch, batch_idx): images, rev_tensor, origin_frame = batch - if hasattr(self, 'fast_model') and self.fast_model: + if hasattr(self, "fast_model") and self.fast_model: predictions = self.fast_model(images) else: predictions = self(images) @@ -155,7 +156,7 @@ def _save_image(self, img, batch_idx): class ExportModel(BaseModel): def __init__(self, cfg: Config): - if hasattr(cfg.model.model, 'auxiliary'): + if hasattr(cfg.model.model, "auxiliary"): cfg.model.model.auxiliary = {} super().__init__(cfg) self.cfg = cfg @@ -163,12 +164,9 @@ def __init__(self, cfg: Config): self.model_exporter = ModelExporter(self.cfg, self.model, format=self.format) def export(self): - if self.format == 'onnx': + if self.format == "onnx": self.model_exporter.export_onnx() - if self.format == 'tflite': + if self.format == "tflite": self.model_exporter.export_tflite() - if self.format == 'coreml': + if self.format == "coreml": self.model_exporter.export_coreml() - - - diff --git a/yolo/utils/deploy_utils.py b/yolo/utils/deploy_utils.py index 685d1c35..6a3c138f 100644 --- a/yolo/utils/deploy_utils.py +++ b/yolo/utils/deploy_utils.py @@ -13,14 +13,14 @@ class FastModelLoader: def __init__(self, cfg: Config, model: YOLO): self.cfg = cfg self.model = model - self.compiler : str = cfg.task.fast_inference + self.compiler: str = cfg.task.fast_inference self.class_num = cfg.dataset.class_num self._validate_compiler() if cfg.weight == True: cfg.weight = Path("weights") / f"{cfg.model.name}.pt" - extension : str = self.compiler + extension: str = self.compiler if self.compiler == "coreml": extension = "mlpackage" @@ -48,10 +48,10 @@ def load_model(self, device): return create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).to(device) def _load_tflite_model(self, device): - + if not Path(self.model_path).exists(): self._create_tflite_model() - + from ai_edge_litert.interpreter import Interpreter try: @@ -66,18 +66,18 @@ def tflite_forward(self: Interpreter, x: Tensor): # Get input & output tensor details input_details = self.get_input_details() - output_details = sorted(self.get_output_details(), key=lambda d: d['name']) # Sort by 'name' + output_details = sorted(self.get_output_details(), key=lambda d: d["name"]) # Sort by 'name' # Convert input tensor to NumPy and assign it to the model x_numpy = x.cpu().numpy() - self.set_tensor(input_details[0]['index'], x_numpy) + self.set_tensor(input_details[0]["index"], x_numpy) model_outputs, layer_output = [], [] x_numpy = x.cpu().numpy() - self.set_tensor(input_details[0]['index'], x_numpy) + self.set_tensor(input_details[0]["index"], x_numpy) self.invoke() for idx, output_detail in enumerate(output_details): - predict = self.get_tensor(output_detail['index']) + predict = self.get_tensor(output_detail["index"]) layer_output.append(torch.from_numpy(predict).to(device)) if idx % 3 == 2: model_outputs.append(layer_output) @@ -126,7 +126,8 @@ def onnx_forward(self: InferenceSession, x: Tensor): def _create_onnx_model(self, providers): from onnxruntime import InferenceSession - model_exporter = ModelExporter(self.cfg, self.model, format='onnx', model_path=self.model_path) + + model_exporter = ModelExporter(self.cfg, self.model, format="onnx", model_path=self.model_path) model_exporter.export_onnx(dynamic_axes={"input": {0: "batch_size"}, "output": {0: 
"batch_size"}}) return InferenceSession(self.model_path, providers=providers) @@ -157,13 +158,13 @@ def coreml_forward(self, x: Tensor): return None return model_coreml - + def _create_tflite_model(self): - model_exporter = ModelExporter(self.cfg, self.model, format='tflite', model_path=self.model_path) + model_exporter = ModelExporter(self.cfg, self.model, format="tflite", model_path=self.model_path) model_exporter.export_tflite() def _create_coreml_model(self): - model_exporter = ModelExporter(self.cfg, self.model, format='coreml', model_path=self.model_path) + model_exporter = ModelExporter(self.cfg, self.model, format="coreml", model_path=self.model_path) model_exporter.export_coreml() def _load_trt_model(self): diff --git a/yolo/utils/export_utils.py b/yolo/utils/export_utils.py index b0d54f2f..06004262 100644 --- a/yolo/utils/export_utils.py +++ b/yolo/utils/export_utils.py @@ -1,10 +1,12 @@ -from typing import Dict, List, Optional from pathlib import Path +from typing import Dict, List, Optional + from yolo.config.config import Config from yolo.model.yolo import YOLO from yolo.utils.logger import logger -class ModelExporter(): + +class ModelExporter: def __init__(self, cfg: Config, model: YOLO, format: str, model_path: Optional[str] = None): self.model = model self.cfg = cfg @@ -17,9 +19,10 @@ def __init__(self, cfg: Config, model: YOLO, format: str, model_path: Optional[s else: self.model_path = f"{Path(self.cfg.weight).stem}.{self.format}" - def export_onnx(self, dynamic_axes : Optional[Dict[str, Dict[int, str]]] = None, model_path: Optional[str] = None): + def export_onnx(self, dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None, model_path: Optional[str] = None): logger.info(f":package: Exporting model to onnx format") import torch + dummy_input = torch.ones((1, 3, *self.cfg.image_size)) if model_path: @@ -28,10 +31,16 @@ def export_onnx(self, dynamic_axes : Optional[Dict[str, Dict[int, str]]] = None, onnx_model_path = self.model_path # onnx_model_path = f"{Path(self.cfg.weight).stem}.onnx" - output_names : List[str] = [ - "1_class_scores_small", "2_box_features_small", "3_bbox_deltas_small", - "4_class_scores_medium", "5_box_features_medium", "6_bbox_deltas_medium", - "7_class_scores_large", "8_box_features_large", "9_bbox_deltas_large" + output_names: List[str] = [ + "1_class_scores_small", + "2_box_features_small", + "3_bbox_deltas_small", + "4_class_scores_medium", + "5_box_features_medium", + "6_bbox_deltas_medium", + "7_class_scores_large", + "8_box_features_large", + "9_bbox_deltas_large", ] torch.onnx.export( @@ -43,20 +52,20 @@ def export_onnx(self, dynamic_axes : Optional[Dict[str, Dict[int, str]]] = None, dynamic_axes=dynamic_axes, ) - - logger.info(f":inbox_tray: ONNX model saved to {onnx_model_path}") return onnx_model_path - + def export_tflite(self): logger.info(f":package: Exporting model to tflite format") import torch + self.model.eval() example_inputs = (torch.rand(1, 3, *self.cfg.image_size),) import ai_edge_torch + edge_model = ai_edge_torch.convert(self.model, example_inputs) edge_model.export(self.model_path) @@ -66,25 +75,31 @@ def export_coreml(self): logger.info(f":package: Exporting model to coreml format") import torch - + self.model.eval() example_inputs = (torch.rand(1, 3, *self.cfg.image_size),) exported_program = torch.export.export(self.model, example_inputs) import logging + import coremltools as ct # Convert to Core ML program using the Unified Conversion API. 
logging.getLogger("coremltools").disabled = True - - output_names : List[str] = [ - "1_class_scores_small", "2_box_features_small", "3_bbox_deltas_small", - "4_class_scores_medium", "5_box_features_medium", "6_bbox_deltas_medium", - "7_class_scores_large", "8_box_features_large", "9_bbox_deltas_large" + + output_names: List[str] = [ + "1_class_scores_small", + "2_box_features_small", + "3_bbox_deltas_small", + "4_class_scores_medium", + "5_box_features_medium", + "6_bbox_deltas_medium", + "7_class_scores_large", + "8_box_features_large", + "9_bbox_deltas_large", ] - - model_from_export = ct.convert(exported_program, - outputs=[ct.TensorType(name=name) for name in output_names]) - + + model_from_export = ct.convert(exported_program, outputs=[ct.TensorType(name=name) for name in output_names]) + model_from_export.save(f"{Path(self.cfg.weight).stem}.mlpackage") - logger.info(f":white_check_mark: Model exported to coreml format") \ No newline at end of file + logger.info(f":white_check_mark: Model exported to coreml format") From 718d28c377b178ac7306f831a0b2fe50a5c4c582 Mon Sep 17 00:00:00 2001 From: Ramon Date: Sun, 23 Feb 2025 14:58:45 +0100 Subject: [PATCH 10/15] =?UTF-8?q?=F0=9F=8D=8F=20Integrate=20post=20process?= =?UTF-8?q?ing=20in=20model=20for=20coreml=20exports?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- yolo/model/yolo.py | 60 ++++++++++++++++++++++++++++++-- yolo/tools/solver.py | 25 +++++++++---- yolo/utils/bounding_box_utils.py | 7 ++-- yolo/utils/deploy_utils.py | 24 ++++++------- yolo/utils/export_utils.py | 57 ++++++++++++++---------------- yolo/utils/model_utils.py | 10 ++++-- 6 files changed, 126 insertions(+), 57 deletions(-) diff --git a/yolo/model/yolo.py b/yolo/model/yolo.py index 634f9e5a..8567fb11 100644 --- a/yolo/model/yolo.py +++ b/yolo/model/yolo.py @@ -21,12 +21,13 @@ class YOLO(nn.Module): parameters, and any other relevant configuration details. 
""" - def __init__(self, model_cfg: ModelConfig, class_num: int = 80): + def __init__(self, model_cfg: ModelConfig, class_num: int = 80, export_mode : bool =False): super(YOLO, self).__init__() self.num_classes = class_num self.layer_map = get_layer_map() # Get the map Dict[str: Module] self.model: List[YOLOLayer] = nn.ModuleList() self.reg_max = getattr(model_cfg.anchor, "reg_max", 16) + self.export_mode = export_mode self.build_model(model_cfg.model) def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]): @@ -68,7 +69,37 @@ def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]): setattr(layer, "out_c", out_channels) layer_idx += 1 + def generate_anchors(self, image_size: List[int], strides: List[int]): + W, H = image_size + anchors = [] + scaler = [] + for stride in strides: + anchor_num = W // stride * H // stride + scaler.append(torch.full((anchor_num,), stride)) + shift = stride // 2 + h = torch.arange(0, H, stride) + shift + w = torch.arange(0, W, stride) + shift + if torch.__version__ >= "2.3.0": + anchor_h, anchor_w = torch.meshgrid(h, w, indexing="ij") + else: + anchor_h, anchor_w = torch.meshgrid(h, w) + anchor = torch.stack([anchor_w.flatten(), anchor_h.flatten()], dim=-1) + anchors.append(anchor) + all_anchors = torch.cat(anchors, dim=0) + all_scalers = torch.cat(scaler, dim=0) + return all_anchors, all_scalers + + def get_strides(self, output, input_width) -> List[int]: + W = input_width + strides = [] + for predict_head in output: + _, _, *anchor_num = predict_head[2].shape + strides.append(W // anchor_num[1]) + + return strides + def forward(self, x, external: Optional[Dict] = None, shortcut: Optional[str] = None): + input_width, input_height = x.shape[-2:] y = {0: x, **(external or {})} output = dict() @@ -96,6 +127,29 @@ def forward(self, x, external: Optional[Dict] = None, shortcut: Optional[str] = index += 1 + if self.export_mode: + + preds_cls, preds_anc, preds_box = [], [], [] + for layer_output in output['Main']: + pred_cls, pred_anc, pred_box = layer_output + preds_cls.append(pred_cls.permute(0, 2, 3, 1).reshape(pred_cls.shape[0], -1, pred_cls.shape[1])) + preds_anc.append(pred_anc.permute(0, 3, 4, 1, 2).reshape(pred_anc.shape[0], -1, pred_anc.shape[2], pred_anc.shape[1])) + preds_box.append(pred_box.permute(0, 2, 3, 1).reshape(pred_box.shape[0], -1, pred_box.shape[1])) + + preds_cls = torch.concat(preds_cls, dim=1).to(x[0][0].device) + preds_anc = torch.concat(preds_anc, dim=1).to(x[0][0].device) + preds_box = torch.concat(preds_box, dim=1).to(x[0][0].device) + + strides = self.get_strides(output['Main'], input_width) + anchor_grid, scaler = self.generate_anchors([input_width,input_height], strides) # + anchor_grid = anchor_grid.to(x[0][0].device) + scaler = scaler.to(x[0][0].device) + pred_LTRB = preds_box * scaler.view(1, -1, 1) + lt, rb = pred_LTRB.chunk(2, dim=-1) + preds_box = torch.cat([anchor_grid - lt, anchor_grid + rb], dim=-1) + + return preds_cls, preds_anc, preds_box + return output def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]): @@ -167,7 +221,7 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]): self.model.load_state_dict(model_state_dict) -def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO: +def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80, export_mode: bool = False) -> YOLO: """Constructs and returns a model from a 
Dictionary configuration file. Args: @@ -177,7 +231,7 @@ def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, YOLO: An instance of the model defined by the given configuration. """ OmegaConf.set_struct(model_cfg, False) - model = YOLO(model_cfg, class_num) + model = YOLO(model_cfg, class_num, export_mode=export_mode) if weight_path: if weight_path == True: weight_path = Path("weights") / f"{model_cfg.name}.pt" diff --git a/yolo/tools/solver.py b/yolo/tools/solver.py index 83dd5228..a75ec9e2 100644 --- a/yolo/tools/solver.py +++ b/yolo/tools/solver.py @@ -17,9 +17,9 @@ class BaseModel(LightningModule): - def __init__(self, cfg: Config, export: bool = False): + def __init__(self, cfg: Config, export_mode: bool = False): super().__init__() - self.model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight) + self.model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight, export_mode=export_mode) def forward(self, x): return self.model(x) @@ -114,8 +114,14 @@ class InferenceModel(BaseModel): def __init__(self, cfg: Config): if hasattr(cfg.model.model, "auxiliary"): cfg.model.model.auxiliary = {} - super().__init__(cfg) - # super().__init__(cfg) + + export_mode = False + fast_inference = cfg.task.fast_inference + # TODO check if we can use export mode for all formats + if fast_inference == "coreml": + export_mode = True + + super().__init__(cfg, export_mode=export_mode) self.cfg = cfg self.predict_loader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task) @@ -158,9 +164,16 @@ class ExportModel(BaseModel): def __init__(self, cfg: Config): if hasattr(cfg.model.model, "auxiliary"): cfg.model.model.auxiliary = {} - super().__init__(cfg) + + export_mode = False + format = cfg.task.format + # TODO check if we can use export mode for all formats + if self.format == "coreml": + export_mode = True + + super().__init__(cfg, export_mode=export_mode) self.cfg = cfg - self.format = cfg.task.format + self.format = format self.model_exporter = ModelExporter(self.cfg, self.model, format=self.format) def export(self): diff --git a/yolo/utils/bounding_box_utils.py b/yolo/utils/bounding_box_utils.py index 0357bfdc..335aa59c 100644 --- a/yolo/utils/bounding_box_utils.py +++ b/yolo/utils/bounding_box_utils.py @@ -342,13 +342,14 @@ def __init__(self, model: YOLO, anchor_cfg: AnchorConfig, image_size, device): if hasattr(anchor_cfg, "strides"): logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}") self.strides = anchor_cfg.strides - else: + elif not model.export_mode: logger.info(":teddy_bear: Found no stride of model, performed a dummy test for auto-anchor size") self.strides = self.create_auto_anchor(model, image_size) - anchor_grid, scaler = generate_anchors(image_size, self.strides) self.image_size = image_size - self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device) + if not model.export_mode: + anchor_grid, scaler = generate_anchors(image_size, self.strides) + self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device) def create_auto_anchor(self, model: YOLO, image_size): W, H = image_size diff --git a/yolo/utils/deploy_utils.py b/yolo/utils/deploy_utils.py index 6a3c138f..d0b66f7a 100644 --- a/yolo/utils/deploy_utils.py +++ b/yolo/utils/deploy_utils.py @@ -20,11 +20,11 @@ def __init__(self, cfg: Config, model: YOLO): if cfg.weight == True: cfg.weight = Path("weights") / f"{cfg.model.name}.pt" - extension: str = self.compiler - if 
self.compiler == "coreml": - extension = "mlpackage" + extention = self.compiler + if self.compiler == 'coreml': + extention = 'mlpackage' - self.model_path = f"{Path(cfg.weight).stem}.{extension}" + self.model_path = f"{Path(cfg.weight).stem}.{extention}" def _validate_compiler(self): if self.compiler not in ["onnx", "trt", "deploy", "coreml", "tflite"]: @@ -136,18 +136,18 @@ def _load_coreml_model(self, device): def coreml_forward(self, x: Tensor): x = x.cpu().numpy() - model_outputs, layer_output = [], [] + model_outputs = [] predictions = self.predict({"x": x}) - for idx, key in enumerate(sorted(predictions.keys())): - layer_output.append(torch.from_numpy(predictions[key]).to(device)) - if idx % 3 == 2: - model_outputs.append(layer_output) - layer_output = [] - return {"Main": model_outputs} + + output_keys = ['preds_cls', 'preds_anc', 'preds_box'] + for key in output_keys: + model_outputs.append(torch.from_numpy(predictions[key]).to(device)) + + return model_outputs models.MLModel.__call__ = coreml_forward - if not Path(self.model_path).exists(): + if True or not Path(self.model_path).exists(): self._create_coreml_model() try: diff --git a/yolo/utils/export_utils.py b/yolo/utils/export_utils.py index 06004262..2a10414d 100644 --- a/yolo/utils/export_utils.py +++ b/yolo/utils/export_utils.py @@ -14,24 +14,17 @@ def __init__(self, cfg: Config, model: YOLO, format: str, model_path: Optional[s self.format = format if cfg.weight == True: cfg.weight = Path("weights") / f"{cfg.model.name}.pt" - if model_path: - self.model_path = model_path - else: - self.model_path = f"{Path(self.cfg.weight).stem}.{self.format}" - - def export_onnx(self, dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None, model_path: Optional[str] = None): - logger.info(f":package: Exporting model to onnx format") - import torch - - dummy_input = torch.ones((1, 3, *self.cfg.image_size)) if model_path: - onnx_model_path = model_path + self.model_path = model_path else: - onnx_model_path = self.model_path + extention = self.format + if self.format == 'coreml': + extention = 'mlpackage' + + self.model_path = f"{Path(self.cfg.weight).stem}.{extention}" - # onnx_model_path = f"{Path(self.cfg.weight).stem}.onnx" - output_names: List[str] = [ + self.output_names: List[str] = [ "1_class_scores_small", "2_box_features_small", "3_bbox_deltas_small", @@ -43,12 +36,27 @@ def export_onnx(self, dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None, "9_bbox_deltas_large", ] + self.output_names: List[str] = [ + "preds_cls", "preds_anc", "preds_box" + ] + + def export_onnx(self, dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None, model_path: Optional[str] = None): + logger.info(f":package: Exporting model to onnx format") + import torch + + dummy_input = torch.ones((1, 3, *self.cfg.image_size)) + + if model_path: + onnx_model_path = model_path + else: + onnx_model_path = self.model_path + torch.onnx.export( self.model, dummy_input, onnx_model_path, input_names=["input"], - output_names=output_names, + output_names=self.output_names, dynamic_axes=dynamic_axes, ) @@ -81,25 +89,12 @@ def export_coreml(self): exported_program = torch.export.export(self.model, example_inputs) import logging - import coremltools as ct # Convert to Core ML program using the Unified Conversion API. 
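# In export_mode the forward pass itself flattens and decodes the detection
# heads, so the converted package exposes exactly three tensors:
#   preds_cls - per-anchor class scores for all three scales concatenated,
#   preds_anc - the per-anchor box distribution features,
#   preds_box - boxes already decoded against the anchor grid into
#               (x1, y1, x2, y2) pixel coordinates of the network input.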
logging.getLogger("coremltools").disabled = True - output_names: List[str] = [ - "1_class_scores_small", - "2_box_features_small", - "3_bbox_deltas_small", - "4_class_scores_medium", - "5_box_features_medium", - "6_bbox_deltas_medium", - "7_class_scores_large", - "8_box_features_large", - "9_bbox_deltas_large", - ] - - model_from_export = ct.convert(exported_program, outputs=[ct.TensorType(name=name) for name in output_names]) + model_from_export = ct.convert(exported_program, outputs=[ct.TensorType(name=name) for name in self.output_names], convert_to="mlprogram") - model_from_export.save(f"{Path(self.cfg.weight).stem}.mlpackage") - logger.info(f":white_check_mark: Model exported to coreml format") + model_from_export.save(self.model_path) + logger.info(f":white_check_mark: Model exported to coreml format {self.model_path}") diff --git a/yolo/utils/model_utils.py b/yolo/utils/model_utils.py index 9d6c0ce5..6c08259c 100644 --- a/yolo/utils/model_utils.py +++ b/yolo/utils/model_utils.py @@ -173,9 +173,15 @@ def __call__( ) -> List[Tensor]: if image_size is not None: self.converter.update(image_size) - prediction = self.converter(predict["Main"]) - pred_class, _, pred_bbox = prediction[:3] + + if isinstance(predict, dict): + prediction = self.converter(predict["Main"]) + else: + prediction = predict + + pred_class, _, pred_bbox = predict[:3] pred_conf = prediction[3] if len(prediction) == 4 else None + if rev_tensor is not None: pred_bbox = (pred_bbox - rev_tensor[:, None, 1:]) / rev_tensor[:, 0:1, None] pred_bbox = bbox_nms(pred_class, pred_bbox, self.nms, pred_conf) From 89ea875e267bcdd8758b1f724fb181d0cd70a458 Mon Sep 17 00:00:00 2001 From: Ramon Date: Sun, 23 Feb 2025 15:00:20 +0100 Subject: [PATCH 11/15] =?UTF-8?q?=F0=9F=8D=8F=20Integrate=20post=20process?= =?UTF-8?q?ing=20in=20model=20for=20coreml=20exports?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- yolo/model/yolo.py | 26 +++++++++++++++----------- yolo/tools/solver.py | 4 +++- yolo/utils/deploy_utils.py | 8 ++++---- yolo/utils/export_utils.py | 15 ++++++++------- yolo/utils/model_utils.py | 4 ++-- 5 files changed, 32 insertions(+), 25 deletions(-) diff --git a/yolo/model/yolo.py b/yolo/model/yolo.py index 8567fb11..b3977a07 100644 --- a/yolo/model/yolo.py +++ b/yolo/model/yolo.py @@ -21,7 +21,7 @@ class YOLO(nn.Module): parameters, and any other relevant configuration details. 
""" - def __init__(self, model_cfg: ModelConfig, class_num: int = 80, export_mode : bool =False): + def __init__(self, model_cfg: ModelConfig, class_num: int = 80, export_mode: bool = False): super(YOLO, self).__init__() self.num_classes = class_num self.layer_map = get_layer_map() # Get the map Dict[str: Module] @@ -88,14 +88,14 @@ def generate_anchors(self, image_size: List[int], strides: List[int]): all_anchors = torch.cat(anchors, dim=0) all_scalers = torch.cat(scaler, dim=0) return all_anchors, all_scalers - + def get_strides(self, output, input_width) -> List[int]: W = input_width strides = [] for predict_head in output: _, _, *anchor_num = predict_head[2].shape strides.append(W // anchor_num[1]) - + return strides def forward(self, x, external: Optional[Dict] = None, shortcut: Optional[str] = None): @@ -130,24 +130,26 @@ def forward(self, x, external: Optional[Dict] = None, shortcut: Optional[str] = if self.export_mode: preds_cls, preds_anc, preds_box = [], [], [] - for layer_output in output['Main']: + for layer_output in output["Main"]: pred_cls, pred_anc, pred_box = layer_output preds_cls.append(pred_cls.permute(0, 2, 3, 1).reshape(pred_cls.shape[0], -1, pred_cls.shape[1])) - preds_anc.append(pred_anc.permute(0, 3, 4, 1, 2).reshape(pred_anc.shape[0], -1, pred_anc.shape[2], pred_anc.shape[1])) + preds_anc.append( + pred_anc.permute(0, 3, 4, 1, 2).reshape(pred_anc.shape[0], -1, pred_anc.shape[2], pred_anc.shape[1]) + ) preds_box.append(pred_box.permute(0, 2, 3, 1).reshape(pred_box.shape[0], -1, pred_box.shape[1])) - + preds_cls = torch.concat(preds_cls, dim=1).to(x[0][0].device) preds_anc = torch.concat(preds_anc, dim=1).to(x[0][0].device) preds_box = torch.concat(preds_box, dim=1).to(x[0][0].device) - - strides = self.get_strides(output['Main'], input_width) - anchor_grid, scaler = self.generate_anchors([input_width,input_height], strides) # + + strides = self.get_strides(output["Main"], input_width) + anchor_grid, scaler = self.generate_anchors([input_width, input_height], strides) # anchor_grid = anchor_grid.to(x[0][0].device) scaler = scaler.to(x[0][0].device) pred_LTRB = preds_box * scaler.view(1, -1, 1) lt, rb = pred_LTRB.chunk(2, dim=-1) preds_box = torch.cat([anchor_grid - lt, anchor_grid + rb], dim=-1) - + return preds_cls, preds_anc, preds_box return output @@ -221,7 +223,9 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]): self.model.load_state_dict(model_state_dict) -def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80, export_mode: bool = False) -> YOLO: +def create_model( + model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80, export_mode: bool = False +) -> YOLO: """Constructs and returns a model from a Dictionary configuration file. 
Args: diff --git a/yolo/tools/solver.py b/yolo/tools/solver.py index a75ec9e2..03bbcd27 100644 --- a/yolo/tools/solver.py +++ b/yolo/tools/solver.py @@ -19,7 +19,9 @@ class BaseModel(LightningModule): def __init__(self, cfg: Config, export_mode: bool = False): super().__init__() - self.model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight, export_mode=export_mode) + self.model = create_model( + cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight, export_mode=export_mode + ) def forward(self, x): return self.model(x) diff --git a/yolo/utils/deploy_utils.py b/yolo/utils/deploy_utils.py index d0b66f7a..6d4c507b 100644 --- a/yolo/utils/deploy_utils.py +++ b/yolo/utils/deploy_utils.py @@ -21,8 +21,8 @@ def __init__(self, cfg: Config, model: YOLO): cfg.weight = Path("weights") / f"{cfg.model.name}.pt" extention = self.compiler - if self.compiler == 'coreml': - extention = 'mlpackage' + if self.compiler == "coreml": + extention = "mlpackage" self.model_path = f"{Path(cfg.weight).stem}.{extention}" @@ -139,10 +139,10 @@ def coreml_forward(self, x: Tensor): model_outputs = [] predictions = self.predict({"x": x}) - output_keys = ['preds_cls', 'preds_anc', 'preds_box'] + output_keys = ["preds_cls", "preds_anc", "preds_box"] for key in output_keys: model_outputs.append(torch.from_numpy(predictions[key]).to(device)) - + return model_outputs models.MLModel.__call__ = coreml_forward diff --git a/yolo/utils/export_utils.py b/yolo/utils/export_utils.py index 2a10414d..42f2aaa3 100644 --- a/yolo/utils/export_utils.py +++ b/yolo/utils/export_utils.py @@ -19,9 +19,9 @@ def __init__(self, cfg: Config, model: YOLO, format: str, model_path: Optional[s self.model_path = model_path else: extention = self.format - if self.format == 'coreml': - extention = 'mlpackage' - + if self.format == "coreml": + extention = "mlpackage" + self.model_path = f"{Path(self.cfg.weight).stem}.{extention}" self.output_names: List[str] = [ @@ -36,9 +36,7 @@ def __init__(self, cfg: Config, model: YOLO, format: str, model_path: Optional[s "9_bbox_deltas_large", ] - self.output_names: List[str] = [ - "preds_cls", "preds_anc", "preds_box" - ] + self.output_names: List[str] = ["preds_cls", "preds_anc", "preds_box"] def export_onnx(self, dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None, model_path: Optional[str] = None): logger.info(f":package: Exporting model to onnx format") @@ -89,12 +87,15 @@ def export_coreml(self): exported_program = torch.export.export(self.model, example_inputs) import logging + import coremltools as ct # Convert to Core ML program using the Unified Conversion API. 
logging.getLogger("coremltools").disabled = True - model_from_export = ct.convert(exported_program, outputs=[ct.TensorType(name=name) for name in self.output_names], convert_to="mlprogram") + model_from_export = ct.convert( + exported_program, outputs=[ct.TensorType(name=name) for name in self.output_names], convert_to="mlprogram" + ) model_from_export.save(self.model_path) logger.info(f":white_check_mark: Model exported to coreml format {self.model_path}") diff --git a/yolo/utils/model_utils.py b/yolo/utils/model_utils.py index 6c08259c..194447cd 100644 --- a/yolo/utils/model_utils.py +++ b/yolo/utils/model_utils.py @@ -178,10 +178,10 @@ def __call__( prediction = self.converter(predict["Main"]) else: prediction = predict - + pred_class, _, pred_bbox = predict[:3] pred_conf = prediction[3] if len(prediction) == 4 else None - + if rev_tensor is not None: pred_bbox = (pred_bbox - rev_tensor[:, None, 1:]) / rev_tensor[:, 0:1, None] pred_bbox = bbox_nms(pred_class, pred_bbox, self.nms, pred_conf) From 3bfa80437b3bbc8e428191d67148f4eff9648915 Mon Sep 17 00:00:00 2001 From: Ramon Date: Sun, 23 Feb 2025 15:51:59 +0100 Subject: [PATCH 12/15] =?UTF-8?q?=F0=9F=90=9B=20Fix=20pipeline?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- yolo/utils/model_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yolo/utils/model_utils.py b/yolo/utils/model_utils.py index 194447cd..a6d4bc65 100644 --- a/yolo/utils/model_utils.py +++ b/yolo/utils/model_utils.py @@ -179,7 +179,7 @@ def __call__( else: prediction = predict - pred_class, _, pred_bbox = predict[:3] + pred_class, _, pred_bbox = prediction[:3] pred_conf = prediction[3] if len(prediction) == 4 else None if rev_tensor is not None: From 114beef0c0b18f6bc73f2d5c6750e505716d5eba Mon Sep 17 00:00:00 2001 From: Ramon Date: Sun, 23 Feb 2025 20:21:33 +0100 Subject: [PATCH 13/15] =?UTF-8?q?=F0=9F=A7=B9=20Correct=20output=5Fnames?= =?UTF-8?q?=20for=20coreml=20exports?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- yolo/utils/export_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/yolo/utils/export_utils.py b/yolo/utils/export_utils.py index 42f2aaa3..906bc568 100644 --- a/yolo/utils/export_utils.py +++ b/yolo/utils/export_utils.py @@ -36,8 +36,6 @@ def __init__(self, cfg: Config, model: YOLO, format: str, model_path: Optional[s "9_bbox_deltas_large", ] - self.output_names: List[str] = ["preds_cls", "preds_anc", "preds_box"] - def export_onnx(self, dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None, model_path: Optional[str] = None): logger.info(f":package: Exporting model to onnx format") import torch @@ -93,6 +91,8 @@ def export_coreml(self): # Convert to Core ML program using the Unified Conversion API. 
logging.getLogger("coremltools").disabled = True + self.output_names: List[str] = ["preds_cls", "preds_anc", "preds_box"] + model_from_export = ct.convert( exported_program, outputs=[ct.TensorType(name=name) for name in self.output_names], convert_to="mlprogram" ) From 116624c29dac0257be15784233784c8d1c29d8f1 Mon Sep 17 00:00:00 2001 From: Ramon Date: Sun, 23 Feb 2025 21:00:20 +0100 Subject: [PATCH 14/15] =?UTF-8?q?=F0=9F=A7=B9=20Clean=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- yolo/utils/deploy_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/yolo/utils/deploy_utils.py b/yolo/utils/deploy_utils.py index 6d4c507b..d8dbe3f1 100644 --- a/yolo/utils/deploy_utils.py +++ b/yolo/utils/deploy_utils.py @@ -147,7 +147,7 @@ def coreml_forward(self, x: Tensor): models.MLModel.__call__ = coreml_forward - if True or not Path(self.model_path).exists(): + if not Path(self.model_path).exists(): self._create_coreml_model() try: @@ -156,7 +156,7 @@ def coreml_forward(self, x: Tensor): except FileNotFoundError: logger.warning(f"🈳 No found model weight at {self.model_path}") return None - + return model_coreml def _create_tflite_model(self): From e002e3b25a2520b287910590e2df9bec42459806 Mon Sep 17 00:00:00 2001 From: Ramon Date: Sun, 23 Feb 2025 21:06:14 +0100 Subject: [PATCH 15/15] =?UTF-8?q?=F0=9F=A7=B9=20Clean=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- yolo/utils/deploy_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yolo/utils/deploy_utils.py b/yolo/utils/deploy_utils.py index d8dbe3f1..5b28c86f 100644 --- a/yolo/utils/deploy_utils.py +++ b/yolo/utils/deploy_utils.py @@ -156,7 +156,7 @@ def coreml_forward(self, x: Tensor): except FileNotFoundError: logger.warning(f"🈳 No found model weight at {self.model_path}") return None - + return model_coreml def _create_tflite_model(self):
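---

For reference, the Core ML path that ModelExporter.export_coreml and coreml_forward implement in the patches above reduces to the torch.export -> coremltools flow sketched below. This is a minimal, self-contained illustration rather than the project's code: TinyHead is a hypothetical stand-in for YOLO(..., export_mode=True), the 640x640 input, the output names and the file name simply mirror the values used in the patches, and mlmodel.predict() only runs on macOS with a recent coremltools release that supports ExportedProgram inputs.

import torch
import torch.nn as nn
import coremltools as ct


class TinyHead(nn.Module):
    # Hypothetical stand-in for YOLO(..., export_mode=True): one image input, three flat tensor outputs.
    def __init__(self):
        super().__init__()
        self.cls = nn.Linear(3, 80)   # -> "preds_cls"
        self.anc = nn.Linear(3, 16)   # -> "preds_anc"
        self.box = nn.Linear(3, 4)    # -> "preds_box"

    def forward(self, x):
        feat = x.mean(dim=(2, 3))     # global average pool over H, W -> (B, 3)
        return self.cls(feat), self.anc(feat), self.box(feat)


model = TinyHead().eval()
example_inputs = (torch.ones(1, 3, 640, 640),)  # illustrative input size

# torch.export produces an ExportedProgram that coremltools can convert directly.
exported_program = torch.export.export(model, example_inputs)

output_names = ["preds_cls", "preds_anc", "preds_box"]
mlmodel = ct.convert(
    exported_program,
    outputs=[ct.TensorType(name=name) for name in output_names],
    convert_to="mlprogram",
)
mlmodel.save("tiny_head.mlpackage")  # illustrative file name

# Querying mirrors coreml_forward: the input key follows the forward() argument name ("x"),
# and the outputs come back keyed by the names passed to ct.convert.
predictions = mlmodel.predict({"x": torch.ones(1, 3, 640, 640).numpy()})
print({name: predictions[name].shape for name in output_names})

Naming the outputs explicitly is what lets coreml_forward fetch preds_cls, preds_anc and preds_box by key instead of relying on positional ordering, which is also why PATCH 13 moves the three-entry output_names list next to the ct.convert call.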