Skip to content

Commit cf0445b

Browse files
ArneDefauwLucaMarconatopre-commit-ci[bot]
authored
Visium hd rasterize bins labels (#811)
* rasterize bins labels * rasterize bins labels * fix mypy * minor fixes in docstrings, mypy, exceptions, todos * fix tests/pre-commit (merge with main) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix test * fix docs * add tests * cleanup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * moved _get_uint_dtype() to models._utils; fix docs --------- Co-authored-by: LucaMarconato <[email protected]> Co-authored-by: Luca Marconato <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7bce868 commit cf0445b

File tree

7 files changed

+329
-67
lines changed

7 files changed

+329
-67
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning][].
1313
### Major
1414

1515
- Added attributes at the SpatialData object level (`.attrs`)
16+
- `rasterize_bins()` can now produce a labels element #811 @ArneDefauw
1617

1718
## [0.2.6] - 2024-11-26
1819

docs/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ Operations on `SpatialData` objects.
3737
transform
3838
rasterize
3939
rasterize_bins
40+
rasterize_bins_link_table_to_labels
4041
to_circles
4142
to_polygons
4243
aggregate

src/spatialdata/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"concatenate",
2828
"rasterize",
2929
"rasterize_bins",
30+
"rasterize_bins_link_table_to_labels",
3031
"to_circles",
3132
"to_polygons",
3233
"transform",
@@ -61,7 +62,7 @@
6162
from spatialdata._core.operations.aggregate import aggregate
6263
from spatialdata._core.operations.map import map_raster, relabel_sequential
6364
from spatialdata._core.operations.rasterize import rasterize
64-
from spatialdata._core.operations.rasterize_bins import rasterize_bins
65+
from spatialdata._core.operations.rasterize_bins import rasterize_bins, rasterize_bins_link_table_to_labels
6566
from spatialdata._core.operations.transform import transform
6667
from spatialdata._core.operations.vectorize import to_circles, to_polygons
6768
from spatialdata._core.query._utils import get_bounding_box_corners

src/spatialdata/_core/operations/rasterize_bins.py

Lines changed: 140 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import dask.array as da
66
import numpy as np
77
import pandas as pd
8+
from anndata import AnnData
89
from dask.dataframe import DataFrame as DaskDataFrame
910
from geopandas import GeoDataFrame
1011
from numpy.random import default_rng
@@ -14,12 +15,16 @@
1415
from xarray import DataArray
1516

1617
from spatialdata._core.query.relational_query import get_values
18+
from spatialdata._logging import logger
1719
from spatialdata._types import ArrayLike
18-
from spatialdata.models import Image2DModel, get_table_keys
20+
from spatialdata.models import Image2DModel, Labels2DModel, get_table_keys
21+
from spatialdata.models._utils import _get_uint_dtype
1922
from spatialdata.transformations import Affine, Sequence, get_transformation
2023

2124
RNG = default_rng(0)
2225

26+
__all__ = ["rasterize_bins", "rasterize_bins_link_table_to_labels"]
27+
2328

2429
if TYPE_CHECKING:
2530
from spatialdata import SpatialData
@@ -32,6 +37,7 @@ def rasterize_bins(
3237
col_key: str,
3338
row_key: str,
3439
value_key: str | list[str] | None = None,
40+
return_region_as_labels: bool = False,
3541
) -> DataArray:
3642
"""
3743
Rasterizes grid-like binned shapes/points annotated by a table (e.g. Visium HD data).
@@ -51,6 +57,14 @@ def rasterize_bins(
5157
value_key
5258
The key(s) (obs columns/var names) in the table that will be used to rasterize the bins.
5359
If `None`, all the var names will be used, and the returned object will be lazily constructed.
60+
Ignored if `return_region_as_labels` is `True`.
61+
return_regions_as_labels
62+
If `False` this function returns a `xarray.DataArray` of shape `(c, y, x)` with dimension
63+
of `c` equal to the number of key(s) specified in `value_key`, or the number of var names
64+
in `table_name` if `value_key` is `None`. If `True`, will return labels of shape `(y, x)`,
65+
where each bin of the `bins` element will be represented as a pixel. The table by default will not be set to
66+
annotate the new rasterized labels; this can be achieved using the helper function
67+
`spatialdata.rasterize_bins_link_table_to_labels()`.
5468
5569
Returns
5670
-------
@@ -73,24 +87,93 @@ def rasterize_bins(
7387
"""
7488
element = sdata[bins]
7589
table = sdata.tables[table_name]
76-
if not isinstance(element, GeoDataFrame | DaskDataFrame):
77-
raise ValueError("The bins should be a GeoDataFrame or a DaskDataFrame.")
90+
if not isinstance(element, GeoDataFrame | DaskDataFrame | DataArray):
91+
raise ValueError("The bins should be a GeoDataFrame, a DaskDataFrame or a DataArray.")
92+
if isinstance(element, DataArray):
93+
if "c" in element.dims:
94+
raise ValueError(
95+
"If bins is a DataArray, it should hold labels; found a image element instead, with"
96+
f" 'c': {element.dims}."
97+
)
98+
if not np.issubdtype(element.dtype, np.integer):
99+
raise ValueError(f"If bins is a DataArray, it should hold integers. Found dtype {element.dtype}.")
78100

