image preprocessing functions #185

Draft: wants to merge 6 commits into main
1 change: 1 addition & 0 deletions .gitignore
@@ -50,6 +50,7 @@ coverage.xml
.hypothesis/
.pytest_cache/
cover/
outputs/

# Translations
*.mo
9 changes: 9 additions & 0 deletions attributions.txt
@@ -0,0 +1,9 @@

Images:
taggui/images/people_landscape.webp "Group of teens at the beach" by Vladimir Pustovit https://flickr.com/photos/pustovit/14689258425 -- License: CC BY 2.0
taggui/images/people_portrait.webp "Crowd" by oatsy40 https://flickr.com/photos/oatsy40/14315923786 -- License: CC BY 2.0
taggui/images/diagram_portrait_adversarial.png "EPK komplexes Beispiel" by Florian Lindner https://commons.wikimedia.org/wiki/File:EPK_komplexes_Beispiel.png -- License: CC SA 2.5

Sounds:
taggui/tada.ogg "success 2" by Leszek_Szary -- https://freesound.org/s/171670/ -- License: CC0

Binary file added images/diagram_portrait_adversarial.png
Binary file added images/people_landscape.webp
Binary file added images/people_landscape_adversarial.webp
Binary file added images/people_portrait.webp
4 changes: 3 additions & 1 deletion requirements.txt
@@ -6,6 +6,9 @@ pillow==10.3.0
pyparsing==3.1.2
PySide6==6.7.1
transformers==4.41.2
playsound3==2.2.1
opencv-contrib-python==4.10.0.82
numpy==1.26.4

# PyTorch
torch==2.2.2; platform_system != "Windows"
@@ -24,7 +27,6 @@ xformers==0.0.25.post1

# InternLM-XComposer2
auto-gptq==0.7.1; platform_system == "Linux" or platform_system == "Windows"
- numpy==1.26.4

# WD Tagger
huggingface-hub==0.23.3
8 changes: 5 additions & 3 deletions taggui/auto_captioning/captioning_thread.py
@@ -33,7 +33,7 @@
from models.image_list_model import ImageListModel
from utils.enums import CaptionDevice, CaptionModelType, CaptionPosition
from utils.image import Image
- from utils.settings import get_tag_separator
from utils.settings import get_settings, get_tag_separator


def replace_template_variable(match: re.Match, image: Image) -> str:
@@ -286,6 +286,7 @@ def get_prompt(self, model_type: CaptionModelType,
def get_model_inputs(self, image: Image, prompt: str | None,
model_type: CaptionModelType, device: torch.device,
model, processor) -> BatchFeature | dict | np.ndarray:
settings = get_settings()
# Load the image.
pil_image = PilImage.open(image.path)
# Rotate the image according to the orientation tag.
@@ -318,9 +319,10 @@ def get_model_inputs(self, image: Image, prompt: str | None,
model_type, model, processor, text, pil_image, beam_count,
device, dtype_argument)
elif model_type == CaptionModelType.COGVLM2:
image_preprocessing_method = settings.value('image_preprocessing_method')
model_inputs = get_cogvlm2_inputs(model, processor, text,
- pil_image, device,
- dtype_argument, beam_count)
pil_image, image_preprocessing_method, device,
dtype_argument, beam_count)
elif model_type in (CaptionModelType.MOONDREAM1,
CaptionModelType.MOONDREAM2):
model_inputs = get_moondream_inputs(
7 changes: 6 additions & 1 deletion taggui/auto_captioning/cogvlm2.py
@@ -5,6 +5,7 @@
from torchvision import transforms

from utils.enums import CaptionDevice
from utils.image import select_preprocess_img_by_str

LANGUAGE_TOKEN_TYPE_ID = 0
VISION_TOKEN_TYPE_ID = 1
@@ -34,6 +35,7 @@ def get_cogvlm2_error_message(model_id: str, device: CaptionDevice,


def get_cogvlm2_inputs(model, processor, text: str, pil_image: PilImage,
image_preprocessing_method: str,
device: torch.device, dtype_argument: dict,
beam_count: int) -> dict:
image_size = model.config.vision_config['image_size']
@@ -49,6 +51,9 @@ def get_cogvlm2_inputs(model, processor, text: str, pil_image: PilImage,
input_ids += text_ids
token_type_ids += [LANGUAGE_TOKEN_TYPE_ID] * len(text_ids)
attention_mask = [1] * len(input_ids)

preprocessed_img = select_preprocess_img_by_str(pil_image, image_size, method=image_preprocessing_method)

transform = transforms.Compose([
transforms.Resize(
(image_size, image_size),
@@ -58,7 +63,7 @@ def get_cogvlm2_inputs(model, processor, text: str, pil_image: PilImage,
transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711))
])
- image = transform(pil_image)
image = transform(preprocessed_img)
inputs = {
'input_ids': torch.tensor(input_ids).unsqueeze(0).to(device),
'token_type_ids': torch.tensor(token_type_ids).unsqueeze(0).to(device),
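With this change the image is squared by the selected preprocessing method before it enters the torchvision pipeline, so the transform's Resize step only rescales an already-square image instead of stretching arbitrary aspect ratios on its own (and it still squares the image when an unknown method falls through unchanged). A minimal sketch of the assumed order, with an illustrative path, size, and interpolation choice:

from PIL import Image
from torchvision import transforms

from utils.image import select_preprocess_img_by_str

pil_image = Image.open('images/people_portrait.webp')
image_size = 1344  # CogVLM2 reads this from model.config.vision_config
square = select_preprocess_img_by_str(pil_image, image_size,
                                      method='scale-and-centercrop')
# Resize is a no-op scale for the known methods, which already return a
# target_size x target_size image.
tensor = transforms.Compose([
    transforms.Resize((image_size, image_size),
                      interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
])(square)
assert tensor.shape[-2:] == (image_size, image_size)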
15 changes: 12 additions & 3 deletions taggui/dialogs/settings_dialog.py
@@ -1,6 +1,6 @@
from PySide6.QtCore import Qt, Slot
from PySide6.QtWidgets import (QDialog, QFileDialog, QGridLayout, QLabel,
- QLineEdit, QPushButton, QVBoxLayout)
QLineEdit, QPushButton, QVBoxLayout, QComboBox)

from utils.settings import DEFAULT_SETTINGS, get_settings
from utils.settings_widgets import (SettingsBigCheckBox, SettingsLineEdit,
@@ -31,6 +31,8 @@ def __init__(self, parent):
5, 0, Qt.AlignmentFlag.AlignRight)
grid_layout.addWidget(QLabel('Auto-captioning models directory'), 6, 0,
Qt.AlignmentFlag.AlignRight)
grid_layout.addWidget(QLabel('Image preprocessing function'), 8, 0,
Qt.AlignmentFlag.AlignRight)

font_size_spin_box = SettingsSpinBox(
key='font_size', default=DEFAULT_SETTINGS['font_size'],
@@ -77,7 +79,11 @@ def __init__(self, parent):
default=DEFAULT_SETTINGS['image_list_file_formats'])
file_types_line_edit.setMinimumWidth(400)
file_types_line_edit.textChanged.connect(self.show_restart_warning)

image_preprocessing_method_combo_box = QComboBox()
image_preprocessing_method_combo_box.addItems(
    ['stretch-and-squish', 'scale-and-centercrop', 'black', 'gray', 'white'])
image_preprocessing_method_combo_box.setCurrentText(self.settings.value(
    'image_preprocessing_method',
    DEFAULT_SETTINGS['image_preprocessing_method']))
image_preprocessing_method_combo_box.currentTextChanged.connect(
    lambda text: self.settings.setValue('image_preprocessing_method', text))
grid_layout.addWidget(font_size_spin_box, 0, 1,
Qt.AlignmentFlag.AlignLeft)
grid_layout.addWidget(file_types_line_edit, 1, 1,
@@ -94,6 +100,9 @@ def __init__(self, parent):
Qt.AlignmentFlag.AlignLeft)
grid_layout.addWidget(models_directory_button, 7, 1,
Qt.AlignmentFlag.AlignLeft)
grid_layout.addWidget(image_preprocessing_method_combo_box, 8, 1,
Qt.AlignmentFlag.AlignLeft)

layout.addLayout(grid_layout)

# Prevent the grid layout from moving to the center when the warning
@@ -140,4 +149,4 @@ def set_models_directory_path(self):
'models',
dir=initial_directory_path)
if models_directory_path:
self.models_directory_line_edit.setText(models_directory_path)
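The new combo box persists its value through an inline lambda, while the dialog's other fields use the wrappers from utils.settings_widgets (SettingsBigCheckBox, SettingsLineEdit, SettingsSpinBox) imported above. A hypothetical SettingsComboBox in the same style, should the wiring ever be centralised (get_settings usage assumed to match the rest of the dialog):

from PySide6.QtWidgets import QComboBox

from utils.settings import get_settings


class SettingsComboBox(QComboBox):
    def __init__(self, key: str, default: str, items: list[str], parent=None):
        super().__init__(parent)
        self.addItems(items)
        self.settings = get_settings()
        # Fall back to the default when the key has not been stored yet.
        self.setCurrentText(str(self.settings.value(key, default)))
        self.currentTextChanged.connect(
            lambda text: self.settings.setValue(key, text))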
3 changes: 3 additions & 0 deletions taggui/run_tests.py
@@ -0,0 +1,3 @@
from tests import test_image

test_image.test_prepares()
Empty file added taggui/tests/__init__.py
Empty file.
20 changes: 20 additions & 0 deletions taggui/tests/test_image.py
@@ -0,0 +1,20 @@
import os

from PIL import Image
from PIL.Image import Resampling

from utils.image import select_preprocess_img_by_str

def test_prepares():
target_size = 1344
resampling = Resampling.LANCZOS
out_dir = "outputs/"

os.makedirs(out_dir, exist_ok=True)

for path in ["images/people_landscape.webp", "images/people_portrait.webp"]:
basename, ext = os.path.splitext(os.path.basename(path))
img = Image.open(path)
for method in ["stretch-and-squish", "scale-and-centercrop", "white", "gray", "black", "noise", "replicate", "reflect", "unknown"]:
ret = select_preprocess_img_by_str(img, target_size, resampling, method)
ret.save(f"{out_dir}/{basename}_{method}.webp", format='WebP', lossless=True, quality=0)
100 changes: 100 additions & 0 deletions taggui/utils/image.py
@@ -1,6 +1,13 @@
import random

from dataclasses import dataclass, field
from pathlib import Path

from PIL import Image as PilImage, ImageColor, ImageOps
from PIL.Image import Resampling
import cv2 as opencv
import numpy as np

from PySide6.QtGui import QIcon


@@ -10,3 +17,96 @@ class Image:
dimensions: tuple[int, int] | None
tags: list[str] = field(default_factory=list)
thumbnail: QIcon | None = None

def select_preprocess_img_by_str(pil_image: PilImage, target_size: int, resample=Resampling.LANCZOS, method="stretch-and-squish") -> PilImage:
color = None
try:
ImageColor.getrgb(method)
color = method
method = "color"
except ValueError:
pass

match method:
case "stretch-and-squish": ret = preprocess_img_stretch_and_squish(pil_image, target_size, resample)
case "scale-and-centercrop": ret = preprocess_img_scale_and_centercrop(pil_image, target_size, resample)
case "noise" | "replicate" | "reflect": ret = preprocess_img_scale_and_fill(pil_image, target_size, resample, method)
case "color": ret = preprocess_img_scale_and_fill(pil_image, target_size, resample, color)
case _: ret = pil_image

return ret

# https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters
def preprocess_img_stretch_and_squish(pil_image: PilImage, target_size: int, resample=Resampling.LANCZOS) -> PilImage:
"""Preprocesses an image for the model by simply stretching and squishing it to the target size. Does not retain shapes (see https://github.com/THUDM/CogVLM2/discussions/83)"""
ret = pil_image.resize((target_size, target_size), resample=resample)
return ret

def preprocess_img_scale_and_centercrop(pil_image: PilImage, target_size: int, resample=Resampling.LANCZOS) -> PilImage:
"""Preprocesses an image for the model by scaling the short side to target size and then center cropping a square. May crop important content especially in very rectangular images (this method was used in Stable Diffusion 1 see https://arxiv.org/abs/2112.10752)"""
width, height = pil_image.size
if width < height:
new_width = target_size
new_height = int(target_size * height / width)
else:
new_height = target_size
new_width = int(target_size * width / height)

# Resize the image with the calculated dimensions
ret = pil_image.resize((new_width, new_height), resample=resample)

# Center crop a square from the resized image (right and bottom are derived
# from left and top so the box is exactly target_size pixels on each side,
# avoiding off-by-one errors)
left = (new_width - target_size) // 2
top = (new_height - target_size) // 2
right = left + target_size
bottom = top + target_size
ret = ret.crop((left, top, right, bottom))
return ret

def preprocess_img_scale_and_fill(pil_image: PilImage, target_size: int, resample=Resampling.LANCZOS, method: str = "black") -> PilImage:
"""
Preprocesses an image for the model by scaling the long side to target size and filling borders of the short side with content according to method (color, noise, replicate, reflect) until it is square. Introduces new content that wasn't there before which might be caught up by the model ("This image showcases a portrait of a person. On the left and right side are black borders.")
- method: can be on of "noise", "replicate", "reflect" or a color value ("gray", "#000000", "rgb(100%,100%,100%)" etc.) which can be interpreted by Pillow (see https://pillow.readthedocs.io/en/stable/reference/ImageColor.html and https://developer.mozilla.org/en-US/docs/Web/CSS/named-color)
"""
color = None
try:
color = ImageColor.getrgb(method)
method = "color"
except ValueError:
pass

width, height = pil_image.size
if width > height:
new_width = target_size
new_height = int((new_width / width) * height)
else:
new_height = target_size
new_width = int((new_height / height) * width)

pastee = pil_image.resize((new_width, new_height), resample=resample)

if method == "color": # fill borders with color
canvas = PilImage.new("RGB", (target_size, target_size), color)
offset = ((target_size - new_width) // 2, (target_size - new_height) // 2)
canvas.paste(pastee, offset)
ret = canvas
elif method == "noise": # fill borders with RGB noise
canvas = PilImage.new("RGB", (target_size, target_size))
for x in range(target_size):
for y in range(target_size):
canvas.putpixel((x, y), (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)))
canvas.paste(pastee, ((target_size - new_width) // 2, (target_size - new_height) // 2))
ret = canvas
elif method in ("replicate", "reflect"): # fill borders with color value of the edge
left_padding = int((target_size - new_width) / 2)
top_padding = int((target_size - new_height) / 2)
right_padding = target_size - new_width - left_padding
bottom_padding = target_size - new_height - top_padding
opencv_pastee = np.array(pastee)
borderType = { "replicate": opencv.BORDER_REPLICATE, "reflect": opencv.BORDER_REFLECT }[method]
opencv_ret = opencv.copyMakeBorder(opencv_pastee, top_padding, bottom_padding, left_padding, right_padding, borderType=borderType)
ret = PilImage.fromarray(opencv_ret)
else:
raise ValueError(f"Invalid method='{method}'")

return ret
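One possible follow-up, not part of this diff: the per-pixel putpixel loop in the noise branch makes target_size squared Python-level calls; numpy, already imported in this module, can produce the same uniform RGB noise in a single vectorised step. A sketch:

import numpy as np
from PIL import Image as PilImage


def make_noise_canvas(target_size: int) -> PilImage.Image:
    # Uniform RGB noise, equivalent to the putpixel loop but vectorised.
    noise = np.random.randint(0, 256, (target_size, target_size, 3),
                              dtype=np.uint8)
    return PilImage.fromarray(noise)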
3 changes: 2 additions & 1 deletion taggui/utils/settings.py
@@ -6,10 +6,11 @@
# Common image formats that are supported in PySide6.
'image_list_file_formats': 'bmp, gif, jpg, jpeg, png, tif, tiff, webp',
'image_list_image_width': 200,
'image_preprocessing_method': 'stretch-and-squish',
'tag_separator': ',',
'insert_space_after_tag_separator': True,
'autocomplete_tags': True,
- 'models_directory_path': ''
'models_directory_path': '',
}


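One caveat, assuming standard QSettings semantics: captioning_thread.py reads settings.value('image_preprocessing_method') without an explicit fallback, so if get_settings() does not pre-seed DEFAULT_SETTINGS the value can be None on a fresh install and select_preprocess_img_by_str would silently return the image unchanged. A defensive read would be:

from utils.settings import DEFAULT_SETTINGS, get_settings

settings = get_settings()
method = settings.value('image_preprocessing_method',
                        DEFAULT_SETTINGS['image_preprocessing_method'])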