Skip to content

Visualize Ultralytics Yolo models #900

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/computer_vision/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
.datachain
output
*.pt
30 changes: 21 additions & 9 deletions examples/computer_vision/ultralytics-bbox.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,37 @@
import os

os.environ["YOLO_VERBOSE"] = "false"


from io import BytesIO

from numpy import asarray
from PIL import Image
from ultralytics import YOLO

from datachain import C, DataChain, File
from datachain import DataChain, File
from datachain.model.ultralytics import YoloBBoxes
from datachain.toolkit.ultralytics import visualize_yolo

OUTPUT_DIR = "output/bbox"


def process_bboxes(yolo: YOLO, file: File) -> YoloBBoxes:
results = yolo(Image.open(BytesIO(file.read())))
return YoloBBoxes.from_results(results)
# read image
img = Image.open(BytesIO(file.read()))

# detect objects using YOLO model
results = yolo(img, verbose=False)
# convert results to YoloBBoxes signal
signal = YoloBBoxes.from_results(results)

# visualize results
img2 = visualize_yolo(asarray(img), signal)
img2.save(f"{OUTPUT_DIR}/{file.get_file_stem()}.jpg")

return signal


os.makedirs(OUTPUT_DIR, exist_ok=True)

(
DataChain.from_storage("gs://datachain-demo/openimages-v6-test-jsonpairs/")
.filter(C("file.path").glob("*.jpg"))
DataChain.from_storage("gs://datachain-demo/coco2017/images")
.limit(20)
.setup(yolo=lambda: YOLO("yolo11n.pt"))
.map(boxes=process_bboxes)
Expand Down
30 changes: 21 additions & 9 deletions examples/computer_vision/ultralytics-pose.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,37 @@
import os

os.environ["YOLO_VERBOSE"] = "false"


from io import BytesIO

from numpy import asarray
from PIL import Image
from ultralytics import YOLO

from datachain import C, DataChain, File
from datachain import DataChain, File
from datachain.model.ultralytics import YoloPoses
from datachain.toolkit.ultralytics import visualize_yolo

OUTPUT_DIR = "output/pose"


def process_poses(yolo: YOLO, file: File) -> YoloPoses:
results = yolo(Image.open(BytesIO(file.read())))
return YoloPoses.from_results(results)
# read image
img = Image.open(BytesIO(file.read()))

# detect objects using YOLO model
results = yolo(img, verbose=False)
# convert results to YoloPoses signal
signal = YoloPoses.from_results(results)

# visualize results
img2 = visualize_yolo(asarray(img), signal)
img2.save(f"{OUTPUT_DIR}/{file.get_file_stem()}.jpg")

return signal


os.makedirs(OUTPUT_DIR, exist_ok=True)

(
DataChain.from_storage("gs://datachain-demo/openimages-v6-test-jsonpairs/")
.filter(C("file.path").glob("*.jpg"))
DataChain.from_storage("gs://datachain-demo/coco2017/images")
.limit(20)
.setup(yolo=lambda: YOLO("yolo11n-pose.pt"))
.map(poses=process_poses)
Expand Down
30 changes: 21 additions & 9 deletions examples/computer_vision/ultralytics-segment.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,37 @@
import os

os.environ["YOLO_VERBOSE"] = "false"


from io import BytesIO

from numpy import asarray
from PIL import Image
from ultralytics import YOLO

from datachain import C, DataChain, File
from datachain import DataChain, File
from datachain.model.ultralytics import YoloSegments
from datachain.toolkit.ultralytics import visualize_yolo

OUTPUT_DIR = "output/segment"


def process_segments(yolo: YOLO, file: File) -> YoloSegments:
results = yolo(Image.open(BytesIO(file.read())))
return YoloSegments.from_results(results)
# read image
img = Image.open(BytesIO(file.read()))

# detect objects using YOLO model
results = yolo(img, verbose=False)
# convert results to YoloSegments signal
signal = YoloSegments.from_results(results)

# visualize results
img2 = visualize_yolo(asarray(img), signal)
img2.save(f"{OUTPUT_DIR}/{file.get_file_stem()}.jpg")

return signal


os.makedirs(OUTPUT_DIR, exist_ok=True)

(
DataChain.from_storage("gs://datachain-demo/openimages-v6-test-jsonpairs/")
.filter(C("file.path").glob("*.jpg"))
DataChain.from_storage("gs://datachain-demo/coco2017/images")
.limit(20)
.setup(yolo=lambda: YOLO("yolo11n-seg.pt"))
.map(segments=process_segments)
Expand Down
159 changes: 159 additions & 0 deletions src/datachain/toolkit/ultralytics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from typing import Union

Check warning on line 1 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L1

Added line #L1 was not covered by tests

import numpy as np
import torch
from PIL import Image
from ultralytics.data.utils import polygon2mask
from ultralytics.engine.results import Results

Check warning on line 7 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L3-L7

Added lines #L3 - L7 were not covered by tests

from datachain.model.ultralytics.bbox import YoloBBox, YoloBBoxes
from datachain.model.ultralytics.pose import YoloPose, YoloPoses
from datachain.model.ultralytics.segment import YoloSegment, YoloSegments

Check warning on line 11 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L9-L11

Added lines #L9 - L11 were not covered by tests

YoloSignal = Union[YoloBBox, YoloBBoxes, YoloPose, YoloPoses, YoloSegment, YoloSegments]

Check warning on line 13 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L13

Added line #L13 was not covered by tests


def _signal_to_results(img: np.ndarray, signal: YoloSignal) -> Results:

Check warning on line 16 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L16

Added line #L16 was not covered by tests
"""Convert a YOLO signal to Ultralytics Results."""
# Convert RGB to BGR
if img.ndim == 3 and img.shape[2] == 3:
bgr_array = img[:, :, ::-1]

Check warning on line 20 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L20

Added line #L20 was not covered by tests
else:
# If the image is not RGB (e.g., grayscale or RGBA), use as is
bgr_array = img

Check warning on line 23 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L23

Added line #L23 was not covered by tests

names = {}
boxes_list = []
keypoints_list = []
masks_list = []

Check warning on line 28 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L25-L28

Added lines #L25 - L28 were not covered by tests

# Get the boxes, keypoints, and masks from the signal
if isinstance(signal, YoloBBox):
names[signal.cls] = signal.name
boxes_list.append(

Check warning on line 33 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L32-L33

Added lines #L32 - L33 were not covered by tests
torch.tensor([[*signal.box.coords, signal.confidence, signal.cls]])
)
elif isinstance(signal, YoloBBoxes):
for i, _ in enumerate(signal.cls):
names[signal.cls[i]] = signal.name[i]
boxes_list.append(

Check warning on line 39 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L38-L39

Added lines #L38 - L39 were not covered by tests
torch.tensor(
[[*signal.box[i].coords, signal.confidence[i], signal.cls[i]]]
)
)
elif isinstance(signal, YoloPose):
names[signal.cls] = signal.name
boxes_list.append(

Check warning on line 46 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L45-L46

Added lines #L45 - L46 were not covered by tests
torch.tensor([[*signal.box.coords, signal.confidence, signal.cls]])
)
keypoints_list.append(

Check warning on line 49 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L49

Added line #L49 was not covered by tests
torch.tensor([list(zip(signal.pose.x, signal.pose.y, signal.pose.visible))])
)
elif isinstance(signal, YoloPoses):
for i, _ in enumerate(signal.cls):
names[signal.cls[i]] = signal.name[i]
boxes_list.append(

Check warning on line 55 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L54-L55

Added lines #L54 - L55 were not covered by tests
torch.tensor(
[[*signal.box[i].coords, signal.confidence[i], signal.cls[i]]]
)
)
keypoints_list.append(

Check warning on line 60 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L60

Added line #L60 was not covered by tests
torch.tensor(
[
list(
zip(
signal.pose[i].x,
signal.pose[i].y,
signal.pose[i].visible,
)
)
]
)
)
elif isinstance(signal, YoloSegment):
names[signal.cls] = signal.name
boxes_list.append(

Check warning on line 75 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L74-L75

Added lines #L74 - L75 were not covered by tests
torch.tensor([[*signal.box.coords, signal.confidence, signal.cls]])
)
masks_list.append(

Check warning on line 78 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L78

Added line #L78 was not covered by tests
torch.tensor(
polygon2mask(
img.shape[:2],
[np.asarray(list(zip(signal.segment.x, signal.segment.y)))],
)
)
)
elif isinstance(signal, YoloSegments):
for i, _ in enumerate(signal.cls):
names[signal.cls[i]] = signal.name[i]
boxes_list.append(

Check warning on line 89 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L88-L89

Added lines #L88 - L89 were not covered by tests
torch.tensor(
[[*signal.box[i].coords, signal.confidence[i], signal.cls[i]]]
)
)
masks_list.append(

Check warning on line 94 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L94

Added line #L94 was not covered by tests
torch.tensor(
polygon2mask(
img.shape[:2],
[
np.asarray(
list(zip(signal.segment[i].x, signal.segment[i].y))
)
],
)
)
)

boxes = torch.cat(boxes_list, dim=0) if len(boxes_list) > 0 else None
keypoints = torch.cat(keypoints_list, dim=0) if len(keypoints_list) > 0 else None
masks = torch.stack(masks_list) if len(masks_list) > 0 else None

Check warning on line 109 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L107-L109

Added lines #L107 - L109 were not covered by tests

return Results(

Check warning on line 111 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L111

Added line #L111 was not covered by tests
bgr_array,
path="",
names=names,
boxes=boxes,
keypoints=keypoints,
masks=masks,
)


def visualize_yolo(

Check warning on line 121 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L121

Added line #L121 was not covered by tests
img: np.ndarray,
signal: YoloSignal,
scale: float = 1.0,
line_width: int = 1,
font_size: int = 20,
kpt_radius: int = 3,
) -> Image.Image:
"""
Visualize signals detected by YOLO.

Args:
image (ndarray): The image to visualize as a NumPy array.
signal: The signal detected by YOLO. Possible signals are YoloBBox, YoloBBoxes,
YoloPose, YoloPoses, YoloSegment, and YoloSegments.
scale (float): The scale factor for the image. Default is 1.0.
line_width (int): The line width for drawing boxes and lines. Default is 1.
font_size (int): The font size for text. Default is 20.
kpt_radius (int): The radius for drawing keypoints. Default is 3.

Returns:
PIL.Image.Image: The image with the detected signals visualized.
"""
results = _signal_to_results(img, signal)

Check warning on line 144 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L144

Added line #L144 was not covered by tests

im_bgr = results.plot(

Check warning on line 146 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L146

Added line #L146 was not covered by tests
line_width=line_width,
font_size=font_size,
kpt_radius=kpt_radius,
)

im_rgb = Image.fromarray(im_bgr[..., ::-1])

Check warning on line 152 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L152

Added line #L152 was not covered by tests

if scale != 1.0:
orig_height, orig_width = results.orig_shape
new_size = (int(orig_width * scale), int(orig_height * scale))
im_rgb = im_rgb.resize(new_size, Image.Resampling.LANCZOS)

Check warning on line 157 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L155-L157

Added lines #L155 - L157 were not covered by tests

return im_rgb

Check warning on line 159 in src/datachain/toolkit/ultralytics.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/toolkit/ultralytics.py#L159

Added line #L159 was not covered by tests
Loading