79101
_, region_key, instance_key = get_table_keys(table)
80102
if not table.obs[region_key].dtype == "category":
81103
raise ValueError(f"Please convert `table.obs['{region_key}']` to a category series to improve performances")
82104
unique_regions = table.obs[region_key].cat.categories
83-
if len(unique_regions) > 1 or unique_regions[0] != bins:
105+
if len(unique_regions) > 1:
106+
raise ValueError(
107+
f"Found multiple regions annotated by the table: {', '.join(list(unique_regions))}, "
108+
"currently only tables annotating a single region are supported. Please open a feature request if you are "
109+
"interested in the general case."
110+
)
111+
if unique_regions[0] != bins:
112+
raise ValueError("The table should be associated with the specified bins.")
113+
114+
if isinstance(element, DataArray) and return_region_as_labels:
84115
raise ValueError(
85-
"The table should be associated with the specified bins. "
86-
f"Found multiple regions annotated by the table: {', '.join(list(unique_regions))}."
116+
f"bins is already a labels layer that annotates the table '{table_name}'. "
117+
"Consider setting 'return_region_as_labels' to 'False' to create a lazy spatial image."
87118
)
88119

89120
min_row, min_col = table.obs[row_key].min(), table.obs[col_key].min()
90121
n_rows, n_cols = table.obs[row_key].max() - min_row + 1, table.obs[col_key].max() - min_col + 1
91122
y = (table.obs[row_key] - min_row).values
92123
x = (table.obs[col_key] - min_col).values
93124

125+
if isinstance(element, DataArray):
126+
transformations = get_transformation(element, get_all=True)
127+
assert isinstance(transformations, dict)
128+
else:
129+
# get the transformation
130+
if table.n_obs < 6:
131+
raise ValueError("At least 6 bins are needed to estimate the transformation.")
132+
133+
random_indices = RNG.choice(table.n_obs, min(20, table.n_obs), replace=True)
134+
location_ids = table.obs[instance_key].iloc[random_indices].values
135+
sub_df, sub_table = element.loc[location_ids], table[random_indices]
136+
137+
src = np.stack([sub_table.obs[col_key] - min_col, sub_table.obs[row_key] - min_row], axis=1)
138+
if isinstance(sub_df, GeoDataFrame):
139+
if isinstance(sub_df.iloc[0].geometry, Point):
140+
sub_x = sub_df.geometry.x.values
141+
sub_y = sub_df.geometry.y.values
142+
else:
143+
assert isinstance(sub_df.iloc[0].geometry, Polygon | MultiPolygon)
144+
sub_x = sub_df.centroid.x
145+
sub_y = sub_df.centroid.y
146+
else:
147+
assert isinstance(sub_df, DaskDataFrame)
148+
sub_x = sub_df.x.compute().values
149+
sub_y = sub_df.y.compute().values
150+
dst = np.stack([sub_x, sub_y], axis=1)
151+
152+
to_bins = Sequence(
153+
[
154+
Affine(
155+
estimate_transform(ttype="affine", src=src, dst=dst).params,
156+
input_axes=("x", "y"),
157+
output_axes=("x", "y"),
158+
)
159+
]
160+
)
161+
bins_transformations = get_transformation(element, get_all=True)
162+
163+
assert isinstance(bins_transformations, dict)
164+
165+
transformations = {cs: to_bins.compose_with(t) for cs, t in bins_transformations.items()}
166+
167+
if return_region_as_labels:
168+
new_instance_key = _get_relabeled_column_name(instance_key)
169+
table.obs[new_instance_key] = _relabel_labels(table=table, instance_key=instance_key)
170+
dtype = table.obs[new_instance_key].dtype
171+
labels_element = np.zeros((n_rows, n_cols), dtype=dtype)
172+
# make labels layer that can visualy represent the cells
173+
labels_element[y, x] = table.obs[new_instance_key].values.T
174+
175+
return Labels2DModel.parse(data=labels_element, dims=("y", "x"), transformations=transformations)
176+
94177
keys = ([value_key] if isinstance(value_key, str) else value_key) if value_key is not None else table.var_names
95178

