Skip to content

Chunkwise image loader #279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
129 changes: 129 additions & 0 deletions src/spatialdata_io/readers/_utils/_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from collections.abc import Callable
from typing import Any

import dask.array as da
import numpy as np
from dask import delayed
from numpy.typing import NDArray


def _compute_chunk_sizes_positions(size: int, chunk: int, min_coord: int) -> tuple[NDArray[np.int_], NDArray[np.int_]]:
"""Calculate chunk sizes and positions for a given dimension and chunk size"""
# All chunks have the same size except for the last one
positions = np.arange(min_coord, min_coord + size, chunk)
lengths = np.full_like(positions, chunk, dtype=int)

if positions[-1] + chunk > size + min_coord:
lengths[-1] = size + min_coord - positions[-1]
Comment on lines +13 to +17
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
positions = np.arange(min_coord, min_coord + size, chunk)
lengths = np.full_like(positions, chunk, dtype=int)
if positions[-1] + chunk > size + min_coord:
lengths[-1] = size + min_coord - positions[-1]
positions = np.arange(min_coord, min_coord + size, chunk)
lengths = np.minimum(chunk, min_coord + size - positions)

Think this is the equivalent two liner:) but just a bit nitpicky


return positions, lengths


def _compute_chunks(
dimensions: tuple[int, int],
chunk_size: tuple[int, int],
min_coordinates: tuple[int, int] = (0, 0),
) -> NDArray[np.int_]:
"""Create all chunk specs for a given image and chunk size.

Creates specifications (x, y, width, height) with (x, y) being the upper left corner
of chunks of size chunk_size. Chunks at the edges correspond to the remainder of
chunk size and dimensions

Parameters
----------
dimensions : tuple[int, int]
Comment on lines +22 to +35
Copy link
Collaborator

@melonora melonora Mar 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _compute_chunks(
dimensions: tuple[int, int],
chunk_size: tuple[int, int],
min_coordinates: tuple[int, int] = (0, 0),
) -> NDArray[np.int_]:
"""Create all chunk specs for a given image and chunk size.
Creates specifications (x, y, width, height) with (x, y) being the upper left corner
of chunks of size chunk_size. Chunks at the edges correspond to the remainder of
chunk size and dimensions
Parameters
----------
dimensions : tuple[int, int]
def _compute_chunks(
shape: tuple[int, int],
chunk_size: tuple[int, int],
min_coordinates: tuple[int, int] = (0, 0),
) -> NDArray[np.int_]:
"""Create all chunk specs for a given image and chunk size.
Creates specifications (x, y, width, height) with (x, y) being the upper left corner
of chunks of size chunk_size. Chunks at the edges correspond to the remainder of
chunk size and dimensions
Parameters
----------
shape : tuple[int, int]

Just to stick to standard numpy / array api conventions:) Dimensions could be interpreted as TCZYX.

Size of the image in (width, height).
chunk_size : tuple[int, int]
Size of individual tiles in (width, height).
min_coordinates : tuple[int, int], optional
Minimum coordinates (x, y) in the image, defaults to (0, 0).

Returns
-------
np.ndarray
Array of shape (n_tiles_x, n_tiles_y, 4). Each entry defines a tile
as (x, y, width, height).
"""
x_positions, widths = _compute_chunk_sizes_positions(dimensions[1], chunk_size[1], min_coord=min_coordinates[1])
y_positions, heights = _compute_chunk_sizes_positions(dimensions[0], chunk_size[0], min_coord=min_coordinates[0])

# Generate the tiles
tiles = np.array(
[
[[x, y, w, h] for x, w in zip(x_positions, widths, strict=True)]
for y, h in zip(y_positions, heights, strict=True)
],
dtype=int,
)
return tiles


def _read_chunks(
func: Callable[..., NDArray[np.int_]],
slide: Any,
coords: NDArray[np.int_],
n_channel: int,
dtype: np.number,
**func_kwargs: Any,
) -> list[list[da.array]]:
"""Abstract method to tile a large microscopy image.

Parameters
----------
func
Function to retrieve a rectangular tile from the slide image. Must take the
arguments:

- slide Full slide image
- x0: x (col) coordinate of upper left corner of chunk
- y0: y (row) coordinate of upper left corner of chunk
- width: Width of chunk
- height: Height of chunk

and should return the chunk as numpy array of shape (c, y, x)
slide
Slide image in lazyly loaded format compatible with func
coords
Coordinates of the upper left corner of the image in format (n_row_x, n_row_y, 4)
where the last dimension defines the rectangular tile in format (x, y, width, height).
n_row_x represents the number of chunks in x dimension and n_row_y the number of chunks
in y dimension.
n_channel
Number of channels in array
dtype
Data type of image
func_kwargs
Additional keyword arguments passed to func

Returns
-------
list[list[da.array]]
List (length: n_row_x) of lists (length: n_row_y) of chunks.
Represents all chunks of the full image.
"""
func_kwargs = func_kwargs if func_kwargs else {}

