Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimizations for deepface.recognition.find, Optimization and New Iterator Functionality in image_utils #1420

Merged
merged 11 commits into from
Jan 7, 2025
Merged
36 changes: 26 additions & 10 deletions deepface/commons/image_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# built-in dependencies
import os
import io
from typing import List, Union, Tuple
from typing import Generator, List, Union, Tuple
import hashlib
import base64
from pathlib import Path
Expand All @@ -14,6 +14,10 @@
from werkzeug.datastructures import FileStorage


IMAGE_EXTS = {".jpg", ".jpeg", ".png"}
PIL_EXTS = {"jpeg", "png"}


def list_images(path: str) -> List[str]:
"""
List images in a given path
Expand All @@ -25,17 +29,29 @@ def list_images(path: str) -> List[str]:
images = []
for r, _, f in os.walk(path):
for file in f:
exact_path = os.path.join(r, file)

ext_lower = os.path.splitext(exact_path)[-1].lower()
if os.path.splitext(file)[1].lower() in IMAGE_EXTS:
exact_path = os.path.join(r, file)
with Image.open(exact_path) as img: # lazy
if img.format.lower() in PIL_EXTS:
images.append(exact_path)
return images

if ext_lower not in {".jpg", ".jpeg", ".png"}:
continue

with Image.open(exact_path) as img: # lazy
if img.format.lower() in {"jpeg", "png"}:
images.append(exact_path)
return images
def yield_images(path: str) -> Generator[str, None, None]:
"""
Yield images in a given path
Args:
path (str): path's location
Yields:
image (str): image path
"""
for r, _, f in os.walk(path):
for file in f:
if os.path.splitext(file)[1].lower() in IMAGE_EXTS:
exact_path = os.path.join(r, file)
with Image.open(exact_path) as img: # lazy
if img.format.lower() in PIL_EXTS:
yield exact_path


def find_image_hash(file_path: str) -> str:
Expand Down
20 changes: 11 additions & 9 deletions deepface/modules/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,15 @@ def find(
representations = []

# required columns for representations
df_cols = [
df_cols = {
"identity",
"hash",
"embedding",
"target_x",
"target_y",
"target_w",
"target_h",
]
}

# Ensure the proper pickle file exists
if not os.path.exists(datastore_path):
Expand All @@ -157,18 +157,15 @@ def find(

# check each item of representations list has required keys
for i, current_representation in enumerate(representations):
missing_keys = set(df_cols) - set(current_representation.keys())
missing_keys = df_cols - set(current_representation.keys())
if len(missing_keys) > 0:
raise ValueError(
f"{i}-th item does not have some required keys - {missing_keys}."
f"Consider to delete {datastore_path}"
)

# embedded images
pickled_images = [representation["identity"] for representation in representations]

# Get the list of images on storage
storage_images = image_utils.list_images(path=db_path)
storage_images = set(image_utils.yield_images(path=db_path))
serengil marked this conversation as resolved.
Show resolved Hide resolved

if len(storage_images) == 0 and refresh_database is True:
raise ValueError(f"No item found in {db_path}")
Expand All @@ -186,8 +183,13 @@ def find(

# Enforce data consistency amongst on disk images and pickle file
if refresh_database:
new_images = set(storage_images) - set(pickled_images) # images added to storage
old_images = set(pickled_images) - set(storage_images) # images removed from storage
# embedded images
pickled_images = {
representation["identity"] for representation in representations
}

new_images = storage_images - pickled_images # images added to storage
old_images = pickled_images - storage_images # images removed from storage

# detect replaced images
for current_representation in representations:
Expand Down
17 changes: 14 additions & 3 deletions tests/test_find.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,23 @@ def test_filetype_for_find():


def test_filetype_for_find_bulk_embeddings():
imgs = image_utils.list_images("dataset")
# List
list_imgs = image_utils.list_images("dataset")

assert len(imgs) > 0
assert len(list_imgs) > 0

# img47 is webp even though its extension is jpg
assert "dataset/img47.jpg" not in imgs
assert "dataset/img47.jpg" not in list_imgs

# Generator
gen_imgs = list(image_utils.yield_images("dataset"))

assert len(gen_imgs) > 0

# img47 is webp even though its extension is jpg
assert "dataset/img47.jpg" not in gen_imgs

assert gen_imgs == list_imgs
serengil marked this conversation as resolved.
Show resolved Hide resolved


def test_find_without_refresh_database():
Expand Down
Loading