Skip to content

Commit

Permalink
✨ Background Segmentation
Browse files Browse the repository at this point in the history
Based on CVZone, ripped out what we needed
Reduces dependencies.
  • Loading branch information
Nixxen committed Oct 22, 2022
1 parent 4e2dd05 commit 944d38d
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 26 deletions.
39 changes: 25 additions & 14 deletions Detector/Detector.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import cv2
import mediapipe as mp
import time
import cv2 # type: ignore
import mediapipe as mp # type: ignore


class Detector:
def __init__(self, mode=False,
upBody=False,
smooth=True,
detectionCon=0.5,
trackCon=0.5):
def __init__(
self,
mode=False,
upBody=False,
smoothBody=False,
smooth=True,
detectionCon=0.5,
trackCon=0.5,
):

self.mode = mode
self.upBody = upBody
self.smoothBody = smoothBody
self.detectionCon = detectionCon
self.trackCon = trackCon
self.smooth = smooth
Expand All @@ -21,11 +26,14 @@ def __init__(self, mode=False,
self.mp_drawing = mp.solutions.drawing_utils
# Importing the pose estimation models
self.mp_pose = mp.solutions.pose
self.pose = self.mp_pose.Pose(static_image_mode=self.mode,
enable_segmentation=self.upBody,
smooth_landmarks=self.smooth,
min_detection_confidence=self.detectionCon,
min_tracking_confidence=self.trackCon)
self.pose = self.mp_pose.Pose(
static_image_mode=self.mode,
enable_segmentation=self.upBody,
smooth_segmentation=self.smoothBody,
smooth_landmarks=self.smooth,
min_detection_confidence=self.detectionCon,
min_tracking_confidence=self.trackCon,
)

def make_detections(self, frame):
# Recolor the frame (opencv gives the image in BGR format. while mediapipe uses images in RGB format)
Expand Down Expand Up @@ -65,6 +73,9 @@ def draw_pose_pose_landmark(self, frame, results):
color=(245, 66, 230), thickness=2, circle_radius=2
),
)

def mask_point(self, frame, pointID, lmList):
if len(lmList) != 0:
cv2.circle(frame, (lmList[pointID][1], lmList[pointID][2]), 40, (255, 0, 0), 4)
cv2.circle(
frame, (lmList[pointID][1], lmList[pointID][2]), 40, (255, 0, 0), 4
)
42 changes: 42 additions & 0 deletions SelfieSegmentation/selfie_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
Selfie segmentation from CVZone
By: Computer Vision Zone
Website: https://www.computervision.zone/
With slight modifications for our project
"""
import cv2 # type: ignore
import mediapipe as mp # type: ignore
import numpy as np


class SelfieSegmentation:
    """Thin wrapper around the MediaPipe selfie-segmentation solution."""

    def __init__(self, model=1):
        """
        :param model: model type 0 or 1. 0 is general 1 is landscape(faster)
        """
        self.model = model
        self.mpDraw = mp.solutions.drawing_utils
        self.mpSelfieSegmentation = mp.solutions.selfie_segmentation
        self.selfieSegmentation = self.mpSelfieSegmentation.SelfieSegmentation(
            self.model
        )

    def removeBG(self, img, imgBg=(255, 255, 255), threshold=0.1):
        """
        Replace the background of ``img``.

        :param img: BGR image to remove the background from
        :param imgBg: either a colour tuple (painted over the whole
            background) or a background image used as-is
        :param threshold: pixels whose segmentation-mask value is above this
            keep the original image; all other pixels take the background
        :return: the composited image
        """
        # The segmenter is fed RGB, while the incoming frame is BGR.
        rgb_frame = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = self.selfieSegmentation.process(rgb_frame).segmentation_mask

        # Repeat the single-channel mask across three channels so it lines up
        # with the colour image in np.where.
        keep_pixel = np.stack([mask, mask, mask], axis=-1) > threshold

        if isinstance(imgBg, tuple):
            # Solid-colour background with the same shape as the frame.
            background = np.full(img.shape, imgBg, dtype=np.uint8)
        else:
            background = imgBg

        return np.where(keep_pixel, img, background)
50 changes: 50 additions & 0 deletions Utility/fps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""
FPS Module
By: Computer Vision Zone
Website: https://www.computervision.zone/
Slightly modified for our project
"""

import time

import cv2 # type: ignore


