Skip to content

Commit 4679e13

Browse files
committed
add tests
1 parent 068e3d5 commit 4679e13

File tree

5 files changed

+135
-37
lines changed

5 files changed

+135
-37
lines changed

docs/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ Operations on `SpatialData` objects.
3737
transform
3838
rasterize
3939
rasterize_bins
40+
rasterize_bins_link_table_to_labels
4041
to_circles
4142
to_polygons
4243
aggregate

src/spatialdata/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"concatenate",
2828
"rasterize",
2929
"rasterize_bins",
30+
"rasterize_bins_link_table_to_labels",
3031
"to_circles",
3132
"to_polygons",
3233
"transform",
@@ -61,7 +62,7 @@
6162
from spatialdata._core.operations.aggregate import aggregate
6263
from spatialdata._core.operations.map import map_raster, relabel_sequential
6364
from spatialdata._core.operations.rasterize import rasterize
64-
from spatialdata._core.operations.rasterize_bins import rasterize_bins
65+
from spatialdata._core.operations.rasterize_bins import rasterize_bins, rasterize_bins_link_table_to_labels
6566
from spatialdata._core.operations.transform import transform
6667
from spatialdata._core.operations.vectorize import to_circles, to_polygons
6768
from spatialdata._core.query._utils import get_bounding_box_corners

src/spatialdata/_core/operations/rasterize_bins.py

Lines changed: 70 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import dask.array as da
66
import numpy as np
77
import pandas as pd
8+
from anndata import AnnData
89
from dask.dataframe import DataFrame as DaskDataFrame
910
from geopandas import GeoDataFrame
1011
from numpy.random import default_rng
@@ -21,6 +22,8 @@
2122

2223
RNG = default_rng(0)
2324

25+
__all__ = ["rasterize_bins", "rasterize_bins_link_table_to_labels"]
26+
2427

