diff --git a/setup.cfg b/setup.cfg index 111dcc3..02f0ef6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -56,7 +56,6 @@ install_requires = pillow>=11.0 requests - [options.packages.find] where = src exclude = @@ -66,6 +65,9 @@ exclude = # Add here additional requirements for extra features, to install with: # `pip install SpatialExperiment[PDF]` like: # PDF = ReportLab; RXP +extra = + spatialdata + anndata # Add here test requirements (semicolon/line-separated) testing = diff --git a/src/spatialexperiment/SpatialExperiment.py b/src/spatialexperiment/SpatialExperiment.py index facd646..a49cab6 100644 --- a/src/spatialexperiment/SpatialExperiment.py +++ b/src/spatialexperiment/SpatialExperiment.py @@ -861,6 +861,109 @@ def mirror_img(self, sample_id=None, image_id=None, axis=("h", "v")): def to_spatial_experiment(): raise NotImplementedError() + ################################ + ####>> SpatialData interop <<### + ################################ + + @classmethod + def from_spatialdata(cls, input: "spatialdata.SpatialData", points_key: str = "") -> "SpatialExperiment": + """Create a ``SpatialExperiment`` from :py:class:`~spatialdata.SpatialData`. + + When building ``SpatialExperiment``'s `img_data`, if the image is stored as a :py:class:`~xarray.DataArray`, the corresponding key will be used as the `sample_id`, `DataArray.name` will be used for the `image_id`, and `DataArray.attrs['scale_factor']` for the `scale_factor`. + + For when the images are stored as a :py:class:`~xarray.DataTree`, see :py:func:`~spatialdata._sdatautils.build_img_data` for details. + + **NOTE**: This is a lossy conversion. The resulting ``SpatialExperiment`` only preserves a subset of the data from the incoming `SpatialData` object. + + Args: + input: + Input data. + points_key: + The key corresponding to the DataFrame that should be used for constructing spatial coordinates. Defaults to the first entry. + + Returns: + A ``SpatialExperiment`` object. + """ + from spatialdata import SpatialData + from xarray import DataArray, DataTree + from ._sdatautils import build_img_data + + if not isinstance(input, SpatialData): + raise TypeError("Input must be a `SpatialData` object.") + + # validate that the incoming SpatialData can be converted to a SpatialExperiment + points = input.points + if points_key: + points_elem = points[points_key] + else: + points_elem = next(iter(points.values())) + + adata = input.table + if adata.shape[0] != len(points_elem): + raise ValueError("Table and Points must have the same number of observations.") + + sce = super().from_anndata(adata) + + # build spatial coordinates + coords_2d = {'x', 'y'} + coords_3d = {'x', 'y', 'z'} + + points_cols = set(points_elem.columns) + if coords_3d.issubset(points_cols): + coords_cols = list(coords_3d) + elif coords_2d.issubset(points_cols): + coords_cols = list(coords_2d) + else: + coords_cols = [] + + if coords_cols: + spatial_coords = points_elem[coords_cols].compute() + spatial_coords = BiocFrame.from_pandas(spatial_coords) + + # build image data + images = input.images + img_data = BiocFrame( + { + "sample_id": [], + "image_id": [], + "data": [], + "scale_factor": [] + } + ) + for name, image in images.items(): + if isinstance(image, DataArray): + curr_img = construct_spatial_image_class(np.array(image)) + curr_scale_factor = [image.attrs["scale_factor"]] if "scale_factor" in image.attrs else [np.nan] + curr_img_data = BiocFrame({ + "sample_id": [name], + "image_id": [image.name], + "data": [curr_img], + "scale_factor": curr_scale_factor + }) + elif isinstance(image, DataTree): + curr_img_data = build_img_data(image, name) + else: + raise TypeError(f"Cannot build image data from {type(image)}") + + img_data = img_data.combine_rows(curr_img_data) + + return cls( + assays=sce.assays, + row_ranges=sce.row_ranges, + row_data=sce.row_data, + column_data=sce.col_data, + row_names=sce.row_names, + column_names=sce.column_names, + metadata=sce.metadata, + reduced_dims=sce.reduced_dims, + main_experiment_name=sce.main_experiment_name, + alternative_experiments=sce.alternative_experiments, + row_pairs=sce.row_pairs, + column_pairs=sce.column_pairs, + spatial_coords=spatial_coords, + img_data=img_data + ) + ################################ #######>> combine ops <<######## ################################ diff --git a/src/spatialexperiment/_sdatautils.py b/src/spatialexperiment/_sdatautils.py new file mode 100644 index 0000000..36b84c0 --- /dev/null +++ b/src/spatialexperiment/_sdatautils.py @@ -0,0 +1,75 @@ +import numpy as np +from biocframe import BiocFrame +from xarray import DataArray, DataTree, Variable + +from .SpatialImage import construct_spatial_image_class + + +def process_dataset_images(dt: DataTree, root_name: str) -> BiocFrame: + """Processes image-related attributes from a :py:class:`~xarray.DataTree` object and compiles them into a :py:class:`~biocframe.BiocFrame`. The resulting BiocFrame adheres to the standards required for a ``SpatialExperiment``'s `img_data`. + + Args: + dt: A DataTree object containing datasets with image data. + root_name: An identifier for the highest ancestor of the DataTree to which this subtree belongs. + + Returns: + A BiocFrame that conforms to the standards of a ``SpatialExperiment``'s `img_data`. + """ + img_data = BiocFrame( + { + "sample_id": [], + "image_id": [], + "data": [], + "scale_factor": [] + } + ) + for var_name, obj in dt.dataset.items(): + if isinstance(obj, (DataArray, Variable)): + var = obj + else: + dims, data, *optional = obj + attrs = optional[0] if optional else None + var = Variable(dims=dims, data=data, attrs=attrs) + + scale_factor = var.attrs.get("scale_factor", np.nan) + spi = construct_spatial_image_class(np.array(var)) + img_row = BiocFrame( + { + "sample_id": [f"{root_name}::{dt.name}"], + "image_id": [var_name], + "data": [spi], + "scale_factor": [scale_factor] + } + ) + img_data = img_data.combine_rows(img_row) + + return img_data + + +def build_img_data(dt: DataTree, root_name: str): + """Recursively compiles image data from a :py:class:`~xarray.DataTree` into a :py:class:`~biocframe.BiocFrame.BiocFrame` structure. + + This function traverses a `DataTree`, extracting image-related attributes from each dataset and compiling them into a `BiocFrame`. It processes the parent dataset and recursively handles dataset(s) from child nodes. The resulting `BiocFrame` adheres to the standards required for a ``SpatialExperiment``'s `img_data`. + + The following conditions are assumed: + - `DataTree.name` will be used as the `sample_id`. + - The keys of `dt.dataset.data_vars` will be used as the `image_id`'s of each image. + - The `scale_factor` is extracted from the attributes of the objects in `dt.dataset.data_vars`. + + Args: + dt: A DataTree object containing datasets with image data. + root_name: An identifier for the highest ancestor of the DataTree to which this subtree belongs. + + Returns: + A BiocFrame containing compiled image data for the entire DataTree. + """ + if len(dt.children) == 0: + return process_dataset_images(dt, root_name) + + parent_img_data = process_dataset_images(dt, root_name) + + for key, child in dt.children.items(): + child_img_data = build_img_data(child, root_name) + parent_img_data = parent_img_data.combine_rows(child_img_data) + + return parent_img_data diff --git a/tests/conftest.py b/tests/conftest.py index 0123424..46ce1b0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,11 @@ import pytest +from random import random import numpy as np from biocframe import BiocFrame +import anndata as ad +import spatialdata as sd from spatialexperiment import SpatialExperiment, construct_spatial_image_class -from random import random +from spatialdata.models import Image2DModel, PointsModel @pytest.fixture @@ -70,3 +73,40 @@ def spe(): ) return spe_instance + + +@pytest.fixture +def sdata(): + img = np.random.randint(0, 256, size=(50, 50, 3), dtype=np.uint8) + img = Image2DModel.parse(data=img) + img.name = "image01" + img.attrs['scale_factor'] = 1 + + num_cols = 25 + x_coords = np.random.uniform(low=0.0, high=100.0, size=num_cols) + y_coords = np.random.uniform(low=0.0, high=100.0, size=num_cols) + stacked_coords = np.column_stack((x_coords, y_coords)) + points = PointsModel.parse(stacked_coords) + + n_vars = 10 + X = np.random.random((num_cols, n_vars)) + adata = ad.AnnData(X=X) + + sdata = sd.SpatialData( + images={"sample01": img}, + points={"coords": points}, + tables=adata + ) + + return sdata + + +@pytest.fixture +def sdata_tree(): + img_1 = np.random.randint(0, 256, size=(50, 50, 3), dtype=np.uint8) + img_1 = Image2DModel.parse(data=img_1) + img_1.name = "image01" + + img_2 = np.random.randint(0, 256, size=(50, 50, 3), dtype=np.uint8) + img_2 = Image2DModel.parse(data=img_2) + img_2.name = "image02" diff --git a/tests/test_sdata_interop.py b/tests/test_sdata_interop.py new file mode 100644 index 0000000..d34755c --- /dev/null +++ b/tests/test_sdata_interop.py @@ -0,0 +1,24 @@ +import pytest +from spatialexperiment import SpatialExperiment + +def test_from_sdata(sdata): + spe = SpatialExperiment.from_spatialdata(sdata) + + assert isinstance(spe, SpatialExperiment) + + table = sdata['table'] + assert spe.shape == (table.shape[1], table.shape[0]) + + sdata_points = next(iter(sdata.points.values())) + assert spe.spatial_coords.shape == (len(sdata_points), sdata_points.shape[1]) + assert sorted(spe.spatial_coords.columns.as_list()) == sorted(['x','y']) + + assert spe.img_data.shape == (1, 4) + assert spe.img_data["sample_id"] == ["sample01"] + assert spe.img_data["image_id"] == ["image01"] + assert spe.img_data["scale_factor"] == [1] + + +def test_invalid_input(): + with pytest.raises(TypeError): + SpatialExperiment.from_spatialdata("Not a SpatialData object!") \ No newline at end of file diff --git a/tox.ini b/tox.ini index 69f8159..32d9620 100644 --- a/tox.ini +++ b/tox.ini @@ -17,6 +17,7 @@ passenv = SETUPTOOLS_* extras = testing + extra commands = pytest {posargs}