class FPS:
    """
    Helps in finding Frames Per Second and display on an OpenCV Image
    """

    def __init__(self):
        # Timestamp of the previous update() call (or of construction).
        self.pTime = time.time()

    def update(self, img=None, pos=(20, 50), color=(255, 0, 0), scale=3, thickness=3):
        """
        Update the frame rate
        :param img: Image to display on, can be left blank if only fps value required
        :param pos: Position on the FPS on the image
        :param color: Color of the FPS Value displayed
        :param scale: Scale of the FPS Value displayed
        :param thickness: Thickness of the FPS Value displayed
        :return: the fps value when ``img`` is None, otherwise the tuple
            ``(fps, img)`` with the value drawn onto ``img``
        """
        cTime = time.time()
        elapsed = cTime - self.pTime
        # Always move the reference point forward, so a single bad interval
        # (e.g. the wall clock stepping backwards) cannot skew later readings.
        self.pTime = cTime

        # Two calls within the clock resolution give elapsed == 0, and a clock
        # adjustment can make it negative; report 0 FPS in both cases instead
        # of raising or returning a negative rate.
        fps = 1 / elapsed if elapsed > 0 else 0.0

        if img is None:
            return fps

        cv2.putText(
            img,
            f"FPS: {int(fps)}",
            pos,
            cv2.FONT_HERSHEY_PLAIN,
            scale,
            color,
            thickness,
        )
        # Same return shape as the normal path even when fps is 0, so callers
        # that unpack ``fps, img`` never break.
        return fps, img
22 changes: 22 additions & 0 deletions Utility/utility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
General collection of utility functions
"""
import cv2 # type: ignore


def whiteness_offset(img) -> float:
    """Derive a background-removal threshold from the image's overall brightness.

    The image is converted to grayscale and its mean pixel value is divided
    by 255.

    Args:
        img (np.ndarray): The BGR image to calculate the threshold for

    Returns:
        float: The threshold to use for background removal
    """
    mean_intensity = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY).mean()
    return mean_intensity / 255
52 changes: 40 additions & 12 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
"""
Proof of concept for a workout assistant, utilizing Mediapipe, OpenCV, and
our own custom code (some based on CVZone).
Detects repetitions for pushups.
"""

import cv2 # type: ignore
import mediapipe as mp # type: ignore
import numpy as np

from functions.calculate_angle_between_points import calculate_angle_between_points
from Detector.Detector import Detector
from Detector.Detector import Detector # type: ignore
from functions.calculate_angle_between_points import ( # type: ignore
calculate_angle_between_points,
)
from SelfieSegmentation.selfie_segmentation import SelfieSegmentation # type: ignore
from Utility.fps import FPS # type: ignore
from Utility.utility import whiteness_offset # type: ignore

# Gives us all the drawing utilities. Going to be used to visualize the poses
mp_drawing = mp.solutions.drawing_utils
Expand All @@ -13,27 +24,40 @@

if __name__ == "__main__":
# instance of the detector class
detector = Detector()
detector = Detector(upBody=True, smoothBody=True)
# Initialize the SelfieSegmentationModule
segmenter = SelfieSegmentation()

# Initialize the FPS reader for displaying on the final image
fps_injector = FPS()

# counter for reps
counter = 0
counter: int = 0
# determine we are now on the up or down of the curl exercise
stage = None
stage: None | str = None

# Video Feed
# setting up the video capture device. The number represents the camera (can change from device to device)
# Video Feed setting up the video capture device. The number represents the
# camera (can change from device to device)
cap = cv2.VideoCapture(0)

# Accesses a pose detection model with detection and tracking confidence of 50%
# Accesses a pose detection model with detection and tracking confidence of
# 50%
with mp_pose.Pose(
min_detection_confidence=0.5, min_tracking_confidence=0.5
) as my_pose:

while cap.isOpened():
# Stores what ever we get from the capture (ret is return variable (nothing here) and frame is the image)
# Stores whatever we get from the capture (ret is return variable
# (nothing here) and frame is the image)
ret, my_frame = cap.read()

my_image, my_results = detector.make_detections(my_frame)
threshold = whiteness_offset(my_frame)
bg_image = cv2.GaussianBlur(my_frame, (55, 55), 0)
clean_img = segmenter.removeBG(
my_frame, imgBg=bg_image, threshold=threshold
)

my_image, my_results = detector.make_detections(clean_img)

# Extract landmarks
try:
Expand Down Expand Up @@ -77,7 +101,8 @@
counter += 1
print(counter)

except:
except TypeError:
# If there is no pose detected (NoneType error), pass
pass

# Visualize the curl counter in a box
Expand Down Expand Up @@ -129,14 +154,17 @@
1,
cv2.LINE_AA,
)
lmList = detector.get_interest_points(frame = my_image, results=my_results)
lmList = detector.get_interest_points(frame=my_image, results=my_results)
print(lmList)

detector.mask_point(frame=my_image, lmList=lmList, pointID=13)

# Draws the pose landmarks and the connections between them to the image
detector.draw_pose_pose_landmark(frame=my_image, results=my_results)

# Inject the FPS onto the frame
fps_injector.update(my_image, (20, 200))

# Shows the image with the landmarks on them (after the processing)
cv2.imshow("Mediapipe Feed", my_image)
# Breaks the loop if you hit q
Expand Down

0 comments on commit 944d38d

Please sign in to comment.