diff --git a/Detector/Detector.py b/Detector/Detector.py
index 51daaa1..8cc4919 100644
--- a/Detector/Detector.py
+++ b/Detector/Detector.py
@@ -1,16 +1,21 @@
-import cv2
-import mediapipe as mp
-import time
+import cv2  # type: ignore
+import mediapipe as mp  # type: ignore
+
 class Detector:
-    def __init__(self, mode=False,
-                 upBody=False,
-                 smooth=True,
-                 detectionCon=0.5,
-                 trackCon=0.5):
+    def __init__(
+        self,
+        mode=False,
+        upBody=False,
+        smoothBody=False,
+        smooth=True,
+        detectionCon=0.5,
+        trackCon=0.5,
+    ):
         self.mode = mode
         self.upBody = upBody
+        self.smoothBody = smoothBody
         self.detectionCon = detectionCon
         self.trackCon = trackCon
         self.smooth = smooth
@@ -21,11 +26,14 @@ def __init__(self, mode=False,
         self.mp_drawing = mp.solutions.drawing_utils
         # Importing the pose estimation models
         self.mp_pose = mp.solutions.pose
-        self.pose = self.mp_pose.Pose(static_image_mode=self.mode,
-                                      enable_segmentation=self.upBody,
-                                      smooth_landmarks=self.smooth,
-                                      min_detection_confidence=self.detectionCon,
-                                      min_tracking_confidence=self.trackCon)
+        self.pose = self.mp_pose.Pose(
+            static_image_mode=self.mode,
+            enable_segmentation=self.upBody,
+            smooth_segmentation=self.smoothBody,
+            smooth_landmarks=self.smooth,
+            min_detection_confidence=self.detectionCon,
+            min_tracking_confidence=self.trackCon,
+        )

     def make_detections(self, frame):
         # Recolor the frame (OpenCV gives the image in BGR format, while MediaPipe uses RGB)
@@ -65,6 +73,9 @@ def draw_pose_pose_landmark(self, frame, results):
                 color=(245, 66, 230), thickness=2, circle_radius=2
             ),
         )
+
     def mask_point(self, frame, pointID, lmList):
         if len(lmList) != 0:
-            cv2.circle(frame, (lmList[pointID][1], lmList[pointID][2]), 40, (255, 0, 0), 4)
\ No newline at end of file
+            cv2.circle(
+                frame, (lmList[pointID][1], lmList[pointID][2]), 40, (255, 0, 0), 4
+            )
diff --git a/SelfieSegmentation/selfie_segmentation.py b/SelfieSegmentation/selfie_segmentation.py
new file mode 100644
index 0000000..6450a5c
--- /dev/null
+++ b/SelfieSegmentation/selfie_segmentation.py
@@ -0,0 +1,42 @@
+"""
+SelfieSegmentation from CVZone
+By: Computer Vision Zone
+Website: https://www.computervision.zone/
+
+With slight modifications for our project
+"""
+import cv2  # type: ignore
+import mediapipe as mp  # type: ignore
+import numpy as np
+
+
+class SelfieSegmentation:
+    def __init__(self, model=1):
+        """
+        :param model: model type, 0 or 1. 0 is general, 1 is landscape (faster)
+        """
+        self.model = model
+        self.mpDraw = mp.solutions.drawing_utils
+        self.mpSelfieSegmentation = mp.solutions.selfie_segmentation
+        self.selfieSegmentation = self.mpSelfieSegmentation.SelfieSegmentation(
+            self.model
+        )
+
+    def removeBG(self, img, imgBg=(255, 255, 255), threshold=0.1):
+        """
+        Remove the background from an image
+        :param img: image to remove the background from
+        :param imgBg: background image, or a BGR color tuple to fill with
+        :param threshold: higher = more aggressive cut, lower = gentler cut
+        :return: the image with its background replaced
+        """
+        imgRGB = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        results = self.selfieSegmentation.process(imgRGB)
+        condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > threshold
+        if isinstance(imgBg, tuple):
+            _imgBg = np.zeros(img.shape, dtype=np.uint8)
+            _imgBg[:] = imgBg
+            imgOut = np.where(condition, img, _imgBg)
+        else:
+            imgOut = np.where(condition, img, imgBg)
+        return imgOut
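For reference, a minimal usage sketch for the module above (illustrative only, not part of this commit; assumes a webcam at index 0, in the style of the CVZone demos):

    import cv2
    from SelfieSegmentation.selfie_segmentation import SelfieSegmentation

    cap = cv2.VideoCapture(0)
    segmenter = SelfieSegmentation()
    while cap.isOpened():
        success, frame = cap.read()
        if not success:
            break
        # Replace everything below the mask-confidence threshold with white
        out = segmenter.removeBG(frame, imgBg=(255, 255, 255), threshold=0.5)
        cv2.imshow("SelfieSegmentation demo", out)
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
    cap.release()
    cv2.destroyAllWindows()

Passing an image (for example a blurred copy of the frame) as imgBg instead of a color tuple composites onto that image, which is how main.py uses it below.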
diff --git a/Utility/fps.py b/Utility/fps.py
new file mode 100644
index 0000000..3c747dc
--- /dev/null
+++ b/Utility/fps.py
@@ -0,0 +1,50 @@
+"""
+FPS Module
+By: Computer Vision Zone
+Website: https://www.computervision.zone/
+
+Slightly modified for our project
+"""
+
+import time
+
+import cv2  # type: ignore
+
+
+class FPS:
+    """
+    Helps in finding the frames per second and displaying it on an OpenCV image
+    """
+
+    def __init__(self):
+        self.pTime = time.time()
+
+    def update(self, img=None, pos=(20, 50), color=(255, 0, 0), scale=3, thickness=3):
+        """
+        Update the frame rate
+        :param img: image to display on; can be left blank if only the fps value is required
+        :param pos: position of the FPS value on the image
+        :param color: color of the FPS value displayed
+        :param scale: scale of the FPS value displayed
+        :param thickness: thickness of the FPS value displayed
+        :return: the fps value, or (fps, img) when an image is supplied
+        """
+        cTime = time.time()
+        try:
+            fps = 1 / (cTime - self.pTime)
+            self.pTime = cTime
+            if img is None:
+                return fps
+            else:
+                cv2.putText(
+                    img,
+                    f"FPS: {int(fps)}",
+                    pos,
+                    cv2.FONT_HERSHEY_PLAIN,
+                    scale,
+                    color,
+                    thickness,
+                )
+                return fps, img
+        except ZeroDivisionError:
+            # Two updates within the same clock tick: report 0 and keep the
+            # return shape consistent with the success path
+            return 0 if img is None else (0, img)
diff --git a/Utility/utility.py b/Utility/utility.py
new file mode 100644
index 0000000..2bdf6ef
--- /dev/null
+++ b/Utility/utility.py
@@ -0,0 +1,22 @@
+"""
+General collection of utility functions
+"""
+import cv2  # type: ignore
+
+
+def whiteness_offset(img) -> float:
+    """Uses the overall brightness of the image to pick a threshold for background removal
+
+    Args:
+        img (np.ndarray): The image to calculate the threshold for
+
+    Returns:
+        float: The threshold to use for background removal
+    """
+    # Convert the image to grayscale
+    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+    # Calculate the average pixel value
+    avg = gray.mean()
+    # Normalize the average to [0, 1] to use as the threshold
+    thresh = avg / 255
+    return thresh
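To make the heuristic concrete, a small worked example (ours, not part of the commit): a uniformly bright frame produces a high threshold, so removeBG cuts out more of a washed-out, white scene than of a dark one:

    import numpy as np
    from Utility.utility import whiteness_offset

    # Synthetic 640x480 BGR frame where every pixel has gray value 191
    frame = np.full((480, 640, 3), 191, dtype=np.uint8)
    print(whiteness_offset(frame))  # 191 / 255 ~= 0.749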
diff --git a/main.py b/main.py
index e26895a..a24ad65 100644
--- a/main.py
+++ b/main.py
@@ -1,9 +1,20 @@
+"""
+Proof of concept for a workout assistant, utilizing MediaPipe, OpenCV, and
+our own custom code (some based on CVZone).
+Detects repetitions for pushups.
+"""
+
 import cv2  # type: ignore
 import mediapipe as mp  # type: ignore
 import numpy as np

-from functions.calculate_angle_between_points import calculate_angle_between_points
-from Detector.Detector import Detector
+from Detector.Detector import Detector  # type: ignore
+from functions.calculate_angle_between_points import (  # type: ignore
+    calculate_angle_between_points,
+)
+from SelfieSegmentation.selfie_segmentation import SelfieSegmentation  # type: ignore
+from Utility.fps import FPS  # type: ignore
+from Utility.utility import whiteness_offset  # type: ignore

 # Gives us all the drawing utilities. Going to be used to visualize the poses
 mp_drawing = mp.solutions.drawing_utils
@@ -13,27 +24,40 @@

 if __name__ == "__main__":
     # instance of the detector class
-    detector = Detector()
+    detector = Detector(upBody=True, smoothBody=True)
+    # Initialize the SelfieSegmentation module
+    segmenter = SelfieSegmentation()
+
+    # Initialize the FPS reader for displaying on the final image
+    fps_injector = FPS()

     # counter for reps
-    counter = 0
+    counter: int = 0

     # tracks whether we are in the up or down phase of the curl exercise
-    stage = None
+    stage: None | str = None

-    # Video Feed
-    # setting up the video capture device. The number represents the camera (can change from device to device)
+    # Video feed: set up the video capture device. The number selects the
+    # camera (and can change from device to device)
     cap = cv2.VideoCapture(0)

-    # Accesses a pose detection model with detection and tracking confidence of 50%
+    # Accesses a pose detection model with detection and tracking confidence of
+    # 50%
     with mp_pose.Pose(
         min_detection_confidence=0.5, min_tracking_confidence=0.5
     ) as my_pose:
         while cap.isOpened():
-            # Stores what ever we get from the capture (ret is return variable (nothing here) and frame is the image)
+            # Stores whatever we get from the capture (ret reports whether the
+            # read succeeded; frame is the image)
             ret, my_frame = cap.read()

-            my_image, my_results = detector.make_detections(my_frame)
+            # Derive a segmentation threshold from the frame's brightness and
+            # soften the background with a heavy blur before detection
+            threshold = whiteness_offset(my_frame)
+            bg_image = cv2.GaussianBlur(my_frame, (55, 55), 0)
+            clean_img = segmenter.removeBG(
+                my_frame, imgBg=bg_image, threshold=threshold
+            )
+
+            my_image, my_results = detector.make_detections(clean_img)

             # Extract landmarks
             try:
@@ -77,7 +101,8 @@
                     counter += 1
                     print(counter)

-            except:
+            except TypeError:
+                # If there is no pose detected (NoneType error), skip the frame
                 pass

             # Visualize the curl counter in a box
@@ -129,7 +154,7 @@
                 1,
                 cv2.LINE_AA,
             )
-            lmList = detector.get_interest_points(frame = my_image, results=my_results)
+            lmList = detector.get_interest_points(frame=my_image, results=my_results)
             print(lmList)

             detector.mask_point(frame=my_image, lmList=lmList, pointID=13)
@@ -137,6 +162,9 @@
             # Draws the pose landmarks and the connections between them to the image
             detector.draw_pose_pose_landmark(frame=my_image, results=my_results)

+            # Inject the FPS onto the frame
+            fps_injector.update(my_image, (20, 200))
+
             # Shows the image with the landmarks on them (after the processing)
             cv2.imshow("Mediapipe Feed", my_image)
             # Breaks the loop if you hit q
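Note: calculate_angle_between_points is imported in main.py but not touched by this diff. For context, a minimal sketch of the usual MediaPipe-style joint-angle helper (an assumption about the project's implementation, which may differ):

    import numpy as np

    def calculate_angle_between_points(a, b, c):
        # Angle at vertex b (in degrees) formed by the segments b->a and b->c
        a, b, c = np.array(a), np.array(b), np.array(c)
        radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(
            a[1] - b[1], a[0] - b[0]
        )
        angle = abs(np.degrees(radians))
        return 360.0 - angle if angle > 180.0 else angle

In main.py, stage presumably flips between up and down as this angle crosses its thresholds, and counter increments once per full cycle.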