Skip to content

Commit 38e73ba

Browse files
allow filtering by ids (#627)
* allow filtering by ids * wip match_table_to_sdata() * add match_table_to_sdata(); improve table model, get_annotated_regions() * remove filter() * add default argument for match_sdata_to_table(); fix tests * Apply suggestions from code review Co-authored-by: Wouter-Michiel Vierdag <[email protected]> * apply suggestions from code review --------- Co-authored-by: Luca Marconato <[email protected]> Co-authored-by: LucaMarconato <[email protected]>
1 parent 5e2b1a4 commit 38e73ba

File tree

7 files changed

+271
-16
lines changed

7 files changed

+271
-16
lines changed

docs/api/operations.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Operations on `SpatialData` objects.
1414
.. autofunction:: join_spatialelement_table
1515
.. autofunction:: match_element_to_table
1616
.. autofunction:: match_table_to_element
17+
.. autofunction:: match_sdata_to_table
1718
.. autofunction:: concatenate
1819
.. autofunction:: transform
1920
.. autofunction:: rasterize

src/spatialdata/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
"join_spatialelement_table",
4141
"match_element_to_table",
4242
"match_table_to_element",
43+
"match_sdata_to_table",
4344
"SpatialData",
4445
"get_extent",
4546
"get_centroids",
@@ -72,6 +73,7 @@
7273
get_values,
7374
join_spatialelement_table,
7475
match_element_to_table,
76+
match_sdata_to_table,
7577
match_table_to_element,
7678
)
7779
from spatialdata._core.query.spatial_query import bounding_box_query, polygon_query

