From 11cfb4cbf1ec400de1cf7c2eaeb8a0dfa6b32ee7 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Fri, 16 Jun 2023 20:00:26 -0400 Subject: [PATCH 1/3] Changes for yolo annotation demo --- src/deepsparse/pipeline.py | 27 ++++++++++++++------------- src/deepsparse/utils/annotate.py | 21 +++++++++++++++++---- src/deepsparse/yolo/utils/utils.py | 4 ++-- 3 files changed, 33 insertions(+), 19 deletions(-) diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index 88d5414992..fb6bd5ddf7 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -205,6 +205,7 @@ def __init__( self.engine = self._initialize_engine() self._batch_size = self._batch_size or 1 + self._timer = Timer() self.log( identifier=f"{SystemGroups.INFERENCE_DETAILS}/num_cores_total", @@ -218,12 +219,12 @@ def __call__(self, *args, **kwargs) -> BaseModel: "invalid kwarg engine_inputs. engine inputs determined " f"by {self.__class__.__qualname__}.parse_inputs" ) - timer = Timer() + self._timer = Timer() - timer.start(InferencePhases.TOTAL_INFERENCE) + self._timer.start(InferencePhases.TOTAL_INFERENCE) # ------ PREPROCESSING ------ - timer.start(InferencePhases.PRE_PROCESS) + self._timer.start(InferencePhases.PRE_PROCESS) # parse inputs into input_schema pipeline_inputs = self.parse_inputs(*args, **kwargs) @@ -244,7 +245,7 @@ def __call__(self, *args, **kwargs) -> BaseModel: engine_inputs, postprocess_kwargs = engine_inputs else: postprocess_kwargs = {} - timer.stop(InferencePhases.PRE_PROCESS) + self._timer.stop(InferencePhases.PRE_PROCESS) self.log( identifier="engine_inputs", @@ -253,13 +254,13 @@ def __call__(self, *args, **kwargs) -> BaseModel: ) self.log( identifier=f"{SystemGroups.PREDICTION_LATENCY}/{InferencePhases.PRE_PROCESS}_seconds", # noqa E501 - value=timer.time_delta(InferencePhases.PRE_PROCESS), + value=self._timer.time_delta(InferencePhases.PRE_PROCESS), category=MetricCategories.SYSTEM, ) # ------ INFERENCE ------ # split inputs into batches of size `self._batch_size` - timer.start(InferencePhases.ENGINE_FORWARD) + self._timer.start(InferencePhases.ENGINE_FORWARD) batches = self.split_engine_inputs(engine_inputs, self._batch_size) # submit split batches to engine threadpool @@ -267,7 +268,7 @@ def __call__(self, *args, **kwargs) -> BaseModel: # join together the batches of size `self._batch_size` engine_outputs = self.join_engine_outputs(batch_outputs) - timer.stop(InferencePhases.ENGINE_FORWARD) + self._timer.stop(InferencePhases.ENGINE_FORWARD) self.log( identifier=f"{SystemGroups.INFERENCE_DETAILS}/input_batch_size_total", @@ -286,12 +287,12 @@ def __call__(self, *args, **kwargs) -> BaseModel: ) self.log( identifier=f"{SystemGroups.PREDICTION_LATENCY}/{InferencePhases.ENGINE_FORWARD}_seconds", # noqa E501 - value=timer.time_delta(InferencePhases.ENGINE_FORWARD), + value=self._timer.time_delta(InferencePhases.ENGINE_FORWARD), category=MetricCategories.SYSTEM, ) # ------ POSTPROCESSING ------ - timer.start(InferencePhases.POST_PROCESS) + self._timer.start(InferencePhases.POST_PROCESS) pipeline_outputs = self.process_engine_outputs( engine_outputs, **postprocess_kwargs ) @@ -300,8 +301,8 @@ def __call__(self, *args, **kwargs) -> BaseModel: f"Outputs of {self.__class__} must be instances of " f"{self.output_schema} found output of type {type(pipeline_outputs)}" ) - timer.stop(InferencePhases.POST_PROCESS) - timer.stop(InferencePhases.TOTAL_INFERENCE) + self._timer.stop(InferencePhases.POST_PROCESS) + self._timer.stop(InferencePhases.TOTAL_INFERENCE) self.log( identifier="pipeline_outputs", @@ -310,12 +311,12 @@ def __call__(self, *args, **kwargs) -> BaseModel: ) self.log( identifier=f"{SystemGroups.PREDICTION_LATENCY}/{InferencePhases.POST_PROCESS}_seconds", # noqa E501 - value=timer.time_delta(InferencePhases.POST_PROCESS), + value=self._timer.time_delta(InferencePhases.POST_PROCESS), category=MetricCategories.SYSTEM, ) self.log( identifier=f"{SystemGroups.PREDICTION_LATENCY}/{InferencePhases.TOTAL_INFERENCE}_seconds", # noqa E501 - value=timer.time_delta(InferencePhases.TOTAL_INFERENCE), + value=self._timer.time_delta(InferencePhases.TOTAL_INFERENCE), category=MetricCategories.SYSTEM, ) diff --git a/src/deepsparse/utils/annotate.py b/src/deepsparse/utils/annotate.py index 36164a7537..ca8d18cb3b 100644 --- a/src/deepsparse/utils/annotate.py +++ b/src/deepsparse/utils/annotate.py @@ -20,12 +20,14 @@ import os import shutil import time +from collections import deque from copy import copy from pathlib import Path from typing import Any, Callable, Iterable, Iterator, List, Optional, Tuple, Union import numpy +from deepsparse.timing import InferencePhases from sparsezoo.utils import create_dirs @@ -41,6 +43,19 @@ __all__ = ["get_image_loader_and_saver", "get_annotations_save_dir", "annotate"] +class AverageFPS: + def __init__(self, num_samples=20): + self.frame_times = deque(maxlen=num_samples) + + def measure(self, duration): + self.frame_times.append(duration) + + def calculate(self): + if len(self.frame_times) > 1: + return numpy.average(self.frame_times) + else: + return 0.0 +afps = AverageFPS() def get_image_loader_and_saver( path: str, @@ -364,13 +379,11 @@ def annotate( if isinstance(original_image, str): original_image = cv2.imread(image) - if target_fps is None and calc_fps: - start = time.perf_counter() - pipeline_output = pipeline(images=[image]) if target_fps is None and calc_fps: - target_fps = 1 / (time.perf_counter() - start) + afps.measure(1 / pipeline._timer.time_delta(InferencePhases.ENGINE_FORWARD)) + target_fps = afps.calculate() result = annotation_func( image=original_image, diff --git a/src/deepsparse/yolo/utils/utils.py b/src/deepsparse/yolo/utils/utils.py index e778fabe17..ff71fca2de 100644 --- a/src/deepsparse/yolo/utils/utils.py +++ b/src/deepsparse/yolo/utils/utils.py @@ -410,12 +410,12 @@ def modify_yolo_onnx_input_shape( f"at {model_path} with the new input shape" ) save_onnx(model, model_path) - return model_path + return model_path, None else: _LOGGER.info( "Saving the ONNX model with the " "new input shape to a temporary file" ) - return save_onnx_to_temp_files(model, with_external_data=not inplace) + return save_onnx_to_temp_files(model, with_external_data=not inplace), None def get_tensor_dim_shape(tensor: onnx.TensorProto, dim: int) -> int: From a121d9934e2c6be1e6c8a9a1644a3bf76f3ba3d6 Mon Sep 17 00:00:00 2001 From: mgoin Date: Sun, 18 Jun 2023 19:57:36 -0400 Subject: [PATCH 2/3] More edits --- src/deepsparse/pipeline.py | 35 +++++++++++------- src/deepsparse/utils/annotate.py | 23 +++++++----- src/deepsparse/yolo/annotate.py | 5 ++- src/deepsparse/yolo/utils/utils.py | 57 ++++++++++++++++++++++++++---- src/deepsparse/yolov8/annotate.py | 5 ++- 5 files changed, 94 insertions(+), 31 deletions(-) diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index fb6bd5ddf7..6de288c414 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -219,12 +219,12 @@ def __call__(self, *args, **kwargs) -> BaseModel: "invalid kwarg engine_inputs. engine inputs determined " f"by {self.__class__.__qualname__}.parse_inputs" ) - self._timer = Timer() + timer = Timer() - self._timer.start(InferencePhases.TOTAL_INFERENCE) + timer.start(InferencePhases.TOTAL_INFERENCE) # ------ PREPROCESSING ------ - self._timer.start(InferencePhases.PRE_PROCESS) + timer.start(InferencePhases.PRE_PROCESS) # parse inputs into input_schema pipeline_inputs = self.parse_inputs(*args, **kwargs) @@ -245,7 +245,7 @@ def __call__(self, *args, **kwargs) -> BaseModel: engine_inputs, postprocess_kwargs = engine_inputs else: postprocess_kwargs = {} - self._timer.stop(InferencePhases.PRE_PROCESS) + timer.stop(InferencePhases.PRE_PROCESS) self.log( identifier="engine_inputs", @@ -254,13 +254,13 @@ def __call__(self, *args, **kwargs) -> BaseModel: ) self.log( identifier=f"{SystemGroups.PREDICTION_LATENCY}/{InferencePhases.PRE_PROCESS}_seconds", # noqa E501 - value=self._timer.time_delta(InferencePhases.PRE_PROCESS), + value=timer.time_delta(InferencePhases.PRE_PROCESS), category=MetricCategories.SYSTEM, ) # ------ INFERENCE ------ # split inputs into batches of size `self._batch_size` - self._timer.start(InferencePhases.ENGINE_FORWARD) + timer.start(InferencePhases.ENGINE_FORWARD) batches = self.split_engine_inputs(engine_inputs, self._batch_size) # submit split batches to engine threadpool @@ -268,7 +268,7 @@ def __call__(self, *args, **kwargs) -> BaseModel: # join together the batches of size `self._batch_size` engine_outputs = self.join_engine_outputs(batch_outputs) - self._timer.stop(InferencePhases.ENGINE_FORWARD) + timer.stop(InferencePhases.ENGINE_FORWARD) self.log( identifier=f"{SystemGroups.INFERENCE_DETAILS}/input_batch_size_total", @@ -287,12 +287,12 @@ def __call__(self, *args, **kwargs) -> BaseModel: ) self.log( identifier=f"{SystemGroups.PREDICTION_LATENCY}/{InferencePhases.ENGINE_FORWARD}_seconds", # noqa E501 - value=self._timer.time_delta(InferencePhases.ENGINE_FORWARD), + value=timer.time_delta(InferencePhases.ENGINE_FORWARD), category=MetricCategories.SYSTEM, ) # ------ POSTPROCESSING ------ - self._timer.start(InferencePhases.POST_PROCESS) + timer.start(InferencePhases.POST_PROCESS) pipeline_outputs = self.process_engine_outputs( engine_outputs, **postprocess_kwargs ) @@ -301,8 +301,8 @@ def __call__(self, *args, **kwargs) -> BaseModel: f"Outputs of {self.__class__} must be instances of " f"{self.output_schema} found output of type {type(pipeline_outputs)}" ) - self._timer.stop(InferencePhases.POST_PROCESS) - self._timer.stop(InferencePhases.TOTAL_INFERENCE) + timer.stop(InferencePhases.POST_PROCESS) + timer.stop(InferencePhases.TOTAL_INFERENCE) self.log( identifier="pipeline_outputs", @@ -311,15 +311,17 @@ def __call__(self, *args, **kwargs) -> BaseModel: ) self.log( identifier=f"{SystemGroups.PREDICTION_LATENCY}/{InferencePhases.POST_PROCESS}_seconds", # noqa E501 - value=self._timer.time_delta(InferencePhases.POST_PROCESS), + value=timer.time_delta(InferencePhases.POST_PROCESS), category=MetricCategories.SYSTEM, ) self.log( identifier=f"{SystemGroups.PREDICTION_LATENCY}/{InferencePhases.TOTAL_INFERENCE}_seconds", # noqa E501 - value=self._timer.time_delta(InferencePhases.TOTAL_INFERENCE), + value=timer.time_delta(InferencePhases.TOTAL_INFERENCE), category=MetricCategories.SYSTEM, ) + self._timer = timer + return pipeline_outputs @staticmethod @@ -705,6 +707,13 @@ def engine_type(self) -> str: """ return self._engine_type + @property + def timer(self) -> Timer: + """ + :return: reference to timer used for latest inference + """ + return self._timer + def to_config(self) -> "PipelineConfig": """ :return: PipelineConfig that can be used to reload this object diff --git a/src/deepsparse/utils/annotate.py b/src/deepsparse/utils/annotate.py index ca8d18cb3b..9642e43c76 100644 --- a/src/deepsparse/utils/annotate.py +++ b/src/deepsparse/utils/annotate.py @@ -76,6 +76,19 @@ def get_image_loader_and_saver( image_batch, video, or web-cam based on path given, and a boolean value that is True is the returned objects load videos """ + # webcam + if path.isnumeric(): + loader = WebcamLoader(int(path), image_shape) + saver = ( + VideoSaver(save_dir, 30, loader.original_frame_size, None) + if not no_save + else None + ) + return loader, saver, True + + if no_save: + print("no_save ignored since not using webcam") + # video if path.endswith(".mp4"): loader = VideoLoader(path, image_shape) @@ -86,15 +99,7 @@ def get_image_loader_and_saver( target_fps, ) return loader, saver, True - # webcam - if path.isnumeric(): - loader = WebcamLoader(int(path), image_shape) - saver = ( - VideoSaver(save_dir, 30, loader.original_frame_size, None) - if not no_save - else None - ) - return loader, saver, True + # image file(s) return ImageLoader(path, image_shape), ImageSaver(save_dir), False diff --git a/src/deepsparse/yolo/annotate.py b/src/deepsparse/yolo/annotate.py index ffecbf6b51..3752873e1e 100644 --- a/src/deepsparse/yolo/annotate.py +++ b/src/deepsparse/yolo/annotate.py @@ -219,8 +219,11 @@ def main( ) if is_webcam: + cv2.namedWindow("annotated", cv2.WINDOW_NORMAL) cv2.imshow("annotated", annotated_image) - cv2.waitKey(1) + ch = cv2.waitKey(1) + if ch == 27 or ch == ord("q") or ch == ord("Q"): + break # save if saver: diff --git a/src/deepsparse/yolo/utils/utils.py b/src/deepsparse/yolo/utils/utils.py index ff71fca2de..e31ec21789 100644 --- a/src/deepsparse/yolo/utils/utils.py +++ b/src/deepsparse/yolo/utils/utils.py @@ -461,9 +461,11 @@ def annotate_image( img_res = numpy.copy(image) + num_ppl = 0 for idx in range(len(boxes)): label = labels[idx] if scores[idx] > score_threshold: + num_ppl += 1 if label == "person" else 0 annotation_text = f"{label}: {scores[idx]:.0%}" # bounding box points @@ -509,13 +511,22 @@ def annotate_image( ) if images_per_sec is not None: - img_res = _plot_fps( - img_res=img_res, - images_per_sec=images_per_sec, - x=20, - y=30, - font_scale=0.9, - thickness=2, + # img_res = _plot_fps( + # img_res=img_res, + # images_per_sec=images_per_sec, + # x=20, + # y=30, + # font_scale=0.9, + # thickness=2, + # ) + img_res = _draw_text( + img_res, + f"FPS: {images_per_sec:0.1f} | People Count: {num_ppl} | YOLOv8 on DeepSparse", + pos=(10, 10), + font_scale=0.7, + text_color=(204, 85, 17), + text_color_bg=(255, 255, 255), + font_thickness=2, ) return img_res @@ -557,3 +568,35 @@ def _plot_fps( cv2.LINE_AA, ) return img_res + + +def _draw_text( + img: numpy.ndarray, + text: str, + font=cv2.FONT_HERSHEY_SIMPLEX, + pos=(0, 0), + font_scale=1, + font_thickness=2, + text_color=(0, 255, 0), + text_color_bg=(0, 0, 0), +): + + offset = (5, 5) + x, y = pos + text_size, _ = cv2.getTextSize(text, font, font_scale, font_thickness) + text_w, text_h = text_size + rec_start = tuple(x - y for x, y in zip(pos, offset)) + rec_end = tuple(x + y for x, y in zip((x + text_w, y + text_h), offset)) + cv2.rectangle(img, rec_start, rec_end, text_color_bg, -1) + cv2.putText( + img, + text, + (x, int(y + text_h + font_scale - 1)), + font, + font_scale, + text_color, + font_thickness, + cv2.LINE_AA, + ) + + return text_size \ No newline at end of file diff --git a/src/deepsparse/yolov8/annotate.py b/src/deepsparse/yolov8/annotate.py index 3140311d1e..12a518d936 100644 --- a/src/deepsparse/yolov8/annotate.py +++ b/src/deepsparse/yolov8/annotate.py @@ -230,8 +230,11 @@ def main( ) if is_webcam: + cv2.namedWindow("annotations", cv2.WINDOW_NORMAL) cv2.imshow("annotated", annotated_image) - cv2.waitKey(1) + ch = cv2.waitKey(1) + if ch == 27 or ch == ord("q") or ch == ord("Q"): + break # save if saver: From 5925cd472a2324fab64fb871e0c5366b07c09909 Mon Sep 17 00:00:00 2001 From: mgoin Date: Sun, 18 Jun 2023 20:01:48 -0400 Subject: [PATCH 3/3] Format --- src/deepsparse/utils/annotate.py | 4 ++++ src/deepsparse/yolo/utils/utils.py | 2 +- src/deepsparse/yolov8/annotate.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/deepsparse/utils/annotate.py b/src/deepsparse/utils/annotate.py index 9642e43c76..ef560064ed 100644 --- a/src/deepsparse/utils/annotate.py +++ b/src/deepsparse/utils/annotate.py @@ -43,6 +43,7 @@ __all__ = ["get_image_loader_and_saver", "get_annotations_save_dir", "annotate"] + class AverageFPS: def __init__(self, num_samples=20): self.frame_times = deque(maxlen=num_samples) @@ -55,8 +56,11 @@ def calculate(self): return numpy.average(self.frame_times) else: return 0.0 + + afps = AverageFPS() + def get_image_loader_and_saver( path: str, save_dir: str, diff --git a/src/deepsparse/yolo/utils/utils.py b/src/deepsparse/yolo/utils/utils.py index e31ec21789..7261d71e11 100644 --- a/src/deepsparse/yolo/utils/utils.py +++ b/src/deepsparse/yolo/utils/utils.py @@ -599,4 +599,4 @@ def _draw_text( cv2.LINE_AA, ) - return text_size \ No newline at end of file + return text_size diff --git a/src/deepsparse/yolov8/annotate.py b/src/deepsparse/yolov8/annotate.py index 12a518d936..deace00f3b 100644 --- a/src/deepsparse/yolov8/annotate.py +++ b/src/deepsparse/yolov8/annotate.py @@ -230,7 +230,7 @@ def main( ) if is_webcam: - cv2.namedWindow("annotations", cv2.WINDOW_NORMAL) + cv2.namedWindow("annotated", cv2.WINDOW_NORMAL) cv2.imshow("annotated", annotated_image) ch = cv2.waitKey(1) if ch == 27 or ch == ord("q") or ch == ord("Q"):