Skip to content

Commit f8fb7ab

Browse files
committed
Add support for probing SVG images
1 parent 7671d87 commit f8fb7ab

File tree

2 files changed

+42
-11
lines changed

2 files changed

+42
-11
lines changed

src/zimscraperlib/image/probing.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import colorthief
1313
import PIL.Image
1414

15+
from zimscraperlib.filesystem import get_content_mimetype, get_file_mimetype
16+
1517

1618
def get_colors(
1719
src: pathlib.Path, *, use_palette: bool | None = True
@@ -59,8 +61,23 @@ def format_for(
5961
) -> str | None:
6062
"""Pillow format of a given filename, either Pillow-detected or from suffix"""
6163
if not from_suffix:
62-
with PIL.Image.open(src) as img:
63-
return img.format
64+
try:
65+
with PIL.Image.open(src) as img:
66+
return img.format
67+
except PIL.UnidentifiedImageError:
68+
# Fall back to mimetype detection for SVG, which PIL does not support
69+
if (
70+
isinstance(src, pathlib.Path)
71+
and get_file_mimetype(src) == "image/svg+xml"
72+
):
73+
return "SVG"
74+
elif (
75+
isinstance(src, io.BytesIO)
76+
and get_content_mimetype(src.getvalue()) == "image/svg+xml"
77+
):
78+
return "SVG"
79+
else: # pragma: no cover
80+
raise
6481

6582
if not isinstance(src, pathlib.Path):
6683
raise ValueError(
@@ -70,8 +87,11 @@ def format_for(
7087
from PIL.Image import EXTENSION as PIL_FMT_EXTENSION
7188
from PIL.Image import init as init_pil
7289

73-
init_pil()
74-
return PIL_FMT_EXTENSION[src.suffix] if src.suffix in PIL_FMT_EXTENSION else None
90+
init_pil() # populate the PIL_FMT_EXTENSION dictionary
91+
92+
known_extensions = {".svg": "SVG"}
93+
known_extensions.update(PIL_FMT_EXTENSION)
94+
return known_extensions[src.suffix] if src.suffix in known_extensions else None
7595

7696

7797
def is_valid_image(

tests/image/test_image.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,15 @@ def get_src_dst(
6868
jpg_image: pathlib.Path | None = None,
6969
gif_image: pathlib.Path | None = None,
7070
webp_image: pathlib.Path | None = None,
71+
svg_image: pathlib.Path | None = None,
7172
) -> tuple[pathlib.Path, pathlib.Path]:
72-
options = {"png": png_image, "jpg": jpg_image, "webp": webp_image, "gif": gif_image}
73+
options = {
74+
"png": png_image,
75+
"jpg": jpg_image,
76+
"webp": webp_image,
77+
"gif": gif_image,
78+
"svg": svg_image,
79+
}
7380
if fmt not in options:
7481
raise LookupError(f"Unsupported fmt passed: {fmt}")
7582
src = options[fmt]
@@ -616,10 +623,10 @@ def test_ensure_matches(webp_image):
616623

617624
@pytest.mark.parametrize(
618625
"fmt,expected",
619-
[("png", "PNG"), ("jpg", "JPEG"), ("gif", "GIF"), ("webp", "WEBP")],
626+
[("png", "PNG"), ("jpg", "JPEG"), ("gif", "GIF"), ("webp", "WEBP"), ("svg", "SVG")],
620627
)
621628
def test_format_for_real_images_suffix(
622-
png_image, jpg_image, gif_image, webp_image, tmp_path, fmt, expected
629+
png_image, jpg_image, gif_image, webp_image, svg_image, tmp_path, fmt, expected
623630
):
624631
src, _ = get_src_dst(
625632
tmp_path,
@@ -628,16 +635,17 @@ def test_format_for_real_images_suffix(
628635
jpg_image=jpg_image,
629636
gif_image=gif_image,
630637
webp_image=webp_image,
638+
svg_image=svg_image,
631639
)
632640
assert format_for(src) == expected
633641

634642

635643
@pytest.mark.parametrize(
636644
"fmt,expected",
637-
[("png", "PNG"), ("jpg", "JPEG"), ("gif", "GIF"), ("webp", "WEBP")],
645+
[("png", "PNG"), ("jpg", "JPEG"), ("gif", "GIF"), ("webp", "WEBP"), ("svg", "SVG")],
638646
)
639647
def test_format_for_real_images_content_path(
640-
png_image, jpg_image, gif_image, webp_image, tmp_path, fmt, expected
648+
png_image, jpg_image, gif_image, webp_image, svg_image, tmp_path, fmt, expected
641649
):
642650
src, _ = get_src_dst(
643651
tmp_path,
@@ -646,16 +654,17 @@ def test_format_for_real_images_content_path(
646654
jpg_image=jpg_image,
647655
gif_image=gif_image,
648656
webp_image=webp_image,
657+
svg_image=svg_image,
649658
)
650659
assert format_for(src, from_suffix=False) == expected
651660

652661

653662
@pytest.mark.parametrize(
654663
"fmt,expected",
655-
[("png", "PNG"), ("jpg", "JPEG"), ("gif", "GIF"), ("webp", "WEBP")],
664+
[("png", "PNG"), ("jpg", "JPEG"), ("gif", "GIF"), ("webp", "WEBP"), ("svg", "SVG")],
656665
)
657666
def test_format_for_real_images_content_bytes(
658-
png_image, jpg_image, gif_image, webp_image, tmp_path, fmt, expected
667+
png_image, jpg_image, gif_image, webp_image, svg_image, tmp_path, fmt, expected
659668
):
660669
src, _ = get_src_dst(
661670
tmp_path,
@@ -664,6 +673,7 @@ def test_format_for_real_images_content_bytes(
664673
jpg_image=jpg_image,
665674
gif_image=gif_image,
666675
webp_image=webp_image,
676+
svg_image=svg_image,
667677
)
668678
assert format_for(io.BytesIO(src.read_bytes()), from_suffix=False) == expected
669679

@@ -675,6 +685,7 @@ def test_format_for_real_images_content_bytes(
675685
("image.jpg", "JPEG"),
676686
("image.gif", "GIF"),
677687
("image.webp", "WEBP"),
688+
("image.svg", "SVG"),
678689
("image.raster", None),
679690
],
680691
)

0 commit comments

Comments
 (0)