2528
if TYPE_CHECKING:
2629
from spatialdata import SpatialData
@@ -58,7 +61,8 @@ def rasterize_bins(
5861
If `False` this function returns a `xarray.DataArray` of shape `(c, y, x)` with dimension
5962
of `c` equal to the number of key(s) specified in `value_key`, or the number of var names
6063
in `table_name` if `value_key` is `None`. If `True`, will return labels of shape `(y, x)`,
61-
where each bin of the `bins` element will be represented as a pixel.
64+
where each bin of the `bins` element will be represented as a pixel. The table by default will not be set to
65+
annotate the new rasterized labels; see the Notes section for how to do this.
6266
6367
Returns
6468
-------
@@ -78,6 +82,10 @@ def rasterize_bins(
7882
7983
If `spatialdata-plot` is used to visualized the returned image, the parameter `scale='full'` needs to be passed to
8084
`.render_shapes()`, to disable an automatic rasterization that would confict with the rasterization performed here.
85+
86+
When `return_region_as_labels` is `True`, the function will return a labels layer that is not annotated by default
87+
by the table. To change the annotation target of the table you can call the helper function
88+
`spatialdata.rasterize_bins_link_table_to_labels()`.
8189
"""
8290
element = sdata[bins]
8391
table = sdata.tables[table_name]
@@ -159,22 +167,12 @@ def rasterize_bins(
159167
transformations = {cs: to_bins.compose_with(t) for cs, t in bins_transformations.items()}
160168

161169
if return_region_as_labels:
162-
dtype = _get_uint_dtype(table.obs[instance_key].max())
163-
_min_value = table.obs[instance_key].min()
164-
# TODO: add a new column instead of modyfing the table inplace
165-
# TODO: do not modify the index of the elements
166-
if _min_value == 0:
167-
logger.info(
168-
f"The minimum value of the instance key column ('table.obs[{instance_key}]') has been"
169-
" detected to be 0. Since the label 0 is reserved for the background, "
170-
"both the instance key column in 'table.obs' "
171-
f"and the index of the annotating element '{bins}' is incremented by 1."
172-
)
173-
table.obs[instance_key] += 1
174-
element.index += 1
170+
new_instance_key = _get_relabeled_column_name(instance_key)
171+
table.obs[new_instance_key] = _relabel_labels(table=table, instance_key=instance_key)
172+
dtype = table.obs[new_instance_key].dtype
175173
labels_element = np.zeros((n_rows, n_cols), dtype=dtype)
176174
# make labels layer that can visualy represent the cells
177-
labels_element[y, x] = table.obs[instance_key].values.T
175+
labels_element[y, x] = table.obs[new_instance_key].values.T
178176

179177
return Labels2DModel.parse(data=labels_element, dims=("y", "x"), transformations=transformations)
180178

@@ -242,17 +240,69 @@ def channel_rasterization(block_id: tuple[int, int, int] | None) -> ArrayLike:
242240
)
243241

244242

245-
def _get_uint_dtype(value: int) -> str:
243+
def _get_uint_dtype(labels_values_count: int) -> str:
246244
max_uint64 = np.iinfo(np.uint64).max
247245
max_uint32 = np.iinfo(np.uint32).max
248246
max_uint16 = np.iinfo(np.uint16).max
249247

250-
if max_uint16 >= value:
248+
if max_uint16 >= labels_values_count:
251249
dtype = "uint16"
252-
elif max_uint32 >= value:
250+
elif max_uint32 >= labels_values_count:
253251
dtype = "uint32"
254-
elif max_uint64 >= value:
252+
elif max_uint64 >= labels_values_count:
255253
dtype = "uint64"
256254
else:
257-
raise ValueError(f"Maximum cell number is {value}. Values higher than {max_uint64} are not supported.")
255+
raise ValueError(
256+
f"Maximum cell number is {labels_values_count}. Values higher than {max_uint64} are not supported."
257+
)
258258
return dtype
259+
260+
261+
def _get_relabeled_column_name(column_name: str) -> str:
262+
return f"relabeled_{column_name}"
263+
264+
265+
def _relabel_labels(table: AnnData, instance_key: str) -> pd.Series:
266+
labels_values_count = len(table.obs[instance_key].unique())
267+
268+
is_not_numeric = not np.issubdtype(table.obs[instance_key].dtype, np.number)
269+
zero_in_instance_key = 0 in table.obs[instance_key].values
270+
has_gaps = not is_not_numeric and labels_values_count != table.obs[instance_key].max() + int(zero_in_instance_key)
271+
272+
relabeling_is_needed = is_not_numeric or zero_in_instance_key or has_gaps
273+
if relabeling_is_needed:
274+
logger.info(
275+
f"The instance_key column in 'table.obs' ('table.obs[{instance_key}]') will be relabeled to ensure"
276+
" a numeric data type, with a continuous range and without including the value 0 (which is reserved "
277+
"for the background). The new labels will be stored in a new column named "
278+
f"{_get_relabeled_column_name(instance_key)!r}."
279+
)
280+
281+
relabeled_instance_key_column = table.obs[instance_key].astype("category").cat.codes + int(zero_in_instance_key)
282+
# uses only allowed dtypes that passes our model validations, in particualr no uint8
283+
dtype = _get_uint_dtype(labels_values_count=relabeled_instance_key_column.max())
284+
return relabeled_instance_key_column.astype(dtype)
285+
286+
287+
def rasterize_bins_link_table_to_labels(sdata: SpatialData, table_name: str, rasterized_labels_name: str) -> None:
288+
"""
289+
Annotates the table with the rasterized labels.
290+
291+
This function should be called after rasterizing the bins (with `return_regions_as_labels` is `True`) and adding
292+
the rasterized labels in the spatial data object.
293+
294+
Parameters
295+
----------
296+
sdata
297+
The spatial data object containing the rasterized labels.
298+
table_name
299+
The name of the table to be annotated.
300+
rasterized_labels_name
301+
The name of the rasterized labels in the spatial data object.
302+
"""
303+
_, region_key, instance_key = get_table_keys(sdata[table_name])
304+
sdata[table_name].obs[region_key] = rasterized_labels_name
305+
relabled_instance_key = _get_relabeled_column_name(instance_key)
306+
sdata.set_table_annotates_spatialelement(
307+
table_name=table_name, region=rasterized_labels_name, region_key=region_key, instance_key=relabled_instance_key
308+
)

src/spatialdata/transformations/_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import warnings
34
from functools import singledispatch
45
from typing import TYPE_CHECKING, Any, Optional, Union
56

@@ -8,7 +9,6 @@
89
from geopandas import GeoDataFrame
910
from xarray import DataArray, Dataset, DataTree
1011

11-
from spatialdata._logging import logger
1212
from spatialdata._types import ArrayLike
1313

1414
if TYPE_CHECKING:
@@ -253,10 +253,12 @@ def scale_radii(radii: ArrayLike, affine: Affine, axes: tuple[str, ...]) -> Arra
253253
modules = np.absolute(eigenvalues)
254254
if not np.allclose(modules, modules[0]):
255255
scale_factor = np.mean(modules)
256-
logger.warning(
256+
warnings.warn(
257257
"The vector part of the transformation matrix is not isotropic, the radius will be scaled by the average "
258258
f"of the modules of eigenvalues of the affine transformation matrix.\nmatrix={matrix}\n"
259-
f"eigenvalues={eigenvalues}\nscale_factor={scale_factor}"
259+
f"eigenvalues={eigenvalues}\nscale_factor={scale_factor}",
260+
UserWarning,
261+
stacklevel=2,
260262
)
261263
else:
262264
scale_factor = modules[0]

tests/core/operations/test_rasterize_bins.py

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import logging
34
import re
45

56
import numpy as np
@@ -12,8 +13,13 @@
1213
from shapely.geometry import Polygon
1314

1415
from spatialdata._core.data_extent import are_extents_equal, get_extent
15-
from spatialdata._core.operations.rasterize_bins import rasterize_bins
16+
from spatialdata._core.operations.rasterize_bins import (
17+
_relabel_labels,
18+
rasterize_bins,
19+
rasterize_bins_link_table_to_labels,
20+
)
1621
from spatialdata._core.spatialdata import SpatialData
22+
from spatialdata._logging import logger
1723
from spatialdata._types import ArrayLike
1824
from spatialdata.models.models import Image2DModel, Labels2DModel, PointsModel, ShapesModel, TableModel
1925
from spatialdata.transformations.transformations import Scale
@@ -80,18 +86,21 @@ def test_rasterize_bins(geometry: str, value_key: str | list[str] | None, return
8086
if return_region_as_labels:
8187
labels_name = "labels"
8288
sdata[labels_name] = rasterized
83-
adata = sdata["table"]
84-
adata.obs["region"] = labels_name
85-
adata.obs["region"] = adata.obs["region"].astype("category")
86-
del adata.uns[TableModel.ATTRS_KEY]
87-
adata = TableModel.parse(
88-
adata,
89-
region=labels_name,
90-
region_key="region",
91-
instance_key="instance_id",
92-
)
93-
del sdata["table"]
94-
sdata["table"] = adata
89+
90+
rasterize_bins_link_table_to_labels(sdata=sdata, table_name="table", rasterized_labels_name=labels_name)
91+
# adata = sdata["table"]
92+
# adata.obs["region"] = labels_name
93+
# adata.obs["region"] = adata.obs["region"].astype("category")
94+
# del adata.uns[TableModel.ATTRS_KEY]
95+
# adata = TableModel.parse(
96+
# adata,
97+
# region=labels_name,
98+
# region_key="region",
99+
# instance_key="instance_id",
100+
# )
101+
# del sdata["table"]
102+
# sdata["table"] = adata
103+
95104
# this fails because table already annotated by labels layer
96105
with pytest.raises(
97106
ValueError,
@@ -265,3 +274,38 @@ def _get_sdata(n: int):
265274
row_key="row_index",
266275
value_key="instance_id",
267276
)
277+
278+
279+
def test_relabel_labels(caplog):
280+
obs = DataFrame(
281+
data={
282+
"instance_key0": np.arange(1, 11),
283+
"instance_key1": np.arange(10),
284+
"instance_key2": [1, 2] + list(range(4, 12)),
285+
"instance_key3": [str(i) for i in range(1, 11)],
286+
}
287+
)
288+
adata = AnnData(X=RNG.normal(size=(10, 2)), obs=obs)
289+
_relabel_labels(table=adata, instance_key="instance_key0")
290+
# check logger info message
291+
expected_log_message = (
292+
"will be relabeled to ensure a numeric data type, with a continuous range and without including the value 0 ("
293+
"which is reserved for the background). The new labels will be stored in a new column named"
294+
)
295+
logger.propagate = True
296+
with caplog.at_level(logging.INFO):
297+
_relabel_labels(table=adata, instance_key="instance_key1")
298+
assert expected_log_message in caplog.text
299+
300+
with caplog.at_level(logging.INFO):
301+
_relabel_labels(table=adata, instance_key="instance_key2")
302+
assert expected_log_message in caplog.text
303+
304+
with caplog.at_level(logging.INFO):
305+
_relabel_labels(table=adata, instance_key="instance_key3")
306+
assert expected_log_message in caplog.text
307+
logger.propagate = False
308+
309+
310+
if __name__ == "__main__":
311+
test_relabel_labels()

0 commit comments

Comments
 (0)