Skip to content

Commit 7d7dd5d

Browse files
committed
fix pre-commit and tests
1 parent 19b4d5f commit 7d7dd5d

File tree

3 files changed

+48
-24
lines changed

3 files changed

+48
-24
lines changed

src/spatialdata/datasets.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def _image_blobs(
189189
masks.append(mask)
190190

191191
x = np.stack(masks, axis=0)
192+
model: type[Image2DModel] | type[Image3DModel]
192193
if ndim == 2:
193194
dims = ["c", "y", "x"]
194195
model = Image2DModel
@@ -229,6 +230,7 @@ def _labels_blobs(
229230
out[out == val[idx]] = 0
230231
else:
231232
out[out == val[idx]] = i
233+
model: type[Labels2DModel] | type[Labels3DModel]
232234
if ndim == 2:
233235
dims = ["y", "x"]
234236
model = Labels2DModel

src/spatialdata/models/pyramids_utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,36 @@ def to_multiscale(
5151
scale_factors: ScaleFactors_t,
5252
chunks: Chunks_t | None = None,
5353
) -> DataTree:
54+
"""Build a multiscale pyramid DataTree from a single-scale image.
55+
56+
Iteratively downscales the image by the given scale factors using
57+
interpolation (order 1 for images with a channel dimension, order 0
58+
for labels) and assembles all levels into a DataTree.
59+
60+
Makes uses of internal ome-zarr-py APIs for dask downscaling.
61+
62+
TODO: ome-zarr-py will support 3D downscaling once https://github.com/ome/ome-zarr-py/pull/516 is merged, and this
63+
function could make use of it. Also the PR will introduce new downscaling methods such as "nearest". Nevertheless,
64+
this function supports different scaling factors per axis, which is not supported by ome-zarr-py yet.
65+
66+
Parameters
67+
----------
68+
image
69+
Input image/labels as an xarray DataArray (e.g. with dims ``("c", "y", "x")``
70+
or ``("y", "x")``). Supports both 2D/3D images and 2D/3D labels.
71+
scale_factors
72+
Sequence of per-level scale factors. Each element is either an int
73+
(applied to all spatial axes) or a dict mapping dimension names to
74+
per-axis factors (e.g. ``{"y": 2, "x": 2}``).
75+
chunks
76+
Optional chunk specification passed to :meth:`dask.array.Array.rechunk`
77+
after building the pyramid.
78+
79+
Returns
80+
-------
81+
DataTree
82+
Multiscale DataTree with children ``scale0``, ``scale1``, etc.
83+
"""
5484
dims = [str(dim) for dim in image.dims]
5585
spatial_dims = [d for d in dims if d != "c"]
5686
order = 1 if "c" in dims else 0

tests/models/test_pyramids_utils.py

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
@pytest.mark.parametrize(
1313
("model", "length", "ndim", "n_channels", "scale_factors", "method"),
1414
[
15-
# (Image2DModel, 128, 2, 3, (2, 2), Methods.XARRAY_COARSEN),
16-
# (Image3DModel, 32, 3, 3, (2, 2), Methods.XARRAY_COARSEN),
15+
(Image2DModel, 128, 2, 3, (2, 2), Methods.XARRAY_COARSEN),
16+
(Image3DModel, 32, 3, 3, (2, 2), Methods.XARRAY_COARSEN),
1717
(Labels2DModel, 128, 2, 0, (2, 2), Methods.DASK_IMAGE_NEAREST),
1818
(Labels3DModel, 32, 3, 0, (2, 2), Methods.DASK_IMAGE_NEAREST),
1919
],
@@ -38,35 +38,27 @@ def test_to_multiscale_via_ome_zarr_scaler(model, length, ndim, n_channels, scal
3838
dims = model.dims
3939
dask_data = dask.array.from_array(array).rechunk(CHUNK_SIZE)
4040

41-
# # multiscale-spatial-image path (explicit method)
41+
# multiscale-spatial-image path (explicit method)
4242
result_msi = model.parse(dask_data, dims=dims, scale_factors=scale_factors, chunks=CHUNK_SIZE, method=method)
4343

4444
# ome-zarr-py scaler path (method=None triggers the ome-zarr-py scaler)
4545
result_ozp = model.parse(dask_data, dims=dims, scale_factors=scale_factors, chunks=CHUNK_SIZE)
4646

47-
# ##
48-
# from napari_spatialdata import Interactive
49-
# from spatialdata import SpatialData
50-
#
51-
# sdata = SpatialData.init_from_elements({'msi': result_msi, 'ozp': result_ozp})
52-
# Interactive(sdata)
53-
54-
##
55-
5647
# Compare data values at each scale level
57-
import matplotlib.pyplot as plt
58-
_, axes = plt.subplots(len(result_msi.children), 2, figsize=(8, 4 * len(result_msi.children)))
5948
for i, scale_name in enumerate(result_msi.children):
6049
msi_arr = result_msi[scale_name].ds["image"]
6150
ozp_arr = result_ozp[scale_name].ds["image"]
6251
assert msi_arr.sizes == ozp_arr.sizes
63-
64-
if msi_arr.ndim == 3:
65-
msi_arr = msi_arr[0]
66-
ozp_arr = ozp_arr[0]
67-
axes[i, 0].imshow(msi_arr.values)
68-
axes[i, 1].imshow(ozp_arr.values)
69-
pass
70-
# np.testing.assert_allclose(msi_arr.values, ozp_arr.values)
71-
plt.tight_layout()
72-
plt.show()
52+
if model in [Image2DModel, Image3DModel]:
53+
# exact comparison for images
54+
np.testing.assert_allclose(msi_arr.values, ozp_arr.values)
55+
else:
56+
if i == 0:
57+
# no downscaling is performed, so they must be equal
58+
np.testing.assert_array_equal(msi_arr.values, ozp_arr.values)
59+
else:
60+
# we expect differences: ngff-zarr uses "nearest", ozp uses "resize"
61+
# TODO: when https://github.com/ome/ome-zarr-py/pull/516 is merged we can use nearest for labels and
62+
# expect a much stricter adherence
63+
fraction_non_equal = np.sum(msi_arr.values != ozp_arr.values) / np.prod(msi_arr.values.shape)
64+
assert fraction_non_equal < 0.5

0 commit comments

Comments
 (0)