96179
if (value_key is None or any(key in table.var_names for key in keys)) and not isinstance(
@@ -115,7 +198,6 @@ def rasterize_bins(
115198
shape = (n_rows, n_cols)
116199

117200
def channel_rasterization(block_id: tuple[int, int, int] | None) -> ArrayLike:
118-
119201
image: ArrayLike = np.zeros((1, *shape), dtype=dtype)
120202

121203
if block_id is None:
@@ -148,42 +230,59 @@ def channel_rasterization(block_id: tuple[int, int, int] | None) -> ArrayLike:
148230
else:
149231
image[i, y, x] = table.X[:, key_index]
150232

151-
# get the transformation
152-
if table.n_obs < 6:
153-
raise ValueError("At least 6 bins are needed to estimate the transformation.")
233+
return Image2DModel.parse(
234+
data=image,
235+
dims=("c", "y", "x"),
236+
transformations=transformations,
237+
c_coords=keys,
238+
)
154239

155-
random_indices = RNG.choice(table.n_obs, min(20, table.n_obs), replace=True)
156-
location_ids = table.obs[instance_key].iloc[random_indices].values
157-
sub_df, sub_table = element.loc[location_ids], table[random_indices]
158240

159-
src = np.stack([sub_table.obs[col_key] - min_col, sub_table.obs[row_key] - min_row], axis=1)
160-
if isinstance(sub_df, GeoDataFrame):
161-
if isinstance(sub_df.iloc[0].geometry, Point):
162-
sub_x = sub_df.geometry.x.values
163-
sub_y = sub_df.geometry.y.values
164-
else:
165-
assert isinstance(sub_df.iloc[0].geometry, Polygon | MultiPolygon)
166-
sub_x = sub_df.centroid.x
167-
sub_y = sub_df.centroid.y
168-
else:
169-
assert isinstance(sub_df, DaskDataFrame)
170-
sub_x = sub_df.x.compute().values
171-
sub_y = sub_df.y.compute().values
172-
dst = np.stack([sub_x, sub_y], axis=1)
173-
174-
to_bins = Sequence(
175-
[
176-
Affine(
177-
estimate_transform(ttype="affine", src=src, dst=dst).params,
178-
input_axes=("x", "y"),
179-
output_axes=("x", "y"),
180-
)
181-
]
182-
)
183-
bins_transformations = get_transformation(element, get_all=True)
241+
def _get_relabeled_column_name(column_name: str) -> str:
242+
return f"relabeled_{column_name}"
184243

185-
assert isinstance(bins_transformations, dict)
186244

187-
transformations = {cs: to_bins.compose_with(t) for cs, t in bins_transformations.items()}
245+
def _relabel_labels(table: AnnData, instance_key: str) -> pd.Series:
246+
labels_values_count = len(table.obs[instance_key].unique())
188247

189-
return Image2DModel.parse(image, transformations=transformations, c_coords=keys, dims=("c", "y", "x"))
248+
is_not_numeric = not np.issubdtype(table.obs[instance_key].dtype, np.number)
249+
zero_in_instance_key = 0 in table.obs[instance_key].values
250+
has_gaps = not is_not_numeric and labels_values_count != table.obs[instance_key].max() + int(zero_in_instance_key)
251+
252+
relabeling_is_needed = is_not_numeric or zero_in_instance_key or has_gaps
253+
if relabeling_is_needed:
254+
logger.info(
255+
f"The instance_key column in 'table.obs' ('table.obs[{instance_key}]') will be relabeled to ensure"
256+
" a numeric data type, with a continuous range and without including the value 0 (which is reserved "
257+
"for the background). The new labels will be stored in a new column named "
258+
f"{_get_relabeled_column_name(instance_key)!r}."
259+
)
260+
261+
relabeled_instance_key_column = table.obs[instance_key].astype("category").cat.codes + int(zero_in_instance_key)
262+
# uses only allowed dtypes that passes our model validations, in particuar no uint8
263+
dtype = _get_uint_dtype(value=relabeled_instance_key_column.max())
264+
return relabeled_instance_key_column.astype(dtype)
265+
266+
267+
def rasterize_bins_link_table_to_labels(sdata: SpatialData, table_name: str, rasterized_labels_name: str) -> None:
268+
"""
269+
Change the annotation target of the table to the rasterized labels.
270+
271+
This function should be called after having rasterized the bins (calling `rasterize_bins()` with
272+
`return_regions_as_labels=True`) and after having added the rasterized labels to the spatial data object.
273+
274+
Parameters
275+
----------
276+
sdata
277+
The spatial data object containing the rasterized labels.
278+
table_name
279+
The name of the table to be annotated.
280+
rasterized_labels_name
281+
The name of the rasterized labels in the spatial data object.
282+
"""
283+
_, region_key, instance_key = get_table_keys(sdata[table_name])
284+
sdata[table_name].obs[region_key] = rasterized_labels_name
285+
relabled_instance_key = _get_relabeled_column_name(instance_key)
286+
sdata.set_table_annotates_spatialelement(
287+
table_name=table_name, region=rasterized_labels_name, region_key=region_key, instance_key=relabled_instance_key
288+
)

src/spatialdata/models/_utils.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import warnings
24
from functools import singledispatch
35
from typing import TYPE_CHECKING, Any, TypeAlias
@@ -367,7 +369,7 @@ def force_2d(gdf: GeoDataFrame) -> None:
367369
gdf.geometry = new_shapes
368370

369371

370-
def get_raster_model_from_data_dims(dims: tuple[str, ...]) -> type["RasterSchema"]:
372+
def get_raster_model_from_data_dims(dims: tuple[str, ...]) -> type[RasterSchema]:
371373
"""
372374
Get the raster model from the dimensions of the data.
373375
@@ -435,3 +437,19 @@ def set_channel_names(element: DataArray | DataTree, channel_names: str | list[s
435437
raise TypeError("Element model does not support setting channel names, no `c` dimension found.")
436438

437439
return element
440+
441+
442+
def _get_uint_dtype(value: int) -> str:
443+
max_uint64 = np.iinfo(np.uint64).max
444+
max_uint32 = np.iinfo(np.uint32).max
445+
max_uint16 = np.iinfo(np.uint16).max
446+
447+
if max_uint16 >= value:
448+
dtype = "uint16"
449+
elif max_uint32 >= value:
450+
dtype = "uint32"
451+
elif max_uint64 >= value:
452+
dtype = "uint64"
453+
else:
454+
raise ValueError(f"Maximum cell number is {value}. Values higher than {max_uint64} are not supported.")
455+
return dtype

src/spatialdata/transformations/_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import warnings
34
from functools import singledispatch
45
from typing import TYPE_CHECKING, Any, Optional, Union
56

@@ -8,7 +9,6 @@
89
from geopandas import GeoDataFrame
910
from xarray import DataArray, Dataset, DataTree
1011

11-
from spatialdata._logging import logger
1212
from spatialdata._types import ArrayLike
1313

1414
if TYPE_CHECKING:
@@ -253,10 +253,12 @@ def scale_radii(radii: ArrayLike, affine: Affine, axes: tuple[str, ...]) -> Arra
253253
modules = np.absolute(eigenvalues)
254254
if not np.allclose(modules, modules[0]):
255255
scale_factor = np.mean(modules)
256-
logger.warning(
256+
warnings.warn(
257257
"The vector part of the transformation matrix is not isotropic, the radius will be scaled by the average "
258258
f"of the modules of eigenvalues of the affine transformation matrix.\nmatrix={matrix}\n"
259-
f"eigenvalues={eigenvalues}\nscale_factor={scale_factor}"
259+
f"eigenvalues={eigenvalues}\nscale_factor={scale_factor}",
260+
UserWarning,
261+
stacklevel=2,
260262
)
261263
else:
262264
scale_factor = modules[0]

0 commit comments

Comments
 (0)