@@ -143,6 +143,7 @@ def _(
143143
144144
145145# TODO: replace function use throughout repo by `join_sdata_spatialelement_table`
146+ # TODO: benchmark against join operations before removing
146147def _filter_table_by_elements (
147148 table : AnnData | None , elements_dict : dict [str , dict [str , Any ]], match_rows : bool = False
148149) -> AnnData | None :
@@ -312,6 +313,8 @@ def _right_exclusive_join_spatialelement_table(
312313 element_dict : dict [str , dict [str , Any ]], table : AnnData , match_rows : Literal ["left" , "no" , "right" ]
313314) -> tuple [dict [str , Any ], AnnData | None ]:
314315 regions , region_column_name , instance_key = get_table_keys (table )
316+ if isinstance (regions , str ):
317+ regions = [regions ]
315318 groups_df = table .obs .groupby (by = region_column_name , observed = False )
316319 mask = []
317320 for element_type , name_element in element_dict .items ():
@@ -350,6 +353,8 @@ def _right_join_spatialelement_table(
350353 if match_rows == "left" :
351354 warnings .warn ("Matching rows 'left' is not supported for 'right' join." , UserWarning , stacklevel = 2 )
352355 regions , region_column_name , instance_key = get_table_keys (table )
356+ if isinstance (regions , str ):
357+ regions = [regions ]
353358 groups_df = table .obs .groupby (by = region_column_name , observed = False )
354359 for element_type , name_element in element_dict .items ():
355360 for name , element in name_element .items ():
@@ -380,6 +385,8 @@ def _inner_join_spatialelement_table(
380385 element_dict : dict [str , dict [str , Any ]], table : AnnData , match_rows : Literal ["left" , "no" , "right" ]
381386) -> tuple [dict [str , Any ], AnnData ]:
382387 regions , region_column_name , instance_key = get_table_keys (table )
388+ if isinstance (regions , str ):
389+ regions = [regions ]
383390 obs = table .obs .reset_index ()
384391 groups_df = obs .groupby (by = region_column_name , observed = False )
385392 joined_indices = None
@@ -424,6 +431,8 @@ def _left_exclusive_join_spatialelement_table(
424431 element_dict : dict [str , dict [str , Any ]], table : AnnData , match_rows : Literal ["left" , "no" , "right" ]
425432) -> tuple [dict [str , Any ], AnnData | None ]:
426433 regions , region_column_name , instance_key = get_table_keys (table )
434+ if isinstance (regions , str ):
435+ regions = [regions ]
427436 groups_df = table .obs .groupby (by = region_column_name , observed = False )
428437 for element_type , name_element in element_dict .items ():
429438 for name , element in name_element .items ():
@@ -457,6 +466,8 @@ def _left_join_spatialelement_table(
457466 if match_rows == "right" :
458467 warnings .warn ("Matching rows 'right' is not supported for 'left' join." , UserWarning , stacklevel = 2 )
459468 regions , region_column_name , instance_key = get_table_keys (table )
469+ if isinstance (regions , str ):
470+ regions = [regions ]
460471 obs = table .obs .reset_index ()
461472 groups_df = obs .groupby (by = region_column_name , observed = False )
462473 joined_indices = None
0 commit comments