src/spatialdata/_core/query/relational_query.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,45 @@ def match_element_to_table(
784784
return element_dict, table
785785

786786

787+
def match_sdata_to_table(
788+
sdata: SpatialData,
789+
table_name: str,
790+
table: AnnData | None = None,
791+
how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right",
792+
) -> SpatialData:
793+
"""
794+
Filter the elements of a SpatialData object to match only the rows present in the table.
795+
796+
Parameters
797+
----------
798+
sdata
799+
SpatialData object containing all the elements and tables.
800+
table
801+
The table to join with the spatial elements. Has precedence over `table_name`.
802+
table_name
803+
The name of the table to join with the SpatialData object if `table` is not provided. If table is provided,
804+
`table_name` is used to name the table in the returned `SpatialData` object.
805+
how
806+
The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right".
807+
808+
"""
809+
if table is None:
810+
table = sdata[table_name]
811+
_, region_key, instance_key = get_table_keys(table)
812+
annotated_regions = SpatialData.get_annotated_regions(table)
813+
filtered_elements, filtered_table = join_spatialelement_table(
814+
sdata, spatial_element_names=annotated_regions, table=table, how=how
815+
)
816+
filtered_table = TableModel.parse(
817+
filtered_table,
818+
region=annotated_regions,
819+
region_key=region_key,
820+
instance_key=instance_key,
821+
overwrite_metadata=True,
822+
)
823+
return SpatialData.init_from_elements(filtered_elements | {table_name: filtered_table})
824+
825+
787826
@dataclass
788827
class _ValueOrigin:
789828
origin: str

src/spatialdata/_core/spatialdata.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def from_elements_dict(
262262
return SpatialData.init_from_elements(elements=elements_dict, attrs=attrs)
263263

264264
@staticmethod
265-
def get_annotated_regions(table: AnnData) -> str | list[str]:
265+
def get_annotated_regions(table: AnnData) -> list[str]:
266266
"""
267267
Get the regions annotated by a table.
268268
@@ -275,8 +275,9 @@ def get_annotated_regions(table: AnnData) -> str | list[str]:
275275
-------
276276
The annotated regions.
277277
"""
278-
regions, _, _ = get_table_keys(table)
279-
return regions
278+
from spatialdata.models.models import _get_region_metadata_from_region_key_column
279+
280+
return _get_region_metadata_from_region_key_column(table)
280281

281282
@staticmethod
282283
def get_region_key_column(table: AnnData) -> pd.Series:

src/spatialdata/models/models.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,6 +1064,7 @@ def parse(
10641064
region: str | list[str] | None = None,
10651065
region_key: str | None = None,
10661066
instance_key: str | None = None,
1067+
overwrite_metadata: bool = False,
10671068
) -> AnnData:
10681069
"""
10691070
Parse the :class:`anndata.AnnData` to be compatible with the model.
@@ -1078,6 +1079,8 @@ def parse(
10781079
Key in `adata.obs` that specifies the region.
10791080
instance_key
10801081
Key in `adata.obs` that specifies the instance.
1082+
overwrite_metadata
1083+
If `True`, the `region`, `region_key` and `instance_key` metadata will be overwritten.
10811084
10821085
Returns
10831086
-------
@@ -1087,31 +1090,38 @@ def parse(
10871090
# either all live in adata.uns or all be passed in as argument
10881091
n_args = sum([region is not None, region_key is not None, instance_key is not None])
10891092
if n_args == 0:
1090-
return adata
1091-
if n_args > 0:
1092-
if cls.ATTRS_KEY in adata.uns:
1093-
raise ValueError(
1094-
f"`{cls.REGION_KEY}`, `{cls.REGION_KEY_KEY}` and / or `{cls.INSTANCE_KEY}` is/has been passed as"
1095-
f"as argument(s). However, `adata.uns[{cls.ATTRS_KEY!r}]` has already been set."
1096-
)
1097-
elif cls.ATTRS_KEY in adata.uns:
1093+
if cls.ATTRS_KEY not in adata.uns:
1094+
# table not annotating any element
1095+
return adata
10981096
attr = adata.uns[cls.ATTRS_KEY]
10991097
region = attr[cls.REGION_KEY]
11001098
region_key = attr[cls.REGION_KEY_KEY]
11011099
instance_key = attr[cls.INSTANCE_KEY]
1100+
elif n_args > 0 and not overwrite_metadata and cls.ATTRS_KEY in adata.uns:
1101+
raise ValueError(
1102+
f"`{cls.REGION_KEY}`, `{cls.REGION_KEY_KEY}` and / or `{cls.INSTANCE_KEY}` is/has been passed as"
1103+
f" argument(s). However, `adata.uns[{cls.ATTRS_KEY!r}]` has already been set."
1104+
)
1105+
1106+
if cls.ATTRS_KEY not in adata.uns:
1107+
adata.uns[cls.ATTRS_KEY] = {}
11021108

1109+
if region is None:
1110+
raise ValueError(f"`{cls.REGION_KEY}` must be provided.")
11031111
if region_key is None:
11041112
raise ValueError(f"`{cls.REGION_KEY_KEY}` must be provided.")
1113+
if instance_key is None:
1114+
raise ValueError("`instance_key` must be provided.")
1115+
11051116
if isinstance(region, np.ndarray):
11061117
region = region.tolist()
1107-
if region is None:
1108-
raise ValueError(f"`{cls.REGION_KEY}` must be provided.")
11091118
region_: list[str] = region if isinstance(region, list) else [region]
11101119
if not adata.obs[region_key].isin(region_).all():
11111120
raise ValueError(f"`adata.obs[{region_key}]` values do not match with `{cls.REGION_KEY}` values.")
11121121

1113-
if instance_key is None:
1114-
raise ValueError("`instance_key` must be provided.")
1122+
adata.uns[cls.ATTRS_KEY][cls.REGION_KEY] = region
1123+
adata.uns[cls.ATTRS_KEY][cls.REGION_KEY_KEY] = region_key
1124+
adata.uns[cls.ATTRS_KEY][cls.INSTANCE_KEY] = instance_key
11151125

11161126
# note! this is an expensive check and therefore we skip it during validation
11171127
# https://github.com/scverse/spatialdata/issues/715
@@ -1214,3 +1224,20 @@ def get_table_keys(table: AnnData) -> tuple[str | list[str], str, str]:
12141224
raise ValueError(
12151225
"No spatialdata_attrs key found in table.uns, therefore, no table keys found. Please parse the table."
12161226
)
1227+
1228+
1229+
def _get_region_metadata_from_region_key_column(table: AnnData) -> list[str]:
1230+
_, region_key, instance_key = get_table_keys(table)
1231+
region_key_column = table.obs[region_key]
1232+
if not isinstance(region_key_column.dtype, CategoricalDtype):
1233+
warnings.warn(
1234+
f"The region key column `{region_key}` is not of type `pd.Categorical`. Consider casting it to "
1235+
f"improve performance.",
1236+
UserWarning,
1237+
stacklevel=2,
1238+
)
1239+
annotated_regions = region_key_column.unique().tolist()
1240+
else:
1241+
annotated_regions = table.obs[region_key].cat.remove_unused_categories().cat.categories.unique().tolist()
1242+
assert isinstance(annotated_regions, list)
1243+
return annotated_regions
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import pytest
2+
3+
from spatialdata import SpatialData, concatenate, match_sdata_to_table
4+
from spatialdata.datasets import blobs_annotating_element
5+
6+
7+
def _make_test_data() -> SpatialData:
8+
sdata1 = blobs_annotating_element("blobs_polygons")
9+
sdata2 = blobs_annotating_element("blobs_polygons")
10+
sdata = concatenate({"sdata1": sdata1, "sdata2": sdata2}, concatenate_tables=True)
11+
sdata["table"].obs["value"] = list(range(sdata["table"].obs.shape[0]))
12+
return sdata
13+
14+
15+
# constructing the example data; let's use a global variable as we can reuse the same object on most tests
16+
# without having to recreate it
17+
sdata = _make_test_data()
18+
19+
20+
def test_match_sdata_to_table_filter_specific_instances():
21+
"""
22+
Filter to keep only specific instances. Note that it works even when the table annotates multiple elements.
23+
"""
24+
matched = match_sdata_to_table(
25+
sdata,
26+
table=sdata["table"][sdata["table"].obs.instance_id.isin([1, 2])],
27+
table_name="table",
28+
)
29+
assert len(matched["table"]) == 4
30+
assert "blobs_polygons-sdata1" in matched
31+
assert "blobs_polygons-sdata2" in matched
32+
33+
34+
def test_match_sdata_to_table_filter_specific_instances_element():
35+
"""
36+
Filter to keep only specific instances, in a specific element.
37+
"""
38+
matched = match_sdata_to_table(
39+
sdata,
40+
table=sdata["table"][
41+
sdata["table"].obs.instance_id.isin([1, 2]) & (sdata["table"].obs.region == "blobs_polygons-sdata1")
42+
],
43+
table_name="table",
44+
)
45+
assert len(matched["table"]) == 2
46+
assert "blobs_polygons-sdata1" in matched
47+
assert "blobs_polygons-sdata2" not in matched
48+
49+
50+
def test_match_sdata_to_table_filter_by_threshold():
51+
"""
52+
Filter by a threshold on a value column, in a specific element.
53+
"""
54+
matched = match_sdata_to_table(
55+
sdata,
56+
table=sdata["table"][sdata["table"].obs.query('value < 5 and region == "blobs_polygons-sdata1"').index],
57+
table_name="table",
58+
)
59+
assert len(matched["table"]) == 5
60+
assert "blobs_polygons-sdata1" in matched
61+
assert "blobs_polygons-sdata2" not in matched
62+
63+
64+
def test_match_sdata_to_table_subset_certain_obs():
65+
"""
66+
Subset to certain obs (we could also subset to certain var or layer).
67+
"""
68+
matched = match_sdata_to_table(
69+
sdata,
70+
table=sdata["table"][[0, 1, 2, 3]],
71+
table_name="table",
72+
)
73+
assert len(matched["table"]) == 4
74+
assert "blobs_polygons-sdata1" in matched
75+
assert "blobs_polygons-sdata2" not in matched
76+
77+
78+
def test_match_sdata_to_table_shapes_and_points():
79+
"""
80+
The function works both for shapes (examples above) and points.
81+
Changes the target of the table to labels.
82+
"""
83+
sdata = _make_test_data()
84+
sdata["table"].obs["region"] = sdata["table"].obs["region"].apply(lambda x: x.replace("polygons", "points"))
85+
sdata["table"].obs["region"] = sdata["table"].obs["region"].astype("category")
86+
sdata.set_table_annotates_spatialelement(
87+
table_name="table",
88+
region=["blobs_points-sdata1", "blobs_points-sdata2"],
89+
region_key="region",
90+
instance_key="instance_id",
91+
)
92+
93+
matched = match_sdata_to_table(
94+
sdata,
95+
table=sdata["table"],
96+
table_name="table",
97+
)
98+
99+
assert len(matched["table"]) == 10
100+
assert "blobs_points-sdata1" in matched
101+
assert "blobs_points-sdata2" in matched
102+
assert "blobs_polygons-sdata1" not in matched
103+
104+
105+
def test_match_sdata_to_table_match_labels_error():
106+
"""
107+
match_sdata_to_table() uses the join operations; so when trying to match labels, the error will be raised by the
108+
join.
109+
"""
110+
sdata = _make_test_data()
111+
sdata["table"].obs["region"] = sdata["table"].obs["region"].apply(lambda x: x.replace("polygons", "labels"))
112+
sdata["table"].obs["region"] = sdata["table"].obs["region"].astype("category")
113+
sdata.set_table_annotates_spatialelement(
114+
table_name="table",
115+
region=["blobs_labels-sdata1", "blobs_labels-sdata2"],
116+
region_key="region",
117+
instance_key="instance_id",
118+
)
119+
120+
with pytest.warns(
121+
UserWarning,
122+
match="Element type `labels` not supported for 'right' join. Skipping ",
123+
):
124+
matched = match_sdata_to_table(
125+
sdata,
126+
table=sdata["table"],
127+
table_name="table",
128+
)
129+
130+
assert len(matched["table"]) == 10
131+
assert "blobs_labels-sdata1" in matched
132+
assert "blobs_labels-sdata2" in matched
133+
assert "blobs_points-sdata1" not in matched
134+
135+
136+
def test_match_sdata_to_table_no_table_argument():
137+
"""
138+
If no table argument is passed, the table_name argument will be used to match the table.
139+
"""
140+
matched = match_sdata_to_table(sdata=sdata, table_name="table")
141+
142+
assert len(matched["table"]) == 10
143+
assert "blobs_polygons-sdata1" in matched
144+
assert "blobs_polygons-sdata2" in matched

tests/models/test_models.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from spatialdata._core.validation import ValidationError
3030
from spatialdata._types import ArrayLike
3131
from spatialdata.config import LARGE_CHUNK_THRESHOLD_BYTES
32+
from spatialdata.models import get_table_keys
3233
from spatialdata.models._utils import (
3334
force_2d,
3435
points_dask_dataframe_to_geopandas,
@@ -377,6 +378,46 @@ def test_table_model(
377378
assert TableModel.REGION_KEY_KEY in table.uns[TableModel.ATTRS_KEY]
378379
assert table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] == region
379380

381+
# error when trying to parse a table by specifying region, region_key, instance_key, but these keys are
382+
# already set
383+
with pytest.raises(ValueError, match=" has already been set"):
384+
_ = TableModel.parse(adata, region=region, region_key=region_key, instance_key="A")
385+
386+
# error when region is missing
387+
with pytest.raises(ValueError, match="`region` must be provided"):
388+
_ = TableModel.parse(adata, region_key=region_key, instance_key="A", overwrite_metadata=True)
389+
390+
# error when region_key is missing
391+
with pytest.raises(ValueError, match="`region_key` must be provided"):
392+
_ = TableModel.parse(adata, region=region, instance_key="A", overwrite_metadata=True)
393+
394+
# error when instance_key is missing
395+
with pytest.raises(ValueError, match="`instance_key` must be provided"):
396+
_ = TableModel.parse(adata, region=region, region_key=region_key, overwrite_metadata=True)
397+
398+
# we try to overwrite, but the values in the `region_key` column do not match the expected `region` values
399+
with pytest.raises(ValueError, match="values do not match with `region` values"):
400+
_ = TableModel.parse(adata, region="element", region_key="B", instance_key="C", overwrite_metadata=True)
401+
402+
# we correctly overwrite; here we check that the metadata is updated
403+
region_, region_key_, instance_key_ = get_table_keys(table)
404+
assert region_ == region
405+
assert region_key_ == region_key
406+
assert instance_key_ == "A"
407+
408+
# let's fix the region_key column
409+
table.obs["B"] = ["element"] * len(table)
410+
_ = TableModel.parse(adata, region="element", region_key="B", instance_key="C", overwrite_metadata=True)
411+
412+
region_, region_key_, instance_key_ = get_table_keys(table)
413+
assert region_ == "element"
414+
assert region_key_ == "B"
415+
assert instance_key_ == "C"
416+
417+
# we can parse a table when no metadata is present (i.e. the table does not annotate any element)
418+
del table.uns[TableModel.ATTRS_KEY]
419+
_ = TableModel.parse(table)
420+
380421
@pytest.mark.parametrize(
381422
"name",
382423
[
@@ -423,7 +464,7 @@ def test_table_instance_key_values_not_unique(self, model: TableModel, region: s
423464
ValueError,
424465
match=re.escape("Instance key column for region(s) `sample_1, sample_2`"),
425466
):
426-
model.parse(adata, region=region, region_key=region_key, instance_key="A")
467+
model.parse(adata, region=region, region_key=region_key, instance_key="A", overwrite_metadata=True)
427468

428469
@pytest.mark.parametrize(
429470
"key",

0 commit comments

Comments
 (0)