55import dask .array as da
66import numpy as np
77import pandas as pd
8+ from anndata import AnnData
89from dask .dataframe import DataFrame as DaskDataFrame
910from geopandas import GeoDataFrame
1011from numpy .random import default_rng
1415from xarray import DataArray
1516
1617from spatialdata ._core .query .relational_query import get_values
18+ from spatialdata ._logging import logger
1719from spatialdata ._types import ArrayLike
18- from spatialdata .models import Image2DModel , get_table_keys
20+ from spatialdata .models import Image2DModel , Labels2DModel , get_table_keys
21+ from spatialdata .models ._utils import _get_uint_dtype
1922from spatialdata .transformations import Affine , Sequence , get_transformation
2023
2124RNG = default_rng (0 )
2225
26+ __all__ = ["rasterize_bins" , "rasterize_bins_link_table_to_labels" ]
27+
2328
2429if TYPE_CHECKING :
2530 from spatialdata import SpatialData
@@ -32,6 +37,7 @@ def rasterize_bins(
3237 col_key : str ,
3338 row_key : str ,
3439 value_key : str | list [str ] | None = None ,
40+ return_region_as_labels : bool = False ,
3541) -> DataArray :
3642 """
3743 Rasterizes grid-like binned shapes/points annotated by a table (e.g. Visium HD data).
@@ -51,6 +57,14 @@ def rasterize_bins(
5157 value_key
5258 The key(s) (obs columns/var names) in the table that will be used to rasterize the bins.
5359 If `None`, all the var names will be used, and the returned object will be lazily constructed.
60+ Ignored if `return_region_as_labels` is `True`.
61+ return_regions_as_labels
62+ If `False` this function returns a `xarray.DataArray` of shape `(c, y, x)` with dimension
63+ of `c` equal to the number of key(s) specified in `value_key`, or the number of var names
64+ in `table_name` if `value_key` is `None`. If `True`, will return labels of shape `(y, x)`,
65+ where each bin of the `bins` element will be represented as a pixel. The table by default will not be set to
66+ annotate the new rasterized labels; this can be achieved using the helper function
67+ `spatialdata.rasterize_bins_link_table_to_labels()`.
5468
5569 Returns
5670 -------
@@ -73,24 +87,93 @@ def rasterize_bins(
7387 """
7488 element = sdata [bins ]
7589 table = sdata .tables [table_name ]
76- if not isinstance (element , GeoDataFrame | DaskDataFrame ):
77- raise ValueError ("The bins should be a GeoDataFrame or a DaskDataFrame." )
90+ if not isinstance (element , GeoDataFrame | DaskDataFrame | DataArray ):
91+ raise ValueError ("The bins should be a GeoDataFrame, a DaskDataFrame or a DataArray." )
92+ if isinstance (element , DataArray ):
93+ if "c" in element .dims :
94+ raise ValueError (
95+ "If bins is a DataArray, it should hold labels; found a image element instead, with"
96+ f" 'c': { element .dims } ."
97+ )
98+ if not np .issubdtype (element .dtype , np .integer ):
99+ raise ValueError (f"If bins is a DataArray, it should hold integers. Found dtype { element .dtype } ." )
78100
79101 _ , region_key , instance_key = get_table_keys (table )
80102 if not table .obs [region_key ].dtype == "category" :
81103 raise ValueError (f"Please convert `table.obs['{ region_key } ']` to a category series to improve performances" )
82104 unique_regions = table .obs [region_key ].cat .categories
83- if len (unique_regions ) > 1 or unique_regions [0 ] != bins :
105+ if len (unique_regions ) > 1 :
106+ raise ValueError (
107+ f"Found multiple regions annotated by the table: { ', ' .join (list (unique_regions ))} , "
108+ "currently only tables annotating a single region are supported. Please open a feature request if you are "
109+ "interested in the general case."
110+ )
111+ if unique_regions [0 ] != bins :
112+ raise ValueError ("The table should be associated with the specified bins." )
113+
114+ if isinstance (element , DataArray ) and return_region_as_labels :
84115 raise ValueError (
85- "The table should be associated with the specified bins . "
86- f"Found multiple regions annotated by the table: { ', ' . join ( list ( unique_regions )) } ."
116+ f"bins is already a labels layer that annotates the table ' { table_name } ' . "
117+ "Consider setting 'return_region_as_labels' to 'False' to create a lazy spatial image ."
87118 )
88119
89120 min_row , min_col = table .obs [row_key ].min (), table .obs [col_key ].min ()
90121 n_rows , n_cols = table .obs [row_key ].max () - min_row + 1 , table .obs [col_key ].max () - min_col + 1
91122 y = (table .obs [row_key ] - min_row ).values
92123 x = (table .obs [col_key ] - min_col ).values
93124
125+ if isinstance (element , DataArray ):
126+ transformations = get_transformation (element , get_all = True )
127+ assert isinstance (transformations , dict )
128+ else :
129+ # get the transformation
130+ if table .n_obs < 6 :
131+ raise ValueError ("At least 6 bins are needed to estimate the transformation." )
132+
133+ random_indices = RNG .choice (table .n_obs , min (20 , table .n_obs ), replace = True )
134+ location_ids = table .obs [instance_key ].iloc [random_indices ].values
135+ sub_df , sub_table = element .loc [location_ids ], table [random_indices ]
136+
137+ src = np .stack ([sub_table .obs [col_key ] - min_col , sub_table .obs [row_key ] - min_row ], axis = 1 )
138+ if isinstance (sub_df , GeoDataFrame ):
139+ if isinstance (sub_df .iloc [0 ].geometry , Point ):
140+ sub_x = sub_df .geometry .x .values
141+ sub_y = sub_df .geometry .y .values
142+ else :
143+ assert isinstance (sub_df .iloc [0 ].geometry , Polygon | MultiPolygon )
144+ sub_x = sub_df .centroid .x
145+ sub_y = sub_df .centroid .y
146+ else :
147+ assert isinstance (sub_df , DaskDataFrame )
148+ sub_x = sub_df .x .compute ().values
149+ sub_y = sub_df .y .compute ().values
150+ dst = np .stack ([sub_x , sub_y ], axis = 1 )
151+
152+ to_bins = Sequence (
153+ [
154+ Affine (
155+ estimate_transform (ttype = "affine" , src = src , dst = dst ).params ,
156+ input_axes = ("x" , "y" ),
157+ output_axes = ("x" , "y" ),
158+ )
159+ ]
160+ )
161+ bins_transformations = get_transformation (element , get_all = True )
162+
163+ assert isinstance (bins_transformations , dict )
164+
165+ transformations = {cs : to_bins .compose_with (t ) for cs , t in bins_transformations .items ()}
166+
167+ if return_region_as_labels :
168+ new_instance_key = _get_relabeled_column_name (instance_key )
169+ table .obs [new_instance_key ] = _relabel_labels (table = table , instance_key = instance_key )
170+ dtype = table .obs [new_instance_key ].dtype
171+ labels_element = np .zeros ((n_rows , n_cols ), dtype = dtype )
172+ # make labels layer that can visualy represent the cells
173+ labels_element [y , x ] = table .obs [new_instance_key ].values .T
174+
175+ return Labels2DModel .parse (data = labels_element , dims = ("y" , "x" ), transformations = transformations )
176+
94177 keys = ([value_key ] if isinstance (value_key , str ) else value_key ) if value_key is not None else table .var_names
95178
96179 if (value_key is None or any (key in table .var_names for key in keys )) and not isinstance (
@@ -115,7 +198,6 @@ def rasterize_bins(
115198 shape = (n_rows , n_cols )
116199
117200 def channel_rasterization (block_id : tuple [int , int , int ] | None ) -> ArrayLike :
118-
119201 image : ArrayLike = np .zeros ((1 , * shape ), dtype = dtype )
120202
121203 if block_id is None :
@@ -148,42 +230,59 @@ def channel_rasterization(block_id: tuple[int, int, int] | None) -> ArrayLike:
148230 else :
149231 image [i , y , x ] = table .X [:, key_index ]
150232
151- # get the transformation
152- if table .n_obs < 6 :
153- raise ValueError ("At least 6 bins are needed to estimate the transformation." )
233+ return Image2DModel .parse (
234+ data = image ,
235+ dims = ("c" , "y" , "x" ),
236+ transformations = transformations ,
237+ c_coords = keys ,
238+ )
154239
155- random_indices = RNG .choice (table .n_obs , min (20 , table .n_obs ), replace = True )
156- location_ids = table .obs [instance_key ].iloc [random_indices ].values
157- sub_df , sub_table = element .loc [location_ids ], table [random_indices ]
158240
159- src = np .stack ([sub_table .obs [col_key ] - min_col , sub_table .obs [row_key ] - min_row ], axis = 1 )
160- if isinstance (sub_df , GeoDataFrame ):
161- if isinstance (sub_df .iloc [0 ].geometry , Point ):
162- sub_x = sub_df .geometry .x .values
163- sub_y = sub_df .geometry .y .values
164- else :
165- assert isinstance (sub_df .iloc [0 ].geometry , Polygon | MultiPolygon )
166- sub_x = sub_df .centroid .x
167- sub_y = sub_df .centroid .y
168- else :
169- assert isinstance (sub_df , DaskDataFrame )
170- sub_x = sub_df .x .compute ().values
171- sub_y = sub_df .y .compute ().values
172- dst = np .stack ([sub_x , sub_y ], axis = 1 )
173-
174- to_bins = Sequence (
175- [
176- Affine (
177- estimate_transform (ttype = "affine" , src = src , dst = dst ).params ,
178- input_axes = ("x" , "y" ),
179- output_axes = ("x" , "y" ),
180- )
181- ]
182- )
183- bins_transformations = get_transformation (element , get_all = True )
241+ def _get_relabeled_column_name (column_name : str ) -> str :
242+ return f"relabeled_{ column_name } "
184243
185- assert isinstance (bins_transformations , dict )
186244
187- transformations = {cs : to_bins .compose_with (t ) for cs , t in bins_transformations .items ()}
245+ def _relabel_labels (table : AnnData , instance_key : str ) -> pd .Series :
246+ labels_values_count = len (table .obs [instance_key ].unique ())
188247
189- return Image2DModel .parse (image , transformations = transformations , c_coords = keys , dims = ("c" , "y" , "x" ))
248+ is_not_numeric = not np .issubdtype (table .obs [instance_key ].dtype , np .number )
249+ zero_in_instance_key = 0 in table .obs [instance_key ].values
250+ has_gaps = not is_not_numeric and labels_values_count != table .obs [instance_key ].max () + int (zero_in_instance_key )
251+
252+ relabeling_is_needed = is_not_numeric or zero_in_instance_key or has_gaps
253+ if relabeling_is_needed :
254+ logger .info (
255+ f"The instance_key column in 'table.obs' ('table.obs[{ instance_key } ]') will be relabeled to ensure"
256+ " a numeric data type, with a continuous range and without including the value 0 (which is reserved "
257+ "for the background). The new labels will be stored in a new column named "
258+ f"{ _get_relabeled_column_name (instance_key )!r} ."
259+ )
260+
261+ relabeled_instance_key_column = table .obs [instance_key ].astype ("category" ).cat .codes + int (zero_in_instance_key )
262+ # uses only allowed dtypes that passes our model validations, in particuar no uint8
263+ dtype = _get_uint_dtype (value = relabeled_instance_key_column .max ())
264+ return relabeled_instance_key_column .astype (dtype )
265+
266+
267+ def rasterize_bins_link_table_to_labels (sdata : SpatialData , table_name : str , rasterized_labels_name : str ) -> None :
268+ """
269+ Change the annotation target of the table to the rasterized labels.
270+
271+ This function should be called after having rasterized the bins (calling `rasterize_bins()` with
272+ `return_regions_as_labels=True`) and after having added the rasterized labels to the spatial data object.
273+
274+ Parameters
275+ ----------
276+ sdata
277+ The spatial data object containing the rasterized labels.
278+ table_name
279+ The name of the table to be annotated.
280+ rasterized_labels_name
281+ The name of the rasterized labels in the spatial data object.
282+ """
283+ _ , region_key , instance_key = get_table_keys (sdata [table_name ])
284+ sdata [table_name ].obs [region_key ] = rasterized_labels_name
285+ relabled_instance_key = _get_relabeled_column_name (instance_key )
286+ sdata .set_table_annotates_spatialelement (
287+ table_name = table_name , region = rasterized_labels_name , region_key = region_key , instance_key = relabled_instance_key
288+ )
0 commit comments