Skip to content

Commit a6e226c

Browse files
committed
wip new lazy to_multiscale() for labels
1 parent e88c24e commit a6e226c

File tree

5 files changed

+199
-52
lines changed

5 files changed

+199
-52
lines changed

src/spatialdata/_io/io_raster.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ome_zarr.writer import write_labels as write_labels_ngff
1414
from ome_zarr.writer import write_multiscale as write_multiscale_ngff
1515
from ome_zarr.writer import write_multiscale_labels as write_multiscale_labels_ngff
16-
from xarray import DataArray, Dataset, DataTree
16+
from xarray import DataArray, DataTree
1717

1818
from spatialdata._io._utils import (
1919
_get_transformations_from_ngff_dict,
@@ -27,6 +27,7 @@
2727
from spatialdata._utils import get_pyramid_levels
2828
from spatialdata.models._utils import get_channel_names
2929
from spatialdata.models.models import ATTRS_KEY
30+
from spatialdata.models.pyramids_utils import dask_arrays_to_datatree
3031
from spatialdata.transformations._utils import (
3132
_get_transformations,
3233
_get_transformations_xarray,
@@ -91,20 +92,8 @@ def _read_multiscale(
9192
channels = [d["label"] for d in omero_metadata["channels"]]
9293
axes = [i["name"] for i in node.metadata["axes"]]
9394
if len(datasets) > 1:
94-
multiscale_image = {}
95-
for i, d in enumerate(datasets):
96-
data = node.load(Multiscales).array(resolution=d)
97-
multiscale_image[f"scale{i}"] = Dataset(
98-
{
99-
"image": DataArray(
100-
data,
101-
name="image",
102-
dims=axes,
103-
coords={"c": channels} if channels is not None else {},
104-
)
105-
}
106-
)
107-
msi = DataTree.from_dict(multiscale_image)
95+
arrays = [node.load(Multiscales).array(resolution=d) for d in datasets]
96+
msi = dask_arrays_to_datatree(arrays, dims=axes, channels=channels)
10897
_set_transformations(msi, transformations)
10998
return compute_coordinates(msi)
11099

src/spatialdata/datasets.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,15 @@
2020
from spatialdata._core.query.relational_query import get_element_instances
2121
from spatialdata._core.spatialdata import SpatialData
2222
from spatialdata._types import ArrayLike
23-
from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, ShapesModel, TableModel
23+
from spatialdata.models import (
24+
Image2DModel,
25+
Image3DModel,
26+
Labels2DModel,
27+
Labels3DModel,
28+
PointsModel,
29+
ShapesModel,
30+
TableModel,
31+
)
2432
from spatialdata.transformations import Identity
2533

2634
__all__ = ["blobs", "raccoon"]
@@ -172,37 +180,47 @@ def _image_blobs(
172180
n_channels: int = 3,
173181
c_coords: str | list[str] | None = None,
174182
multiscale: bool = False,
183+
ndim: int = 2,
175184
) -> DataArray | DataTree:
176185
masks = []
177186
for i in range(n_channels):
178-
mask = self._generate_blobs(length=length, seed=i)
187+
mask = self._generate_blobs(length=length, seed=i, ndim=ndim)
179188
mask = (mask - mask.min()) / np.ptp(mask)
180189
masks.append(mask)
181190

182191
x = np.stack(masks, axis=0)
183-
dims = ["c", "y", "x"]
192+
if ndim == 2:
193+
dims = ["c", "y", "x"]
194+
model = Image2DModel
195+
else:
196+
dims = ["c", "z", "y", "x"]
197+
model = Image3DModel
184198
if not multiscale:
185-
return Image2DModel.parse(x, transformations=transformations, dims=dims, c_coords=c_coords)
186-
return Image2DModel.parse(
187-
x, transformations=transformations, dims=dims, c_coords=c_coords, scale_factors=[2, 2]
188-
)
199+
return model.parse(x, transformations=transformations, dims=dims, c_coords=c_coords)
200+
return model.parse(x, transformations=transformations, dims=dims, c_coords=c_coords, scale_factors=[2, 2])
189201

190202
def _labels_blobs(
191-
self, transformations: dict[str, Any] | None = None, length: int = 512, multiscale: bool = False
203+
self,
204+
transformations: dict[str, Any] | None = None,
205+
length: int = 512,
206+
multiscale: bool = False,
207+
ndim: int = 2,
192208
) -> DataArray | DataTree:
193-
"""Create a 2D labels."""
209+
"""Create labels in 2D or 3D."""
194210
from scipy.ndimage import watershed_ift
195211

196212
# from skimage
197-
mask = self._generate_blobs(length=length)
213+
mask = self._generate_blobs(length=length, ndim=ndim)
198214
threshold = np.percentile(mask, 100 * (1 - 0.3))
199215
inputs = np.logical_not(mask < threshold).astype(np.uint8)
200216
# use watershed from scipy
201-
xm, ym = np.ogrid[0:length:10, 0:length:10]
217+
grid = np.ogrid[tuple(slice(0, length, 10) for _ in range(ndim))]
202218
markers = np.zeros_like(inputs).astype(np.int16)
203-
markers[xm, ym] = np.arange(xm.size * ym.size).reshape((xm.size, ym.size))
219+
grid_shape = tuple(g.size for g in grid)
220+
markers[tuple(grid)] = np.arange(np.prod(grid_shape)).reshape(grid_shape)
204221
out = watershed_ift(inputs, markers)
205-
out[xm, ym] = out[xm - 1, ym - 1] # remove the isolate seeds
222+
shifted = tuple(g - 1 for g in grid)
223+
out[tuple(grid)] = out[tuple(shifted)] # remove the isolated seeds
206224
# reindex by frequency
207225
val, counts = np.unique(out, return_counts=True)
208226
sorted_idx = np.argsort(counts)
@@ -211,20 +229,25 @@ def _labels_blobs(
211229
out[out == val[idx]] = 0
212230
else:
213231
out[out == val[idx]] = i
214-
dims = ["y", "x"]
232+
if ndim == 2:
233+
dims = ["y", "x"]
234+
model = Labels2DModel
235+
else:
236+
dims = ["z", "y", "x"]
237+
model = Labels3DModel
215238
if not multiscale:
216-
return Labels2DModel.parse(out, transformations=transformations, dims=dims)
217-
return Labels2DModel.parse(out, transformations=transformations, dims=dims, scale_factors=[2, 2])
239+
return model.parse(out, transformations=transformations, dims=dims)
240+
return model.parse(out, transformations=transformations, dims=dims, scale_factors=[2, 2])
218241

219-
def _generate_blobs(self, length: int = 512, seed: int | None = None) -> ArrayLike:
242+
def _generate_blobs(self, length: int = 512, seed: int | None = None, ndim: int = 2) -> ArrayLike:
220243
from scipy.ndimage import gaussian_filter
221244

222245
rng = default_rng(42) if seed is None else default_rng(seed)
223246
# from skimage
224-
shape = tuple([length] * 2)
247+
shape = (length,) * ndim
225248
mask = np.zeros(shape)
226-
n_pts = max(int(1.0 / 0.1) ** 2, 1)
227-
points = (length * rng.random((2, n_pts))).astype(int)
249+
n_pts = max(int(1.0 / 0.1) ** ndim, 1)
250+
points = (length * rng.random((ndim, n_pts))).astype(int)
228251
mask[tuple(indices for indices in points)] = 1
229252
mask = gaussian_filter(mask, sigma=0.25 * length * 0.1)
230253
assert isinstance(mask, np.ndarray)

src/spatialdata/models/models.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from dask.array.core import from_array
1515
from dask.dataframe import DataFrame as DaskDataFrame
1616
from geopandas import GeoDataFrame, GeoSeries
17-
from multiscale_spatial_image import to_multiscale
17+
from multiscale_spatial_image import to_multiscale as to_multiscale_msi
1818
from multiscale_spatial_image.to_multiscale.to_multiscale import Methods
1919
from pandas import CategoricalDtype
2020
from shapely._geometry import GeometryType
@@ -38,17 +38,15 @@
3838
_validate_mapping_to_coordinate_system_type,
3939
convert_region_column_to_categorical,
4040
)
41+
from spatialdata.models.pyramids_utils import Chunks_t, ScaleFactors_t
42+
from spatialdata.models.pyramids_utils import to_multiscale as to_multiscale_ozp # ozp -> ome-zarr-py
4143
from spatialdata.transformations._utils import (
4244
_get_transformations,
4345
_set_transformations,
4446
compute_coordinates,
4547
)
4648
from spatialdata.transformations.transformations import Identity
4749

48-
# Types
49-
Chunks_t: TypeAlias = int | tuple[int, ...] | tuple[tuple[int, ...], ...] | Mapping[Any, None | int | tuple[int, ...]]
50-
ScaleFactors_t = Sequence[dict[str, int] | int]
51-
5250
ATTRS_KEY = "spatialdata_attrs"
5351

5452

@@ -225,12 +223,19 @@ def parse(
225223
chunks = {dim: chunks[index] for index, dim in enumerate(data.dims)}
226224
if isinstance(chunks, float):
227225
chunks = {dim: chunks for index, dim in data.dims}
228-
data = to_multiscale(
229-
data,
230-
scale_factors=scale_factors,
231-
method=method,
232-
chunks=chunks,
233-
)
226+
if method is not None:
227+
data = to_multiscale_msi(
228+
data,
229+
scale_factors=scale_factors,
230+
method=method,
231+
chunks=chunks,
232+
)
233+
else:
234+
data = to_multiscale_ozp(
235+
data,
236+
scale_factors=scale_factors,
237+
chunks=chunks,
238+
)
234239
_parse_transformations(data, parsed_transform)
235240
else:
236241
# Chunk single scale images
@@ -375,9 +380,6 @@ def parse( # noqa: D102
375380
) -> DataArray | DataTree:
376381
if kwargs.get("c_coords") is not None:
377382
raise ValueError("`c_coords` is not supported for labels")
378-
if kwargs.get("scale_factors") is not None and kwargs.get("method") is None:
379-
# Override default scaling method to preserve labels
380-
kwargs["method"] = Methods.DASK_IMAGE_NEAREST
381383
return super().parse(*args, **kwargs)
382384

383385

@@ -388,9 +390,6 @@ class Labels3DModel(RasterSchema):
388390
def parse(self, *args: Any, **kwargs: Any) -> DataArray | DataTree: # noqa: D102
389391
if kwargs.get("c_coords") is not None:
390392
raise ValueError("`c_coords` is not supported for labels")
391-
if kwargs.get("scale_factors") is not None and kwargs.get("method") is None:
392-
# Override default scaling method to preserve labels
393-
kwargs["method"] = Methods.DASK_IMAGE_NEAREST
394393
return super().parse(*args, **kwargs)
395394

396395

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from collections.abc import Mapping, Sequence
2+
from typing import Any, TypeAlias
3+
4+
import dask.array as da
5+
from ome_zarr.dask_utils import resize
6+
from xarray import DataArray, Dataset, DataTree
7+
8+
Chunks_t: TypeAlias = int | tuple[int, ...] | tuple[tuple[int, ...], ...] | Mapping[Any, None | int | tuple[int, ...]]
9+
ScaleFactors_t = Sequence[dict[str, int] | int]
10+
11+
12+
def dask_arrays_to_datatree(
13+
arrays: Sequence[da.Array],
14+
dims: Sequence[str],
15+
channels: list[Any] | None = None,
16+
) -> DataTree:
17+
"""Build a multiscale DataTree from a sequence of dask arrays.
18+
19+
Parameters
20+
----------
21+
arrays
22+
Sequence of dask arrays, one per scale level (scale0, scale1, ...).
23+
dims
24+
Dimension names for the arrays (e.g. ``("c", "y", "x")``).
25+
channels
26+
Optional channel coordinate values. If provided, a ``"c"`` coordinate
27+
is added to each scale level.
28+
29+
Returns
30+
-------
31+
DataTree with one child per scale level.
32+
"""
33+
coords = {"c": channels} if channels is not None else {}
34+
d = {}
35+
for i, arr in enumerate(arrays):
36+
d[f"scale{i}"] = Dataset(
37+
{
38+
"image": DataArray(
39+
arr,
40+
name="image",
41+
dims=list(dims),
42+
coords=coords,
43+
)
44+
}
45+
)
46+
return DataTree.from_dict(d)
47+
48+
49+
def to_multiscale(
50+
image: DataArray,
51+
scale_factors: ScaleFactors_t,
52+
chunks: Chunks_t | None = None,
53+
) -> DataTree:
54+
dims = [str(dim) for dim in image.dims]
55+
spatial_dims = [d for d in dims if d != "c"]
56+
order = 1 if "c" in dims else 0
57+
pyramid = [image.data]
58+
for sf in scale_factors:
59+
prev = pyramid[-1]
60+
# Compute per-axis scale factors: int applies to spatial axes only, dict to specific ones.
61+
sf_by_axis = dict.fromkeys(dims, 1)
62+
if isinstance(sf, int):
63+
sf_by_axis.update(dict.fromkeys(spatial_dims, sf))
64+
else:
65+
sf_by_axis.update(sf)
66+
# Clamp: skip axes where the scale factor exceeds the axis size.
67+
for ax, factor in sf_by_axis.items():
68+
ax_size = prev.shape[dims.index(ax)]
69+
if factor > ax_size:
70+
sf_by_axis[ax] = 1
71+
output_shape = tuple(prev.shape[dims.index(ax)] // f for ax, f in sf_by_axis.items())
72+
resized = resize(
73+
image=prev.astype(float),
74+
output_shape=output_shape,
75+
order=order,
76+
mode="reflect",
77+
anti_aliasing=False,
78+
)
79+
pyramid.append(resized.astype(prev.dtype))
80+
if chunks is not None:
81+
if isinstance(chunks, Mapping):
82+
chunks = {dims.index(k) if isinstance(k, str) else k: v for k, v in chunks.items()}
83+
pyramid = [arr.rechunk(chunks) for arr in pyramid]
84+
return dask_arrays_to_datatree(pyramid, dims=dims)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import dask
2+
import numpy as np
3+
import pytest
4+
from multiscale_spatial_image.to_multiscale.to_multiscale import Methods
5+
6+
from spatialdata.datasets import BlobsDataset
7+
from spatialdata.models import Image2DModel, Image3DModel, Labels2DModel, Labels3DModel
8+
9+
CHUNK_SIZE = 32
10+
11+
12+
@pytest.mark.parametrize(
13+
("model", "length", "ndim", "n_channels", "scale_factors", "method"),
14+
[
15+
(Image2DModel, 128, 2, 3, (2, 2), Methods.XARRAY_COARSEN),
16+
(Image3DModel, 32, 3, 3, (2, 2), Methods.XARRAY_COARSEN),
17+
(Labels2DModel, 128, 2, 0, (2, 2), Methods.DASK_IMAGE_NEAREST),
18+
(Labels3DModel, 32, 3, 0, (2, 2), Methods.DASK_IMAGE_NEAREST),
19+
],
20+
)
21+
def test_to_multiscale_via_ome_zarr_scaler(model, length, ndim, n_channels, scale_factors, method):
22+
blob_gen = BlobsDataset()
23+
24+
if n_channels > 0:
25+
# Image: stack multiple blob channels
26+
masks = []
27+
for i in range(n_channels):
28+
mask = blob_gen._generate_blobs(length=length, seed=i, ndim=ndim)
29+
mask = (mask - mask.min()) / np.ptp(mask)
30+
masks.append(mask)
31+
array = np.stack(masks, axis=0)
32+
else:
33+
# Labels: threshold blob pattern to get integer labels
34+
mask = blob_gen._generate_blobs(length=length, ndim=ndim)
35+
threshold = np.percentile(mask, 70)
36+
array = (mask >= threshold).astype(np.int64)
37+
38+
dims = model.dims
39+
dask_data = dask.array.from_array(array).rechunk(CHUNK_SIZE)
40+
41+
# multiscale-spatial-image path (explicit method)
42+
result_msi = model.parse(dask_data, dims=dims, scale_factors=scale_factors, chunks=CHUNK_SIZE, method=method)
43+
44+
# ome-zarr-py scaler path (method=None triggers the ome-zarr-py scaler)
45+
result_ozp = model.parse(dask_data, dims=dims, scale_factors=scale_factors, chunks=CHUNK_SIZE)
46+
47+
# Compare data values at each scale level
48+
for scale_name in result_msi.children:
49+
msi_arr = result_msi[scale_name].ds["image"]
50+
ozp_arr = result_ozp[scale_name].ds["image"]
51+
assert msi_arr.sizes == ozp_arr.sizes
52+
np.testing.assert_allclose(msi_arr.values, ozp_arr.values)

0 commit comments

Comments (0)