Skip to content

Commit c2136b3

Browse files
Fix bug join operation with non-annotating table (#864)
fix bug join non-annotating table wrong region metadata
1 parent e3ab814 commit c2136b3

File tree

2 files changed

+208
-46
lines changed

2 files changed

+208
-46
lines changed

src/spatialdata/_core/query/relational_query.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def _(
143143

144144

145145
# TODO: replace function use throughout repo by `join_sdata_spatialelement_table`
146+
# TODO: benchmark against join operations before removing
146147
def _filter_table_by_elements(
147148
table: AnnData | None, elements_dict: dict[str, dict[str, Any]], match_rows: bool = False
148149
) -> AnnData | None:
@@ -312,6 +313,8 @@ def _right_exclusive_join_spatialelement_table(
312313
element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
313314
) -> tuple[dict[str, Any], AnnData | None]:
314315
regions, region_column_name, instance_key = get_table_keys(table)
316+
if isinstance(regions, str):
317+
regions = [regions]
315318
groups_df = table.obs.groupby(by=region_column_name, observed=False)
316319
mask = []
317320
for element_type, name_element in element_dict.items():
@@ -350,6 +353,8 @@ def _right_join_spatialelement_table(
350353
if match_rows == "left":
351354
warnings.warn("Matching rows 'left' is not supported for 'right' join.", UserWarning, stacklevel=2)
352355
regions, region_column_name, instance_key = get_table_keys(table)
356+
if isinstance(regions, str):
357+
regions = [regions]
353358
groups_df = table.obs.groupby(by=region_column_name, observed=False)
354359
for element_type, name_element in element_dict.items():
355360
for name, element in name_element.items():
@@ -380,6 +385,8 @@ def _inner_join_spatialelement_table(
380385
element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
381386
) -> tuple[dict[str, Any], AnnData]:
382387
regions, region_column_name, instance_key = get_table_keys(table)
388+
if isinstance(regions, str):
389+
regions = [regions]
383390
obs = table.obs.reset_index()
384391
groups_df = obs.groupby(by=region_column_name, observed=False)
385392
joined_indices = None
@@ -424,6 +431,8 @@ def _left_exclusive_join_spatialelement_table(
424431
element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
425432
) -> tuple[dict[str, Any], AnnData | None]:
426433
regions, region_column_name, instance_key = get_table_keys(table)
434+
if isinstance(regions, str):
435+
regions = [regions]
427436
groups_df = table.obs.groupby(by=region_column_name, observed=False)
428437
for element_type, name_element in element_dict.items():
429438
for name, element in name_element.items():
@@ -457,6 +466,8 @@ def _left_join_spatialelement_table(
457466
if match_rows == "right":
458467
warnings.warn("Matching rows 'right' is not supported for 'left' join.", UserWarning, stacklevel=2)
459468
regions, region_column_name, instance_key = get_table_keys(table)
469+
if isinstance(regions, str):
470+
regions = [regions]
460471
obs = table.obs.reset_index()
461472
groups_df = obs.groupby(by=region_column_name, observed=False)
462473
joined_indices = None

0 commit comments

Comments
 (0)