55import dask .array as da
66import numpy as np
77import pandas as pd
8+ from anndata import AnnData
89from dask .dataframe import DataFrame as DaskDataFrame
910from geopandas import GeoDataFrame
1011from numpy .random import default_rng
2122
2223RNG = default_rng (0 )
2324
25+ __all__ = ["rasterize_bins" , "rasterize_bins_link_table_to_labels" ]
26+
2427
2528if TYPE_CHECKING :
2629 from spatialdata import SpatialData
@@ -58,7 +61,8 @@ def rasterize_bins(
5861 If `False` this function returns a `xarray.DataArray` of shape `(c, y, x)` with dimension
5962 of `c` equal to the number of key(s) specified in `value_key`, or the number of var names
6063 in `table_name` if `value_key` is `None`. If `True`, will return labels of shape `(y, x)`,
61- where each bin of the `bins` element will be represented as a pixel.
64+ where each bin of the `bins` element will be represented as a pixel. The table by default will not be set to
65+ annotate the new rasterized labels; see the Notes section for how to do this.
6266
6367 Returns
6468 -------
@@ -78,6 +82,10 @@ def rasterize_bins(
7882
7983 If `spatialdata-plot` is used to visualize the returned image, the parameter `scale='full'` needs to be passed to
8084 `.render_shapes()`, to disable an automatic rasterization that would conflict with the rasterization performed here.
85+
86+ When `return_region_as_labels` is `True`, the function will return a labels layer that is not annotated by default
87+ by the table. To change the annotation target of the table you can call the helper function
88+ `spatialdata.rasterize_bins_link_table_to_labels()`.
8189 """
8290 element = sdata [bins ]
8391 table = sdata .tables [table_name ]
@@ -159,22 +167,12 @@ def rasterize_bins(
159167 transformations = {cs : to_bins .compose_with (t ) for cs , t in bins_transformations .items ()}
160168
161169 if return_region_as_labels :
162- dtype = _get_uint_dtype (table .obs [instance_key ].max ())
163- _min_value = table .obs [instance_key ].min ()
164- # TODO: add a new column instead of modyfing the table inplace
165- # TODO: do not modify the index of the elements
166- if _min_value == 0 :
167- logger .info (
168- f"The minimum value of the instance key column ('table.obs[{ instance_key } ]') has been"
169- " detected to be 0. Since the label 0 is reserved for the background, "
170- "both the instance key column in 'table.obs' "
171- f"and the index of the annotating element '{ bins } ' is incremented by 1."
172- )
173- table .obs [instance_key ] += 1
174- element .index += 1
170+ new_instance_key = _get_relabeled_column_name (instance_key )
171+ table .obs [new_instance_key ] = _relabel_labels (table = table , instance_key = instance_key )
172+ dtype = table .obs [new_instance_key ].dtype
175173 labels_element = np .zeros ((n_rows , n_cols ), dtype = dtype )
176174 # make labels layer that can visually represent the cells
177- labels_element [y , x ] = table .obs [instance_key ].values .T
175+ labels_element [y , x ] = table .obs [new_instance_key ].values .T
178176
179177 return Labels2DModel .parse (data = labels_element , dims = ("y" , "x" ), transformations = transformations )
180178
@@ -242,17 +240,69 @@ def channel_rasterization(block_id: tuple[int, int, int] | None) -> ArrayLike:
242240 )
243241
244242
def _get_uint_dtype(labels_values_count: int) -> str:
    """
    Return the name of the narrowest unsigned integer dtype able to hold a value.

    Only `uint16`, `uint32` and `uint64` are considered (in this order); `uint8` is never returned.

    Parameters
    ----------
    labels_values_count
        The largest value that has to be representable.

    Returns
    -------
    One of `"uint16"`, `"uint32"` or `"uint64"`.

    Raises
    ------
    ValueError
        If `labels_values_count` does not fit in `uint64`.
    """
    # candidates are ordered from narrowest to widest, so the first match is the smallest dtype that fits
    candidates = ((np.uint16, "uint16"), (np.uint32, "uint32"), (np.uint64, "uint64"))
    for np_type, dtype_name in candidates:
        if np.iinfo(np_type).max >= labels_values_count:
            return dtype_name

    max_uint64 = np.iinfo(np.uint64).max
    raise ValueError(
        f"Maximum cell number is {labels_values_count}. Values higher than {max_uint64} are not supported."
    )
259+
260+
def _get_relabeled_column_name(column_name: str) -> str:
    """Return the name of the column storing the relabeled version of `column_name`."""
    prefix = "relabeled_"
    return "{}{}".format(prefix, column_name)
263+
264+
def _relabel_labels(table: AnnData, instance_key: str) -> pd.Series:
    """
    Map the instance key column of a table to consecutive, strictly positive integer labels.

    The values of `table.obs[instance_key]` are converted to categorical codes and shifted by one, so that the
    returned labels start at 1 and never use the value 0, which is reserved for the background. The table itself
    is not modified; the caller decides where to store the returned column.

    Parameters
    ----------
    table
        The table whose `obs` contains the `instance_key` column.
    instance_key
        The name of the column of `table.obs` holding the instance ids.

    Returns
    -------
    The relabeled values, with the unsigned integer dtype selected by `_get_uint_dtype()`.
    """
    instance_ids = table.obs[instance_key]
    labels_values_count = len(instance_ids.unique())

    # `is_numeric_dtype` also understands pandas extension dtypes (e.g. categorical columns)
    is_not_numeric = not pd.api.types.is_numeric_dtype(instance_ids.dtype)
    zero_in_instance_key = 0 in instance_ids.values
    # a gap-free numeric column spans 1..n (or 0..n-1 when it contains the value 0)
    has_gaps = not is_not_numeric and labels_values_count != instance_ids.max() + int(zero_in_instance_key)

    relabeling_is_needed = is_not_numeric or zero_in_instance_key or has_gaps
    if relabeling_is_needed:
        logger.info(
            f"The instance_key column in 'table.obs' ('table.obs[{instance_key}]') will be relabeled to ensure"
            " a numeric data type, with a continuous range and without including the value 0 (which is reserved "
            "for the background). The new labels will be stored in a new column named "
            f"{_get_relabeled_column_name(instance_key)!r}."
        )

    # Categorical codes always start at 0, so they are always shifted by one: otherwise the first instance would
    # receive the background label 0 whenever the original column does not contain the value 0 (e.g. ids
    # starting at 1, or string ids). Missing values have code -1 and therefore end up as background.
    relabeled_instance_key_column = instance_ids.astype("category").cat.codes + 1
    # uses only allowed dtypes that pass our model validations, in particular no uint8
    dtype = _get_uint_dtype(labels_values_count=relabeled_instance_key_column.max())
    return relabeled_instance_key_column.astype(dtype)
285+
286+
def rasterize_bins_link_table_to_labels(sdata: SpatialData, table_name: str, rasterized_labels_name: str) -> None:
    """
    Make a table annotate the labels produced by rasterizing the bins.

    Call this function after `rasterize_bins()` has been run with `return_region_as_labels=True` and the resulting
    labels have been added to the spatial data object. The region key column of the table is overwritten with
    `rasterized_labels_name`, and the relabeled instance key column becomes the instance key of the table.

    Parameters
    ----------
    sdata
        The spatial data object containing the rasterized labels.
    table_name
        The name of the table to be annotated.
    rasterized_labels_name
        The name of the rasterized labels in the spatial data object.
    """
    table = sdata[table_name]
    _, region_key, instance_key = get_table_keys(table)

    # every observation of the table now refers to the rasterized labels element
    table.obs[region_key] = rasterized_labels_name

    sdata.set_table_annotates_spatialelement(
        table_name=table_name,
        region=rasterized_labels_name,
        region_key=region_key,
        instance_key=_get_relabeled_column_name(instance_key),
    )
0 commit comments