# Collect each delayed chunk as item in list of list
# Inner list becomes dim=-1 (cols/x)
# Outer list becomes dim=-2 (rows/y)
# see dask.array.block
chunks = [
[
da.from_delayed(
delayed(func)(
slide,
x0=coords[y, x, 0],
y0=coords[y, x, 1],
width=coords[y, x, 2],
height=coords[y, x, 3],
**func_kwargs,
),
dtype=dtype,
shape=(n_channel, *coords[y, x, [3, 2]]),
)
for x in range(coords.shape[1])
]
for y in range(coords.shape[0])
]
return chunks
71 changes: 65 additions & 6 deletions src/spatialdata_io/readers/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,34 @@
import warnings
from collections.abc import Sequence
from pathlib import Path
from typing import Protocol, TypeVar

import dask.array as da
import numpy as np
from dask_image.imread import imread
from geopandas import GeoDataFrame
from numpy.typing import NDArray
from spatialdata._docs import docstring_parameter
from spatialdata.models import Image2DModel, ShapesModel
from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM
from spatialdata.transformations import Identity
from tifffile import memmap as tiffmmemap
from xarray import DataArray

from ._utils._image import _compute_chunks, _read_chunks

VALID_IMAGE_TYPES = [".tif", ".tiff", ".png", ".jpg", ".jpeg"]
VALID_SHAPE_TYPES = [".geojson"]
DEFAULT_CHUNKSIZE = (1000, 1000)

__all__ = ["generic", "geojson", "image", "VALID_IMAGE_TYPES", "VALID_SHAPE_TYPES"]

T = TypeVar("T", bound=np.generic) # Restrict to NumPy scalar types


class DaskArray(Protocol[T]):
dtype: np.dtype[T]


