diff --git a/src/segger/export/spatialdata_writer.py b/src/segger/export/spatialdata_writer.py index 75db1c4..9bada98 100644 --- a/src/segger/export/spatialdata_writer.py +++ b/src/segger/export/spatialdata_writer.py @@ -94,7 +94,7 @@ def __init__( shapes_key: str = "cells", fragment_shapes_key: str = "fragments", include_table: bool = True, - table_key: str = "cells_table", # no duplicate names allowed + table_key: str = "cells_table", fragment_table_key: str = "fragments_table", table_region_key: str = "cell_id", ): @@ -250,7 +250,6 @@ def _create_spatialdata( """Create SpatialData object from transcripts and boundaries.""" import spatialdata from spatialdata.models import PointsModel, ShapesModel, TableModel - import spatialdata.models._accessor # for points parsing on pre-release (https://github.com/scverse/spatialdata/issues/1093) import dask.dataframe as dd identity = self._identity_transform() @@ -291,6 +290,7 @@ def _create_spatialdata( tx_pd[col] = tx_pd[col].astype(float) # Create Dask DataFrame for points + tx_pd[feature_column] = tx_pd[feature_column].astype("category") tx_dask = dd.from_pandas(tx_pd) # Points element @@ -336,17 +336,17 @@ def _parse_shapes(shapes): if self.include_boundaries and self.boundary_method != "skip": if self.boundary_method == "input": - for bd_type in ["cell", "nucleus"]: # these are segger hard-coded + bd_types = {"cell": "cells", "nucleus": "nuclei"} + for k, v in bd_types.items(): shapes = self._get_input_boundaries( cell_tx_pd, cell_id_column, boundaries, - bd_type) + k) shapes = _ensure_cell_id(shapes) parsed = _parse_shapes(shapes) if parsed is not None: - shapes_elements[f"{bd_type}_boundaries"] = parsed - # this naming convention is very Xenium-based (ideally one would maintain the input one which is currently lost) + shapes_elements[v] = parsed else: shape_specs = [(self.shapes_key, cell_tx_pd)] if has_fragments and fragment_tx_pd is not None: @@ -391,6 +391,12 @@ def _parse_shapes(shapes): z_column=z_column, ) + for name, table in tables_elements.items(): + if 'spatialdata_attrs' not in table.uns.keys(): + warnings.warn( + f"Table {name} does not contain the `uns['spatialdata_attrs']` field as no shapes element is associated." + ) + # Create SpatialData (prefer modern constructor methods, keep fallback on single elemnts) sdata = self._build_spatialdata( spatialdata=spatialdata, @@ -449,7 +455,7 @@ def _build_table_element( obs_index_as_str=True, ) if region is None: - return table + return TableModel.validate(table) instance_key = self.table_region_key table.obs["region"] = region @@ -462,7 +468,8 @@ def _build_table_element( region_key="region", instance_key=instance_key or "instance_id", ) - except Exception: + except Exception as e: + warnings.warn(f"TableModel.parse failed: {e}") return table def _write_spatialdata_zarr(self, sdata, output_path: Path, overwrite: bool) -> None: @@ -535,6 +542,7 @@ def _get_generated_boundaries( elif self.boundary_method == "delaunay": from segger.export.boundary import generate_boundaries + warnings.filterwarnings('ignore', 'GeoSeries.notna', UserWarning) boundaries_gdf = generate_boundaries( assigned, diff --git a/src/segger/io/filtering.py b/src/segger/io/filtering.py index 6d8fbca..abb796d 100644 --- a/src/segger/io/filtering.py +++ b/src/segger/io/filtering.py @@ -57,13 +57,13 @@ def platform_feature_filter_patterns(platform: str | None) -> list[str]: return list(MerscopeTranscriptFields.filter_substrings) return [] - def glob_patterns_to_regex(patterns: Sequence[str]) -> str: - """Convert glob-like patterns (``*``) to a regex union.""" - return "|".join( - f"^{re.escape(pattern).replace(r'\\*', '.*')}$" - for pattern in patterns - ) + """Convert glob-like patterns (`*`) to a regex union.""" + regexes = [] + for pattern in patterns: + regex_pattern = re.escape(pattern).replace(r"\*", ".*") + regexes.append(f"^{regex_pattern}$") + return "|".join(regexes) def apply_feature_filters( diff --git a/src/segger/io/spatialdata_loader.py b/src/segger/io/spatialdata_loader.py index 2daeeb5..889dc40 100644 --- a/src/segger/io/spatialdata_loader.py +++ b/src/segger/io/spatialdata_loader.py @@ -314,7 +314,7 @@ def transcripts( cell_id_col = self._detect_column( columns, - ["cell_id", "cell", "segger_cell_id", "segmentation_cell_id", "instance_id"], + ["cell_id", "cell", "segger_cell_id", "segmentation_cell_id", "instance_id", "cell_ID"], optional=True, )