diff --git a/pyproject.toml b/pyproject.toml index 097a80e..4fb31ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "scikit-learn", "tifffile", "torch_geometric", + "json", ] [build-system] diff --git a/src/segger/io/fields.py b/src/segger/io/fields.py index 40bd6be..8dc30a9 100644 --- a/src/segger/io/fields.py +++ b/src/segger/io/fields.py @@ -23,6 +23,26 @@ class XeniumTranscriptFields: 'UnassignedCodeword_*', ] +@dataclass +class XeniumTranscriptFieldsV1: + filename: str = 'transcripts.parquet' + x: str = 'x_location' + y: str = 'y_location' + feature: str = 'feature_name' + cell_id: str = 'cell_id' + null_cell_id: str = "-1" + compartment: str = 'overlaps_nucleus' + nucleus_value: int = 1 + quality: str = 'qv' + filter_substrings = [ + 'NegControlProbe_*', + 'antisense_*', + 'NegControlCodeword*', + 'BLANK_*', + 'DeprecatedCodeword_*', + 'UnassignedCodeword_*', + ] + @dataclass class XeniumBoundaryFields: cell_filename: str = 'cell_boundaries.parquet' diff --git a/src/segger/io/preprocessor.py b/src/segger/io/preprocessor.py index 597a818..8320230 100644 --- a/src/segger/io/preprocessor.py +++ b/src/segger/io/preprocessor.py @@ -7,6 +7,7 @@ import geopandas as gpd import polars as pl import pandas as pd +import json import warnings import logging import sys @@ -21,7 +22,8 @@ MerscopeBoundaryFields, StandardTranscriptFields, StandardBoundaryFields, - XeniumTranscriptFields, + XeniumTranscriptFields, + XeniumTranscriptFieldsV1, XeniumBoundaryFields, CosMxTranscriptFields, CosMxBoundaryFields, @@ -372,16 +374,48 @@ class XeniumPreprocessor(ISTPreprocessor): """ Preprocessor for 10x Genomics Xenium datasets. """ + + tx_fields = XeniumTranscriptFields() + bd_fields = XeniumBoundaryFields() + sw_version = lambda version: version[0] > 1 + @staticmethod - def _validate_directory(data_dir: Path): + def _get_analysis_sw_version(data_dir: Path) -> str: + """ + Get 10x xenium analysis software version. Example experiment.xenium file: + { + ..., + "analysis_sw_version": "xenium-3.3.1.1" + } + Return: + version : list of ints representing major, minor, and patch version numbers (e.g. [3, 3, 1, 1]) + """ + # get version + path_meta = data_dir / "experiment.xenium" + with open(path_meta) as f: + meta = json.load(f) + # version can be xenium-x.y.z or Xenium-x.y.z, ... + version = meta["analysis_sw_version"].split("-")[-1].split(".") + version = [int(v) for v in version] + return version + + @classmethod + def _validate_directory(cls, data_dir: Path): + + # Apply xenium software version 2 or higher (when cell id "Unassigned" was introduced. Previously -1) + version = XeniumPreprocessor._get_analysis_sw_version(data_dir) + if not cls.sw_version(version): + raise IOError( + f"Xenium analysis software version must be 2.0.0 or higher, " + f"but found version {'.'.join(version)}." + ) + # Check required files/directories - bd_fields = XeniumBoundaryFields() - tx_fields = XeniumTranscriptFields() for pat in [ - tx_fields.filename, - bd_fields.cell_filename, - bd_fields.nucleus_filename, + cls.tx_fields.filename, + cls.bd_fields.cell_filename, + cls.bd_fields.nucleus_filename, ]: num_matches = len(list(data_dir.glob(pat))) if not num_matches == 1: @@ -394,7 +428,7 @@ def _validate_directory(data_dir: Path): def transcripts(self) -> pl.DataFrame: # Field names - raw = XeniumTranscriptFields() + raw = self.tx_fields std = StandardTranscriptFields() return ( @@ -405,6 +439,11 @@ def transcripts(self) -> pl.DataFrame: ) # Add numeric index at beginning .with_row_index(name=std.row_index) + # Cast binary columns to string (Some Xenium parquet stores these as binary) + .with_columns( + pl.col(raw.feature).cast(pl.Utf8), + pl.col(raw.cell_id).cast(pl.Utf8), + ) # Filter data .filter(pl.col(raw.quality) >= 20) .filter(pl.col(raw.feature).str.contains( @@ -437,15 +476,16 @@ def transcripts(self) -> pl.DataFrame: .collect() ) - @staticmethod + @classmethod def _get_boundaries( + cls, filepath: Path, boundary_type: str ) -> gpd.GeoDataFrame: # TODO: Add documentation # Field names - raw = XeniumBoundaryFields() + raw = cls.bd_fields std = StandardBoundaryFields() # Read in flat vertices and convert to geometries @@ -463,7 +503,7 @@ def _get_boundaries( @cached_property def boundaries(self) -> gpd.GeoDataFrame: # TODO: Add documentation - raw = XeniumBoundaryFields() + raw = self.bd_fields std = StandardBoundaryFields() # Join boundary datasets @@ -496,14 +536,24 @@ def boundaries(self) -> gpd.GeoDataFrame: cells.reset_index(drop=False, names=std.id), nuclei.reset_index(drop=False, names=std.id), ]) - # Convert index to string type (to join on AnnData) - bd.index = bd[std.id] + '_' + bd[std.boundary_type].map({ + # cell_id is string in later 10x versions, but int in earlier versions. + bd.index = bd[std.id].astype(str) + '_' + bd[std.boundary_type].map({ std.nucleus_value: '0', std.cell_value: '1', }) return bd +@register_preprocessor("10x_xenium_v1") +class XeniumPreprocessorV1(XeniumPreprocessor): + """ + Preprocessor for 10x Genomics Xenium datasets analyzed with software version 1.x. + """ + + tx_fields = XeniumTranscriptFieldsV1() + bd_fields = XeniumBoundaryFields() + sw_version = lambda version: version[0] == 1 + @register_preprocessor("vizgen_merscope") class MerscopePreprocessor(ISTPreprocessor):