@docstring_parameter(
valid_image_types=", ".join(VALID_IMAGE_TYPES),
Expand Down Expand Up @@ -68,11 +81,57 @@ def geojson(input: Path, coordinate_system: str) -> GeoDataFrame:
return ShapesModel.parse(input, transformations={coordinate_system: Identity()})


def _tiff_to_chunks(input: Path, axes_dim_mapping: dict[str, int]) -> list[list[DaskArray[np.int_]]]:
"""Chunkwise reader for tiff files.

Parameters
----------
input
Path to image
axes_dim_mapping
Mapping between dimension name (x, y, c) and index

Returns
-------
list[list[DaskArray]]
"""
# Lazy file reader
slide = tiffmmemap(input)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the tiff.memmap might not always work for example with compression or tiling so I would add a try, except clause.


# Transpose to cyx order
slide = np.transpose(slide, (axes_dim_mapping["c"], axes_dim_mapping["y"], axes_dim_mapping["x"]))

# Get dimensions in (x, y)
slide_dimensions = slide.shape[2], slide.shape[1]

# Get number of channels (c)
n_channel = slide.shape[0]

# Compute chunk coords
chunk_coords = _compute_chunks(slide_dimensions, chunk_size=DEFAULT_CHUNKSIZE, min_coordinates=(0, 0))

# Define reader func
def _reader_func(slide: NDArray[np.int_], x0: int, y0: int, width: int, height: int) -> NDArray[np.int_]:
return np.array(slide[:, y0 : y0 + height, x0 : x0 + width])

return _read_chunks(_reader_func, slide, coords=chunk_coords, n_channel=n_channel, dtype=slide.dtype)


def image(input: Path, data_axes: Sequence[str], coordinate_system: str) -> DataArray:
"""Reads an image file and returns a parsed Image2D spatial element"""
# this function is just a draft, the more general one will be available when
# https://github.com/scverse/spatialdata-io/pull/234 is merged
image = imread(input)
if len(image.shape) == len(data_axes) + 1 and image.shape[0] == 1:
image = np.squeeze(image, axis=0)
"""Reads an image file and returns a parsed Image2DModel"""
# Map passed data axes to position of dimension
axes_dim_mapping = {axes: ndim for ndim, axes in enumerate(data_axes)}

if input.suffix in [".tiff", ".tif"]:
chunks = _tiff_to_chunks(input, axes_dim_mapping=axes_dim_mapping)
image = da.block(chunks, allow_unknown_chunksizes=True)

elif input.suffix in [".png", ".jpg", ".jpeg"]:
image = imread(input)
if len(image.shape) == len(data_axes) + 1 and image.shape[0] == 1:
image = np.squeeze(image, axis=0)

else:
raise NotImplementedError(f"File format {input.suffix} not implemented")

return Image2DModel.parse(image, dims=data_axes, transformations={coordinate_system: Identity()})
65 changes: 65 additions & 0 deletions tests/readers/test_utils_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy as np
import pytest
from numpy.typing import NDArray

from spatialdata_io.readers._utils._image import (
_compute_chunk_sizes_positions,
_compute_chunks,
)


@pytest.mark.parametrize(
("size", "chunk", "min_coordinate", "positions", "lengths"),
[
(300, 100, 0, np.array([0, 100, 200]), np.array([100, 100, 100])),
(300, 200, 0, np.array([0, 200]), np.array([200, 100])),
(300, 100, -100, np.array([-100, 0, 100]), np.array([100, 100, 100])),
(300, 200, -100, np.array([-100, 100]), np.array([200, 100])),
],
)
def test_compute_chunk_sizes_positions(
size: int,
chunk: int,
min_coordinate: int,
positions: NDArray[np.number],
lengths: NDArray[np.number],
) -> None:
computed_positions, computed_lengths = _compute_chunk_sizes_positions(size, chunk, min_coordinate)
assert (positions == computed_positions).all()
assert (lengths == computed_lengths).all()


@pytest.mark.parametrize(
("dimensions", "chunk_size", "min_coordinates", "result"),
[
# Regular grid 2x2
(
(2, 2),
(1, 1),
(0, 0),
np.array([[[0, 0, 1, 1], [1, 0, 1, 1]], [[0, 1, 1, 1], [1, 1, 1, 1]]]),
),
# Different tile sizes
(
(3, 3),
(2, 2),
(0, 0),
np.array([[[0, 0, 2, 2], [2, 0, 1, 2]], [[0, 2, 2, 1], [2, 2, 1, 1]]]),
),
(
(2, 2),
(1, 1),
(-1, 0),
np.array([[[0, -1, 1, 1], [1, -1, 1, 1]], [[0, 0, 1, 1], [1, 0, 1, 1]]]),
),
],
)
def test_compute_chunks(
dimensions: tuple[int, int],
chunk_size: tuple[int, int],
min_coordinates: tuple[int, int],
result: NDArray[np.number],
) -> None:
tiles = _compute_chunks(dimensions, chunk_size, min_coordinates)

assert (tiles == result).all()
29 changes: 29 additions & 0 deletions tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
from PIL import Image
from spatialdata import SpatialData
from spatialdata.datasets import blobs
from tifffile import imread as tiffread
from tifffile import imwrite as tiffwrite

from spatialdata_io.__main__ import read_generic_wrapper
from spatialdata_io.converters.generic_to_zarr import generic_to_zarr
from spatialdata_io.readers.generic import image


@contextmanager
Expand All @@ -33,6 +36,32 @@ def save_temp_files() -> Generator[tuple[Path, Path, Path], None, None]:
yield jpg_path, geojson_path, Path(tmpdir)


@pytest.fixture(scope="module", params=[("c", "y", "x"), ("x", "y", "c")])
def save_tiff_files(
request: pytest.FixtureRequest,
) -> Generator[tuple[Path, tuple[str], Path], None, None]:
with tempfile.TemporaryDirectory() as tmpdir:
axes = request.param
sdata = blobs()
# save the image as tiff
x = sdata["blobs_image"].transpose(*axes).data.compute()
img = np.clip(x * 255, 0, 255).astype(np.uint8)

tiff_path = Path(tmpdir) / "blobs_image.tiff"
tiffwrite(tiff_path, img)

yield tiff_path, axes, Path(tmpdir)


def test_read_tiff(save_tiff_files: tuple[Path, tuple[str], Path]) -> None:
tiff_path, axes, _ = save_tiff_files
img = image(tiff_path, data_axes=axes, coordinate_system="global")

reference = tiffread(tiff_path)

assert (img.compute() == reference).all()


@pytest.mark.parametrize("cli", [True, False])
@pytest.mark.parametrize("element_name", [None, "test_element"])
def test_read_generic_image(runner: CliRunner, cli: bool, element_name: str | None) -> None:
Expand Down
Loading