diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..c1e2c64 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,12 @@ +# EditorConfig is awesome: https://EditorConfig.org + +# top-most EditorConfig file +root = true + +[*] +indent_style = space +indent_size = 4 +end_of_line = lf +charset = utf-8 +trim_trailing_whitespace = true +insert_final_newline = true diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..7d40f36 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,23 @@ +name: Lint + +on: + push: + branches: main + pull_request: + branches: main + +jobs: + linting: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Setup Python 3.12 + uses: actions/setup-python@v4 + with: + python-version: "3.12" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pre-commit + - name: Check pre-commit compatibility + run: pre-commit run --all-files --show-diff-on-failure diff --git a/.gitignore b/.gitignore index 0268a52..9d666bc 100644 --- a/.gitignore +++ b/.gitignore @@ -182,9 +182,9 @@ cython_debug/ .abstra/ # Visual Studio Code -# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore -# and can be added to the global gitignore or merged into this file. However, if you prefer, +# and can be added to the global gitignore or merged into this file. 
However, if you prefer, # you could uncomment the following to ignore the entire vscode folder # .vscode/ @@ -205,4 +205,4 @@ __marimo__/ # Custom .dev .dev/* -*.pyc \ No newline at end of file +*.pyc diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..554acf6 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,54 @@ +fail_fast: false +default_language_version: + python: python3 +default_stages: + - pre-commit + - pre-push +minimum_pre_commit_version: 2.16.0 +repos: + - repo: https://github.com/psf/black + rev: "23.1.0" + hooks: + - id: black + - repo: https://github.com/asottile/blacken-docs + rev: 1.13.0 + hooks: + - id: blacken-docs + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v3.0.0-alpha.4 + hooks: + - id: prettier + # Newer versions of node don't work on systems that have an older version of GLIBC + # (in particular Ubuntu 18.04 and Centos 7) + # EOL of Centos 7 is in 2024-06, we can probably get rid of this then. + # See https://github.com/scverse/cookiecutter-scverse/issues/143 and + # https://github.com/jupyterlab/jupyterlab/issues/12675 + language_version: "17.9.1" + - repo: https://github.com/charliermarsh/ruff-pre-commit + rev: v0.0.253 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: detect-private-key + - id: check-ast + - id: end-of-file-fixer + - id: mixed-line-ending + args: [--fix=lf] + - id: trailing-whitespace + - id: check-case-conflict + - repo: local + hooks: + - id: forbid-to-commit + name: Don't commit rej files + entry: | + Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates. + Fix the merge conflicts manually and remove the .rej files. 
+ language: fail + files: '.*\.rej$' diff --git a/README.md b/README.md index 050674a..552a1dd 100644 --- a/README.md +++ b/README.md @@ -2,15 +2,16 @@ ## pip -Before installing **segger**, please install GPU-accelerated versions of PyTorch, RAPIDS, and related packages compatible with your system. *Please ensure all CUDA-enabled packages are compiled for the same CUDA version.* +Before installing **segger**, please install GPU-accelerated versions of PyTorch, RAPIDS, and related packages compatible with your system. _Please ensure all CUDA-enabled packages are compiled for the same CUDA version._ -- **PyTorch & torchvision:** [Installation guide](https://pytorch.org/get-started/locally/) -- **torch_scatter:** [Installation guide](https://github.com/rusty1s/pytorch_scatter#installation) -- **RAPIDS (cuDF, cuML, cuGraph):** [Installation guide](https://docs.rapids.ai/install) -- **CuPy:** [Installation guide](https://docs.cupy.dev/en/stable/install.html) -- **cuSpatial:** [Installation guide](https://docs.rapids.ai/api/cuspatial/stable/user_guide/cuspatial_api_examples/#Installing-cuSpatial) +- **PyTorch & torchvision:** [Installation guide](https://pytorch.org/get-started/locally/) +- **torch_scatter:** [Installation guide](https://github.com/rusty1s/pytorch_scatter#installation) +- **RAPIDS (cuDF, cuML, cuGraph):** [Installation guide](https://docs.rapids.ai/install) +- **CuPy:** [Installation guide](https://docs.cupy.dev/en/stable/install.html) +- **cuSpatial:** [Installation guide](https://docs.rapids.ai/api/cuspatial/stable/user_guide/cuspatial_api_examples/#Installing-cuSpatial) For example, on Linux with CUDA 12.1 and PyTorch 2.5.0: + ```bash # Install PyTorch and torchvision for CUDA 12.1 pip install torch==2.5.0 torchvision==0.20.0 --index-url https://download.pytorch.org/whl/cu121 @@ -24,6 +25,7 @@ pip install --extra-index-url=https://pypi.nvidia.com cuspatial-cu12 cudf-cu12 c # Install CuPy for CUDA 12.x pip install cupy-cuda12x ``` + **December 2025:** 
To stay up-to-date with new developments, we recommend installing the latest version directly from GitHub: ```bash @@ -35,11 +37,13 @@ pip install -e . # Usage You can run **segger** from the command line with: + ```bash segger segment -i /path/to/your/ist/data/ -o /path/to/save/outputs/ ``` To see all available parameter options: + ```bash segger segment --help -``` \ No newline at end of file +``` diff --git a/pyproject.toml b/pyproject.toml index 097a80e..67537cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,11 @@ dependencies = [ "torch_geometric", ] +[project.optional-dependencies] +dev = [ + "pre-commit", +] + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" @@ -39,4 +44,89 @@ build-backend = "hatchling.build" packages = ["src/segger"] [project.scripts] -segger = "segger.cli.main:app" \ No newline at end of file +segger = "segger.cli.main:app" + +[tool.black] +line-length = 120 +include = '\.pyi?$|\.ipynb?$' +exclude = ''' +( + /( + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist + )/ +) +''' + +[tool.isort] +profile = "black" +use_parentheses = true +known_num = "geopandas,lightning,networkx,numba,numpy,opencv-python,pandas,polars,scipy,scikit-image,scikit-learn,shapely,sklearn,statmodels,torch_geometric" +known_bio = "anndata,scanpy" +sections = "FUTURE,STDLIB,THIRDPARTY,NUM,PLOT,BIO,FIRSTPARTY,LOCALFOLDER" +no_lines_before = "LOCALFOLDER" +balanced_wrapping = true +length_sort = "0" +indent = " " +float_to_top = true +order_by_type = false + +[tool.ruff] +src = ["."] +line-length = 119 +target-version = "py38" +select = [ + "F", # Errors detected by Pyflakes + "E", # Error detected by Pycodestyle + "W", # Warning detected by Pycodestyle + "D", # pydocstyle + "B", # flake8-bugbear + "TID", # flake8-tidy-imports + "C4", # flake8-comprehensions + "BLE", # flake8-blind-except + "UP", # pyupgrade + "RUF100", # Report unused noqa directives +] +ignore = [ + # line 
too long -> we accept long comment lines; black gets rid of long code lines + "E501", + # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient + "E731", + # allow I, O, l as variable names -> I is the identity matrix + "E741", + # Missing docstring in public package + "D104", + # Missing docstring in public module + "D100", + # Missing docstring in __init__ + "D107", + # Errors from function calls in argument defaults. These are fine when the result is immutable. + "B008", + # __magic__ methods are often self-explanatory, allow missing docstrings + "D105", + # first line should end with a period [Bug: doesn't work with single-line docstrings] + "D400", + # First line should be in imperative mood; try rephrasing + "D401", + ## Disable one in each pair of mutually incompatible rules + # We don’t want a blank line before a class docstring + "D203", + # We want docstrings to start immediately after the opening triple quote + "D213", + # Missing argument description in the docstring TODO: enable + "D417", + # Unable to detect undefined names + "F403", + # Undefined, or defined from star imports: module + "F405", + # Within an except clause, raise exceptions with `raise ... from err` + "B904", +] diff --git a/src/segger/cli/config.yaml b/src/segger/cli/config.yaml index a415513..b617235 100644 --- a/src/segger/cli/config.yaml +++ b/src/segger/cli/config.yaml @@ -39,4 +39,4 @@ transcripts_loss_weight_end: 1. cells_loss_weight_start: 1. cells_loss_weight_start: 1. segmentation_loss_weight_start: 0. 
-segmentation_loss_weight_end: 0.1 \ No newline at end of file +segmentation_loss_weight_end: 0.1 diff --git a/src/segger/cli/main.py b/src/segger/cli/main.py index 89905d2..31aeba3 100644 --- a/src/segger/cli/main.py +++ b/src/segger/cli/main.py @@ -1,13 +1,13 @@ -from cyclopts import App, Parameter, Group, validators -from typing import Annotated, Literal from pathlib import Path +from typing import Annotated, Literal -from .registry import ParameterRegistry +from cyclopts import App, Group, Parameter, validators +from .registry import ParameterRegistry # Register defaults and descriptions from files directly # This is to avoid needing to import all requirements before calling CLI -registry = ParameterRegistry(framework='cyclopts') +registry = ParameterRegistry(framework="cyclopts") base_dir = Path(__file__).parent.parent to_register = [ ("data/data_module.py", "ISTDataModule"), @@ -58,241 +58,305 @@ sort_key=7, ) + @app.command def segment( # I/O - input_directory: Annotated[Path, registry.get_parameter( - "input_directory", - alias="-i", - group=group_io, - validator=validators.Path(exists=True, dir_okay=True), - )] = registry.get_default("input_directory"), - - output_directory: Annotated[Path, registry.get_parameter( - "output_directory", - alias="-o", - group=group_io, - validator=validators.Path(exists=True, dir_okay=True), - )] = registry.get_default("output_directory"), - - + input_directory: Annotated[ + Path, + registry.get_parameter( + "input_directory", + alias="-i", + group=group_io, + validator=validators.Path(exists=True, dir_okay=True), + ), + ] = registry.get_default("input_directory"), + output_directory: Annotated[ + Path, + registry.get_parameter( + "output_directory", + alias="-o", + group=group_io, + validator=validators.Path(exists=True, dir_okay=True), + ), + ] = registry.get_default("output_directory"), # Cell Representation - node_representation_dim: Annotated[int, Parameter( - help="Number of dimensions used to represent each node 
type.", - validator=validators.Number(gt=0), - group=group_nodes, - required=False, - )] = registry.get_default("cells_embedding_size"), - - cells_representation: Annotated[Literal['pca', 'morphology'], registry.get_parameter( - "cells_representation_mode", - group=group_nodes, - )] = registry.get_default("cells_representation_mode"), - - cells_min_counts: Annotated[int, registry.get_parameter( - "cells_min_counts", - validator=validators.Number(gte=0), - group=group_nodes, - )] = registry.get_default("cells_min_counts"), - - cells_clusters_n_neighbors: Annotated[int, registry.get_parameter( - "cells_clusters_n_neighbors", - validator=validators.Number(gt=0), - group=group_nodes, - )] = registry.get_default("cells_clusters_n_neighbors"), - - cells_clusters_resolution: Annotated[float, registry.get_parameter( - "cells_clusters_resolution", - validator=validators.Number(gt=0, lte=5), - group=group_nodes, - )] = registry.get_default("cells_clusters_resolution"), - - + node_representation_dim: Annotated[ + int, + Parameter( + help="Number of dimensions used to represent each node type.", + validator=validators.Number(gt=0), + group=group_nodes, + required=False, + ), + ] = registry.get_default("cells_embedding_size"), + cells_representation: Annotated[ + Literal["pca", "morphology"], + registry.get_parameter( + "cells_representation_mode", + group=group_nodes, + ), + ] = registry.get_default("cells_representation_mode"), + cells_min_counts: Annotated[ + int, + registry.get_parameter( + "cells_min_counts", + validator=validators.Number(gte=0), + group=group_nodes, + ), + ] = registry.get_default("cells_min_counts"), + cells_clusters_n_neighbors: Annotated[ + int, + registry.get_parameter( + "cells_clusters_n_neighbors", + validator=validators.Number(gt=0), + group=group_nodes, + ), + ] = registry.get_default("cells_clusters_n_neighbors"), + cells_clusters_resolution: Annotated[ + float, + registry.get_parameter( + "cells_clusters_resolution", + 
validator=validators.Number(gt=0, lte=5), + group=group_nodes, + ), + ] = registry.get_default("cells_clusters_resolution"), # Gene Representation - genes_clusters_n_neighbors: Annotated[int, registry.get_parameter( - "genes_clusters_n_neighbors", - validator=validators.Number(gt=0), - group=group_nodes, - )] = registry.get_default("genes_clusters_n_neighbors"), - - genes_clusters_resolution: Annotated[float, registry.get_parameter( - "genes_clusters_resolution", - validator=validators.Number(gt=0, lte=5), - group=group_nodes, - )] = registry.get_default("genes_clusters_resolution"), - - + genes_clusters_n_neighbors: Annotated[ + int, + registry.get_parameter( + "genes_clusters_n_neighbors", + validator=validators.Number(gt=0), + group=group_nodes, + ), + ] = registry.get_default("genes_clusters_n_neighbors"), + genes_clusters_resolution: Annotated[ + float, + registry.get_parameter( + "genes_clusters_resolution", + validator=validators.Number(gt=0, lte=5), + group=group_nodes, + ), + ] = registry.get_default("genes_clusters_resolution"), # Transcript-Transcript Graph - transcripts_max_k: Annotated[int, registry.get_parameter( - "transcripts_graph_max_k", - validator=validators.Number(gt=0), - group=group_transcripts_graph, - )] = registry.get_default("transcripts_graph_max_k"), - - transcripts_max_dist: Annotated[float, registry.get_parameter( - "transcripts_graph_max_dist", - validator=validators.Number(gt=0), - group=group_transcripts_graph, - )] = registry.get_default("transcripts_graph_max_dist"), - - + transcripts_max_k: Annotated[ + int, + registry.get_parameter( + "transcripts_graph_max_k", + validator=validators.Number(gt=0), + group=group_transcripts_graph, + ), + ] = registry.get_default("transcripts_graph_max_k"), + transcripts_max_dist: Annotated[ + float, + registry.get_parameter( + "transcripts_graph_max_dist", + validator=validators.Number(gt=0), + group=group_transcripts_graph, + ), + ] = registry.get_default("transcripts_graph_max_dist"), # 
Segmentation (Prediction) Graph prediction_mode: Annotated[ Literal["nucleus", "cell", "uniform"], registry.get_parameter( "prediction_graph_mode", group=group_prediction, - ) + ), ] = registry.get_default("prediction_graph_mode"), - - prediction_max_k: Annotated[int | None, registry.get_parameter( - "prediction_graph_max_k", - validator=validators.Number(gt=0), - group=group_prediction, - )] = registry.get_default("prediction_graph_max_k"), - - prediction_expansion_ratio: Annotated[float | None, registry.get_parameter( - "prediction_graph_buffer_ratio", - validator=validators.Number(gt=0), - group=group_prediction, - )] = registry.get_default("prediction_graph_buffer_ratio"), - + prediction_max_k: Annotated[ + int | None, + registry.get_parameter( + "prediction_graph_max_k", + validator=validators.Number(gt=0), + group=group_prediction, + ), + ] = registry.get_default("prediction_graph_max_k"), + prediction_expansion_ratio: Annotated[ + float | None, + registry.get_parameter( + "prediction_graph_buffer_ratio", + validator=validators.Number(gt=0), + group=group_prediction, + ), + ] = registry.get_default("prediction_graph_buffer_ratio"), # Tiling - tiling_margin_training: Annotated[float, registry.get_parameter( - "tiling_margin_training", - validator=validators.Number(gte=0), - group=group_tiling, - )] = registry.get_default("tiling_margin_training"), - - tiling_margin_prediction: Annotated[float, registry.get_parameter( - "tiling_margin_prediction", - validator=validators.Number(gte=0), - group=group_tiling, - )] = registry.get_default("tiling_margin_prediction"), - - max_nodes_per_tile: Annotated[int, registry.get_parameter( - "tiling_nodes_per_tile", - validator=validators.Number(gt=0), - group=group_tiling, - )] = registry.get_default("tiling_nodes_per_tile"), - - max_edges_per_batch: Annotated[int, registry.get_parameter( - "edges_per_batch", - validator=validators.Number(gt=0), - group=group_tiling, - )] = registry.get_default("edges_per_batch"), - + 
tiling_margin_training: Annotated[ + float, + registry.get_parameter( + "tiling_margin_training", + validator=validators.Number(gte=0), + group=group_tiling, + ), + ] = registry.get_default("tiling_margin_training"), + tiling_margin_prediction: Annotated[ + float, + registry.get_parameter( + "tiling_margin_prediction", + validator=validators.Number(gte=0), + group=group_tiling, + ), + ] = registry.get_default("tiling_margin_prediction"), + max_nodes_per_tile: Annotated[ + int, + registry.get_parameter( + "tiling_nodes_per_tile", + validator=validators.Number(gt=0), + group=group_tiling, + ), + ] = registry.get_default("tiling_nodes_per_tile"), + max_edges_per_batch: Annotated[ + int, + registry.get_parameter( + "edges_per_batch", + validator=validators.Number(gt=0), + group=group_tiling, + ), + ] = registry.get_default("edges_per_batch"), # Model - n_epochs: Annotated[int, Parameter( - validator=validators.Number(gt=0), - group=group_model, - help="Number of training epochs.", - )] = 20, - - n_mid_layers: Annotated[int, registry.get_parameter( - "n_mid_layers", - validator=validators.Number(gte=0), - group=group_model, - )] = registry.get_default("n_mid_layers"), - - n_heads: Annotated[int, registry.get_parameter( - "n_heads", - validator=validators.Number(gt=0), - group=group_model, - )] = registry.get_default("n_heads"), - - hidden_channels: Annotated[int, registry.get_parameter( - "hidden_channels", - validator=validators.Number(gt=0), - group=group_model, - )] = registry.get_default("hidden_channels"), - - out_channels: Annotated[int, registry.get_parameter( - "out_channels", - validator=validators.Number(gt=0), - group=group_model, - )] = registry.get_default("out_channels"), - - learning_rate: Annotated[float, registry.get_parameter( - "learning_rate", - validator=validators.Number(gt=0), - group=group_model, - )] = registry.get_default("learning_rate"), - - use_positional_embeddings: Annotated[bool, registry.get_parameter( - "use_positional_embeddings", - 
group=group_model, - )] = registry.get_default("use_positional_embeddings"), - - normalize_embeddings: Annotated[bool, registry.get_parameter( - "normalize_embeddings", - group=group_model, - )] = registry.get_default("normalize_embeddings"), - + n_epochs: Annotated[ + int, + Parameter( + validator=validators.Number(gt=0), + group=group_model, + help="Number of training epochs.", + ), + ] = 20, + n_mid_layers: Annotated[ + int, + registry.get_parameter( + "n_mid_layers", + validator=validators.Number(gte=0), + group=group_model, + ), + ] = registry.get_default("n_mid_layers"), + n_heads: Annotated[ + int, + registry.get_parameter( + "n_heads", + validator=validators.Number(gt=0), + group=group_model, + ), + ] = registry.get_default("n_heads"), + hidden_channels: Annotated[ + int, + registry.get_parameter( + "hidden_channels", + validator=validators.Number(gt=0), + group=group_model, + ), + ] = registry.get_default("hidden_channels"), + out_channels: Annotated[ + int, + registry.get_parameter( + "out_channels", + validator=validators.Number(gt=0), + group=group_model, + ), + ] = registry.get_default("out_channels"), + learning_rate: Annotated[ + float, + registry.get_parameter( + "learning_rate", + validator=validators.Number(gt=0), + group=group_model, + ), + ] = registry.get_default("learning_rate"), + use_positional_embeddings: Annotated[ + bool, + registry.get_parameter( + "use_positional_embeddings", + group=group_model, + ), + ] = registry.get_default("use_positional_embeddings"), + normalize_embeddings: Annotated[ + bool, + registry.get_parameter( + "normalize_embeddings", + group=group_model, + ), + ] = registry.get_default("normalize_embeddings"), # Loss segmentation_loss: Annotated[ Literal["triplet", "bce"], registry.get_parameter( "sg_loss_type", group=group_loss, - ) + ), ] = registry.get_default("sg_loss_type"), - - transcripts_margin: Annotated[float, registry.get_parameter( - "tx_margin", - validator=validators.Number(gt=0), - group=group_loss, - )] 
= registry.get_default("tx_margin"), - - segmentation_margin: Annotated[float, registry.get_parameter( - "sg_margin", - validator=validators.Number(gt=0), - group=group_loss, - )] = registry.get_default("sg_margin"), - - transcripts_loss_weight_start: Annotated[float, registry.get_parameter( - "tx_weight_start", - validator=validators.Number(gte=0), - group=group_loss, - )] = registry.get_default("tx_weight_start"), - - transcripts_loss_weight_end: Annotated[float, registry.get_parameter( - "tx_weight_end", - validator=validators.Number(gte=0), - group=group_loss, - )] = registry.get_default("tx_weight_end"), - - cells_loss_weight_start: Annotated[float, registry.get_parameter( - "bd_weight_start", - validator=validators.Number(gte=0), - group=group_loss, - )] = registry.get_default("bd_weight_start"), - - cells_loss_weight_end: Annotated[float, registry.get_parameter( - "bd_weight_end", - validator=validators.Number(gte=0), - group=group_loss, - )] = registry.get_default("bd_weight_end"), - - segmentation_loss_weight_start: Annotated[float, registry.get_parameter( - "sg_weight_start", - validator=validators.Number(gte=0), - group=group_loss, - )] = registry.get_default("sg_weight_start"), - - segmentation_loss_weight_end: Annotated[float, registry.get_parameter( - "sg_weight_end", - validator=validators.Number(gte=0), - group=group_loss, - )] = registry.get_default("sg_weight_end"), + transcripts_margin: Annotated[ + float, + registry.get_parameter( + "tx_margin", + validator=validators.Number(gt=0), + group=group_loss, + ), + ] = registry.get_default("tx_margin"), + segmentation_margin: Annotated[ + float, + registry.get_parameter( + "sg_margin", + validator=validators.Number(gt=0), + group=group_loss, + ), + ] = registry.get_default("sg_margin"), + transcripts_loss_weight_start: Annotated[ + float, + registry.get_parameter( + "tx_weight_start", + validator=validators.Number(gte=0), + group=group_loss, + ), + ] = registry.get_default("tx_weight_start"), + 
transcripts_loss_weight_end: Annotated[ + float, + registry.get_parameter( + "tx_weight_end", + validator=validators.Number(gte=0), + group=group_loss, + ), + ] = registry.get_default("tx_weight_end"), + cells_loss_weight_start: Annotated[ + float, + registry.get_parameter( + "bd_weight_start", + validator=validators.Number(gte=0), + group=group_loss, + ), + ] = registry.get_default("bd_weight_start"), + cells_loss_weight_end: Annotated[ + float, + registry.get_parameter( + "bd_weight_end", + validator=validators.Number(gte=0), + group=group_loss, + ), + ] = registry.get_default("bd_weight_end"), + segmentation_loss_weight_start: Annotated[ + float, + registry.get_parameter( + "sg_weight_start", + validator=validators.Number(gte=0), + group=group_loss, + ), + ] = registry.get_default("sg_weight_start"), + segmentation_loss_weight_end: Annotated[ + float, + registry.get_parameter( + "sg_weight_end", + validator=validators.Number(gte=0), + group=group_loss, + ), + ] = registry.get_default("sg_weight_end"), ): """Run cell segmentation on spatial transcriptomics data.""" # Remove SLURM environment autodetect from lightning.pytorch.plugins.environments import SLURMEnvironment + SLURMEnvironment.detect = lambda: False # Setup Lightning Data Module - from ..data import ISTDataModule + from ..data import ISTDataModule + datamodule = ISTDataModule( input_directory=input_directory, cells_representation_mode=cells_representation, @@ -312,9 +376,10 @@ def segment( tiling_nodes_per_tile=max_nodes_per_tile, edges_per_batch=max_edges_per_batch, ) - + # Setup Lightning Model - from ..models import LitISTEncoder + from ..models import LitISTEncoder + n_genes = datamodule.ad.shape[1] model = LitISTEncoder( n_genes=n_genes, @@ -338,9 +403,11 @@ def segment( ) # Setup Lightning Trainer - from lightning.pytorch.loggers import CSVLogger - from ..data import ISTSegmentationWriter + from ..data import ISTSegmentationWriter + from lightning.pytorch import Trainer + from lightning.pytorch.loggers 
import CSVLogger + logger = CSVLogger(output_directory) writer = ISTSegmentationWriter(output_directory) trainer = Trainer( diff --git a/src/segger/cli/registry.py b/src/segger/cli/registry.py index 139a2ea..f5c1ed5 100644 --- a/src/segger/cli/registry.py +++ b/src/segger/cli/registry.py @@ -1,29 +1,30 @@ -""" -Parameter registry for extracting docstring descriptions and default values +"""Parameter registry for extracting docstring descriptions and default values from class constructors to populate a CLI (works with both Cyclopts and Typer). """ -from typing import Any, Type, Annotated -from dataclasses import dataclass, MISSING -from docstring_parser import parse import ast import inspect +from dataclasses import dataclass, MISSING from pathlib import Path +from typing import Any, Type + +from docstring_parser import parse @dataclass class ParameterInfo: """Container for parameter information.""" + default: Any help: str type_annotation: Any _is_required: bool = False - + @property def is_required(self) -> bool: """Check if this parameter is required (has no default value).""" return self._is_required - + @property def has_default(self) -> bool: """Check if this parameter has a default value.""" @@ -31,17 +32,15 @@ def has_default(self) -> bool: class ParameterRegistry: - """ - Registry for collecting parameter information from multiple classes + """Registry for collecting parameter information from multiple classes and making it available for CLI construction. - + Works with both Cyclopts and Typer frameworks. """ - + def __init__(self, framework: str = "typer"): - """ - Initialize the registry. - + """Initialize the registry. 
+ Parameters ---------- framework : str @@ -50,14 +49,13 @@ def __init__(self, framework: str = "typer"): self._parameters: dict[str, ParameterInfo] = {} self._registration_order: list[str] = [] # Track order of unprefixed names self._framework = framework.lower() - + if self._framework not in ("typer", "cyclopts"): raise ValueError("framework must be either 'typer' or 'cyclopts'") - + def register_from_file(self, file_path: str | Path, class_name: str, prefix: str | None = None) -> None: - """ - Register a class by parsing its source file without importing. - + """Register a class by parsing its source file without importing. + Parameters ---------- file_path : str | Path @@ -66,7 +64,7 @@ def register_from_file(self, file_path: str | Path, class_name: str, prefix: str Name of the class to parse prefix : str, optional Optional prefix for parameter names (defaults to class_name) - + Raises ------ ValueError @@ -77,106 +75,103 @@ def register_from_file(self, file_path: str | Path, class_name: str, prefix: str file_path = Path(file_path) if not file_path.exists(): raise FileNotFoundError(f"File not found: {file_path}") - + # Parse the file - with open(file_path, 'r') as f: + with open(file_path) as f: source = f.read() - + tree = ast.parse(source) - + # Find the class definition class_node = None for node in ast.walk(tree): if isinstance(node, ast.ClassDef) and node.name == class_name: class_node = node break - + if class_node is None: raise ValueError(f"Class '{class_name}' not found in {file_path}") - + # Extract docstring docstring = ast.get_docstring(class_node) or "" - + # Parse parameters from class body (for dataclasses) and __init__ defaults, type_annotations = self._extract_from_ast(class_node) - + # Parse docstring doc_params = {} if docstring: parsed_doc = parse(docstring) - doc_params = {param.arg_name: { - 'description': param.description or "", - 'type_name': param.type_name - } for param in parsed_doc.params} - + doc_params = { + param.arg_name: 
{"description": param.description or "", "type_name": param.type_name} + for param in parsed_doc.params + } + # Use class name as prefix if not provided if prefix is None: prefix = class_name - + # Process each parameter self._process_parameters(defaults, type_annotations, doc_params, prefix) - + def register_class(self, cls: Type, prefix: str | None = None) -> None: - """ - Register a class by inspecting it directly (requires importing). - + """Register a class by inspecting it directly (requires importing). + Parameters ---------- cls : Type The class to register (e.g., LightningDataModule, LightningModule) prefix : str, optional Optional prefix for parameter names (defaults to class name) - + Raises ------ ValueError If a parameter already exists with conflicting default value or description """ - from dataclasses import fields - # Get docstring information docstring = parse(cls.__doc__ or "") - doc_params = {param.arg_name: { - 'description': param.description or "", - 'type_name': param.type_name - } for param in docstring.params} - + doc_params = { + param.arg_name: {"description": param.description or "", "type_name": param.type_name} + for param in docstring.params + } + # Get default values and type annotations defaults = self._extract_defaults_from_class(cls) type_annotations = self._extract_type_annotations_from_class(cls) - + # Use class name as prefix if not provided if prefix is None: prefix = cls.__name__ - + # Process each parameter self._process_parameters(defaults, type_annotations, doc_params, prefix) - + def _process_parameters(self, defaults: dict, type_annotations: dict, doc_params: dict, prefix: str) -> None: """Process and register parameters from extracted information.""" for param_name, default_value in defaults.items(): is_required = default_value is MISSING - doc_info = doc_params.get(param_name, {'description': '', 'type_name': None}) - + doc_info = doc_params.get(param_name, {"description": "", "type_name": None}) + # Get type 
annotation (prefer from class annotations, fall back to docstring) type_ann = type_annotations.get(param_name) - if type_ann is None and doc_info['type_name']: + if type_ann is None and doc_info["type_name"]: # Store the string representation from docstring if no annotation found - type_ann = doc_info['type_name'] - + type_ann = doc_info["type_name"] + param_info = ParameterInfo( default=None if is_required else default_value, - help=doc_info['description'], + help=doc_info["description"], type_annotation=type_ann, - _is_required=is_required + _is_required=is_required, ) - + # Register with both prefixed and unprefixed names prefixed_name = f"{prefix}.{param_name}" - + # Store prefixed version self._parameters[prefixed_name] = param_info - + # Check for conflicts on unprefixed name if param_name in self._parameters: self._check_conflicts(param_name, self._parameters[param_name], param_info) @@ -185,11 +180,10 @@ def _process_parameters(self, defaults: dict, type_annotations: dict, doc_params # First time seeing this unprefixed name self._parameters[param_name] = param_info self._registration_order.append(param_name) - + def _extract_from_ast(self, class_node: ast.ClassDef) -> tuple[dict[str, Any], dict[str, str]]: - """ - Extract defaults and type annotations from an AST ClassDef node. - + """Extract defaults and type annotations from an AST ClassDef node. 
+ Returns ------- tuple[dict[str, Any], dict[str, str]] @@ -197,15 +191,15 @@ def _extract_from_ast(self, class_node: ast.ClassDef) -> tuple[dict[str, Any], d """ defaults = {} type_annotations = {} - + # Look for annotated assignments in the class body (dataclass fields) for node in class_node.body: if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name): param_name = node.target.id - + # Get type annotation as string type_annotations[param_name] = ast.unparse(node.annotation) - + # Get default value if present if node.value is not None: try: @@ -216,26 +210,26 @@ def _extract_from_ast(self, class_node: ast.ClassDef) -> tuple[dict[str, Any], d defaults[param_name] = ast.unparse(node.value) else: defaults[param_name] = MISSING - + # Also look for __init__ method - elif isinstance(node, ast.FunctionDef) and node.name == '__init__': + elif isinstance(node, ast.FunctionDef) and node.name == "__init__": # Extract parameters from __init__ signature for arg in node.args.args: - if arg.arg == 'self': + if arg.arg == "self": continue - + param_name = arg.arg - + # Get type annotation if present if arg.annotation is not None: type_annotations[param_name] = ast.unparse(arg.annotation) - + # Get default value if present # Defaults are stored in reverse order at the end of args num_defaults = len(node.args.defaults) num_args = len(node.args.args) - 1 # Exclude self arg_index = node.args.args.index(arg) - 1 # Exclude self from index - + if arg_index >= num_args - num_defaults: # This arg has a default default_index = arg_index - (num_args - num_defaults) @@ -248,22 +242,21 @@ def _extract_from_ast(self, class_node: ast.ClassDef) -> tuple[dict[str, Any], d # No default - only add if not already from dataclass fields if param_name not in defaults: defaults[param_name] = MISSING - + return defaults, type_annotations - + def _extract_defaults_from_class(self, cls: Type) -> dict[str, Any]: - """ - Extract parameter names and their default values from a class. 
- + """Extract parameter names and their default values from a class. + Returns a dict mapping parameter names to their default values, or MISSING sentinel if no default exists. """ from dataclasses import fields - + defaults = {} - + # Try dataclass fields first - if hasattr(cls, '__dataclass_fields__'): + if hasattr(cls, "__dataclass_fields__"): for field in fields(cls): if field.default is not MISSING: defaults[field.name] = field.default @@ -276,30 +269,29 @@ def _extract_defaults_from_class(self, cls: Type) -> dict[str, Any]: try: sig = inspect.signature(cls.__init__) for param_name, param in sig.parameters.items(): - if param_name == 'self': + if param_name == "self": continue - + if param.default is inspect.Parameter.empty: defaults[param_name] = MISSING else: defaults[param_name] = param.default except (ValueError, TypeError): pass - + return defaults - + def _extract_type_annotations_from_class(self, cls: Type) -> dict[str, Any]: - """ - Extract type annotations from a class. - + """Extract type annotations from a class. + Returns a dict mapping parameter names to their type annotations. 
""" from dataclasses import fields - + annotations = {} - + # Try dataclass fields first (they have type annotations) - if hasattr(cls, '__dataclass_fields__'): + if hasattr(cls, "__dataclass_fields__"): for field in fields(cls): annotations[field.name] = field.type else: @@ -307,16 +299,16 @@ def _extract_type_annotations_from_class(self, cls: Type) -> dict[str, Any]: try: sig = inspect.signature(cls.__init__) for param_name, param in sig.parameters.items(): - if param_name == 'self': + if param_name == "self": continue - + if param.annotation is not inspect.Parameter.empty: annotations[param_name] = param.annotation except (ValueError, TypeError): pass - + return annotations - + def _check_conflicts(self, param_name: str, existing: ParameterInfo, new: ParameterInfo) -> None: """Check for conflicts between existing and new parameter info.""" # Check default value conflict @@ -325,47 +317,46 @@ def _check_conflicts(self, param_name: str, existing: ParameterInfo, new: Parame f"Parameter '{param_name}' has conflicting requirements: " f"one class requires it, another has a default" ) - + if not existing.is_required and existing.default != new.default: raise ValueError( - f"Parameter '{param_name}' has conflicting default values: " - f"{existing.default} vs {new.default}" + f"Parameter '{param_name}' has conflicting default values: " f"{existing.default} vs {new.default}" ) - + # Check description conflict (only if both are non-empty) - if (existing.help and new.help and existing.help != new.help): + if existing.help and new.help and existing.help != new.help: raise ValueError( - f"Parameter '{param_name}' has conflicting descriptions: " - f"'{existing.help}' vs '{new.help}'" + f"Parameter '{param_name}' has conflicting descriptions: " f"'{existing.help}' vs '{new.help}'" ) - + # Check type annotation conflict (only if both are non-None) - if (existing.type_annotation is not None and - new.type_annotation is not None and - existing.type_annotation != 
new.type_annotation): + if ( + existing.type_annotation is not None + and new.type_annotation is not None + and existing.type_annotation != new.type_annotation + ): raise ValueError( f"Parameter '{param_name}' has conflicting type annotations: " f"{existing.type_annotation} vs {new.type_annotation}" ) - + def _merge_info(self, param_name: str, new: ParameterInfo) -> None: """Merge new parameter info with existing (prefer non-empty values).""" existing = self._parameters[param_name] - + # Merge descriptions (prefer non-empty) if new.help and not existing.help: self._parameters[param_name].help = new.help - + # Merge type annotations (prefer non-None) if new.type_annotation is not None and existing.type_annotation is None: self._parameters[param_name].type_annotation = new.type_annotation - + def get_parameter(self, param_name: str, **kwargs): - """ - Get an Annotated type for use in Cyclopts function signatures. - + """Get an Annotated type for use in Cyclopts function signatures. + Only available when framework="cyclopts". - + Parameters ---------- param_name : str @@ -375,13 +366,13 @@ def get_parameter(self, param_name: str, **kwargs): **kwargs Additional keyword arguments to pass to cyclopts.Parameter (e.g., validator, group, alias) - + Returns ------- Annotated type An Annotated type with the parsed type and cyclopts.Parameter configured with help text and required status - + Raises ------ ValueError @@ -391,14 +382,13 @@ def get_parameter(self, param_name: str, **kwargs): """ if self._framework != "cyclopts": raise ValueError( - "get_parameter() is only available when framework='cyclopts'. " - "Use get() instead for Typer." + "get_parameter() is only available when framework='cyclopts'. " "Use get() instead for Typer." ) - + from cyclopts import Parameter - + # Check if it's a prefixed name first - if '.' in param_name and param_name in self._parameters: + if "." 
in param_name and param_name in self._parameters: info = self._parameters[param_name] elif param_name in self._parameters: info = self._parameters[param_name] @@ -407,37 +397,33 @@ def get_parameter(self, param_name: str, **kwargs): f"Parameter '{param_name}' has not been registered. " f"Available parameters: {', '.join(sorted(self._parameters.keys()))}" ) - + if info.type_annotation is None: - raise ValueError( - f"Parameter '{param_name}' has no type annotation. " - "Cannot create Annotated type." - ) - + raise ValueError(f"Parameter '{param_name}' has no type annotation. " "Cannot create Annotated type.") + # Create the Parameter with help and required, plus any user kwargs - param_kwargs = dict(help=info.help, required=info.is_required) + param_kwargs = {"help": info.help, "required": info.is_required} param_kwargs.update(kwargs) param = Parameter(**param_kwargs) - + # Return Annotated type return param - + def get_default(self, param_name: str) -> Any: - """ - Get the default value for a parameter. - + """Get the default value for a parameter. + Only available when framework="cyclopts". - + Parameters ---------- param_name : str The name of the parameter - + Returns ------- Any The default value (None if no default exists) - + Raises ------ ValueError @@ -450,19 +436,18 @@ def get_default(self, param_name: str) -> Any: "get_default() is only available when framework='cyclopts'. " "For Typer, use get() which returns the configured Option with the default." ) - + if param_name not in self._parameters: raise KeyError(f"Parameter '{param_name}' has not been registered") - + return self._parameters[param_name].default - + def get(self, param_name: str, **kwargs): - """ - Get a configured parameter for the CLI framework. - + """Get a configured parameter for the CLI framework. 
+ - For Typer: Returns typer.Option configured with help and default - For Cyclopts: Use get_parameter() and get_default() instead - + Parameters ---------- param_name : str @@ -477,12 +462,12 @@ def get(self, param_name: str, **kwargs): - file_okay: bool - For Path types, allow files - dir_okay: bool - For Path types, allow directories - help: str - Override the help text - + Returns ------- typer.Option (Typer mode) A Typer Option configured with help text and default value - + Raises ------ ValueError @@ -495,11 +480,11 @@ def get(self, param_name: str, **kwargs): "get() is only available when framework='typer'. " "For Cyclopts, use get_parameter() and get_default() instead." ) - + import typer - + # Check if it's a prefixed name first - if '.' in param_name and param_name in self._parameters: + if "." in param_name and param_name in self._parameters: info = self._parameters[param_name] elif param_name in self._parameters: info = self._parameters[param_name] @@ -508,28 +493,27 @@ def get(self, param_name: str, **kwargs): f"Parameter '{param_name}' has not been registered. " f"Available parameters: {', '.join(sorted(self._parameters.keys()))}" ) - + # Use provided default or fall back to registered default - if 'default' not in kwargs: + if "default" not in kwargs: # If required (no default), use ... as Typer's sentinel - kwargs['default'] = ... if info.is_required else info.default - + kwargs["default"] = ... if info.is_required else info.default + # Use provided help or fall back to registered help - if 'help' not in kwargs: - kwargs['help'] = info.help - + if "help" not in kwargs: + kwargs["help"] = info.help + # Create and return the Typer Option return typer.Option(**kwargs) - + def get_info(self, param_name: str) -> ParameterInfo: - """ - Get the raw parameter information for a parameter. - + """Get the raw parameter information for a parameter. 
+ Parameters ---------- param_name : str The name of the parameter (can be "ClassName.param_name" or just "param_name") - + Returns ------- ParameterInfo @@ -539,7 +523,7 @@ def get_info(self, param_name: str) -> ParameterInfo: - type_annotation: The type annotation (None if not found) - is_required: True if parameter has no default value - has_default: True if parameter has a default value - + Raises ------ KeyError @@ -547,13 +531,12 @@ def get_info(self, param_name: str) -> ParameterInfo: """ if param_name not in self._parameters: raise KeyError(f"Parameter '{param_name}' has not been registered") - + return self._parameters[param_name] - + def get_parameter_names(self) -> list[str]: - """ - Get a list of all registered parameter names. - + """Get a list of all registered parameter names. + Returns ------- list[str] diff --git a/src/segger/data/__init__.py b/src/segger/data/__init__.py index dc1da57..728676a 100644 --- a/src/segger/data/__init__.py +++ b/src/segger/data/__init__.py @@ -1,2 +1,4 @@ from .data_module import ISTDataModule -from .writer import ISTSegmentationWriter \ No newline at end of file +from .writer import ISTSegmentationWriter + +__all__ = ["ISTDataModule", "ISTSegmentationWriter"] diff --git a/src/segger/data/data_module.py b/src/segger/data/data_module.py index 2a1b9d1..ca66bb3 100644 --- a/src/segger/data/data_module.py +++ b/src/segger/data/data_module.py @@ -1,42 +1,33 @@ -from torch_geometric.loader import DataLoader -from torch_geometric.transforms import BaseTransform -from torch_geometric.utils import negative_sampling -from lightning.pytorch import LightningDataModule -from torchvision.transforms import Compose +import gc from dataclasses import dataclass -from typing import Literal +from io import get_preprocessor, StandardBoundaryFields, StandardTranscriptFields from pathlib import Path -import polars as pl +from typing import Literal + import torch -import gc -import numpy as np -from .tile_dataset import ( - TileFitDataset, - 
TilePredictDataset, - DynamicBatchSamplerPatch -) -from ..io import ( - StandardTranscriptFields, - StandardBoundaryFields, - get_preprocessor -) -from .utils import setup_anndata, setup_heterodata -from .tiling import QuadTreeTiling, SquareTiling -from .partition import PartitionSampler +import polars as pl +from lightning.pytorch import LightningDataModule +from torch_geometric.loader import DataLoader +from torch_geometric.transforms import BaseTransform +from torch_geometric.utils import negative_sampling +from .partition import PartitionSampler +from .tile_dataset import DynamicBatchSamplerPatch, TileFitDataset, TilePredictDataset +from .tiling import QuadTreeTiling, SquareTiling +from .utils import setup_anndata, setup_heterodata class NegativeSampling(BaseTransform): - #TODO: Add documentation + # TODO: Add documentation def __init__( self, edge_type: tuple[str], sampling_ratio: float, - pos_index: str = 'edge_index', - neg_index: str = 'neg_edge_index', + pos_index: str = "edge_index", + neg_index: str = "neg_edge_index", ): - #TODO: Add documentation + # TODO: Add documentation super().__init__() self.edge_type = edge_type self.pos_index = pos_index @@ -63,11 +54,11 @@ def forward(self, data): data[self.edge_type][self.neg_index] = neg_idx return data - + @dataclass class ISTDataModule(LightningDataModule): - """PyTorch Lightning DataModule for preparing and loading spatial + """PyTorch Lightning DataModule for preparing and loading spatial transcriptomics data in IST format. This class handles preprocessing, graph construction, tiling, and @@ -125,41 +116,40 @@ class ISTDataModule(LightningDataModule): edges_per_batch : int, default=1_000_000 Maximum number of edges per batch in the DataLoader. """ + input_directory: Path num_workers: int = 8 cells_representation_mode: Literal["pca", "morphology"] = "pca" cells_embedding_size: int | None = 128 cells_min_counts: int = 10 cells_clusters_n_neighbors: int = 10 - cells_clusters_resolution: float = 2. 
+ cells_clusters_resolution: float = 2.0 genes_min_counts: int = 100 genes_clusters_n_neighbors: int = 5 - genes_clusters_resolution: float = 2. + genes_clusters_resolution: float = 2.0 transcripts_graph_max_k: int = 5 - transcripts_graph_max_dist: float = 5. + transcripts_graph_max_dist: float = 5.0 segmentation_graph_mode: Literal["nucleus", "cell"] = "nucleus" - segmentation_graph_negative_edge_rate: float = 1. + segmentation_graph_negative_edge_rate: float = 1.0 prediction_graph_mode: Literal["nucleus", "cell", "uniform"] = "cell" prediction_graph_max_k: int = 3 prediction_graph_buffer_ratio: float = 0.05 tiling_mode: Literal["adaptive", "square"] = "adaptive" # TODO: Remove (benchmarking only) - tiling_margin_training: float = 20. - tiling_margin_prediction: float = 20. + tiling_margin_training: float = 20.0 + tiling_margin_prediction: float = 20.0 tiling_nodes_per_tile: int = 50_000 - tiling_side_length: float = 250. # TODO: Remove (benchmarking only) + tiling_side_length: float = 250.0 # TODO: Remove (benchmarking only) training_fraction: float = 0.75 edges_per_batch: int = 1_000_000 - + def __post_init__(self): - """TODO: Description - """ + """TODO: Description.""" super().__init__() self.save_hyperparameters() self.load() def load(self): - """TODO: Description - """ + """TODO: Description.""" # Load and prepare shared objects tx_fields = StandardTranscriptFields() bd_fields = StandardBoundaryFields() @@ -180,10 +170,7 @@ def load(self): ] boundary_type = bd_fields.cell_value else: - raise ValueError( - f"Unrecognized segmentation graph mode: " - f"'{self.segmentation_graph_mode}'." 
- ) + raise ValueError(f"Unrecognized segmentation graph mode: " f"'{self.segmentation_graph_mode}'.") tx_mask = pl.col(tx_fields.compartment).is_in(compartments) bd_mask = bd[bd_fields.boundary_type] == boundary_type @@ -205,12 +192,8 @@ def load(self): transcripts=tx, boundaries=bd, adata=self.ad, - segmentation_mask=tx_mask, # This is the original mask, which is correct - cells_embedding_key=( - 'X_pca' - if self.cells_representation_mode == 'pca' - else 'X_morphology' - ), + segmentation_mask=tx_mask, # This is the original mask, which is correct + cells_embedding_key=("X_pca" if self.cells_representation_mode == "pca" else "X_morphology"), transcripts_graph_max_k=self.transcripts_graph_max_k, transcripts_graph_max_dist=self.transcripts_graph_max_dist, prediction_graph_mode=self.prediction_graph_mode, @@ -218,47 +201,42 @@ def load(self): prediction_graph_buffer_ratio=self.prediction_graph_buffer_ratio, ) # Tile graph dataset - node_positions = torch.vstack([ - self.data['tx']['pos'], - self.data['bd']['pos'], - ]) + node_positions = torch.vstack( + [ + self.data["tx"]["pos"], + self.data["bd"]["pos"], + ] + ) if self.tiling_mode == "adaptive": self.tiling = QuadTreeTiling( positions=node_positions, max_tile_size=self.tiling_nodes_per_tile, ) - #TODO: Remove (benchmarking only) + # TODO: Remove (benchmarking only) elif self.tiling_mode == "square": self.tiling = SquareTiling( positions=node_positions, side_length=self.tiling_side_length, ) else: - raise ValueError( - f"Unrecognized tiling strategy: '{self.tiling_mode}'." 
- ) + raise ValueError(f"Unrecognized tiling strategy: '{self.tiling_mode}'.") # Objects needed by lightning model self.tx_embedding = ( - pl - .from_numpy(self.ad.varm['X_corr']) + pl.from_numpy(self.ad.varm["X_corr"]) .cast(pl.Float32) - .with_columns( - pl.Series(self.ad.var.index).alias(tx_fields.feature)) + .with_columns(pl.Series(self.ad.var.index).alias(tx_fields.feature)) ) - self.tx_similarity = torch.tensor( - self.ad.uns['gene_cluster_similarities']) - self.bd_similarity = torch.tensor( - self.ad.uns['cell_cluster_similarities']) + self.tx_similarity = torch.tensor(self.ad.uns["gene_cluster_similarities"]) + self.bd_similarity = torch.tensor(self.ad.uns["cell_cluster_similarities"]) def setup(self, stage: str): - """TODO: Description - """ + """TODO: Description.""" # Tile dataset (inner margin) for training if stage == "fit": self.fit_dataset = TileFitDataset( data=self.data, tiling=self.tiling, - margin=self.tiling_margin_training, + margin=self.tiling_margin_training, clone=True, # Keep: Tiling removes edges needed in prediction ) # Setup training-validation split @@ -279,8 +257,7 @@ def setup(self, stage: str): return super().setup(stage) def teardown(self, stage): - """TODO: Description - """ + """TODO: Description.""" # Clean up data objects no longer needed if stage == "fit": del self.fit_dataset.data, self.fit_dataset @@ -292,8 +269,7 @@ def teardown(self, stage): self.data = self.data.cpu() def train_dataloader(self): - """TODO: Description - """ + """TODO: Description.""" sampler = PartitionSampler( self.fit_dataset, max_num=self.edges_per_batch, @@ -306,10 +282,9 @@ def train_dataloader(self): batch_sampler=sampler, num_workers=self.num_workers, ) - + def val_dataloader(self): - """TODO: Description - """ + """TODO: Description.""" sampler = PartitionSampler( self.fit_dataset, max_num=self.edges_per_batch, @@ -324,12 +299,11 @@ def val_dataloader(self): ) def predict_dataloader(self): - """TODO: Description - """ + """TODO: Description.""" 
sampler = DynamicBatchSamplerPatch( self.predict_dataset, max_num=self.edges_per_batch, - mode='edge', + mode="edge", shuffle=False, skip_too_big=False, ) diff --git a/src/segger/data/partition/__init__.py b/src/segger/data/partition/__init__.py index 68d4d1b..78f3d62 100644 --- a/src/segger/data/partition/__init__.py +++ b/src/segger/data/partition/__init__.py @@ -1,2 +1,4 @@ +from .dataset import PartitionDataset from .sampler import PartitionSampler -from .dataset import PartitionDataset \ No newline at end of file + +__all__ = ["PartitionDataset", "PartitionSampler"] diff --git a/src/segger/data/partition/dataset.py b/src/segger/data/partition/dataset.py index 6be8124..7a3a7b1 100644 --- a/src/segger/data/partition/dataset.py +++ b/src/segger/data/partition/dataset.py @@ -1,12 +1,12 @@ -from torch.nested._internal.nested_tensor import NestedTensor -from torch_geometric.data.storage import EdgeStorage -from torch_geometric.transforms import BaseTransform -from torch_geometric.data import Data, HeteroData from dataclasses import dataclass, field -from functools import cached_property -from typing import Any, Literal +from typing import Any + import torch +from torch_geometric.data import Data, HeteroData +from torch_geometric.data.storage import EdgeStorage +from torch_geometric.transforms import BaseTransform + @dataclass class Partition: @@ -30,21 +30,21 @@ class Partition: The permutation that was applied to the original nodes to sort them by partition. """ - node_indptr: torch.Tensor = None - edge_indptr: torch.Tensor = None - node_sizes: torch.Tensor = None - edge_sizes: torch.Tensor = None - node_permutation: torch.Tensor = None + + node_indptr: torch.Tensor = None + edge_indptr: torch.Tensor = None + node_sizes: torch.Tensor = None + edge_sizes: torch.Tensor = None + node_permutation: torch.Tensor = None def _validate_num_partitions(self) -> bool: - """Confirms all node and edge elements have same numbers of partitions. 
- """ + """Confirms all node and edge elements have same numbers of partitions.""" node_attributes = [self.node_sizes, self.node_indptr] if not any(node_attributes): return True elif not all(node_attributes): return False - + edge_attributes = [self.edge_sizes, self.edge_indptr] if not any(edge_attributes): return True @@ -58,18 +58,13 @@ def _validate_num_partitions(self) -> bool: len(self.node_indptr) - 1, len(self.edge_indptr) - 1, ] - + return len(set(num_partitions)) == 1 - + def __len__(self) -> int: - """ - Returns number of partitions tracked by this partitioning, 0 if empty. - """ + """Returns number of partitions tracked by this partitioning, 0 if empty.""" if not self._validate_num_partitions(): - raise ValueError( - "This `Partition` contains inconsistent numbers of partitions " - "across elements." - ) + raise ValueError("This `Partition` contains inconsistent numbers of partitions " "across elements.") if self.node_sizes: return len(self.node_sizes) if self.edge_sizes: @@ -97,11 +92,12 @@ class HeteroPartition: node_permutation : dict Maps node type to the permutation tensor that was applied to its nodes. 
""" - node_indptr: dict = field(default_factory=dict) - edge_indptr: dict = field(default_factory=dict) - node_sizes: dict = field(default_factory=dict) - edge_sizes: dict = field(default_factory=dict) - node_permutation: dict = field(default_factory=dict) + + node_indptr: dict = field(default_factory=dict) + edge_indptr: dict = field(default_factory=dict) + node_sizes: dict = field(default_factory=dict) + edge_sizes: dict = field(default_factory=dict) + node_permutation: dict = field(default_factory=dict) def _validate_keys(self) -> bool: """Confirms all node and edge elements have same sets of keys.""" @@ -114,12 +110,10 @@ def _validate_keys(self) -> bool: set(self.edge_sizes), set(self.edge_indptr), ] - return all(s == node_sets[0] for s in node_sets) and \ - all(s == edge_sets[0] for s in edge_sets) - + return all(s == node_sets[0] for s in node_sets) and all(s == edge_sets[0] for s in edge_sets) + def _validate_num_partitions(self) -> bool: - """Confirms all node and edge elements have same numbers of partitions. - """ + """Confirms all node and edge elements have same numbers of partitions.""" node_attributes = [self.node_sizes.values(), self.node_indptr.values()] if not any(node_attributes): return True @@ -139,16 +133,11 @@ def _validate_num_partitions(self) -> bool: ] return len(set(num_partitions)) == 1 - + def __len__(self) -> int: - """ - Returns number of partitions tracked by this partitioning, 0 if empty. - """ + """Returns number of partitions tracked by this partitioning, 0 if empty.""" if not self._validate_num_partitions(): - raise ValueError( - "This `HeteroPartition` contains inconsistent numbers of " - "partitions across elements." 
- ) + raise ValueError("This `HeteroPartition` contains inconsistent numbers of " "partitions across elements.") if self.node_sizes: return len(next(iter(self.node_sizes.values()))) if self.edge_sizes: @@ -156,7 +145,6 @@ def __len__(self) -> int: return 0 - class PartitionDataset(torch.utils.data.Dataset): """Represents a PyG dataset partitioned into disconnected subgraphs. @@ -183,6 +171,7 @@ class PartitionDataset(torch.utils.data.Dataset): partition : Partition or HeteroPartition An object containing all metadata about the partitions. """ + def __init__( self, data: Data | HeteroData, @@ -202,18 +191,14 @@ def __init__( else: self._validate_dense(data, partition) - self.partition = ( - HeteroPartition() if self._is_hetero else Partition()) + self.partition = HeteroPartition() if self._is_hetero else Partition() self.data = data.clone() if clone else data # Calculate global no. partitions upfront if self._is_hetero: self._num_partitions = -1 for labels in partition.values(): if labels.numel() > 0: - self._num_partitions = max( - self._num_partitions, - labels.max().item() - ) + self._num_partitions = max(self._num_partitions, labels.max().item()) self._num_partitions += 1 else: self._num_partitions = partition.max().item() + 1 @@ -224,31 +209,21 @@ def __init__( self.transform = transform def _validate_sparse(self, data: Data | HeteroData, partition: Any): - """ - Validates that a sparse partition object is consistent with the graph. - """ + """Validates that a sparse partition object is consistent with the graph.""" if self._is_hetero: if not isinstance(partition, HeteroPartition): - raise ValueError( - "For a heterogeneous graph, sparse input must be a " - "`HeteroPartition` object." - ) + raise ValueError("For a heterogeneous graph, sparse input must be a " "`HeteroPartition` object.") if not partition._validate_keys(): raise ValueError( - "Provided `HeteroPartition` contains inconsistent node " - "or edge keys across elements." 
+ "Provided `HeteroPartition` contains inconsistent node " "or edge keys across elements." ) if not partition._validate_num_partitions(): raise ValueError( - "Provided `HeteroPartition` contains inconsistent numbers " - "of partitions across elements." + "Provided `HeteroPartition` contains inconsistent numbers " "of partitions across elements." ) for attr, store in partition.__dict__.items(): if not isinstance(store, dict): - raise TypeError( - f"Attribute '{attr}' in `HeteroPartition` must be a " - f"dict." - ) + raise TypeError(f"Attribute '{attr}' in `HeteroPartition` must be a " f"dict.") is_node_attr = all(n in store for n in data.node_types) is_edge_attr = all(e in store for e in data.edge_types) if not (is_node_attr or is_edge_attr): @@ -264,15 +239,9 @@ def _validate_sparse(self, data: Data | HeteroData, partition: Any): ) else: # homogeneous graph if not isinstance(partition, Partition): - raise ValueError( - "For a homogeneous graph, sparse input must be a " - "`Partition` object." - ) + raise ValueError("For a homogeneous graph, sparse input must be a " "`Partition` object.") if not partition._validate_num_partitions(): - raise ValueError( - "Provided `Partition` contains inconsistent numbers of " - "partitions across elements." - ) + raise ValueError("Provided `Partition` contains inconsistent numbers of " "partitions across elements.") for attr, store in partition.__dict__.items(): if not isinstance(store, torch.Tensor): raise TypeError( @@ -281,9 +250,7 @@ def _validate_sparse(self, data: Data | HeteroData, partition: Any): ) def _validate_dense(self, data: Data | HeteroData, partition: Any): - """ - Validates that a dense partition input is consistent with the graph. 
- """ + """Validates that a dense partition input is consistent with the graph.""" if self._is_hetero: if not isinstance(partition, dict): raise TypeError( @@ -292,20 +259,15 @@ def _validate_dense(self, data: Data | HeteroData, partition: Any): ) for node_type in data.node_types: if node_type not in partition: - raise KeyError( - f"The `partition` dictionary is missing an entry for " - f"node type: '{node_type}'." - ) + raise KeyError(f"The `partition` dictionary is missing an entry for " f"node type: '{node_type}'.") if not isinstance(partition[node_type], torch.Tensor): raise TypeError( f"The partition for node type '{node_type}' must be a " f"`torch.Tensor`, not {type(partition[node_type])}." ) elif not isinstance(partition, torch.Tensor): - raise TypeError( - "For a homogeneous graph, dense input must be a `torch.Tensor`." - ) - + raise TypeError("For a homogeneous graph, dense input must be a `torch.Tensor`.") + @staticmethod def _index_select( input: torch.Tensor, @@ -342,7 +304,7 @@ def _permute_nodes( labels: torch.Tensor | dict[str, torch.Tensor], ): """Permutes all node attributes and stores partition metadata. 
- + This method iterates through each node type, calculates the correct node permutation based on the partition labels, applies this permutation to all node-level attributes, and saves the resulting @@ -374,7 +336,7 @@ def _permute_nodes( self.partition.node_permutation, dim=0, ) - + def _permute_node_labels( self, labels: torch.Tensor, @@ -388,12 +350,9 @@ def _permute_node_labels( labels[permutation], minlength=self._num_partitions, ) - indptr = torch.cat(( - torch.tensor([0], device=labels.device), - torch.cumsum(sizes, dim=0) - )) + indptr = torch.cat((torch.tensor([0], device=labels.device), torch.cumsum(sizes, dim=0))) return permutation, indptr, sizes - + def _permute_edges( self, labels: torch.Tensor | dict[str, torch.Tensor], @@ -448,10 +407,12 @@ def _map_edge_index( dst_perm.numel(), device=dst_perm.device, ) - edge_store.edge_index = torch.stack([ - inv_src_perm[edge_store.edge_index[0]], - inv_dst_perm[edge_store.edge_index[1]], - ]) + edge_store.edge_index = torch.stack( + [ + inv_src_perm[edge_store.edge_index[0]], + inv_dst_perm[edge_store.edge_index[1]], + ] + ) def _permute_edge_store( self, @@ -477,7 +438,7 @@ def _permute_edge_store( # Update edge store with permutation, including edge index for attr in edge_store.edge_attrs(): - if attr == 'edge_index': + if attr == "edge_index": edge_store[attr] = edge_store[attr][:, permutation][:, mask] else: edge_store[attr] = self._index_select( @@ -485,26 +446,28 @@ def _permute_edge_store( permutation[mask], dim=0, ) - + # Get partition properties sizes = torch.bincount( src_edge_labels[mask], minlength=self._num_partitions, ) - indptr = torch.cat(( - torch.tensor([0], device=src_edge_labels.device), - torch.cumsum(sizes, dim=0), - )) + indptr = torch.cat( + ( + torch.tensor([0], device=src_edge_labels.device), + torch.cumsum(sizes, dim=0), + ) + ) return indptr, sizes def __len__(self) -> int: """Description.""" return self._num_partitions - + def __getitem__(self, index: int): - """Get the graph 
partition associated at location `index`. - + """Get the graph partition associated at location `index`. + Initializes an empty Data or HeteroData object and populates with node and edge attributes associated with the indexed graph partition. Other non-node/edge attributes are populated without subsetting. @@ -512,10 +475,7 @@ def __getitem__(self, index: int): if index < 0: index += len(self) if not 0 <= index < len(self): - raise IndexError( - f"Index {index} is out of range for dataset with {len(self)} " - f"partitions." - ) + raise IndexError(f"Index {index} is out of range for dataset with {len(self)} " f"partitions.") if self._is_hetero: part = HeteroData() for node_type, node_store in self.data.node_items(): @@ -536,7 +496,7 @@ def __getitem__(self, index: int): edge_i = self.partition.edge_indptr[edge_type][index] edge_j = self.partition.edge_indptr[edge_type][index + 1] for name, attr in edge_store.items(): - if name == 'edge_index': + if name == "edge_index": edge_index = attr[:, edge_i:edge_j].clone() edge_index[0] -= node_i_src edge_index[1] -= node_i_dst @@ -551,24 +511,21 @@ def __getitem__(self, index: int): else: part = Data() for name, attr in self.data: - node_i, node_j = self.partition.node_indptr[index:index + 2] - edge_i, edge_j = self.partition.edge_indptr[index:index + 2] + node_i, node_j = self.partition.node_indptr[index : index + 2] + edge_i, edge_j = self.partition.edge_indptr[index : index + 2] if self.data.is_node_attr(name): part[name] = self._index_select(attr, slice(node_i, node_j)) elif self.data.is_edge_attr(name): - if name == 'edge_index': + if name == "edge_index": edge_index = attr[:, edge_i:edge_j].clone() part[name] = edge_index - node_i else: - part[name] = self._index_select( - attr, - slice(edge_i, edge_j) - ) + part[name] = self._index_select(attr, slice(edge_i, edge_j)) else: part[name] = attr # Optionally transform part = part if self.transform is None else self.transform(part) - + return part def _add_node_attr( @@ 
-579,7 +536,7 @@ def _add_node_attr( ): """Adds and permutes a new node attribute to self.data. - The provided attribute tensor must correspond to the nodes in the + The provided attribute tensor must correspond to the nodes in the original graph before partitioning. Parameters @@ -592,26 +549,22 @@ def _add_node_attr( node_type : str, optional The target node type for the attribute. This is required for heterogeneous graphs. - + Raises ------ ValueError - If `node_type` is omitted for a heterogeneous graph, if the `attr` - tensor is too small for the permutation, or if `node_type` is + If `node_type` is omitted for a heterogeneous graph, if the `attr` + tensor is too small for the permutation, or if `node_type` is provided for a homogeneous graph. """ if self._is_hetero: if node_type is None: - raise ValueError( - "A node type must be supplied for HeteroData attributes." - ) - node_perm = self.partition.node_permutation[node_type] + raise ValueError("A node type must be supplied for HeteroData attributes.") + node_perm = self.partition.node_permutation[node_type] node_store = self.data[node_type] else: if node_type is not None: - raise ValueError( - "No node type should be supplied for Data attributes." 
- ) + raise ValueError("No node type should be supplied for Data attributes.") node_perm = self.partition.node_permutation node_store = self.data if attr.shape[0] > node_perm.max() + 1: diff --git a/src/segger/data/partition/sampler.py b/src/segger/data/partition/sampler.py index dd350e3..d7f9a1c 100644 --- a/src/segger/data/partition/sampler.py +++ b/src/segger/data/partition/sampler.py @@ -1,9 +1,8 @@ -from typing import List, Literal, Iterator, Optional -from torch_geometric.loader import DataLoader -import random -import torch import math +import random +from typing import Iterator, List, Literal, Optional +import torch from .dataset import PartitionDataset @@ -15,8 +14,8 @@ def best_fit_decreasing( ) -> List[List[int]]: """Implements the Best-Fit Decreasing (BFD) bin packing algorithm. - BFD works by first sorting all items from largest to smallest. Then, each - item is placed into the bin where it fits most tightly (i.e., the bin with + BFD works by first sorting all items from largest to smallest. Then, each + item is placed into the bin where it fits most tightly (i.e., the bin with the least remaining capacity that can still hold the item). Parameters @@ -26,32 +25,27 @@ def best_fit_decreasing( bin_capacity : float The capacity of each bin. skip_too_big : bool, optional - If True, items larger than `bin_capacity` or <= 0 are ignored instead + If True, items larger than `bin_capacity` or <= 0 are ignored instead of raising an error. Defaults to False. Returns ------- list of list of int - A list of bins, where each bin is a list of the original indices of the + A list of bins, where each bin is a list of the original indices of the items it contains. Raises ------ ValueError - If any item has a size greater than `bin_capacity` or less than or + If any item has a size greater than `bin_capacity` or less than or equal to 0, and `skip_too_big` is False. 
""" if skip_too_big: - indexed_items = [ - (val, i) for i, val in enumerate(items) - if 0 < val <= bin_capacity - ] + indexed_items = [(val, i) for i, val in enumerate(items) if 0 < val <= bin_capacity] else: if not all(0 < item <= bin_capacity for item in items): - raise ValueError( - "All items must be > 0 and <= bin_capacity." - ) + raise ValueError("All items must be > 0 and <= bin_capacity.") indexed_items = [(val, i) for i, val in enumerate(items)] # Sort items by size in descending order. @@ -62,7 +56,7 @@ def best_fit_decreasing( for item_val, item_idx in indexed_items: best_bin_idx = -1 - min_remaining_space = float('inf') + min_remaining_space = float("inf") # Find the best bin for the current item. for i, capacity in enumerate(bin_capacities): if capacity >= item_val: @@ -90,11 +84,11 @@ def harmonic_k( ) -> List[List[int]]: """Implements the Harmonic-k online bin packing algorithm. - Classifies each incoming item into a harmonic interval based on its size - and packs it with other items from the same interval. It processes items + Classifies each incoming item into a harmonic interval based on its size + and packs it with other items from the same interval. It processes items in the order they arrive. - The `k` parameter defines `k-1` intervals for items > 1/k, while + The `k` parameter defines `k-1` intervals for items > 1/k, while items <= 1/k are treated as "small" and packed together. Parameters @@ -107,19 +101,19 @@ def harmonic_k( The integer defining the harmonic intervals. Must be >= 2. Defaults to 6. skip_too_big : bool, optional - If True, items larger than `bin_capacity` or <= 0 are ignored instead + If True, items larger than `bin_capacity` or <= 0 are ignored instead of raising an error. Defaults to False. Returns ------- list of list of int - A list of bins, where each bin is a list of the original indices of the + A list of bins, where each bin is a list of the original indices of the items it contains. 
Raises ------ ValueError - If an invalid item size is found and `skip_too_big` is False, or if `k` + If an invalid item size is found and `skip_too_big` is False, or if `k` is less than 2. """ @@ -127,15 +121,10 @@ def harmonic_k( raise ValueError("Parameter k must be an integer >= 2.") if skip_too_big: - indexed_items = [ - (val, i) for i, val in enumerate(items) - if 0 < val <= bin_capacity - ] + indexed_items = [(val, i) for i, val in enumerate(items) if 0 < val <= bin_capacity] else: if not all(0 < item <= bin_capacity for item in items): - raise ValueError( - "All items must be > 0 and <= bin_capacity." - ) + raise ValueError("All items must be > 0 and <= bin_capacity.") indexed_items = list(enumerate(items)) indexed_items = [(val, i) for i, val in indexed_items] @@ -229,15 +218,10 @@ def first_fit_decreasing_bucketed( rng = rng or random if skip_too_big: - indexed_items = [ - (val, i) for i, val in enumerate(items) - if 0 < val <= bin_capacity - ] + indexed_items = [(val, i) for i, val in enumerate(items) if 0 < val <= bin_capacity] else: if not all(0 < item <= bin_capacity for item in items): - raise ValueError( - "All items must be > 0 and <= bin_capacity." - ) + raise ValueError("All items must be > 0 and <= bin_capacity.") indexed_items = [(val, i) for i, val in enumerate(items)] if not indexed_items: @@ -253,13 +237,8 @@ def first_fit_decreasing_bucketed( rng.shuffle(indexed_items) # Full shuffle else: # Find positions of the (k-1) largest adjacent gaps. - gaps = [ - (indexed_items[i - 1][0] - indexed_items[i][0], i) - for i in range(1, n) - ] - cut_at = { - pos for _, pos in sorted(gaps, reverse=True)[:n_buckets - 1] - } + gaps = [(indexed_items[i - 1][0] - indexed_items[i][0], i) for i in range(1, n)] + cut_at = {pos for _, pos in sorted(gaps, reverse=True)[: n_buckets - 1]} # Shuffle within each bucket. 
start = 0 @@ -280,7 +259,7 @@ def first_fit_decreasing_bucketed( bin_capacities[i] -= item_val placed_in_bin = True break - + # If no suitable bin was found, open a new one. if not placed_in_bin: bins.append([item_idx]) @@ -292,15 +271,16 @@ def first_fit_decreasing_bucketed( class PartitionSampler(torch.utils.data.Sampler): """A batch sampler that packs data partitions into pre-computed batches. - This sampler groups partitions (e.g., subgraphs) into batches using bin - packing algorithms. Batches are pre-computed to ensure the sampler's length + This sampler groups partitions (e.g., subgraphs) into batches using bin + packing algorithms. Batches are pre-computed to ensure the sampler's length is always accurate. If `shuffle` is True, it uses an online algorithm (Harmonic-k) and regenerates the batches with a new shuffle after iterating through once - (e.g., at the beginning of a new epoch). If `shuffle` is False, it uses an + (e.g., at the beginning of a new epoch). If `shuffle` is False, it uses an offline algorithm (BFD) and computes the batches only once. """ + def __init__( self, dataset: PartitionDataset, @@ -319,7 +299,7 @@ def __init__( max_num : int The maximum number of nodes or edges allowed per batch. mode : {"node", "edge"}, optional - Determines whether to use partition node counts or edge counts as + Determines whether to use partition node counts or edge counts as the weights for packing. Defaults to "edge". subset : list of int, optional A list of partition indices to sample from. If None, the entire @@ -330,7 +310,7 @@ def __init__( batches. If False, an offline algorithm is used to create a fixed, deterministic set of batches. Defaults to False. skip_too_big : bool, optional - If True, partitions larger than `max_num` are ignored. If False, + If True, partitions larger than `max_num` are ignored. If False, the packing algorithm will raise a ValueError. Defaults to False. 
""" self.dataset = dataset @@ -339,17 +319,14 @@ def __init__( self.subset = subset self.shuffle = shuffle self.skip_too_big = skip_too_big - self.packing_algo = ( - first_fit_decreasing_bucketed if self.shuffle else - best_fit_decreasing - ) + self.packing_algo = first_fit_decreasing_bucketed if self.shuffle else best_fit_decreasing # Get partition sizes in numbers of nodes or edges if mode == "edge": weights = self.dataset.partition.edge_sizes else: weights = self.dataset.partition.node_sizes - + if self.dataset._is_hetero: weights = torch.stack(list(weights.values())).sum(dim=0) self.weights = weights.tolist() @@ -363,8 +340,7 @@ def __init__( def _generate_batches(self) -> None: """Generates and stores a new set of batches for an epoch.""" - indices = self.subset if self.subset is not None else \ - list(range(len(self.weights))) + indices = self.subset if self.subset is not None else list(range(len(self.weights))) if self.shuffle: random.shuffle(indices) @@ -389,8 +365,7 @@ def __iter__(self) -> Iterator[list[int]]: """ if self.stale: self._generate_batches() - for batch in self.batches: - yield batch + yield from self.batches if self.shuffle: self.stale = True diff --git a/src/segger/data/tile_dataset.py b/src/segger/data/tile_dataset.py index 921573f..ebbab86 100644 --- a/src/segger/data/tile_dataset.py +++ b/src/segger/data/tile_dataset.py @@ -1,18 +1,17 @@ -from torch_geometric.loader import DynamicBatchSampler -from torch_geometric.data.storage import NodeStorage -from torch_geometric.data import Data, HeteroData -from torch.utils.data import Dataset -import shapely import torch +from torch.utils.data import Dataset +import shapely +from torch_geometric.data import Data, HeteroData +from torch_geometric.data.storage import NodeStorage +from torch_geometric.loader import DynamicBatchSampler from .partition import PartitionDataset from .tiling import Tiling class TileFitDataset(PartitionDataset): - """ - Partitions a PyG graph based on a geometric tiling of 
its nodes. + """Partitions a PyG graph based on a geometric tiling of its nodes. This class extends `PartitionDataset` to create partitions by assigning each node to a tile based on its spatial coordinates. It can also add a @@ -34,16 +33,17 @@ class TileFitDataset(PartitionDataset): If True, removes the geometry attribute from the data after partitioning, by default True. """ + def __init__( self, data: Data | HeteroData, tiling: Tiling, margin: float, - geometry_key: str = 'geometry', + geometry_key: str = "geometry", clone: bool = True, drop_geometry: bool = True, ): - """Initializes and tiles the dataset""" + """Initializes and tiles the dataset.""" self.geometry_key = geometry_key self._validate_data(data) @@ -51,7 +51,7 @@ def __init__( self.tiling = tiling self.margin = margin partition = self._get_partition(data) - + # Partition graph by tiling # Note: self.data and self.partition are set inside super.__init__() super().__init__(data=data, partition=partition, clone=clone) @@ -66,9 +66,7 @@ def _validate_geometry( ): """Checks that 'node_store' has a valid geometry attribute.""" if self.geometry_key not in node_store.node_attrs(): - raise AttributeError( - f"{store_name} is missing '{self.geometry_key}' attribute." - ) + raise AttributeError(f"{store_name} is missing '{self.geometry_key}' attribute.") geometry = node_store[self.geometry_key] if not isinstance(geometry, torch.Tensor): raise TypeError( @@ -83,64 +81,48 @@ def _validate_geometry( ) def _validate_data(self, data: Data | HeteroData): - """ - Checks 'data' is a Pytorch Geometric data object, that all node types + """Checks 'data' is a Pytorch Geometric data object, that all node types have valid geometry attributes, and that 'mask' does not already exist as an attribute. """ if isinstance(data, Data): store_name = "The 'data' object" self._validate_geometry(data, store_name) - if 'mask' in data: - raise KeyError( - f"{store_name} must not contain an attribute 'mask'." 
- ) + if "mask" in data: + raise KeyError(f"{store_name} must not contain an attribute 'mask'.") elif isinstance(data, HeteroData): if not data.node_types: return for node_type in data.node_types: store_name = f"Node type '{node_type}' in the 'data' object" self._validate_geometry(data[node_type], store_name) - if 'mask' in data[node_type]: - raise KeyError( - f"{store_name} must not contain an attribute 'mask'." - ) + if "mask" in data[node_type]: + raise KeyError(f"{store_name} must not contain an attribute 'mask'.") else: - raise TypeError( - f"Input must be a PyG Data or HeteroData object, but got " - f"{type(data).__name__}." - ) + raise TypeError(f"Input must be a PyG Data or HeteroData object, but got " f"{type(data).__name__}.") def _get_partition(self, data: Data | HeteroData) -> torch.Tensor: - """ - Generates partition labels for all nodes using the tiling object. - """ + """Generates partition labels for all nodes using the tiling object.""" if isinstance(data, HeteroData): - partition = dict() + partition = {} for node_type in data.node_types: - partition[node_type] = self.tiling.label( - data[node_type][self.geometry_key] - ) + partition[node_type] = self.tiling.label(data[node_type][self.geometry_key]) return partition else: # isinstance(data, Data) return self.tiling.label(data[self.geometry_key]) - + def _mask_data(self, data: Data | HeteroData) -> Data | HeteroData: - """ - Adds a boolean 'mask' attribute to each node indicating whether it is + """Adds a boolean 'mask' attribute to each node indicating whether it is within a specified margin of a tile's boundary. 
""" if isinstance(data, HeteroData): for node_type in data.node_types: - data[node_type]['mask'] = self.tiling.mask( + data[node_type]["mask"] = self.tiling.mask( data[node_type][self.geometry_key], self.margin, ) else: # isinstance(data, Data) - data['mask'] = self.tiling.mask( - data[self.geometry_key], - self.margin - ) + data["mask"] = self.tiling.mask(data[self.geometry_key], self.margin) return data def _drop_geometry(self, data: Data | HeteroData) -> Data | HeteroData: @@ -155,11 +137,11 @@ def _drop_geometry(self, data: Data | HeteroData) -> Data | HeteroData: class TilePredictDataset(Dataset): """A dataset for iterating over spatial tiles with overlapping margins. - + This dataset provides subgraphs of a larger graph based on spatial tiling. Each item corresponds to a tile, returning the subgraph of nodes that fall within the tile boundaries plus a specified margin. - + Parameters ---------- data : Data | HeteroData @@ -171,6 +153,7 @@ class TilePredictDataset(Dataset): nodes. Positive values expand tiles outward, negative values shrink them inward. Defaults to 0.0. """ + def __init__( self, data: Data | HeteroData, @@ -187,14 +170,11 @@ def __init__( if self._is_hetero: missing = [] for node_type in self.data.node_types: - if 'pos' not in self.data[node_type].node_attrs(): + if "pos" not in self.data[node_type].node_attrs(): missing.append(node_type) if missing: - raise ValueError( - f"Missing 'pos' attribute for node type: " - f"{', '.join(missing)}" - ) - elif 'pos' not in self.data.node_attrs(): + raise ValueError(f"Missing 'pos' attribute for node type: " f"{', '.join(missing)}") + elif "pos" not in self.data.node_attrs(): raise ValueError("Graph must contain 'pos' attribute.") def __len__(self) -> int: @@ -202,16 +182,14 @@ def __len__(self) -> int: return len(self.tiling.tiles) def __getitem__(self, idx: int) -> Data | HeteroData: - """Get the graph tile associated at location `index`. - + """Get the graph tile associated at location `index`. 
+ Initializes an empty Data or HeteroData object and populates with node and edge attributes associated with the indexed graph partition. Other non-node/edge attributes are populated without subsetting. """ if idx < 0 or idx >= len(self): - raise IndexError( - f"Requested {idx}, but tiling only contains {len(self)} tiles." - ) + raise IndexError(f"Requested {idx}, but tiling only contains {len(self)} tiles.") geometry = self.tiling.tiles[idx] return self._subset(geometry) @@ -222,51 +200,54 @@ def _subset(self, bounds: shapely.Polygon) -> Data | HeteroData: """ inner = bounds.bounds outer = bounds.buffer(self.margin).bounds - + if self._is_hetero: - subset = dict() - p_mask = dict() + subset = {} + p_mask = {} for node_type in self.data.node_types: - pos: torch.Tensor = self.data[node_type]['pos'] + pos: torch.Tensor = self.data[node_type]["pos"] # Row indices of masked elements inside tile w/ margin subset[node_type] = ( - (pos[:, 0] >= outer[0]) & - (pos[:, 0] < outer[2]) & - (pos[:, 1] >= outer[1]) & - (pos[:, 1] < outer[3]) - ).nonzero().flatten() + ( + (pos[:, 0] >= outer[0]) + & (pos[:, 0] < outer[2]) + & (pos[:, 1] >= outer[1]) + & (pos[:, 1] < outer[3]) + ) + .nonzero() + .flatten() + ) p_mask[node_type] = ( - (pos[subset[node_type], 0] >= inner[0]) & - (pos[subset[node_type], 0] <= inner[2]) & - (pos[subset[node_type], 1] >= inner[1]) & - (pos[subset[node_type], 1] <= inner[3]) + (pos[subset[node_type], 0] >= inner[0]) + & (pos[subset[node_type], 0] <= inner[2]) + & (pos[subset[node_type], 1] >= inner[1]) + & (pos[subset[node_type], 1] <= inner[3]) ) sample = self.data.subgraph(subset) - sample.set_value_dict('predict_mask', p_mask) - sample.set_value_dict('global_index', subset) + sample.set_value_dict("predict_mask", p_mask) + sample.set_value_dict("global_index", subset) return sample else: # is homogenous Data - pos: torch.Tensor = self.data['pos'] + pos: torch.Tensor = self.data["pos"] subset = ( - (pos[:, 0] >= outer[0]) & - (pos[:, 0] < outer[2]) & 
- (pos[:, 1] >= outer[1]) & - (pos[:, 1] < outer[3]) - ).nonzero().flatten() + ((pos[:, 0] >= outer[0]) & (pos[:, 0] < outer[2]) & (pos[:, 1] >= outer[1]) & (pos[:, 1] < outer[3])) + .nonzero() + .flatten() + ) sample = self.data.subgraph(subset) - sample['predict_mask'] = ( - (pos[subset, 0] >= inner[0]) & - (pos[subset, 0] <= inner[2]) & - (pos[subset, 1] >= inner[1]) & - (pos[subset, 1] <= inner[3]) + sample["predict_mask"] = ( + (pos[subset, 0] >= inner[0]) + & (pos[subset, 0] <= inner[2]) + & (pos[subset, 1] >= inner[1]) + & (pos[subset, 1] <= inner[3]) ) - sample['global_index'] = subset + sample["global_index"] = subset return sample class DynamicBatchSamplerPatch(DynamicBatchSampler): - """TODO: Description - """ + """TODO: Description.""" + def __len__(self): return len(self.dataset) # ceiling on dataset length diff --git a/src/segger/data/tiling.py b/src/segger/data/tiling.py index e4e3d09..71c3c83 100644 --- a/src/segger/data/tiling.py +++ b/src/segger/data/tiling.py @@ -1,18 +1,17 @@ -from functools import cached_property from abc import ABC, abstractmethod -from numpy.typing import ArrayLike -from shapely import box -import geopandas as gpd -import numpy as np -import torch +from functools import cached_property + import cudf +import torch +from geometry import * -from ..geometry import * +import geopandas as gpd +import numpy as np +from shapely import box class Tiling(ABC): - """ - An abstract base class for spatial tilings. + """An abstract base class for spatial tilings. Implementing classes must define the `tiles` property, which returns a geopandas GeoSeries. This property should be computed once and cached. @@ -25,8 +24,7 @@ def __init__(self) -> None: @property @abstractmethod def tiles(self) -> gpd.GeoSeries: - """ - A collection of Polygon geometries representing the tiles. + """A collection of Polygon geometries representing the tiles. This is an abstract property that must be implemented by subclasses. 
It is recommended to use @cached_property in the implementation for @@ -35,12 +33,11 @@ def tiles(self) -> gpd.GeoSeries: ... def _check_tiles(self): - """ - Explicitly ensure `self.tiles` is a collection of Polygon geometries, + """Explicitly ensure `self.tiles` is a collection of Polygon geometries, e.g., not MultiPolygon or Line. """ - assert self.tiles.geom_type.eq('Polygon').all() - + assert self.tiles.geom_type.eq("Polygon").all() + def _query_tiles( self, geometry: torch.Tensor, @@ -88,16 +85,14 @@ def _query_tiles( f"or polygons of shape (N, V, 2), but got {geometry.shape}." ) if margin < 0: - raise ValueError( - f"The margin must be non-negative, but got {margin}." - ) + raise ValueError(f"The margin must be non-negative, but got {margin}.") # Buffer tiles tiles = self.tiles if margin > 0: buffered = tiles.buffer( -margin, - cap_style='square', - join_style='mitre', + cap_style="square", + join_style="mitre", mitre_limit=margin / 2, ) missing = buffered.is_empty.sum() @@ -109,20 +104,20 @@ def _query_tiles( tiles = buffered # Spatial query - predicate = 'intersects' if inclusive else 'contains' - if geometry.dim() == 2: # points + predicate = "intersects" if inclusive else "contains" + if geometry.dim() == 2: # points result = points_in_polygons(geometry, tiles, predicate) - else: # polygons + else: # polygons result = polygons_in_polygons(geometry, tiles, predicate) - result = result.drop_duplicates('index_query') + result = result.drop_duplicates("index_query") # Format to tensor of indices (-1 where no match found) - kwargs = dict(device=geometry.device, dtype=torch.int64) + kwargs = {"device": geometry.device, "dtype": torch.int64} labels = torch.full((len(geometry),), -1, **kwargs) return labels.scatter_( dim=0, - index=torch.tensor(result['index_query'], **kwargs), - src= torch.tensor(result['index_match'], **kwargs), + index=torch.tensor(result["index_query"], **kwargs), + src=torch.tensor(result["index_match"], **kwargs), ) def label( @@ -177,6 
+172,7 @@ def mask( labels = self._query_tiles(geometry, inclusive=False, margin=margin) return labels != -1 + class QuadTreeTiling(Tiling): """A tiling system based on a quadtree decomposition of input points. @@ -192,24 +188,24 @@ class QuadTreeTiling(Tiling): max_tile_size : int The maximum number of points allowed in any single quadtree tile. """ + def __init__( self, positions: torch.Tensor, max_tile_size: int, ): # Calculate QuadTree on points and set as tiles - points = points_to_geoseries(positions, backend='cuspatial') + points = points_to_geoseries(positions, backend="cuspatial") _, quadtree = get_quadtree_index( points, max_tile_size, with_bounds=True, ) - self._tiles = quadtree_to_geoseries(quadtree, backend='geopandas') + self._tiles = quadtree_to_geoseries(quadtree, backend="geopandas") @property def tiles(self) -> gpd.GeoSeries: - """ - A collection of Polygon geometries representing the boundaries of the + """A collection of Polygon geometries representing the boundaries of the leaves of the generated QuadTree. """ return self._tiles @@ -217,6 +213,7 @@ def tiles(self) -> gpd.GeoSeries: ### Benchmarking Class ### + class SquareTiling(Tiling): """A tiling system based on a uniform square grid. @@ -233,23 +230,19 @@ class SquareTiling(Tiling): side_length : float The side length of each square tile. Must be positive. """ + def __init__( self, positions: torch.Tensor, side_length: float, ): if side_length <= 0: - raise ValueError( - f"side_length must be positive, but got {side_length}." - ) + raise ValueError(f"side_length must be positive, but got {side_length}.") if positions.dim() != 2 or positions.shape[-1] != 2: - raise ValueError( - f"positions must be a tensor of shape (N, 2), " - f"but got {positions.shape}." 
- ) + raise ValueError(f"positions must be a tensor of shape (N, 2), " f"but got {positions.shape}.") if len(positions) == 0: raise ValueError("positions cannot be empty.") - + # Store only the spatial extent, not the positions self.min_x = positions[:, 0].min().item() self.max_x = positions[:, 0].max().item() @@ -260,10 +253,9 @@ def __init__( @cached_property def tiles(self) -> gpd.GeoSeries: - """ - A collection of Polygon geometries representing square tiles + """A collection of Polygon geometries representing square tiles covering the spatial extent of the input positions. - + Returns ------- gpd.GeoSeries @@ -272,11 +264,14 @@ def tiles(self) -> gpd.GeoSeries: x, y = np.meshgrid( np.arange(self.min_x, self.max_x, self.side_length), np.arange(self.min_y, self.max_y, self.side_length), - indexing='ij' + indexing="ij", + ) + coords = np.column_stack( + [ + x.ravel(), + y.ravel(), + np.minimum(x.ravel() + self.side_length, self.max_x), + np.minimum(y.ravel() + self.side_length, self.max_y), + ] ) - coords = np.column_stack([ - x.ravel(), y.ravel(), - np.minimum(x.ravel() + self.side_length, self.max_x), - np.minimum(y.ravel() + self.side_length, self.max_y) - ]) return gpd.GeoSeries([box(*c) for c in coords]) diff --git a/src/segger/data/utils/__init__.py b/src/segger/data/utils/__init__.py index 3984a13..6fb0849 100644 --- a/src/segger/data/utils/__init__.py +++ b/src/segger/data/utils/__init__.py @@ -1,3 +1,5 @@ -from .anndata import setup_anndata, anndata_from_transcripts +from .anndata import anndata_from_transcripts, setup_anndata from .heterodata import setup_heterodata -from .neighbors import phenograph_rapids \ No newline at end of file +from .neighbors import phenograph_rapids + +__all__ = ["anndata_from_transcripts", "phenograph_rapids", "setup_anndata", "setup_heterodata"] diff --git a/src/segger/data/utils/anndata.py b/src/segger/data/utils/anndata.py index 93db4c4..d546e02 100644 --- a/src/segger/data/utils/anndata.py +++ 
b/src/segger/data/utils/anndata.py @@ -1,18 +1,22 @@ +from io.fields import TrainingBoundaryFields, TrainingTranscriptFields + +import cuml +import cupyx +import torch from torch.nn.functional import normalize -from scipy import sparse as sp + import geopandas as gpd -import polars as pl -import pandas as pd -import scanpy as sc import numpy as np +import pandas as pd +import polars as pl import sklearn -import torch -import cupyx -import cuml +from scipy import sparse as sp + +import scanpy as sc -from ...io.fields import TrainingTranscriptFields, TrainingBoundaryFields -from .neighbors import phenograph_rapids from segger.geometry.morphology import get_polygon_props +from .neighbors import phenograph_rapids + def anndata_from_transcripts( tx: pl.DataFrame, @@ -21,18 +25,14 @@ def anndata_from_transcripts( score_column: str | None = None, coordinate_columns: list[str] | None = None, ): - """TODO: Add description. - """ + """TODO: Add description.""" # Remove non-nuclear transcript tx = tx.filter(pl.col(cell_id_column).is_not_null()) # Get sparse counts from transcripts - feature_idx = tx.select( - feature_column).unique().with_row_index() - segment_idx = tx.select( - cell_id_column).unique().with_row_index() + feature_idx = tx.select(feature_column).unique().with_row_index() + segment_idx = tx.select(cell_id_column).unique().with_row_index() groupby = ( - tx - .with_columns( + tx.with_columns( # Map feature to numeric id pl.col(feature_column) .replace_strict( @@ -40,7 +40,7 @@ def anndata_from_transcripts( new=feature_idx["index"], return_dtype=pl.UInt32, ) - .alias('_fid'), + .alias("_fid"), # Map segmentation to numeric id pl.col(cell_id_column) .replace_strict( @@ -48,54 +48,30 @@ def anndata_from_transcripts( new=segment_idx["index"], return_dtype=pl.UInt32, ) - .alias('_sid'), + .alias("_sid"), ) # Create sparse count matrix - .group_by(['_sid', '_fid']) + .group_by(["_sid", "_fid"]) ) # Get correlation matrix ijv = groupby.len().to_numpy().T X = 
sp.coo_matrix((ijv[2], ijv[:2])).tocsr() - + # To AnnData adata = sc.AnnData( X=X, - obs=pd.DataFrame( - index=( - segment_idx - .get_column(cell_id_column) - .to_numpy() - .astype(str) - ) - ), - var=pd.DataFrame( - index=( - feature_idx - .get_column(feature_column) - .to_numpy() - .astype(str) - ) - ), + obs=pd.DataFrame(index=(segment_idx.get_column(cell_id_column).to_numpy().astype(str))), + var=pd.DataFrame(index=(feature_idx.get_column(feature_column).to_numpy().astype(str))), ) # Optionally: Add transcript scores if score_column is not None: ijv = groupby.agg(pl.col(score_column).mean()).to_numpy().T - adata.layers[f'{score_column}_scores'] = sp.coo_matrix( - (ijv[2], ijv[:2].astype(int))).tocsr() + adata.layers[f"{score_column}_scores"] = sp.coo_matrix((ijv[2], ijv[:2].astype(int))).tocsr() # Optionally: Add coordinates if coordinate_columns is not None: - centroids = ( - tx - .group_by(cell_id_column) - .agg([pl.col(c).mean().alias(c) for c in coordinate_columns]) - ) - coords = ( - centroids - .to_pandas() - .set_index(cell_id_column) - .loc[adata.obs.index, coordinate_columns] - ) + centroids = tx.group_by(cell_id_column).agg([pl.col(c).mean().alias(c) for c in coordinate_columns]) + coords = centroids.to_pandas().set_index(cell_id_column).loc[adata.obs.index, coordinate_columns] adata.obsm["X_spatial"] = coords.values return adata @@ -105,11 +81,10 @@ def get_cluster_cosine_similarity( embedding: torch.Tensor, clusters: torch.Tensor, ) -> torch.Tensor: - """TODO: Add description. - """ + """TODO: Add description.""" # Get label mapping unique, inverse = clusters.unique(sorted=False, return_inverse=True) - + # Empty output tensor k = unique.numel() sums = torch.zeros( @@ -140,8 +115,7 @@ def setup_anndata( genes_clusters_resolution: float, compute_morphology: bool = False, ): - """TODO: Add description. 
- """ + """TODO: Add description.""" # Standard fields tx_fields = TrainingTranscriptFields() bd_fields = TrainingBoundaryFields() @@ -156,11 +130,9 @@ def setup_anndata( # Map boundary cell IDs to boundary index ad.obs = ( - ad.obs - .join( + ad.obs.join( ( - boundaries - .reset_index(names=bd_fields.index) + boundaries.reset_index(names=bd_fields.index) .set_index(bd_fields.id, verify_integrity=True) .get(bd_fields.index) ), @@ -173,67 +145,67 @@ def setup_anndata( assert ~ad.obs.index.isna().any() # Remove genes with fewer than min counts permanently - ad.var['n_counts'] = ad.X.sum(0).A.flatten() - ad = ad[:, ad.var['n_counts'].ge(genes_min_counts)] + ad.var["n_counts"] = ad.X.sum(0).A.flatten() + ad = ad[:, ad.var["n_counts"].ge(genes_min_counts)] # Explicitly sort indices for reproducibility ad = ad[ad.obs.index.sort_values(), ad.var.index.sort_values()] - + # Add raw counts ad.raw = ad.copy() - ad.layers['counts'] = ad.raw.X.copy() + ad.layers["counts"] = ad.raw.X.copy() # Keep track of filtered cells - ad.obs['n_counts'] = ad.raw.X.sum(1).A.flatten() - ad.obs['filtered'] = ad.obs['n_counts'].ge(cells_min_counts) + ad.obs["n_counts"] = ad.raw.X.sum(1).A.flatten() + ad.obs["filtered"] = ad.obs["n_counts"].ge(cells_min_counts) # Normalize to filtered dataset counts - ad.layers['norm'] = ad.layers['counts'].copy() - target_sum = ad.obs.loc[ad.obs['filtered'], 'n_counts'].median() - sc.pp.normalize_total(ad, target_sum=target_sum, layer='norm') + ad.layers["norm"] = ad.layers["counts"].copy() + target_sum = ad.obs.loc[ad.obs["filtered"], "n_counts"].median() + sc.pp.normalize_total(ad, target_sum=target_sum, layer="norm") # Build gene embedding on filtered dataset - C = np.corrcoef(ad[ad.obs['filtered']].layers['norm'].todense().T) + C = np.corrcoef(ad[ad.obs["filtered"]].layers["norm"].todense().T) C = np.nan_to_num(C, 0, posinf=True, neginf=True) model = sklearn.decomposition.PCA(n_components=cells_embedding_size) - ad.varm['X_corr'] = model.fit_transform(C) + 
ad.varm["X_corr"] = model.fit_transform(C) # Build PCs on filtered cells and project all cells - counts_sparse_gpu = cupyx.scipy.sparse.csr_matrix(ad.layers['norm']) + counts_sparse_gpu = cupyx.scipy.sparse.csr_matrix(ad.layers["norm"]) model = cuml.PCA(n_components=cells_embedding_size) - model.fit(counts_sparse_gpu[ad.obs['filtered'].values]) - ad.obsm['X_pca'] = model.transform(counts_sparse_gpu).get() + model.fit(counts_sparse_gpu[ad.obs["filtered"].values]) + ad.obsm["X_pca"] = model.transform(counts_sparse_gpu).get() # Compute clusters on filtered cells cell_clusters = phenograph_rapids( - ad[ad.obs['filtered']].obsm['X_pca'], - n_neighbors=cells_clusters_n_neighbors, + ad[ad.obs["filtered"]].obsm["X_pca"], + n_neighbors=cells_clusters_n_neighbors, resolution=cells_clusters_resolution, min_size=100, ) - ad.obs['phenograph_cluster'] = -1 # removed cells have no cluster - ad.obs.loc[ad.obs['filtered'], 'phenograph_cluster'] = cell_clusters - ad.obs['phenograph_cluster'] = pd.Categorical(ad.obs['phenograph_cluster']) + ad.obs["phenograph_cluster"] = -1 # removed cells have no cluster + ad.obs.loc[ad.obs["filtered"], "phenograph_cluster"] = cell_clusters + ad.obs["phenograph_cluster"] = pd.Categorical(ad.obs["phenograph_cluster"]) # Compute pairwise cosine similarities among cell clusters - ad.uns['cell_cluster_similarities'] = get_cluster_cosine_similarity( - embedding=torch.tensor(ad.obsm['X_pca']), - clusters=torch.tensor(ad.obs['phenograph_cluster'].values), + ad.uns["cell_cluster_similarities"] = get_cluster_cosine_similarity( + embedding=torch.tensor(ad.obsm["X_pca"]), + clusters=torch.tensor(ad.obs["phenograph_cluster"].values), ).numpy() # Compute clusters on genes from embedding - ad.var['phenograph_cluster'] = phenograph_rapids( - ad.varm['X_corr'], + ad.var["phenograph_cluster"] = phenograph_rapids( + ad.varm["X_corr"], n_neighbors=genes_clusters_n_neighbors, resolution=genes_clusters_resolution, min_size=-1, ) - ad.var['phenograph_cluster'] = 
pd.Categorical(ad.var['phenograph_cluster']) + ad.var["phenograph_cluster"] = pd.Categorical(ad.var["phenograph_cluster"]) # Compute pairwise cosine similarities among gene clusters - ad.uns['gene_cluster_similarities'] = get_cluster_cosine_similarity( - embedding=torch.tensor(ad.varm['X_corr']), - clusters=torch.tensor(ad.var['phenograph_cluster'].values), + ad.uns["gene_cluster_similarities"] = get_cluster_cosine_similarity( + embedding=torch.tensor(ad.varm["X_corr"]), + clusters=torch.tensor(ad.var["phenograph_cluster"].values), ).numpy() # Add cell and gene numeric encodings to AnnData ad.obs[tx_fields.cell_encoding] = np.arange(len(ad.obs)).astype(int) @@ -254,5 +226,5 @@ def setup_anndata( for col in morpho_props.columns: ad.obs[col] = morpho_props[col].values # concat all morphology properties into a single embedding - ad.obsm['X_morphology'] = morpho_props.to_numpy(dtype=np.float32) + ad.obsm["X_morphology"] = morpho_props.to_numpy(dtype=np.float32) return ad diff --git a/src/segger/data/utils/heterodata.py b/src/segger/data/utils/heterodata.py index 40f43d9..b5f13c1 100644 --- a/src/segger/data/utils/heterodata.py +++ b/src/segger/data/utils/heterodata.py @@ -1,16 +1,18 @@ -from torch_geometric.data import HeteroData +from io import TrainingBoundaryFields, TrainingTranscriptFields from typing import Literal + +import torch + import geopandas as gpd import polars as pl +from torch_geometric.data import HeteroData + import scanpy as sc -import numpy as np -import torch -from ...io import TrainingBoundaryFields, TrainingTranscriptFields from .neighbors import ( + setup_prediction_graph, setup_segmentation_graph, setup_transcripts_graph, - setup_prediction_graph, ) @@ -24,19 +26,18 @@ def setup_heterodata( prediction_graph_mode: Literal["nucleus", "cell", "uniform"], prediction_graph_max_k: int, prediction_graph_buffer_ratio: float, - cells_embedding_key: str = 'X_pca', - cells_clusters_column: str = 'phenograph_cluster', - cells_encoding_column: str = 
'cell_encoding', - genes_embedding_key: str = 'X_corr', - genes_clusters_column: str = 'phenograph_cluster', - genes_encoding_column: str = 'gene_encoding', + cells_embedding_key: str = "X_pca", + cells_clusters_column: str = "phenograph_cluster", + cells_encoding_column: str = "cell_encoding", + genes_embedding_key: str = "X_corr", + genes_clusters_column: str = "phenograph_cluster", + genes_encoding_column: str = "gene_encoding", ) -> HeteroData: - """TODO: Add description. - """ + """TODO: Add description.""" # Standard fields tx_fields = TrainingTranscriptFields() bd_fields = TrainingBoundaryFields() - + # List of columns to potentially drop drop_columns = [ tx_fields.cell_encoding, @@ -45,19 +46,16 @@ def setup_heterodata( tx_fields.gene_cluster, ] # Update transcripts with fields for training - + transcripts = ( transcripts # Reset columns .drop(drop_columns, strict=False) # Add gene embedding and clusters .join( - pl.from_pandas( - adata.var[[genes_encoding_column, genes_clusters_column]], - include_index=True - ), + pl.from_pandas(adata.var[[genes_encoding_column, genes_clusters_column]], include_index=True), left_on=tx_fields.feature, - right_on=adata.var.index.name if adata.var.index.name else 'None', + right_on=adata.var.index.name if adata.var.index.name else "None", ) .rename( { @@ -67,23 +65,17 @@ def setup_heterodata( strict=False, ) # Add cell embedding and clusters - .with_columns( - pl - .when(segmentation_mask) - .then(pl.col(tx_fields.cell_id)) - .alias('join_id_cell') - ) + .with_columns(pl.when(segmentation_mask).then(pl.col(tx_fields.cell_id)).alias("join_id_cell")) .join( pl.from_pandas( - adata.obs[[bd_fields.id, cells_encoding_column, - cells_clusters_column]], + adata.obs[[bd_fields.id, cells_encoding_column, cells_clusters_column]], include_index=True, ), - left_on='join_id_cell', + left_on="join_id_cell", right_on=bd_fields.id, - how='left', + how="left", ) - .drop('join_id_cell') + .drop("join_id_cell") .rename( { 
cells_clusters_column: tx_fields.cell_cluster, @@ -93,16 +85,17 @@ def setup_heterodata( ) .with_columns(pl.col(tx_fields.cell_cluster).fill_null(-1)) # Recast encodings for efficiency - .cast({ - tx_fields.gene_encoding: pl.UInt16, - tx_fields.cell_encoding: pl.UInt32, - }) + .cast( + { + tx_fields.gene_encoding: pl.UInt16, + tx_fields.cell_encoding: pl.UInt32, + } + ) ) - + # Sort boundaries by AnnData ordering boundaries = ( - boundaries - .reset_index(names=bd_fields.index) + boundaries.reset_index(names=bd_fields.index) .set_index(bd_fields.id) .loc[adata.obs[bd_fields.id]] .reset_index(bd_fields.id) @@ -113,38 +106,34 @@ def setup_heterodata( data = HeteroData() # Transcript nodes - data['tx']['x'] = transcripts[tx_fields.gene_encoding].to_torch() - data['tx']['cluster'] = transcripts[tx_fields.gene_cluster].to_torch() - data['tx']['index'] = transcripts[tx_fields.row_index].to_torch() - data['tx']['geometry'] = transcripts[[tx_fields.x, tx_fields.y]].to_torch() - data['tx']['pos'] = data['tx']['geometry'] + data["tx"]["x"] = transcripts[tx_fields.gene_encoding].to_torch() + data["tx"]["cluster"] = transcripts[tx_fields.gene_cluster].to_torch() + data["tx"]["index"] = transcripts[tx_fields.row_index].to_torch() + data["tx"]["geometry"] = transcripts[[tx_fields.x, tx_fields.y]].to_torch() + data["tx"]["pos"] = data["tx"]["geometry"] # Boundary nodes - data['bd']['x'] = torch.tensor( - adata.obsm[cells_embedding_key]).to(torch.float) - data['bd']['cluster'] = torch.tensor( - adata.obs[cells_clusters_column].values).to(torch.int) - data['bd']['index'] = torch.tensor( - adata.obs[cells_encoding_column].values).to(torch.int) - data['bd']['geometry'] = torch.tensor( - adata.obsm['X_spatial']).to(torch.float) - data['bd']['pos'] = data['bd']['geometry'] + data["bd"]["x"] = torch.tensor(adata.obsm[cells_embedding_key]).to(torch.float) + data["bd"]["cluster"] = torch.tensor(adata.obs[cells_clusters_column].values).to(torch.int) + data["bd"]["index"] = 
torch.tensor(adata.obs[cells_encoding_column].values).to(torch.int) + data["bd"]["geometry"] = torch.tensor(adata.obsm["X_spatial"]).to(torch.float) + data["bd"]["pos"] = data["bd"]["geometry"] # Transcript neighbors graph - data['tx', 'neighbors', 'tx'].edge_index = setup_transcripts_graph( + data["tx", "neighbors", "tx"].edge_index = setup_transcripts_graph( transcripts, max_k=transcripts_graph_max_k, max_dist=transcripts_graph_max_dist, ) # Reference segmentation graph - data['tx', 'belongs', 'bd'].edge_index = setup_segmentation_graph( + data["tx", "belongs", "bd"].edge_index = setup_segmentation_graph( transcripts, segmentation_mask=segmentation_mask, ) # Transcript-cell graph for prediction - data['tx', 'neighbors', 'bd'].edge_index = setup_prediction_graph( + data["tx", "neighbors", "bd"].edge_index = setup_prediction_graph( transcripts, boundaries, max_k=prediction_graph_max_k, diff --git a/src/segger/data/utils/neighbors.py b/src/segger/data/utils/neighbors.py index ce7ab3e..23668cf 100644 --- a/src/segger/data/utils/neighbors.py +++ b/src/segger/data/utils/neighbors.py @@ -1,18 +1,19 @@ -from numpy.typing import ArrayLike -from scipy.spatial import KDTree +import gc +from io import TrainingBoundaryFields, TrainingTranscriptFields from typing import Any, Literal -import geopandas as gpd -import polars as pl -import numpy as np -import cupy as cp + +import cudf import cugraph -import torch import cuml -import cudf -import gc +import cupy as cp +import torch +from geometry import points_in_polygons -from ...io import TrainingTranscriptFields, TrainingBoundaryFields -from ...geometry import points_in_polygons +import geopandas as gpd +import numpy as np +import polars as pl +from numpy.typing import ArrayLike +from scipy.spatial import KDTree def phenograph_rapids( @@ -21,42 +22,43 @@ def phenograph_rapids( min_size: int = -1, **kwargs, ) -> np.ndarray: - """TODO: Add description. 
- """ + """TODO: Add description.""" X = cp.array(X) model = cuml.neighbors.NearestNeighbors(n_neighbors=n_neighbors) model.fit(X) _, indices = model.kneighbors(X) n, k = indices.shape - edges = cudf.concat([ - cudf.Series(np.repeat(np.arange(n), k), name='source', dtype="int32"), - cudf.Series(indices.flatten(), name='destination', dtype="int32"), - ], axis=1) + edges = cudf.concat( + [ + cudf.Series(np.repeat(np.arange(n), k), name="source", dtype="int32"), + cudf.Series(indices.flatten(), name="destination", dtype="int32"), + ], + axis=1, + ) G = cugraph.from_cudf_edgelist(edges) - + # Build jaccard-weighted graph in GPU - jaccard_edges = cugraph.jaccard(G, edges[['source', 'destination']]) + jaccard_edges = cugraph.jaccard(G, edges[["source", "destination"]]) G = cugraph.from_cudf_edgelist(jaccard_edges, *jaccard_edges.columns) - + # Cluster jaccard-weighted graph result, _ = cugraph.louvain(G, **kwargs) - + # Sort clusters by size - sizes = result['partition'].value_counts() + sizes = result["partition"].value_counts() sizes.loc[:] = cp.where(sizes > min_size, cp.arange(len(sizes)), -1) - result['partition'] = result['partition'].map(sizes) - + result["partition"] = result["partition"].map(sizes) + # Sort by vertex (e.g. cell) - return result.sort_values('vertex')['partition'].values.get() + return result.sort_values("vertex")["partition"].values.get() def knn_to_edge_index( neighbor_table: torch.Tensor, - padding_value = None, + padding_value=None, ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Convert a dense neighbor table (with padding) into COO edge index. + """Convert a dense neighbor table (with padding) into COO edge index. 
Parameters ---------- @@ -69,22 +71,20 @@ def knn_to_edge_index( index pointer : (N+1,) long tensor """ with torch.no_grad(): - N, K = neighbor_table.shape + N, K = neighbor_table.shape if padding_value is None: padding_value = N device = neighbor_table.device - valid = neighbor_table != padding_value - flat = valid.view(-1).nonzero(as_tuple=False).squeeze(1) - col = neighbor_table.view(-1)[flat] - row = flat // K + valid = neighbor_table != padding_value + flat = valid.view(-1).nonzero(as_tuple=False).squeeze(1) + col = neighbor_table.view(-1)[flat] + row = flat // K edge_index = torch.stack([row, col]) deg = valid.sum(dim=1) - index_ptr = torch.cat( - (torch.zeros(1, dtype=torch.long, device=device), deg.cumsum(0)) - ) + index_ptr = torch.cat((torch.zeros(1, dtype=torch.long, device=device), deg.cumsum(0))) del valid, flat, col, row, deg torch.cuda.empty_cache() gc.collect() @@ -96,8 +96,7 @@ def edge_index_to_knn( edge_index: torch.Tensor, padding_value: Any = None, ) -> torch.Tensor: - """TODO: Add description. 
- """ + """TODO: Add description.""" _, lengths = torch.unique_consecutive( edge_index[0], return_counts=True, @@ -105,11 +104,8 @@ def edge_index_to_knn( B = lengths.size(0) L = lengths.max() neighbor_table = edge_index[0].new_full((B, L), -1) - - row = torch.repeat_interleave( - torch.arange(B, device=neighbor_table.device), - lengths - ) + + row = torch.repeat_interleave(torch.arange(B, device=neighbor_table.device), lengths) start = torch.cumsum(lengths, 0) - lengths col = torch.arange(edge_index[0].size(0), device=neighbor_table.device) col -= torch.repeat_interleave(start, lengths) @@ -138,9 +134,9 @@ def kdtree_neighbors( indices = torch.from_numpy(indices) gc.collect() # make sure numpy copy is gone before conversion edge_index, index_pointer = knn_to_edge_index(indices) - del indices # remove big indices tensor + del indices # remove big indices tensor gc.collect() - + return edge_index, index_pointer @@ -149,8 +145,7 @@ def setup_transcripts_graph( max_k: int, max_dist: float, ) -> torch.Tensor: - """TODO: Add description. - """ + """TODO: Add description.""" tx_fields = TrainingTranscriptFields() points = tx[[tx_fields.x, tx_fields.y]].to_numpy() edge_index, _ = kdtree_neighbors( @@ -165,17 +160,9 @@ def setup_segmentation_graph( tx: pl.DataFrame, segmentation_mask: pl.Expr | pl.Series = None, ) -> torch.Tensor: - """TODO: Add description. - """ + """TODO: Add description.""" tx_fields = TrainingTranscriptFields() - return ( - tx - .with_row_index("_tid") - .filter(segmentation_mask) - .select(["_tid", tx_fields.cell_encoding]) - .to_torch() - .T - ) + return tx.with_row_index("_tid").filter(segmentation_mask).select(["_tid", tx_fields.cell_encoding]).to_torch().T def setup_prediction_graph( @@ -183,10 +170,9 @@ def setup_prediction_graph( bd: gpd.GeoDataFrame, max_k: int, buffer_ratio: float, - mode: Literal['nucleus', 'cell', 'uniform'] = 'cell', + mode: Literal["nucleus", "cell", "uniform"] = "cell", ) -> torch.Tensor: - """TODO: Add description. 
- """ + """TODO: Add description.""" tx_fields = TrainingTranscriptFields() bd_fields = TrainingBoundaryFields() @@ -200,20 +186,18 @@ def setup_prediction_graph( max_k=max_k, ) return edge_index - + # Shape-based graph points = tx[[tx_fields.x, tx_fields.y]].to_numpy() - boundary_type = (bd_fields.cell_value if mode == "cell" - else bd_fields.nucleus_value) + boundary_type = bd_fields.cell_value if mode == "cell" else bd_fields.nucleus_value polygons = bd[bd[bd_fields.boundary_type] == boundary_type].geometry buffer_dists = np.sqrt(polygons.area / np.pi) * buffer_ratio polygons = polygons.buffer(buffer_dists).reset_index(drop=True) result = points_in_polygons( points=points, polygons=polygons, - predicate='contains', + predicate="contains", batches=10, ) - return torch.tensor( - result[['index_query', 'index_match']].values.T).to(torch.int).cpu() + return torch.tensor(result[["index_query", "index_match"]].values.T).to(torch.int).cpu() diff --git a/src/segger/data/writer.py b/src/segger/data/writer.py index 7bd785a..0260cad 100644 --- a/src/segger/data/writer.py +++ b/src/segger/data/writer.py @@ -1,18 +1,20 @@ -from lightning.pytorch.callbacks import BasePredictionWriter -from skimage.filters import threshold_li, threshold_yen -from lightning.pytorch import Trainer, LightningModule -from typing import Sequence, Any +from io import TrainingBoundaryFields, TrainingTranscriptFields from pathlib import Path -import polars as pl +from typing import Any, Sequence + import torch +from skimage.filters import threshold_li, threshold_yen + +import polars as pl +from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch.callbacks import BasePredictionWriter -from ..io import TrainingTranscriptFields, TrainingBoundaryFields from . import ISTDataModule class ISTSegmentationWriter(BasePredictionWriter): - """TODO: Description - + """TODO: Description. 
+ Parameters ---------- output_directory : Path @@ -27,64 +29,43 @@ def write_on_epoch_end( self, trainer: Trainer, pl_module: LightningModule, - predictions: Sequence[list], + predictions: Sequence[list], batch_indices: Sequence[Any], ): - """TODO: Description - """ + """TODO: Description.""" tx_fields = TrainingTranscriptFields() bd_fields = TrainingBoundaryFields() - + # Check datamodule for AnnData input if not isinstance(trainer.datamodule, ISTDataModule): raise TypeError( - f"Expected data module to be `ISTDataModule` but got " - f"{type(self.trainer.datamodule).__name__}." + f"Expected data module to be `ISTDataModule` but got " f"{type(self.trainer.datamodule).__name__}." ) if not hasattr(trainer.datamodule, "ad"): raise ValueError("Data module has no attribute `ad`.") - + # Create segmentation output segmentation = ( - pl - .concat( + pl.concat( [ - pl.from_torch( - torch.hstack([batch[0] for batch in predictions]), - schema=[tx_fields.row_index] - ), + pl.from_torch(torch.hstack([batch[0] for batch in predictions]), schema=[tx_fields.row_index]), pl.from_torch( torch.hstack([batch[1] for batch in predictions]), schema={bd_fields.cell_encoding: pl.Int64}, ), - pl.from_torch( - torch.hstack([batch[2] for batch in predictions]), - schema=["segger_similarity"] - ), + pl.from_torch(torch.hstack([batch[2] for batch in predictions]), schema=["segger_similarity"]), pl.from_torch( torch.hstack([batch[3] for batch in predictions]), schema={tx_fields.feature: pl.Int64}, ), ], - how='horizontal' - ) - .with_columns( - pl - .col(bd_fields.cell_encoding) - .replace(-1, None) - .cast(pl.Int64) + how="horizontal", ) + .with_columns(pl.col(bd_fields.cell_encoding).replace(-1, None).cast(pl.Int64)) .join( ( - pl - .from_pandas(trainer.datamodule.ad.obs[[ - bd_fields.id, - bd_fields.cell_encoding - ]]) - .with_columns( - pl - .col(bd_fields.cell_encoding) - .cast(pl.Int64) + pl.from_pandas(trainer.datamodule.ad.obs[[bd_fields.id, bd_fields.cell_encoding]]).with_columns( + 
pl.col(bd_fields.cell_encoding).cast(pl.Int64) ) ), on=bd_fields.cell_encoding, @@ -98,11 +79,10 @@ def write_on_epoch_end( ) .unique(tx_fields.row_index, keep="first") ) - + # Per-gene thresholding (iterative to reduce memory usage) feature_counts = ( - segmentation - .filter(pl.col('segger_cell_id').is_not_null()) + segmentation.filter(pl.col("segger_cell_id").is_not_null()) .select(tx_fields.feature) .to_series() .value_counts() @@ -110,32 +90,27 @@ def write_on_epoch_end( thresholds = [] n = 10_000_000 for feature, count in feature_counts.iter_rows(): - similarities = ( - segmentation - .filter( - (pl.col(tx_fields.feature) == feature) & - (pl.col('segger_cell_id').is_not_null()) - ) - .select('segger_similarity') - ) + similarities = segmentation.filter( + (pl.col(tx_fields.feature) == feature) & (pl.col("segger_cell_id").is_not_null()) + ).select("segger_similarity") if count > n: similarities = similarities.sample(n=n, seed=0) similarities = similarities.to_series().to_numpy() threshold_value = min( - threshold_li( similarities), + threshold_li(similarities), threshold_yen(similarities), ) - thresholds.append({ - tx_fields.feature: feature, - 'similarity_threshold': threshold_value, - }) + thresholds.append( + { + tx_fields.feature: feature, + "similarity_threshold": threshold_value, + } + ) thresholds = pl.DataFrame(thresholds) - + # Join and write output to file ( - segmentation - .join(thresholds, on=tx_fields.feature, how='left') + segmentation.join(thresholds, on=tx_fields.feature, how="left") .drop(tx_fields.feature) - .write_parquet( - self.output_directory / 'segger_segmentation.parquet') + .write_parquet(self.output_directory / "segger_segmentation.parquet") ) diff --git a/src/segger/geometry/__init__.py b/src/segger/geometry/__init__.py index 3abe3e5..47cfd69 100644 --- a/src/segger/geometry/__init__.py +++ b/src/segger/geometry/__init__.py @@ -1,4 +1,14 @@ from .conversion import points_to_geoseries, polygons_to_geoseries -from .query import 
points_in_polygons, polygons_in_polygons +from .morphology import get_polygon_props from .quadtree import get_quadtree_index, quadtree_to_geoseries -from .morphology import get_polygon_props \ No newline at end of file +from .query import points_in_polygons, polygons_in_polygons + +__all__ = [ + "get_polygon_props", + "get_quadtree_index", + "points_in_polygons", + "polygons_in_polygons", + "points_to_geoseries", + "polygons_to_geoseries", + "quadtree_to_geoseries", +] diff --git a/src/segger/geometry/conversion.py b/src/segger/geometry/conversion.py index a149dea..e95e74b 100644 --- a/src/segger/geometry/conversion.py +++ b/src/segger/geometry/conversion.py @@ -1,18 +1,17 @@ from functools import singledispatch from typing import Literal -import geopandas as gpd -import pandas as pd -import numpy as np + +import cudf import cupy as cp import cuspatial -import shapely import torch -import cudf +from cuspatial.utils.column_utils import contains_only_points, contains_only_polygons + +import geopandas as gpd +import numpy as np +import pandas as pd +import shapely -from cuspatial.utils.column_utils import ( - contains_only_polygons, - contains_only_points, -) # --- Coordinates Conversion --- @singledispatch @@ -39,9 +38,8 @@ def points_to_coords(data: any) -> np.ndarray | cp.ndarray: NotImplementedError If no converter is registered for the specific input data type. """ - raise NotImplementedError( - f"No implementation registered for type {type(data).__name__}" - ) + raise NotImplementedError(f"No implementation registered for type {type(data).__name__}") + @points_to_coords.register def _(data: list): @@ -49,53 +47,44 @@ def _(data: list): raise ValueError("Input must be a non-empty list of shapely.Point.") return np.array([p.coords[0] for p in data]) + @points_to_coords.register def _(data: np.ndarray | cp.ndarray): if data.ndim != 2 or data.shape[1] != 2: - raise ValueError( - f"Input array must have shape (N, 2), but got {data.shape}." 
- ) + raise ValueError(f"Input array must have shape (N, 2), but got {data.shape}.") return data + @points_to_coords.register def _(data: torch.Tensor): if data.dim() != 2 or data.shape[1] != 2: - raise ValueError( - f"Input tensor must have shape (N, 2), but got {data.shape}." - ) + raise ValueError(f"Input tensor must have shape (N, 2), but got {data.shape}.") if data.device == "cpu": return data.numpy() else: # "cuda", zero-copy transfer return cp.array(data.cuda()) + @points_to_coords.register def _(data: gpd.GeoSeries): - if data.geometry.empty or data.geom_type.ne('Point').any(): - raise ValueError( - f"Input must be a non-empty geopandas.GeoSeries of points." - ) + if data.geometry.empty or data.geom_type.ne("Point").any(): + raise ValueError("Input must be a non-empty geopandas.GeoSeries of points.") coords = data.get_coordinates() if coords.shape[1] != 2: - raise ValueError( - f"Input must be points in 2 dimensions, but got {data.shape[1]}." - ) + raise ValueError(f"Input must be points in 2 dimensions, but got {data.shape[1]}.") return coords + @points_to_coords.register def _(data: cuspatial.GeoSeries): if data.empty or not contains_only_points(data): - raise ValueError( - f"Input must be a non-empty cuspatial.GeoSeries of points." - ) + raise ValueError("Input must be a non-empty cuspatial.GeoSeries of points.") return data.points.xy.to_cupy().reshape(-1, 2) # inherently 2D + # --- Points API --- -def points_to_geoseries( - data: any, - backend: Literal['geopandas', 'cuspatial'] -) -> gpd.GeoSeries | cuspatial.GeoSeries: - """ - Converts various point data formats to a specified GeoSeries backend. +def points_to_geoseries(data: any, backend: Literal["geopandas", "cuspatial"]) -> gpd.GeoSeries | cuspatial.GeoSeries: + """Converts various point data formats to a specified GeoSeries backend. This is a generic function that dispatches to a registered implementation based on the type of the input `data`. 
@@ -114,7 +103,7 @@ def points_to_geoseries( ------- gpd.GeoSeries | cuspatial.GeoSeries The converted GeoSeries object. - + Raises ------ NotImplementedError @@ -124,26 +113,24 @@ def points_to_geoseries( TypeError If the backend is not supported. """ - if backend not in ['geopandas', 'cuspatial']: - raise TypeError( - f"Unsupported backend '{backend}'. Supported backends are " - f"'geopandas' and 'cuspatial'." - ) + if backend not in ["geopandas", "cuspatial"]: + raise TypeError(f"Unsupported backend '{backend}'. Supported backends are " f"'geopandas' and 'cuspatial'.") # Passthrough - if (backend == 'geopandas' and isinstance(data, gpd.GeoSeries)) or \ - (backend == 'cuspatial' and isinstance(data, cuspatial.GeoSeries)): + if (backend == "geopandas" and isinstance(data, gpd.GeoSeries)) or ( + backend == "cuspatial" and isinstance(data, cuspatial.GeoSeries) + ): return data # Collect points coordinates coords = points_to_coords(data) # Convert to backend - if backend == 'geopandas': + if backend == "geopandas": coords = cp.asnumpy(coords) points = gpd.GeoSeries(gpd.points_from_xy(*coords.T)) if isinstance(data, cuspatial.GeoSeries): points.index = pd.Index(data.index.to_numpy()) else: # cuspatial - coords = cp.asarray(coords).ravel().astype('double') + coords = cp.asarray(coords).ravel().astype("double") points = cuspatial.GeoSeries.from_points_xy(coords) if isinstance(data, gpd.GeoSeries): points.index = cudf.Index(data.index) @@ -175,9 +162,8 @@ def polygons_to_parts(data: any) -> tuple[np.ndarray] | tuple[cp.ndarray]: NotImplementedError If no converter is registered for the specific input data type. 
""" - raise NotImplementedError( - f"No implementation registered for type {type(data).__name__}" - ) + raise NotImplementedError(f"No implementation registered for type {type(data).__name__}") + @polygons_to_parts.register def _(data: list): @@ -188,16 +174,13 @@ def _(data: list): ring_offsets = np.cumsum([0] + [len(c) for c in coords]) return vertices, ring_offsets + @polygons_to_parts.register def _(data: torch.Tensor): if not data.is_nested or data.layout != torch.jagged: - raise ValueError( - "Input tensor must be nested and have 'jagged' layout." - ) + raise ValueError("Input tensor must be nested and have 'jagged' layout.") if data.dim() != 3 or data.shape[-1] != 2: - raise ValueError( - "Input tensor must be of shape (N, j2, 2), but got {data.shape}." - ) + raise ValueError("Input tensor must be of shape (N, j2, 2), but got {data.shape}.") if data.device == "cpu": vertices = data.values().numpy() ring_offsets = data.offsets().numpy() @@ -207,37 +190,32 @@ def _(data: torch.Tensor): ring_offsets = cp.array(data.offsets().cuda()) return vertices, ring_offsets + @polygons_to_parts.register def _(data: gpd.GeoSeries): - if data.geometry.empty or data.geom_type.ne('Polygon').any(): - raise ValueError( - f"Input must be a non-empty geopandas.GeoSeries of polygons." - ) + if data.geometry.empty or data.geom_type.ne("Polygon").any(): + raise ValueError("Input must be a non-empty geopandas.GeoSeries of polygons.") coords = data.get_coordinates() - vertices = coords[['x', 'y']].to_numpy() + vertices = coords[["x", "y"]].to_numpy() _, idx = np.unique(coords.index.to_numpy(), return_index=True) ring_offsets = np.sort(np.append(idx, len(coords))) return vertices, ring_offsets + @polygons_to_parts.register def _(data: cuspatial.GeoSeries): if data.empty or not contains_only_polygons(data): - raise ValueError( - f"Input must be a non-empty cuspatial.GeoSeries of polygons." 
- ) + raise ValueError("Input must be a non-empty cuspatial.GeoSeries of polygons.") vertices = data.polygons.xy.to_cupy().reshape(-1, 2) ring_offsets = data.polygons.ring_offset return vertices, ring_offsets + # --- Polygons API --- -def polygons_to_geoseries( - data: any, - backend: Literal['geopandas', 'cuspatial'] -) -> gpd.GeoSeries | cuspatial.GeoSeries: - """ - Converts various polygon data formats to a specified GeoSeries backend. - - Polygon geometries must contains exterior rings only; they cannot have +def polygons_to_geoseries(data: any, backend: Literal["geopandas", "cuspatial"]) -> gpd.GeoSeries | cuspatial.GeoSeries: + """Converts various polygon data formats to a specified GeoSeries backend. + + Polygon geometries must contains exterior rings only; they cannot have interior holes, etc. Parameters @@ -245,7 +223,7 @@ def polygons_to_geoseries( data : any The input polygon data. Supported types include: - List of shapely Polygons - - Jagged Torch tensors of shape (N, V, 2) where V is the number of + - Jagged Torch tensors of shape (N, V, 2) where V is the number of vertices per polygon. - GeoPandas/cuSpatial GeoSeries of polygons backend : Literal['geopandas', 'cuspatial'] @@ -255,7 +233,7 @@ def polygons_to_geoseries( ------- gpd.GeoSeries | cuspatial.GeoSeries The converted GeoSeries object. - + Raises ------ NotImplementedError @@ -265,34 +243,34 @@ def polygons_to_geoseries( TypeError If the backend is not supported. """ - if backend not in ['geopandas', 'cuspatial']: - raise TypeError( - f"Unsupported backend '{backend}'. Supported backends are " - f"'geopandas' and 'cuspatial'." - ) + if backend not in ["geopandas", "cuspatial"]: + raise TypeError(f"Unsupported backend '{backend}'. 
Supported backends are " f"'geopandas' and 'cuspatial'.") # Passthrough - if (backend == 'geopandas' and isinstance(data, gpd.GeoSeries)) or \ - (backend == 'cuspatial' and isinstance(data, cuspatial.GeoSeries)): + if (backend == "geopandas" and isinstance(data, gpd.GeoSeries)) or ( + backend == "cuspatial" and isinstance(data, cuspatial.GeoSeries) + ): return data # Collect points coordinates vertices, ring_offsets = polygons_to_parts(data) # Convert to backend - if backend == 'geopandas': + if backend == "geopandas": vertices = cp.asnumpy(vertices) ring_offsets = cp.asnumpy(ring_offsets) - part_offsets = np.arange(len(ring_offsets), dtype='int32') - polygons = gpd.GeoSeries(shapely.from_ragged_array( - shapely.GeometryType.POLYGON, - vertices, - (ring_offsets, part_offsets), - )) + part_offsets = np.arange(len(ring_offsets), dtype="int32") + polygons = gpd.GeoSeries( + shapely.from_ragged_array( + shapely.GeometryType.POLYGON, + vertices, + (ring_offsets, part_offsets), + ) + ) if isinstance(data, cuspatial.GeoSeries): polygons.index = pd.Index(data.index.to_numpy()) else: # cuspatial - vertices = cp.asarray(vertices).ravel().astype('double') + vertices = cp.asarray(vertices).ravel().astype("double") ring_offsets = cp.asarray(ring_offsets) - part_offsets = cp.arange(len(ring_offsets), dtype='int32') + part_offsets = cp.arange(len(ring_offsets), dtype="int32") polygons = cuspatial.GeoSeries.from_polygons_xy( vertices, ring_offsets, @@ -303,13 +281,14 @@ def polygons_to_geoseries( polygons.index = cudf.Index(data.index) return polygons + def polygons_to_nested_tensor( data: any, device: str | None = None, ) -> torch.Tensor: - """Converts polygon geometries into a nested tensor in jagged layout + """Converts polygon geometries into a nested tensor in jagged layout. - The jagged tensor format is used here for representing polygons with + The jagged tensor format is used here for representing polygons with varying numbers of vertices without requiring padding. 
Parameters @@ -328,9 +307,9 @@ def polygons_to_nested_tensor( polygon. """ # Convert to universal format (GeoSeries) - backend = 'cuspatial' if device == 'cuda' else 'geopandas' + backend = "cuspatial" if device == "cuda" else "geopandas" polygons = polygons_to_geoseries(data, backend=backend) - + # Build nested tensor from coordinates coords = polygons.geometry.get_coordinates() _, counts = torch.unique( @@ -339,8 +318,4 @@ def polygons_to_nested_tensor( ) indices = torch.cumsum(counts, 0)[:-1] splits = torch.tensor_split(torch.tensor(coords.values), indices) - return torch.nested.nested_tensor( - splits, - layout=torch.jagged, - device=device - ) + return torch.nested.nested_tensor(splits, layout=torch.jagged, device=device) diff --git a/src/segger/geometry/morphology.py b/src/segger/geometry/morphology.py index eeba1b5..31fbb37 100644 --- a/src/segger/geometry/morphology.py +++ b/src/segger/geometry/morphology.py @@ -1,6 +1,7 @@ import geopandas as gpd import pandas as pd + def get_polygon_props( polygons: gpd.GeoSeries, area: bool = True, @@ -8,8 +9,7 @@ def get_polygon_props( elongation: bool = True, circularity: bool = True, ) -> pd.DataFrame: - """ - Computes geometric properties of polygons. + """Computes geometric properties of polygons. 
Parameters ---------- @@ -40,4 +40,4 @@ def get_polygon_props( if circularity: r = polygons.minimum_bounding_radius() props["circularity"] = polygons.area / r**2 - return props \ No newline at end of file + return props diff --git a/src/segger/geometry/quadtree.py b/src/segger/geometry/quadtree.py index c29b4f8..03c84bd 100644 --- a/src/segger/geometry/quadtree.py +++ b/src/segger/geometry/quadtree.py @@ -1,12 +1,14 @@ -from shapely import from_ragged_array, GeometryType from typing import Literal -import geopandas as gpd -from numba import njit -import pandas as pd -import numpy as np + +import cudf import cupy as cp import cuspatial -import cudf + +import geopandas as gpd +import numpy as np +import pandas as pd +from numba import njit +from shapely import from_ragged_array, GeometryType def get_quadtree_kwargs( @@ -39,20 +41,19 @@ def get_quadtree_kwargs( scale = extent // (1 << max_depth - 1) # Return as dictionary - return dict( - x_min=x_min, - x_max=x_max, - y_min=y_min, - y_max=y_max, - scale=scale, - max_depth=max_depth, - ) + return { + "x_min": x_min, + "x_max": x_max, + "y_min": y_min, + "y_max": y_max, + "scale": scale, + "max_depth": max_depth, + } @njit def keys_to_coordinates(keys): - """ - Decode quadtree keys into 2D integer (x, y) coordinates. + """Decode quadtree keys into 2D integer (x, y) coordinates. Each key encodes the quadrant traversal path using two bits per level: - bit 0: x-direction @@ -97,8 +98,7 @@ def get_quadrant_bounds( y_min: float, y_max: float, ): - """ - Add spatial bounds to each leaf in a cuSpatial quadtree. + """Add spatial bounds to each leaf in a cuSpatial quadtree. This computes the (x_min, x_max, y_min, y_max) of each quadrant using its level and key. Coordinates are clipped to the full extent. @@ -118,21 +118,21 @@ def get_quadrant_bounds( Input DataFrame with added bounding box columns: 'x_min', 'x_max', 'y_min', and 'y_max'. 
""" - width = x_max - x_min + width = x_max - x_min height = y_max - y_min - levels = quadtree['level'].astype(float) + 1 - coords = cp.array(keys_to_coordinates(quadtree['key'].to_numpy())) + levels = quadtree["level"].astype(float) + 1 + coords = cp.array(keys_to_coordinates(quadtree["key"].to_numpy())) quadrant_max = np.ceil(np.log2(max(width, height))) quadrant_dim = 2 ** (quadrant_max - levels) - - quadtree['x_min'] = x_min + coords[0] * quadrant_dim - quadtree['x_max'] = quadtree['x_min'] + quadrant_dim - quadtree['y_min'] = y_min + coords[1] * quadrant_dim - quadtree['y_max'] = quadtree['y_min'] + quadrant_dim - - quadtree['x_max'] = quadtree['x_max'].clip(x_min, x_max) - quadtree['y_max'] = quadtree['y_max'].clip(y_min, y_max) - + + quadtree["x_min"] = x_min + coords[0] * quadrant_dim + quadtree["x_max"] = quadtree["x_min"] + quadrant_dim + quadtree["y_min"] = y_min + coords[1] * quadrant_dim + quadtree["y_max"] = quadtree["y_min"] + quadrant_dim + + quadtree["x_max"] = quadtree["x_max"].clip(x_min, x_max) + quadtree["y_max"] = quadtree["y_max"].clip(y_min, y_max) + return quadtree @@ -150,7 +150,7 @@ def get_quadtree_index( max_size : int Maximum number of points allowed in a single tile. with_bounds : bool, optional - Whether to return the x, y bounds of each leaf with the quadtree + Whether to return the x, y bounds of each leaf with the quadtree DataFrame. Default is True. 
Returns @@ -162,12 +162,12 @@ def get_quadtree_index( """ # Get hyperparams for quadtree kwargs = get_quadtree_kwargs(points) - x_min = kwargs['x_min'] - x_max = kwargs['x_max'] - y_min = kwargs['y_min'] - y_max = kwargs['y_max'] - scale = kwargs['scale'] - max_depth = kwargs['max_depth'] + x_min = kwargs["x_min"] + x_max = kwargs["x_max"] + y_min = kwargs["y_min"] + y_max = kwargs["y_max"] + scale = kwargs["scale"] + max_depth = kwargs["max_depth"] # Calculate quadtree on region indices, quadtree = cuspatial.quadtree_on_points( @@ -183,7 +183,7 @@ def get_quadtree_index( # Add bounds of tiles if with_bounds: quadtree = get_quadrant_bounds( - quadtree, + quadtree, x_min=x_min, x_max=x_max, y_min=y_min, @@ -195,10 +195,10 @@ def get_quadtree_index( def quadtree_to_geoseries( quadtree: cudf.DataFrame, - backend: Literal['cuspatial', 'geopandas'], + backend: Literal["cuspatial", "geopandas"], ) -> cuspatial.GeoSeries | gpd.GeoSeries: """Helper function to convert cuspatial Quadtree to leaf geometries. - + Parameters ---------- quadtree : cudf.DataFrame @@ -210,27 +210,27 @@ def quadtree_to_geoseries( The quadtree leaves converted to GeoSeries format. 
""" # Raise error if bounds not added - bounds_columns = ['x_min', 'y_min', 'x_max', 'y_max'] + bounds_columns = ["x_min", "y_min", "x_max", "y_max"] if not pd.Index(bounds_columns).isin(quadtree.columns).all(): raise IndexError("Quadtree missing boundary column(s).") - + # Convert to GeoSeries - mask = ~quadtree['is_internal_node'] + mask = ~quadtree["is_internal_node"] bounds = quadtree.loc[mask, bounds_columns].values - vertices = bounds[:, [0, 1, 0, 3, 2, 3, 2, 1]].astype('double').flatten() + vertices = bounds[:, [0, 1, 0, 3, 2, 3, 2, 1]].astype("double").flatten() ring_offset = cp.arange(0, bounds.shape[0] * 4 + 1, 4) part_offset = geometry_offset = cp.arange(bounds.shape[0] + 1) - if backend == 'cuspatial': + if backend == "cuspatial": return cuspatial.GeoSeries.from_polygons_xy( vertices, ring_offset, part_offset, geometry_offset, ) - else: # geopandas + else: # geopandas geometry = from_ragged_array( GeometryType.POLYGON, vertices.reshape(-1, 2).get(), (ring_offset.get(), part_offset.get()), ) - return gpd.GeoSeries(geometry) \ No newline at end of file + return gpd.GeoSeries(geometry) diff --git a/src/segger/geometry/query.py b/src/segger/geometry/query.py index 022803c..e6b338b 100644 --- a/src/segger/geometry/query.py +++ b/src/segger/geometry/query.py @@ -1,17 +1,13 @@ from typing import Literal + +import cudf +import cuspatial + import geopandas as gpd import numpy as np -import cuspatial -import cudf -from .conversion import ( - polygons_to_geoseries, - points_to_geoseries, -) -from .quadtree import ( - get_quadtree_index, - get_quadtree_kwargs, -) +from .conversion import points_to_geoseries, polygons_to_geoseries +from .quadtree import get_quadtree_index, get_quadtree_kwargs def _points_in_polygons_contains( @@ -49,26 +45,17 @@ def _points_in_polygons_contains( # Setup inputs for spatial join if max_size is None: max_size = 10000 if len(points) > 5e7 else 1000 # heuristic - point_indices, quadtree = get_quadtree_index( - points, - max_size, - 
with_bounds=False - ) + point_indices, quadtree = get_quadtree_index(points, max_size, with_bounds=False) kwargs = get_quadtree_kwargs(points) # Perform spatial join in batches batch_idx = np.linspace(0, len(polygons), (batches or 1) + 1, dtype=int) results = [] for start_idx, end_idx in zip(batch_idx, batch_idx[1:]): - # Get polygons for this batch batch_polygons = polygons.iloc[start_idx:end_idx] bboxes = cuspatial.polygon_bounding_boxes(batch_polygons) - poly_quad_pairs = cuspatial.join_quadtree_and_bounding_boxes( - quadtree=quadtree, - bounding_boxes=bboxes, - **kwargs - ) + poly_quad_pairs = cuspatial.join_quadtree_and_bounding_boxes(quadtree=quadtree, bounding_boxes=bboxes, **kwargs) # Run spatial join result = cuspatial.quadtree_point_in_polygon( poly_quad_pairs, @@ -78,24 +65,21 @@ def _points_in_polygons_contains( batch_polygons, ) # Adjust polygon indices back to global indices - result['polygon_index'] += start_idx + result["polygon_index"] += start_idx results.append(result) # Concatenate all batch results result = cudf.concat(results, ignore_index=True) result = result.rename( - {'point_index': 'index_query', 'polygon_index': 'index_match'}, + {"point_index": "index_query", "polygon_index": "index_match"}, axis=1, ) # Remap spatial index order to original point indices - point_indices.name = 'index_query' - result = ( - result - .set_index('index_query') - .join(point_indices) - ) + point_indices.name = "index_query" + result = result.set_index("index_query").join(point_indices) return result + def _points_in_polygons_intersects( points: cuspatial.GeoSeries, polygons: cuspatial.GeoSeries, @@ -133,52 +117,48 @@ def _points_in_polygons_intersects( """ # GPU pass to find all points strictly contained by the polygons contains = _points_in_polygons_contains(points, polygons, batches=batches) - + # Isolate points not found, which are potential boundary cases idx_all = cudf.RangeIndex(len(points)) - idx_missing = idx_all.difference(contains['index_query']) + 
idx_missing = idx_all.difference(contains["index_query"]) if idx_missing.empty: return contains # Buffer-filter on GPU for a large number of candidates pts_ixn = points.iloc[idx_missing] - ply_ixn = polygons_to_geoseries(polygons, backend='geopandas') + ply_ixn = polygons_to_geoseries(polygons, backend="geopandas") if len(pts_ixn) >= max_unassigned_points: ply_buf = polygons_to_geoseries( ply_ixn.buffer(boundary_buffer), - backend='cuspatial', + backend="cuspatial", ) in_buffer = _points_in_polygons_contains(pts_ixn, ply_buf) - in_buffer = in_buffer['index_query'].drop_duplicates() + in_buffer = in_buffer["index_query"].drop_duplicates() pts_ixn = pts_ixn.iloc[in_buffer] if pts_ixn.empty: return contains # Final CPU Join on the selected candidate set - pts_ixn = points_to_geoseries(pts_ixn, backend='geopandas') - boundary = gpd.sjoin( - gpd.GeoDataFrame(geometry=pts_ixn), - gpd.GeoDataFrame(geometry=ply_ixn), - predicate='intersects' - ) + pts_ixn = points_to_geoseries(pts_ixn, backend="geopandas") + boundary = gpd.sjoin(gpd.GeoDataFrame(geometry=pts_ixn), gpd.GeoDataFrame(geometry=ply_ixn), predicate="intersects") boundary = cudf.DataFrame( - boundary - .rename({'index_right': 'index_match'}, axis=1) - .reset_index(names='index_query') - [['index_query', 'index_match']] + boundary.rename({"index_right": "index_match"}, axis=1).reset_index(names="index_query")[ + ["index_query", "index_match"] + ] ) # Combine results from the initial 'contains' and boundary 'intersects' return cudf.concat([contains, boundary]).reset_index(drop=True) + def points_in_polygons( points: any, polygons: any, - predicate: Literal['contains', 'intersects'] = 'intersects', + predicate: Literal["contains", "intersects"] = "intersects", max_unasigned_points: int = 100_000, boundary_buffer: float = 1e-9, - batches: int | None = None + batches: int | None = None, ) -> cudf.DataFrame: """Finds which points fall inside which polygons using a given predicate. 
@@ -191,9 +171,9 @@ def points_in_polygons( A collection of polygons to search within. predicate : Literal['contains', 'intersects'], optional The spatial relationship to test for. Defaults to 'intersects'. - - contains: Finds points strictly inside a polygon, excluding its + - contains: Finds points strictly inside a polygon, excluding its boundary. This is a fast, GPU-only operation. - - intersects: Finds points inside a polygon or on its boundary. This + - intersects: Finds points inside a polygon or on its boundary. This uses achybrid GPU/CPU approach. max_unassigned_points : int, optional Used only for the 'intersects' predicate. This is the threshold @@ -215,17 +195,16 @@ def points_in_polygons( mapping each query point to its corresponding matching polygon. """ # Early error catch - if predicate not in ['contains', 'intersects']: + if predicate not in ["contains", "intersects"]: raise TypeError( - f"Unsupported predicate '{predicate}'. Supported predicates are " - f"'contains' and 'intersects'." + f"Unsupported predicate '{predicate}'. Supported predicates are " f"'contains' and 'intersects'." 
) # Convert geometries to GeoSeries on GPU - points = points_to_geoseries(points, backend='cuspatial') - polygons = polygons_to_geoseries(polygons, backend='cuspatial') + points = points_to_geoseries(points, backend="cuspatial") + polygons = polygons_to_geoseries(polygons, backend="cuspatial") # Perform spatial join - if predicate == 'contains': + if predicate == "contains": return _points_in_polygons_contains(points, polygons, batches=batches) else: # predicate == 'intersects' return _points_in_polygons_intersects( @@ -236,13 +215,13 @@ def points_in_polygons( batches, ) + def polygons_in_polygons( query_polygons: any, index_polygons: any, - predicate: Literal['contains', 'intersects'] = 'intersects', + predicate: Literal["contains", "intersects"] = "intersects", ): - """ - Finds which query polygons fall inside which index polygons using a given + """Finds which query polygons fall inside which index polygons using a given predicate. Parameters @@ -265,16 +244,13 @@ def polygons_in_polygons( that maps the index of each query polygon to the index of every index polygon it matches based on the predicate. 
""" - query_polygons = polygons_to_geoseries(query_polygons, backend='geopandas') - index_polygons = polygons_to_geoseries(index_polygons, backend='geopandas') + query_polygons = polygons_to_geoseries(query_polygons, backend="geopandas") + index_polygons = polygons_to_geoseries(index_polygons, backend="geopandas") joined = gpd.sjoin( gpd.GeoDataFrame(geometry=index_polygons), gpd.GeoDataFrame(geometry=query_polygons), predicate=predicate, ) - return ( - joined - .reset_index(names='index_match') - .rename({'index_right': 'index_query'}, axis=1) - [['index_query', 'index_match']] - ) + return joined.reset_index(names="index_match").rename({"index_right": "index_query"}, axis=1)[ + ["index_query", "index_match"] + ] diff --git a/src/segger/io/__init__.py b/src/segger/io/__init__.py index 1f1ad20..355b4a4 100644 --- a/src/segger/io/__init__.py +++ b/src/segger/io/__init__.py @@ -1,7 +1,15 @@ -from .preprocessor import get_preprocessor from .fields import ( StandardBoundaryFields, - TrainingBoundaryFields, StandardTranscriptFields, + TrainingBoundaryFields, TrainingTranscriptFields, -) \ No newline at end of file +) +from .preprocessor import get_preprocessor + +__all__ = [ + "get_preprocessor", + "StandardBoundaryFields", + "StandardTranscriptFields", + "TrainingBoundaryFields", + "TrainingTranscriptFields", +] diff --git a/src/segger/io/cosmx.py b/src/segger/io/cosmx.py index 30a390a..f9586da 100644 --- a/src/segger/io/cosmx.py +++ b/src/segger/io/cosmx.py @@ -1,29 +1,30 @@ -from skimage.transform import AffineTransform -from typing import Literal +import os from pathlib import Path +from typing import Literal + +import tifffile +from skimage.transform import AffineTransform + import geopandas as gpd -import pandas as pd import numpy as np -import tifffile +import pandas as pd import shapely -import os -from .utils import masks_to_contours, contours_to_polygons from .fields import CosMxBoundaryFields +from .utils import contours_to_polygons, masks_to_contours 
-TOL_FRAC = 1. / 50 # Fraction of area to simplify by +TOL_FRAC = 1.0 / 50 # Fraction of area to simplify by # NOTE: In CosMX, there is a bug in their segmentation where cell masks overlap -# with compartment masks from other cells (e.g. a cell mask A overlaps with +# with compartment masks from other cells (e.g. a cell mask A overlaps with # nuclear mask for cell B). def get_cosmx_polygons( data_dir: os.PathLike, - boundary_type: Literal['cell', 'nucleus'], + boundary_type: Literal["cell", "nucleus"], ) -> gpd.GeoDataFrame: - """ - Extract cell or nuclear polygons from CosMX segmentation outputs. + """Extract cell or nuclear polygons from CosMX segmentation outputs. Parameters ---------- @@ -54,34 +55,31 @@ def get_cosmx_polygons( fov_pos_file = next(data_dir.glob(fields.fov_positions_filename)) # Check file and directory structures - fov_info = pd.read_csv(fov_pos_file, index_col='FOV') + fov_info = pd.read_csv(fov_pos_file, index_col="FOV") # Add 'Slide' column if doesn't exist - if 'Slide' not in fov_info: - fov_info['Slide'] = 1 + if "Slide" not in fov_info: + fov_info["Slide"] = 1 # Check compartment type - if boundary_type == 'cell': + if boundary_type == "cell": valid_codes = [ fields.nucleus_value, fields.membrane_value, fields.cytoplasmic_value, ] - elif boundary_type == 'nucleus': + elif boundary_type == "nucleus": valid_codes = [fields.nucleus_value] else: - msg = ( - f"Invalid compartment '{boundary_type}'. " - f"Choose 'cell' or 'nucleus'." - ) + msg = f"Invalid compartment '{boundary_type}'. " f"Choose 'cell' or 'nucleus'." 
raise ValueError(msg) # Assemble polygons per FOV polygons = [] for fov, row in fov_info.iterrows(): fov_id = str.zfill(str(fov), 3) - cell_path = cell_labels_dir / f'CellLabels_F{fov_id}.tif' - comp_path = comp_labels_dir / f'CompartmentLabels_F{fov_id}.tif' + cell_path = cell_labels_dir / f"CellLabels_F{fov_id}.tif" + comp_path = comp_labels_dir / f"CompartmentLabels_F{fov_id}.tif" # Get shapely polygons from cell masks cell_labels = tifffile.imread(cell_path) @@ -92,12 +90,12 @@ def get_cosmx_polygons( fov_poly = contours_to_polygons(*contours) # Remove redundant vertices - tol = np.sqrt(fov_poly.area).mean() * TOL_FRAC # scale by avg cell size + tol = np.sqrt(fov_poly.area).mean() * TOL_FRAC # scale by avg cell size fov_poly.geometry = fov_poly.geometry.simplify(tolerance=tol) # FOV coords -> Global coords - tx = row['X_mm'] * 1e3 / fields.mpp - ty = row['Y_mm'] * 1e3 / fields.mpp + tx = row["X_mm"] * 1e3 / fields.mpp + ty = row["Y_mm"] * 1e3 / fields.mpp # Flip y-axis and Translate to global position transform = AffineTransform(scale=[1, -1], translation=[tx, ty]) fov_poly.geometry = shapely.transform(fov_poly.geometry, transform) @@ -105,21 +103,20 @@ def get_cosmx_polygons( prefix = f"c_{row['Slide']}_{fov}_" # match CosMX ID structure fov_poly.index = prefix + fov_poly.index.astype(str) polygons.append(fov_poly) - + polygons = pd.concat(polygons) - tx = fov_info['X_mm'].max() * 1e3 / fields.mpp - ty = fov_info['Y_mm'].max() * 1e3 / fields.mpp - #transform = AffineTransform(translation=[tx, ty]) - #polygons.geometry = shapely.transform(polygons.geometry, transform) - + tx = fov_info["X_mm"].max() * 1e3 / fields.mpp + ty = fov_info["Y_mm"].max() * 1e3 / fields.mpp + # transform = AffineTransform(translation=[tx, ty]) + # polygons.geometry = shapely.transform(polygons.geometry, transform) + return polygons def _preflight_checks( data_dir: Path, ) -> None: - """ - Ensure input directories and FOV info file contain expected files and + """Ensure input 
directories and FOV info file contain expected files and columns. """ fields = CosMxBoundaryFields() @@ -133,38 +130,25 @@ def _preflight_checks( try: next(data_dir.glob(pat)) except StopIteration: - msg = ( - f"No file or directory with pattern '{pat}' " - f"found in {data_dir}." - ) + msg = f"No file or directory with pattern '{pat}' " f"found in {data_dir}." raise FileNotFoundError(msg) - fov_info = pd.read_csv( - next(data_dir.glob(fields.fov_positions_filename)), - index_col='FOV' - ) - required_cols = {'X_mm', 'Y_mm'} + fov_info = pd.read_csv(next(data_dir.glob(fields.fov_positions_filename)), index_col="FOV") + required_cols = {"X_mm", "Y_mm"} missing_cols = required_cols - set(fov_info.columns) if missing_cols: - raise ValueError( - f"Missing columns in FOV info: {', '.join(missing_cols)}" - ) + raise ValueError(f"Missing columns in FOV info: {', '.join(missing_cols)}") expected_fovs = [str.zfill(str(fov), 3) for fov in fov_info.index] - expected_files = lambda prefix: { - f"{prefix}_F{fov_id}.tif" for fov_id in expected_fovs - } + expected_files = lambda prefix: {f"{prefix}_F{fov_id}.tif" for fov_id in expected_fovs} for dirname, prefix in [ (fields.cell_labels_dirname, "CellLabels"), - (fields.compartment_labels_dirname, "CompartmentLabels") + (fields.compartment_labels_dirname, "CompartmentLabels"), ]: directory = next(data_dir.glob(dirname)) actual = {f.name for f in directory.glob("*.tif")} expected = expected_files(prefix) missing = expected - actual if missing: - raise FileNotFoundError( - f"Missing {len(missing)} {prefix} TIFFs:\n" + - "\n".join(sorted(missing)) - ) + raise FileNotFoundError(f"Missing {len(missing)} {prefix} TIFFs:\n" + "\n".join(sorted(missing))) diff --git a/src/segger/io/fields.py b/src/segger/io/fields.py index 40bd6be..011186e 100644 --- a/src/segger/io/fields.py +++ b/src/segger/io/fields.py @@ -2,77 +2,81 @@ from dataclasses import dataclass + # TODO: Add description @dataclass class XeniumTranscriptFields: - filename: str = 
'transcripts.parquet' - x: str = 'x_location' - y: str = 'y_location' - feature: str = 'feature_name' - cell_id: str = 'cell_id' - null_cell_id: str = 'UNASSIGNED' - compartment: str = 'overlaps_nucleus' + filename: str = "transcripts.parquet" + x: str = "x_location" + y: str = "y_location" + feature: str = "feature_name" + cell_id: str = "cell_id" + null_cell_id: str = "UNASSIGNED" + compartment: str = "overlaps_nucleus" nucleus_value: int = 1 - quality: str = 'qv' + quality: str = "qv" filter_substrings = [ - 'NegControlProbe_*', - 'antisense_*', - 'NegControlCodeword*', - 'BLANK_*', - 'DeprecatedCodeword_*', - 'UnassignedCodeword_*', + "NegControlProbe_*", + "antisense_*", + "NegControlCodeword*", + "BLANK_*", + "DeprecatedCodeword_*", + "UnassignedCodeword_*", ] + @dataclass class XeniumBoundaryFields: - cell_filename: str = 'cell_boundaries.parquet' - nucleus_filename: str = 'nucleus_boundaries.parquet' - x: str = 'vertex_x' - y: str = 'vertex_y' - id: str = 'cell_id' + cell_filename: str = "cell_boundaries.parquet" + nucleus_filename: str = "nucleus_boundaries.parquet" + x: str = "vertex_x" + y: str = "vertex_y" + id: str = "cell_id" # TODO: Add description @dataclass class MerscopeTranscriptFields: - filename: str = 'detected_transcripts.csv' - x: str = 'global_x' - y: str = 'global_y' - feature: str = 'gene' - cell_id: str = 'cell_id' + filename: str = "detected_transcripts.csv" + x: str = "global_x" + y: str = "global_y" + feature: str = "gene" + cell_id: str = "cell_id" + @dataclass class MerscopeBoundaryFields: - cell_filename: str = 'cell_boundaries.parquet' - nucleus_filename: str = 'nucleus_boundaries.parquet' - id = 'EntityID' + cell_filename: str = "cell_boundaries.parquet" + nucleus_filename: str = "nucleus_boundaries.parquet" + id = "EntityID" # TODO: Add description @dataclass class CosMxTranscriptFields: - filename: str = '*_tx_file.csv' - x: str = 'x_global_px' - y: str = 'y_global_px' - feature: str = 'target' - cell_id: str = 'cell' - 
compartment: str = 'CellComp' - nucleus_value: str = 'Nuclear' - membrane_value: str = 'Membrane' - cytoplasmic_value: str = 'Cytoplasm' - extracellular_value: str = 'None' + filename: str = "*_tx_file.csv" + x: str = "x_global_px" + y: str = "y_global_px" + feature: str = "target" + cell_id: str = "cell" + compartment: str = "CellComp" + nucleus_value: str = "Nuclear" + membrane_value: str = "Membrane" + cytoplasmic_value: str = "Cytoplasm" + extracellular_value: str = "None" filter_substrings = [ - 'Negative*', - 'SystemControl*', - 'NegPrb*', + "Negative*", + "SystemControl*", + "NegPrb*", ] + @dataclass class CosMxBoundaryFields: - id: str = 'cell_id' - cell_labels_dirname: str = 'CellLabels' - compartment_labels_dirname: str = 'CompartmentLabels' - fov_positions_filename: str = '*fov_positions_file.csv' + id: str = "cell_id" + cell_labels_dirname: str = "CellLabels" + compartment_labels_dirname: str = "CompartmentLabels" + fov_positions_filename: str = "*fov_positions_file.csv" extracellular_value: int = 0 nucleus_value: int = 1 membrane_value: int = 2 @@ -83,37 +87,39 @@ class CosMxBoundaryFields: # TODO: Add description @dataclass class StandardTranscriptFields: - filename: str = 'transcripts.parquet' - row_index: str = 'row_index' - x: str = 'x' - y: str = 'y' - feature: str = 'feature_name' - cell_id: str = 'cell_id' - compartment: str = 'cell_compartment' + filename: str = "transcripts.parquet" + row_index: str = "row_index" + x: str = "x" + y: str = "y" + feature: str = "feature_name" + cell_id: str = "cell_id" + compartment: str = "cell_compartment" extracellular_value: int = 0 cytoplasmic_value: int = 1 nucleus_value: int = 2 + @dataclass class StandardBoundaryFields: - filename: str = 'boundaries.parquet' - id: str = 'cell_id' - boundary_type: str = 'boundary_type' - cell_value: str = 'cell' - nucleus_value: str = 'nucleus' - contains_nucleus: str = 'contains_nucleus' + filename: str = "boundaries.parquet" + id: str = "cell_id" + boundary_type: str = 
"boundary_type" + cell_value: str = "cell" + nucleus_value: str = "nucleus" + contains_nucleus: str = "contains_nucleus" # TODO: Add description @dataclass class TrainingTranscriptFields(StandardTranscriptFields): - cell_encoding: str = 'cell_encoding' - gene_encoding: str = 'gene_encoding' - cell_cluster: str = 'cell_cluster' - gene_cluster: str = 'gene_cluster' + cell_encoding: str = "cell_encoding" + gene_encoding: str = "gene_encoding" + cell_cluster: str = "cell_cluster" + gene_cluster: str = "gene_cluster" + @dataclass class TrainingBoundaryFields(StandardBoundaryFields): - index: str = 'entity_index' - cell_encoding: str = 'cell_encoding' - cell_cluster: str = 'cell_cluster' + index: str = "entity_index" + cell_encoding: str = "cell_encoding" + cell_cluster: str = "cell_cluster" diff --git a/src/segger/io/preprocessor.py b/src/segger/io/preprocessor.py index 597a818..7c9d99c 100644 --- a/src/segger/io/preprocessor.py +++ b/src/segger/io/preprocessor.py @@ -1,32 +1,27 @@ -from pandas.errors import DtypeWarning -from functools import cached_property +import logging +import sys +import warnings from abc import ABC, abstractmethod -from anndata import AnnData -from typing import Literal +from functools import cached_property from pathlib import Path + import geopandas as gpd -import polars as pl import pandas as pd -import warnings -import logging -import sys +import polars as pl +from pandas.errors import DtypeWarning + +from anndata import AnnData from .cosmx import get_cosmx_polygons -from .utils import ( - contours_to_polygons, - fix_invalid_geometry, -) from .fields import ( - MerscopeTranscriptFields, - MerscopeBoundaryFields, - StandardTranscriptFields, + CosMxBoundaryFields, + CosMxTranscriptFields, StandardBoundaryFields, - XeniumTranscriptFields, + StandardTranscriptFields, XeniumBoundaryFields, - CosMxTranscriptFields, - CosMxBoundaryFields, + XeniumTranscriptFields, ) - +from .utils import contours_to_polygons, fix_invalid_geometry # Ignore pandas 
warnings in CosMX transcripts file warnings.filterwarnings("ignore", category=DtypeWarning) @@ -34,10 +29,10 @@ # Register of available ISTPreprocessor subclasses keyed by platform name. PREPROCESSORS = {} + def register_preprocessor(name): - """ - Decorator to register a preprocessor class under a given platform name. - + """Decorator to register a preprocessor class under a given platform name. + Parameters ---------- name : str @@ -48,21 +43,22 @@ def register_preprocessor(name): decorator : Callable Class decorator that adds the class to the PREPROCESSORS registry. """ + def decorator(cls): PREPROCESSORS[name] = cls return cls + return decorator + class ISTPreprocessor(ABC): - """ - Abstract base class for platform-specific preprocessing of spatial + """Abstract base class for platform-specific preprocessing of spatial transcriptomics data. Subclasses must implement methods to construct transcript and boundary GeoDataFrames for the given platform. """ def __init__(self, data_dir: Path): - """ - Parameters + """Parameters ---------- data_dir : Path Path to the raw data directory for the spatial platform. @@ -74,34 +70,23 @@ def __init__(self, data_dir: Path): @staticmethod @abstractmethod def _validate_directory(data_dir: Path): - """ - Check that all required files/directories are present in `data_dir`. - """ + """Check that all required files/directories are present in `data_dir`.""" ... @property @abstractmethod def transcripts(self) -> pl.DataFrame: - """ - Construct, standardize, and return transcripts as a Polars DataFrame. - """ + """Construct, standardize, and return transcripts as a Polars DataFrame.""" ... @property @abstractmethod def boundaries(self) -> gpd.GeoDataFrame: - """ - Construct, standardize, and return cell boundaries. - """ + """Construct, standardize, and return cell boundaries.""" ... 
- def _get_anndata( - self, - transcripts: gpd.GeoDataFrame, - label: str - ) -> AnnData: - """ - Convert transcript data to an AnnData object using a specified + def _get_anndata(self, transcripts: gpd.GeoDataFrame, label: str) -> AnnData: + """Convert transcript data to an AnnData object using a specified segmentation label column. Parameters @@ -118,14 +103,8 @@ def _get_anndata( """ ... - def save( - self, - out_dir: Path, - verbose: bool = False, - overwrite: bool = False - ): - """ - Generate and save GeoParquet files for transcripts, cell and nucleus + def save(self, out_dir: Path, verbose: bool = False, overwrite: bool = False): + """Generate and save GeoParquet files for transcripts, cell and nucleus boundaries, and an AnnData object from transcript-to-nucleus mappings. Parameters @@ -137,10 +116,10 @@ def save( """ logger = self._setup_logging(verbose) - self.tx_out = out_dir / 'transcripts.parquet' - self.ad_out = out_dir / 'nucleus_boundaries.h5ad' - self.bd_out_cell = out_dir / 'cell_boundaries_geo.parquet' - self.bd_out_nuc = out_dir / 'nucleus_boundaries_geo.parquet' + self.tx_out = out_dir / "transcripts.parquet" + self.ad_out = out_dir / "nucleus_boundaries.h5ad" + self.bd_out_cell = out_dir / "cell_boundaries_geo.parquet" + self.bd_out_nuc = out_dir / "nucleus_boundaries_geo.parquet" logger.info("Loading transcripts") tx = self._get_transcripts() @@ -150,24 +129,16 @@ def save( bd_nuc = gpd.read_parquet(self.bd_out_nuc) else: logger.info("Constructing & saving nuclear boundaries") - bd_nuc = self._get_boundaries('nucleus') - bd_nuc.to_parquet( - self.bd_out_nuc, - write_covering_bbox=True, - geometry_encoding="geoarrow" - ) - + bd_nuc = self._get_boundaries("nucleus") + bd_nuc.to_parquet(self.bd_out_nuc, write_covering_bbox=True, geometry_encoding="geoarrow") + if self.bd_out_cell.exists() and not overwrite: logger.info("Loading cell boundaries (from file)") bd_cell = gpd.read_parquet(self.bd_out_cell) else: logger.info("Constructing & saving cell 
boundaries") - bd_cell = self._get_boundaries('cell') - bd_cell.to_parquet( - self.bd_out_cell, - write_covering_bbox=True, - geometry_encoding="geoarrow" - ) + bd_cell = self._get_boundaries("cell") + bd_cell.to_parquet(self.bd_out_cell, write_covering_bbox=True, geometry_encoding="geoarrow") logger.info("Assigning to nuclear boundaries") lbl = "nucleus_boundaries_id" @@ -178,7 +149,7 @@ def save( tx = self.assign_transcripts_to_boundaries(tx, bd_cell, lbl) logger.info("Saving transcripts") - tx = pd.DataFrame(tx.drop(columns='geometry')) + tx = pd.DataFrame(tx.drop(columns="geometry")) tx.to_parquet(self.tx_out, index=False) logger.info("Creating AnnData") @@ -188,13 +159,9 @@ def save( ad.write_h5ad(self.ad_out) def assign_transcripts_to_boundaries( - self, - transcripts: gpd.GeoDataFrame, - boundaries: gpd.GeoDataFrame, - boundary_label: str = "boundaries_id" + self, transcripts: gpd.GeoDataFrame, boundaries: gpd.GeoDataFrame, boundary_label: str = "boundaries_id" ) -> gpd.GeoDataFrame: - """ - Assign transcripts to boundaries using spatial join. + """Assign transcripts to boundaries using spatial join. Parameters ---------- @@ -210,28 +177,21 @@ def assign_transcripts_to_boundaries( gpd.GeoDataFrame Transcripts with assigned segmentation labels. 
""" - joined = gpd.sjoin( - transcripts, - boundaries, - how="left", - predicate="intersects" - ) - + joined = gpd.sjoin(transcripts, boundaries, how="left", predicate="intersects") + return joined.rename(columns={"index_right": boundary_label}) - + def _setup_logging(self, verbose: bool = False) -> logging.Logger: class TimeFilter(logging.Filter): - def filter(self, record): from datetime import datetime + try: last = self.last except AttributeError: last = record.relativeCreated - delta = datetime.fromtimestamp(record.relativeCreated/1e3) - \ - datetime.fromtimestamp(last/1e3) - record.relative = '{0:.2f}'.format( - delta.seconds + delta.microseconds/1e6) + delta = datetime.fromtimestamp(record.relativeCreated / 1e3) - datetime.fromtimestamp(last / 1e3) + record.relative = f"{delta.seconds + delta.microseconds / 1e6:.2f}" self.last = record.relativeCreated return True @@ -243,20 +203,16 @@ def filter(self, record): logger.addHandler(handler) for hndl in logger.handlers: hndl.addFilter(TimeFilter()) - hndl.setFormatter(logging.Formatter( - fmt="%(asctime)s (%(relative)ss) %(message)s" - )) + hndl.setFormatter(logging.Formatter(fmt="%(asctime)s (%(relative)ss) %(message)s")) return logger @register_preprocessor("nanostring_cosmx") class CosMXPreprocessor(ISTPreprocessor): - """ - Preprocessor for NanoString CosMX datasets. - """ + """Preprocessor for NanoString CosMX datasets.""" + @staticmethod def _validate_directory(data_dir: Path): - # Check required files/directories bd_fields = CosMxBoundaryFields() tx_fields = CosMxTranscriptFields() @@ -268,14 +224,13 @@ def _validate_directory(data_dir: Path): ]: num_matches = len(list(data_dir.glob(pat))) if not num_matches == 1: - raise IOError( + raise OSError( f"CosMx sample directory must contain exactly 1 file or " f"directory matching {pat}, but found {num_matches}." 
) @cached_property def transcripts(self) -> pl.DataFrame: - # Field names raw = CosMxTranscriptFields() std = StandardTranscriptFields() @@ -285,9 +240,7 @@ def transcripts(self) -> pl.DataFrame: pl.scan_csv(next(self.data_dir.glob(raw.filename))) .with_row_index(name=std.row_index) # Filter data - .filter(pl.col(raw.feature).str.contains( - '|'.join(raw.filter_substrings)).not_() - ) + .filter(pl.col(raw.feature).str.contains("|".join(raw.filter_substrings)).not_()) # Standardize compartment labels .with_columns( pl.col(raw.compartment) @@ -312,11 +265,8 @@ def transcripts(self) -> pl.DataFrame: ) # Map to standard field names .rename({raw.x: std.x, raw.y: std.y, raw.feature: std.feature}) - - # Subset to necessary fields - .select([std.row_index, std.x, std.y, std.feature, std.cell_id, - std.compartment]) - + # Subset to necessary fields + .select([std.row_index, std.x, std.y, std.feature, std.cell_id, std.compartment]) # Add numeric index .with_row_index() .collect() @@ -324,19 +274,16 @@ def transcripts(self) -> pl.DataFrame: @cached_property def boundaries(self) -> gpd.GeoDataFrame: - # Field names - raw = CosMxBoundaryFields() + CosMxBoundaryFields() std = StandardBoundaryFields() # Join boundary datasets - cells = get_cosmx_polygons(self.data_dir, 'cell').reset_index( - drop=False, names=std.id) + cells = get_cosmx_polygons(self.data_dir, "cell").reset_index(drop=False, names=std.id) cells = fix_invalid_geometry(cells) cells[std.boundary_type] = std.cell_value - nuclei = get_cosmx_polygons(self.data_dir, 'nucleus').reset_index( - drop=False, names=std.id) + nuclei = get_cosmx_polygons(self.data_dir, "nucleus").reset_index(drop=False, names=std.id) nuclei = fix_invalid_geometry(nuclei) nuclei[std.boundary_type] = std.nucleus_value @@ -352,29 +299,30 @@ def boundaries(self) -> gpd.GeoDataFrame: .get(std.boundary_type) ) # Convert index to string type (to join on AnnData) - bd.index = bd[std.id] + '_' + bd[std.boundary_type].map({ - std.nucleus_value: '0', - 
std.cell_value: '1', - }) + bd.index = ( + bd[std.id] + + "_" + + bd[std.boundary_type].map( + { + std.nucleus_value: "0", + std.cell_value: "1", + } + ) + ) return bd - + def _get_anndata(self, transcripts, label): return utils.transcripts_to_anndata( - transcripts=transcripts, - cell_label=label, - gene_label=self._gene, - coordinate_labels=[self._x, self._y] + transcripts=transcripts, cell_label=label, gene_label=self._gene, coordinate_labels=[self._x, self._y] ) @register_preprocessor("10x_xenium") class XeniumPreprocessor(ISTPreprocessor): - """ - Preprocessor for 10x Genomics Xenium datasets. - """ + """Preprocessor for 10x Genomics Xenium datasets.""" + @staticmethod def _validate_directory(data_dir: Path): - # Check required files/directories bd_fields = XeniumBoundaryFields() tx_fields = XeniumTranscriptFields() @@ -385,63 +333,45 @@ def _validate_directory(data_dir: Path): ]: num_matches = len(list(data_dir.glob(pat))) if not num_matches == 1: - raise IOError( + raise OSError( f"Xenium sample directory must contain exactly 1 file or " f"directory matching {pat}, but found {num_matches}." 
) @cached_property def transcripts(self) -> pl.DataFrame: - # Field names raw = XeniumTranscriptFields() std = StandardTranscriptFields() return ( # Read in lazily - pl.scan_parquet( - self.data_dir / raw.filename, - parallel='row_groups' - ) + pl.scan_parquet(self.data_dir / raw.filename, parallel="row_groups") # Add numeric index at beginning .with_row_index(name=std.row_index) # Filter data .filter(pl.col(raw.quality) >= 20) - .filter(pl.col(raw.feature).str.contains( - '|'.join(raw.filter_substrings)).not_() - ) + .filter(pl.col(raw.feature).str.contains("|".join(raw.filter_substrings)).not_()) # Standardize compartment labels .with_columns( pl.when(pl.col(raw.compartment) == raw.nucleus_value) .then(std.nucleus_value) - .when( - (pl.col(raw.compartment) != raw.nucleus_value) & - (pl.col(raw.cell_id) != raw.null_cell_id) - ) + .when((pl.col(raw.compartment) != raw.nucleus_value) & (pl.col(raw.cell_id) != raw.null_cell_id)) .then(std.cytoplasmic_value) .otherwise(std.extracellular_value) .alias(std.compartment) ) # Standardize cell IDs - .with_columns( - pl.col(raw.cell_id) - .replace(raw.null_cell_id, None) - .alias(std.cell_id) - ) + .with_columns(pl.col(raw.cell_id).replace(raw.null_cell_id, None).alias(std.cell_id)) # Map to standard field names .rename({raw.x: std.x, raw.y: std.y, raw.feature: std.feature}) - - # Subset to necessary fields - .select([std.row_index, std.x, std.y, std.feature, std.cell_id, - std.compartment]) + # Subset to necessary fields + .select([std.row_index, std.x, std.y, std.feature, std.cell_id, std.compartment]) .collect() ) @staticmethod - def _get_boundaries( - filepath: Path, - boundary_type: str - ) -> gpd.GeoDataFrame: + def _get_boundaries(filepath: Path, boundary_type: str) -> gpd.GeoDataFrame: # TODO: Add documentation # Field names @@ -449,7 +379,7 @@ def _get_boundaries( std = StandardBoundaryFields() # Read in flat vertices and convert to geometries - bd = pl.read_parquet(filepath, parallel='row_groups') + bd = 
pl.read_parquet(filepath, parallel="row_groups") bd = contours_to_polygons( x=bd[raw.x].to_numpy(), y=bd[raw.y].to_numpy(), @@ -459,7 +389,7 @@ def _get_boundaries( # Standardize cell ids and types bd[std.boundary_type] = boundary_type return bd - + @cached_property def boundaries(self) -> gpd.GeoDataFrame: # TODO: Add documentation @@ -467,18 +397,12 @@ def boundaries(self) -> gpd.GeoDataFrame: std = StandardBoundaryFields() # Join boundary datasets - cells = self._get_boundaries( - self.data_dir / raw.cell_filename, - std.cell_value - ) - nuclei = self._get_boundaries( - self.data_dir / raw.nucleus_filename, - std.nucleus_value - ) + cells = self._get_boundaries(self.data_dir / raw.cell_filename, std.cell_value) + nuclei = self._get_boundaries(self.data_dir / raw.nucleus_filename, std.nucleus_value) # 10X Xenium nucleus segmentation is intersection of geometries idx = cells.index.intersection(nuclei.index) - ixn = cells.loc[idx].intersection(nuclei.loc[idx]) + cells.loc[idx].intersection(nuclei.loc[idx]) # Remove non-overlapping geometries (10X bug) # empty = ixn.is_empty # nuclei.drop(idx[empty], axis=0, inplace=True) @@ -492,24 +416,31 @@ def boundaries(self) -> gpd.GeoDataFrame: cells.loc[idx, std.contains_nucleus] = True # Join geometries - bd = pd.concat([ - cells.reset_index(drop=False, names=std.id), - nuclei.reset_index(drop=False, names=std.id), - ]) + bd = pd.concat( + [ + cells.reset_index(drop=False, names=std.id), + nuclei.reset_index(drop=False, names=std.id), + ] + ) # Convert index to string type (to join on AnnData) - bd.index = bd[std.id] + '_' + bd[std.boundary_type].map({ - std.nucleus_value: '0', - std.cell_value: '1', - }) + bd.index = ( + bd[std.id] + + "_" + + bd[std.boundary_type].map( + { + std.nucleus_value: "0", + std.cell_value: "1", + } + ) + ) return bd @register_preprocessor("vizgen_merscope") class MerscopePreprocessor(ISTPreprocessor): - """ - Preprocessor for Vizgen MERSCOPE datasets. 
- """ + """Preprocessor for Vizgen MERSCOPE datasets.""" + @staticmethod def _validate_directory(data_dir: Path): raise NotImplementedError() @@ -526,29 +457,18 @@ def _infer_platform(data_dir: Path) -> str: exceptions.append(e) if len(matches) == 0: err_str = ", ".join(map(str, exceptions)) - raise ValueError( - f"Could not infer platform from data directory: {err_str}." - ) + raise ValueError(f"Could not infer platform from data directory: {err_str}.") elif len(matches) > 1: conflicting_platforms = ", ".join(matches) - raise ValueError( - f"Ambiguous data directory: Multiple platforms match: " - f"{conflicting_platforms}." - ) + raise ValueError(f"Ambiguous data directory: Multiple platforms match: " f"{conflicting_platforms}.") return matches[0] -def get_preprocessor( - data_dir: Path, - platform: str | None = None -) -> ISTPreprocessor: +def get_preprocessor(data_dir: Path, platform: str | None = None) -> ISTPreprocessor: data_dir = Path(data_dir) if platform is None: - platform = _infer_platform(data_dir) + platform = _infer_platform(data_dir) if platform not in PREPROCESSORS: - raise ValueError( - f"Unknown platform: '{platform}'. " - f"Available: {list(PREPROCESSORS)}" - ) + raise ValueError(f"Unknown platform: '{platform}'. " f"Available: {list(PREPROCESSORS)}") cls = PREPROCESSORS[platform.lower()] return cls(data_dir) diff --git a/src/segger/io/utils.py b/src/segger/io/utils.py index 1e55b8f..f2d0fa8 100644 --- a/src/segger/io/utils.py +++ b/src/segger/io/utils.py @@ -1,13 +1,14 @@ -from numpy.typing import ArrayLike +import cv2 +import skimage + import geopandas as gpd import numpy as np -import skimage import shapely -import cv2 +from numpy.typing import ArrayLike + def masks_to_contours(masks: ArrayLike) -> np.ndarray: - """ - Convert labeled mask image to contours with cell ID annotations. + """Convert labeled mask image to contours with cell ID annotations. 
Parameters ---------- @@ -25,19 +26,18 @@ def masks_to_contours(masks: ArrayLike) -> np.ndarray: for p in props: # Get largest contour with label lbl_contours = cv2.findContours( - np.pad(p.image, 0).astype('uint8'), + np.pad(p.image, 0).astype("uint8"), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE, )[0] contour = sorted(lbl_contours, key=lambda c: c.shape[0])[-1] if contour.shape[0] > 2: - contour = np.hstack([ - np.squeeze(contour)[:, ::-1] + p.bbox[:2], # vertices - np.full((contour.shape[0], 1), p.label) # ID - ]) + contour = np.hstack( + [np.squeeze(contour)[:, ::-1] + p.bbox[:2], np.full((contour.shape[0], 1), p.label)] # vertices # ID + ) contours.append(contour) contours = np.concatenate(contours) - + return contours @@ -46,8 +46,7 @@ def contours_to_polygons( y: ArrayLike, ids: ArrayLike, ) -> gpd.GeoDataFrame: - """ - Convert contour vertices into Shapely polygons. + """Convert contour vertices into Shapely polygons. Parameters ---------- @@ -72,7 +71,7 @@ def contours_to_polygons( part_offset = np.arange(len(np.unique(ids)) + 1) polygons = shapely.from_ragged_array( shapely.GeometryType.POLYGON, - coords=np.stack([x, y]).T.copy(order='C'), + coords=np.stack([x, y]).T.copy(order="C"), offsets=(geometry_offset, part_offset), ) @@ -81,8 +80,7 @@ def contours_to_polygons( def resort_coordinates(poly): - """ - Sort a list of (x, y) coordinates in counter-clockwise order. + """Sort a list of (x, y) coordinates in counter-clockwise order. Parameters ---------- @@ -96,19 +94,18 @@ def resort_coordinates(poly): """ coords = np.asarray(poly.exterior.xy).T cx, cy = coords.mean(axis=0) - angles = np.arctan2(coords[:,1] - cy, coords[:,0] - cx) + angles = np.arctan2(coords[:, 1] - cy, coords[:, 0] - cx) sorted_coords = coords[np.argsort(angles)] return shapely.Polygon(sorted_coords) def fix_invalid_geometry(gdf: gpd.GeoDataFrame): - """ - Fix invalid geometries by resorting coordinates. 
- """ + """Fix invalid geometries by resorting coordinates.""" mask = ~gdf.geometry.is_valid - if not mask.any(): return gdf - + if not mask.any(): + return gdf + fixed = gdf.loc[mask].geometry.apply(resort_coordinates) gdf.loc[mask, gdf.geometry.name] = fixed return gdf diff --git a/src/segger/metrics/segment.py b/src/segger/metrics/segment.py index cba36d6..282857b 100644 --- a/src/segger/metrics/segment.py +++ b/src/segger/metrics/segment.py @@ -1,33 +1,34 @@ -import shapely -import pyarrow -import numpy as np -import cupy as cp +import os +import sys + import cudf +import cupy as cp import cuspatial +import pyarrow + +import numpy as np import pandas as pd -import sys -import os +import shapely + + +class TranscriptColumns: + """_summary_.""" + x = "x_location" + y = "y_location" + id = "codeword_index" + label = "feature_name" + xy = [x, y] -class TranscriptColumns(): - """ - _summary_ - """ - x = 'x_location' - y = 'y_location' - id = 'codeword_index' - label = 'feature_name' - xy = [x,y] -class BoundaryColumns(): - """ - _summary_ - """ - x = 'vertex_x' - y = 'vertex_y' - id = 'label_id' - label = 'cell_id' - xy = [x,y] +class BoundaryColumns: + """_summary_.""" + + x = "vertex_x" + y = "vertex_y" + id = "label_id" + label = "cell_id" + xy = [x, y] def get_xy_bounds( @@ -70,14 +71,17 @@ def read_parquet_region( # Find bounds of full file if not supplied if bounds is None: bounds = get_xy_bounds(filepath, x, y) - + # Load pre-filtered data from Parquet file - filters = [[ - (x, '>', bounds.bounds[0]), - (y, '>', bounds.bounds[1]), - (x, '<', bounds.bounds[2]), - (y, '<', bounds.bounds[3]), - ] + extra_filters] + filters = [ + [ + (x, ">", bounds.bounds[0]), + (y, ">", bounds.bounds[1]), + (x, "<", bounds.bounds[2]), + (y, "<", bounds.bounds[3]), + ] + + extra_filters + ] columns = [x, y] + extra_columns region = dask_cudf.read_parquet( @@ -94,7 +98,7 @@ def get_polygons_from_xy( ): # Directly convert to GeoSeries from cuDF names = BoundaryColumns - vertices = 
boundaries[names.xy].astype('double') + vertices = boundaries[names.xy].astype("double") ids = boundaries[names.id].values splits = cp.where(ids[:-1] != ids[1:])[0] + 1 geometry_offset = cp.hstack([0, splits, len(ids)]) @@ -115,7 +119,7 @@ def get_points_from_xy( ): # Directly convert to GeoSeries from cuDF names = TranscriptColumns - coords = transcripts[names.xy].astype('double') + coords = transcripts[names.xy].astype("double") points = cuspatial.GeoSeries.from_points_xy(coords.interleave_columns()) del coords gc.collect() @@ -129,29 +133,31 @@ def filter_boundaries( ): # Determine overlaps of boundary polygons names = BoundaryColumns + def in_region(region): in_x = boundaries[names.x].between(region.bounds[0], region.bounds[2]) in_y = boundaries[names.y].between(region.bounds[1], region.bounds[3]) return in_x & in_y + x1, y1, x4, y4 = outset.bounds x2, y2, x3, y3 = inset.bounds - boundaries['top'] = in_region(shapely.box(x1, y1, x4, y2)) - boundaries['left'] = in_region(shapely.box(x1, y1, x2, y4)) - boundaries['right'] = in_region(shapely.box(x3, y1, x4, y4)) - boundaries['bottom'] = in_region(shapely.box(x1, y3, x4, y4)) - boundaries['center'] = in_region(inset) + boundaries["top"] = in_region(shapely.box(x1, y1, x4, y2)) + boundaries["left"] = in_region(shapely.box(x1, y1, x2, y4)) + boundaries["right"] = in_region(shapely.box(x3, y1, x4, y4)) + boundaries["bottom"] = in_region(shapely.box(x1, y3, x4, y4)) + boundaries["center"] = in_region(inset) # Filter boundary polygons # Include overlaps with top and left, not bottom and right gb = boundaries.groupby(names.id, sort=False) - total = gb['center'].transform('size') - in_top = gb['top'].transform('sum') - in_left = gb['left'].transform('sum') - in_right = gb['right'].transform('sum') - in_bottom = gb['bottom'].transform('sum') - in_center = gb['center'].transform('sum') + total = gb["center"].transform("size") + in_top = gb["top"].transform("sum") + in_left = gb["left"].transform("sum") + in_right = 
gb["right"].transform("sum") + in_bottom = gb["bottom"].transform("sum") + in_center = gb["center"].transform("sum") keep = in_center == total - keep |= ((in_center > 0) & (in_left > 0) & (in_bottom == 0)) - keep |= ((in_center > 0) & (in_top > 0) & (in_right == 0)) + keep |= (in_center > 0) & (in_left > 0) & (in_bottom == 0) + keep |= (in_center > 0) & (in_top > 0) & (in_right == 0) inset_boundaries = boundaries.loc[keep] return inset_boundaries @@ -173,19 +179,25 @@ def get_quadtree_kwargs( ): if user_quadtree_kwargs is None: user_quadtree_kwargs = {} - kwargs = dict(max_depth=10, max_size=10000) + kwargs = {"max_depth": 10, "max_size": 10000} kwargs.update(user_quadtree_kwargs) - if 'scale' not in kwargs: - kwargs['scale'] = max( - bounds.bounds[2] - bounds.bounds[0], - bounds.bounds[3] - bounds.bounds[1], - ) // (1 << kwargs['max_depth']) + 1 - kwargs.update(dict( - x_min=bounds.bounds[0], - y_min=bounds.bounds[1], - x_max=bounds.bounds[2], - y_max=bounds.bounds[3], - )) + if "scale" not in kwargs: + kwargs["scale"] = ( + max( + bounds.bounds[2] - bounds.bounds[0], + bounds.bounds[3] - bounds.bounds[1], + ) + // (1 << kwargs["max_depth"]) + + 1 + ) + kwargs.update( + { + "x_min": bounds.bounds[0], + "y_min": bounds.bounds[1], + "x_max": bounds.bounds[2], + "y_max": bounds.bounds[3], + } + ) return kwargs @@ -196,22 +208,18 @@ def get_expression_matrix( polygons_idx: np.ndarray, bounds: shapely.Polygon, quadtree_kwargs: dict = None, -): +): # Keyword arguments reused below kwargs = get_quadtree_kwargs(bounds, quadtree_kwargs) - + # Build quadtree on points keys_to_pts, quadtree = cuspatial.quadtree_on_points(points, **kwargs) - + # Create bounding box and quadtree lookup - kwargs.pop('max_size') # not used below + kwargs.pop("max_size") # not used below bboxes = cuspatial.polygon_bounding_boxes(polygons) - poly_quad_pairs = cuspatial.join_quadtree_and_bounding_boxes( - quadtree, - bboxes, - **kwargs - ) - + poly_quad_pairs = 
cuspatial.join_quadtree_and_bounding_boxes(quadtree, bboxes, **kwargs) + # Assign transcripts to cells based on polygon boundaries result = cuspatial.quadtree_point_in_polygon( poly_quad_pairs, @@ -220,17 +228,17 @@ def get_expression_matrix( points, polygons, ) - + # Map from transcript index to gene index codes = cudf.Series(points_idx) - col_ind = result['point_index'].map(keys_to_pts).map(codes) + col_ind = result["point_index"].map(keys_to_pts).map(codes) col_ind = col_ind.to_numpy() - + # Get ordered cell IDs from Xenium _, row_uniques = pd.factorize(polygons_idx) - row_ind = result['polygon_index'].map(cudf.Series(row_uniques)) + row_ind = result["polygon_index"].map(cudf.Series(row_uniques)) row_ind = row_ind.to_numpy() - 1 # originally, 1-index - + # Construct sparse expression matrix X = sp.sparse.csr_array( (np.ones(result.shape[0]), (row_ind, col_ind)), @@ -246,31 +254,26 @@ def get_buffered_counts( buffer_distance: float, overlap: float = 100, quadtree_kwargs: dict = None, - ): # Load transcripts - outset = bounds.buffer(overlap, join_style='mitre') + outset = bounds.buffer(overlap, join_style="mitre") transcripts = read_parquet_region( filepath_transcripts, TranscriptColumns.x, TranscriptColumns.y, bounds=outset, extra_columns=[TranscriptColumns.id], - extra_filters=[('qv', '>', 20)], + extra_filters=[("qv", ">", 20)], ).compute() points = get_points_from_xy(transcripts) - + # Load boundaries boundaries = read_parquet_region( - filepath_boundaries, - BoundaryColumns.x, - BoundaryColumns.y, - bounds=outset, - extra_columns=[BoundaryColumns.id] + filepath_boundaries, BoundaryColumns.x, BoundaryColumns.y, bounds=outset, extra_columns=[BoundaryColumns.id] ).compute() boundaries = filter_boundaries(boundaries, bounds, outset) polygons = get_polygons_from_xy(boundaries) - + if buffer_distance != 0: polygons = buffer_polygons(polygons, buffer_distance) @@ -289,24 +292,24 @@ def get_buffered_counts( def key_to_coordinate(key): # Convert the key to binary and 
remove the '0b' prefix binary_key = bin(key)[2:] - + # Make sure the binary string length is even by prepending a '0' if necessary if len(binary_key) % 2 != 0: - binary_key = '0' + binary_key + binary_key = "0" + binary_key # Split the binary string into pairs - pairs = [binary_key[i:i+2] for i in range(0, len(binary_key), 2)] - + pairs = [binary_key[i : i + 2] for i in range(0, len(binary_key), 2)] + # Initialize coordinates x, y = 0, 0 - + # Iterate through each pair to calculate the sum of positions for i, pair in enumerate(pairs): power_of_2 = 2 ** (len(pairs) - i - 1) y += int(pair[0]) * power_of_2 x += int(pair[1]) * power_of_2 - - return pd.Series([y, x], index=['y', 'x'], name=key) + + return pd.Series([y, x], index=["y", "x"], name=key) def get_quadrant_bounds( @@ -317,15 +320,15 @@ def get_quadrant_bounds( x_min, y_min, x_max, y_max = bounds.bounds width = x_max - x_min height = y_max - y_min - levels = quadtree['level'] + 1 - coords = quadtree['key'].apply(key_to_coordinate) - quadrant_size_x = width / 2**levels - quadrant_size_y = height / 2**levels - - quadtree['x_min'] = x_min + coords['x'] * quadrant_size_x - quadtree['x_max'] = quadtree['x_min'] + quadrant_size_x - quadtree['y_min'] = y_min + coords['y'] * quadrant_size_y - quadtree['y_max'] = quadtree['y_min'] + quadrant_size_y + levels = quadtree["level"] + 1 + coords = quadtree["key"].apply(key_to_coordinate) + quadrant_size_x = width / 2**levels + quadrant_size_y = height / 2**levels + + quadtree["x_min"] = x_min + coords["x"] * quadrant_size_x + quadtree["x_max"] = quadtree["x_min"] + quadrant_size_x + quadtree["y_min"] = y_min + coords["y"] * quadrant_size_y + quadtree["y_max"] = quadtree["y_min"] + quadrant_size_y return quadtree @@ -346,21 +349,18 @@ def get_transcripts_regions( points = get_points_from_xy(transcripts) del transcripts gc.collect() - + # Build quadtree on points - kwargs = dict(max_depth=10, max_size=max_size) + kwargs = {"max_depth": 10, "max_size": max_size} kwargs = 
get_quadtree_kwargs(bounds, kwargs) _, quadtree = cuspatial.quadtree_on_points(points, **kwargs) quadtree_df = quadtree.to_pandas() del quadtree gc.collect() - + # Get boundaries of quadtree quadrants quadtree_df = get_quadrant_bounds(quadtree_df, bounds) - regions = quadtree_df.loc[ - ~quadtree_df['is_internal_node'], - ['x_min', 'y_min', 'x_max', 'y_max'] - ] + regions = quadtree_df.loc[~quadtree_df["is_internal_node"], ["x_min", "y_min", "x_max", "y_max"]] return regions @@ -375,13 +375,13 @@ def get_cell_labels( split_row_groups=row_group_chunksize, columns=[id, label], ) - boundaries[label] = boundaries[label].str.replace('\x00', '') + boundaries[label] = boundaries[label].str.replace("\x00", "") cell_labels = boundaries[label].unique().compute() return cell_labels.to_numpy() def get_buffered_counts_distributed( - filepath_transcripts, # + filepath_transcripts, # filepath_boundaries, # Need adapting for Geoparquet filepath_gene_panel, buffer_distance, @@ -401,15 +401,12 @@ def get_buffered_counts_distributed( shapely.box(x_mid, y_min, x_max, y_mid), # Q3 shapely.box(x_mid, y_mid, x_max, y_max), # Q4 ] - + # Build quadtree and get boundaries of each quadrant region - futures = client.map( - lambda q: get_transcripts_regions(filepath_transcripts, bounds=q), - quadrants - ) + futures = client.map(lambda q: get_transcripts_regions(filepath_transcripts, bounds=q), quadrants) regions = pd.concat(client.gather(futures)) gc.collect() - + # Build new counts matrices for each region using buffered (offset) cell # boundaries # Note: transcripts can be doubly-counted @@ -424,22 +421,22 @@ def get_buffered_counts_distributed( ) matrices = client.gather(futures) gc.collect() - + # Combine matrices into one output_shape = tuple(np.array([*map(np.shape, matrices)]).max(0)) - X = sp.sparse.csr_array(output_shape, dtype='uint32') + X = sp.sparse.csr_array(output_shape, dtype="uint32") for matrix in matrices: matrix.resize(output_shape) X += matrix - + # Get gene labels and 
reorder according to gene panel with open(filepath_gene_panel) as f: gene_panel = json.load(f) - targets = gene_panel['payload']['targets'] - index = [t['codewords'][0] for t in targets] - gene_labels = [t['type']['data']['name'].upper() for t in targets] + targets = gene_panel["payload"]["targets"] + index = [t["codewords"][0] for t in targets] + gene_labels = [t["type"]["data"]["name"].upper() for t in targets] X = X[:, index] - + # Get cell labels # cell_labels = get_cell_labels(filepath_boundaries) @@ -449,4 +446,4 @@ def get_buffered_counts_distributed( obs=pd.DataFrame(index=np.arange(X.shape[0]).astype(str)), var=pd.DataFrame(index=gene_labels), ) - return ad \ No newline at end of file + return ad diff --git a/src/segger/models/__init__.py b/src/segger/models/__init__.py index 3fc2907..e004423 100644 --- a/src/segger/models/__init__.py +++ b/src/segger/models/__init__.py @@ -1 +1,3 @@ -from .lightning_model import LitISTEncoder \ No newline at end of file +from .lightning_model import LitISTEncoder + +__all__ = ["LitISTEncoder"] diff --git a/src/segger/models/_triplet_loss_dev.py b/src/segger/models/_triplet_loss_dev.py index d66d65a..9bc31e6 100644 --- a/src/segger/models/_triplet_loss_dev.py +++ b/src/segger/models/_triplet_loss_dev.py @@ -1,13 +1,12 @@ -from torch.nn import TripletMarginLoss, TripletMarginWithDistanceLoss -from torch.nn.functional import mse_loss, cosine_similarity -from torch_geometric.data import Data from typing import Tuple + import torch +from torch.nn import TripletMarginLoss +from torch.nn.functional import mse_loss -class FastTripletSelector(): - """ - Efficient triplet sampling using a pre-computed node clustering and +class FastTripletSelector: + """Efficient triplet sampling using a pre-computed node clustering and similarity matrix used to weight sampling probabilities of positive and negative examples. 
""" @@ -25,20 +24,19 @@ def __init__( self._index_built = False @torch.no_grad() - def _build_index( - self, - labels: torch.Tensor - ) -> None: + def _build_index(self, labels: torch.Tensor) -> None: C = self.similarity.size(0) device = labels.device counts = torch.bincount(labels, minlength=C).to(torch.long) - offsets = torch.cat([ - torch.zeros(1, dtype=torch.long, device=device), - counts.cumsum(0), - ])[:-1] - - sorted_idx = torch.argsort(labels) + offsets = torch.cat( + [ + torch.zeros(1, dtype=torch.long, device=device), + counts.cumsum(0), + ] + )[:-1] + + sorted_idx = torch.argsort(labels) sorted_labels = labels[sorted_idx] N = labels.numel() @@ -50,92 +48,81 @@ def _build_index( present = torch.nonzero(counts > 0, as_tuple=False).flatten() diss_pres = self.dissimilarity.to(device)[present][:, present] - pdf_neg = diss_pres / diss_pres.sum(dim=1, keepdim=True) - cdf_neg = torch.cumsum(pdf_neg, dim=1) + pdf_neg = diss_pres / diss_pres.sum(dim=1, keepdim=True) + cdf_neg = torch.cumsum(pdf_neg, dim=1) try: cdf_neg[:, -1] = 1.0 except: - print( - f"No. labels: {N}\n" - f"Present: {present.sum()}\n" - f"PDF Neg.: {pdf_neg}" - ) + print(f"No. 
labels: {N}\n" f"Present: {present.sum()}\n" f"PDF Neg.: {pdf_neg}") sim_pres = self.similarity.to(device)[present][:, present] - pdf_pos = sim_pres / sim_pres.sum(dim=1, keepdim=True) - cdf_pos = torch.cumsum(pdf_pos, dim=1) + pdf_pos = sim_pres / sim_pres.sum(dim=1, keepdim=True) + cdf_pos = torch.cumsum(pdf_pos, dim=1) cdf_pos[:, -1] = 1.0 present_idx = -torch.ones(C, dtype=torch.long, device=device) present_idx[present] = torch.arange(present.numel(), device=device) - self._counts = counts.to(device) - self._offsets = offsets.to(device) - self._sorted_idx = sorted_idx.to(device) - self._present = present.to(device) - self._cdf_neg = cdf_neg.to(device) - self._cdf_pos = cdf_pos.to(device) - self._present_idx = present_idx.to(device) + self._counts = counts.to(device) + self._offsets = offsets.to(device) + self._sorted_idx = sorted_idx.to(device) + self._present = present.to(device) + self._cdf_neg = cdf_neg.to(device) + self._cdf_pos = cdf_pos.to(device) + self._present_idx = present_idx.to(device) self._index_built = True @torch.no_grad() - def stratified_sample_anchors( - self, - labels: torch.Tensor - ) -> torch.Tensor: - """ - Sample equal numbers of anchors for each label using stratified sampling. + def stratified_sample_anchors(self, labels: torch.Tensor) -> torch.Tensor: + """Sample equal numbers of anchors for each label using stratified sampling. The number of anchors per label equals the average cluster size. 
- - Returns: + + Returns + ------- anchor_indices: Indices of sampled anchors (stratified across labels) """ if not self._index_built: self._build_index(labels) - + device = labels.device present = self._present counts = self._counts offsets = self._offsets sorted_idx = self._sorted_idx - + # Compute average cluster size (only among present clusters) present_counts = counts[present] avg_size = present_counts.float().mean() samples_per_label = max(1, int(avg_size.round().item())) - + num_present = present.numel() - + # For each present cluster, sample `samples_per_label` anchors # Use random sampling with replacement if cluster is smaller than samples_per_label anchor_indices_list = [] - + for i in range(num_present): cluster_id = present[i] cluster_size = counts[cluster_id].item() cluster_offset = offsets[cluster_id].item() - + if cluster_size >= samples_per_label: # Sample without replacement perm = torch.randperm(cluster_size, device=device)[:samples_per_label] else: # Sample with replacement perm = torch.randint(0, cluster_size, (samples_per_label,), device=device) - + # Convert to global indices cluster_indices = sorted_idx[cluster_offset + perm] anchor_indices_list.append(cluster_indices) - + anchor_indices = torch.cat(anchor_indices_list) return anchor_indices @torch.no_grad() - def sample_triplets( - self, - labels: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - + def sample_triplets(self, labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: self._build_index(labels) device = labels.device @@ -144,9 +131,7 @@ def sample_triplets( # positive sampling uni_pos = torch.rand(N, device=device) - pos_pres = torch.searchsorted( - self._cdf_pos[pres_idx], uni_pos.unsqueeze(-1) - ).squeeze(-1) + pos_pres = torch.searchsorted(self._cdf_pos[pres_idx], uni_pos.unsqueeze(-1)).squeeze(-1) pos_clust = self._present[pos_pres] pos_sizes = self._counts[pos_clust] uni2 = torch.rand(N, device=device) * 
pos_sizes.float() @@ -155,21 +140,19 @@ def sample_triplets( # negative sampling uni_neg = torch.rand(N, device=device) - neg_pres = torch.searchsorted( - self._cdf_neg[pres_idx], uni_neg.unsqueeze(-1) - ).squeeze(-1) + neg_pres = torch.searchsorted(self._cdf_neg[pres_idx], uni_neg.unsqueeze(-1)).squeeze(-1) neg_clust = self._present[neg_pres] neg_sizes = self._counts[neg_clust] uni3 = torch.rand(N, device=device) * neg_sizes.float() neg_pos = uni3.floor().to(torch.long) negatives = self._sorted_idx[self._offsets[neg_clust] + neg_pos] - dists = 1. - self.similarity.to(labels.device) + dists = 1.0 - self.similarity.to(labels.device) dists_pos = dists[labels, labels[positives]] dists_neg = dists[labels, labels[negatives]] return ( - positives.detach(), + positives.detach(), negatives.detach(), dists_pos.detach(), dists_neg.detach(), @@ -177,18 +160,16 @@ def sample_triplets( @torch.no_grad() def sample_triplets_stratified( - self, - labels: torch.Tensor, - anchor_indices: torch.Tensor + self, labels: torch.Tensor, anchor_indices: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Sample triplets for the given anchor indices (stratified anchors). - + """Sample triplets for the given anchor indices (stratified anchors). 
+ Args: labels: Cluster labels for all nodes anchor_indices: Indices of anchor nodes (from stratified sampling) - - Returns: + + Returns + ------- positives, negatives, dists_pos, dists_neg """ if not self._index_built: @@ -201,9 +182,7 @@ def sample_triplets_stratified( # positive sampling uni_pos = torch.rand(N, device=device) - pos_pres = torch.searchsorted( - self._cdf_pos[pres_idx], uni_pos.unsqueeze(-1) - ).squeeze(-1) + pos_pres = torch.searchsorted(self._cdf_pos[pres_idx], uni_pos.unsqueeze(-1)).squeeze(-1) pos_clust = self._present[pos_pres] pos_sizes = self._counts[pos_clust] uni2 = torch.rand(N, device=device) * pos_sizes.float() @@ -212,43 +191,37 @@ def sample_triplets_stratified( # negative sampling uni_neg = torch.rand(N, device=device) - neg_pres = torch.searchsorted( - self._cdf_neg[pres_idx], uni_neg.unsqueeze(-1) - ).squeeze(-1) + neg_pres = torch.searchsorted(self._cdf_neg[pres_idx], uni_neg.unsqueeze(-1)).squeeze(-1) neg_clust = self._present[neg_pres] neg_sizes = self._counts[neg_clust] uni3 = torch.rand(N, device=device) * neg_sizes.float() neg_pos = uni3.floor().to(torch.long) negatives = self._sorted_idx[self._offsets[neg_clust] + neg_pos] - dists = 1. - self.similarity.to(device) + dists = 1.0 - self.similarity.to(device) dists_pos = dists[anchor_labels, labels[positives]] dists_neg = dists[anchor_labels, labels[negatives]] return ( - positives.detach(), + positives.detach(), negatives.detach(), dists_pos.detach(), dists_neg.detach(), ) - + class TripletLoss(TripletMarginLoss): - """ - Triplet margin loss on triplets sampled from FastTripletSelector + """Triplet margin loss on triplets sampled from FastTripletSelector with stratified anchor sampling. 
""" + def __init__( - self, - cluster_similarity: torch.Tensor, - margin: float = 1.0, - stratified: bool = True, - **kwargs + self, cluster_similarity: torch.Tensor, margin: float = 1.0, stratified: bool = True, **kwargs ) -> None: - """ - Initialize TripletLoss with cluster similarity and margin. - + """Initialize TripletLoss with cluster similarity and margin. + Args: + ---- cluster_similarity: Similarity matrix between clusters margin: Triplet margin stratified: If True, use stratified anchor sampling (equal anchors per label) @@ -257,28 +230,21 @@ def __init__( self.selector = FastTripletSelector(cluster_similarity) self.stratified = stratified - def forward( - self, - embeddings: torch.Tensor, - labels: torch.Tensor - ) -> torch.Tensor: - """ - Compute triplet loss on embeddings given cluster labels. + def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """Compute triplet loss on embeddings given cluster labels. Uses stratified sampling to ensure equal representation of each label. """ if labels.numel() == 0: - return 0. - + return 0.0 + if self.stratified: # Build index and sample stratified anchors self.selector._build_index(labels) anchor_indices = self.selector.stratified_sample_anchors(labels) - + # Sample triplets for stratified anchors - positives, negatives, _, _ = self.selector.sample_triplets_stratified( - labels, anchor_indices - ) - + positives, negatives, _, _ = self.selector.sample_triplets_stratified(labels, anchor_indices) + anchor = embeddings[anchor_indices] positive = embeddings[positives] negative = embeddings[negatives] @@ -288,47 +254,42 @@ def forward( anchor = embeddings positive = embeddings[positives] negative = embeddings[negatives] - + return super().forward(anchor, positive, negative) class MetricLoss: - """ - Metric loss on triplets sampled from FastTripletSelector + """Metric loss on triplets sampled from FastTripletSelector with stratified anchor sampling. 
""" + def __init__( self, cluster_similarity: torch.Tensor, stratified: bool = True, ) -> None: - """ - Initialize MetricLoss with cluster similarity. - + """Initialize MetricLoss with cluster similarity. + Args: + ---- cluster_similarity: Similarity matrix between clusters stratified: If True, use stratified anchor sampling (equal anchors per label) """ self.selector = FastTripletSelector(cluster_similarity) self.stratified = stratified - def forward( - self, - embeddings: torch.Tensor, - labels: torch.Tensor - ) -> torch.Tensor: - """ - Compute metric loss on embeddings given cluster labels. + def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """Compute metric loss on embeddings given cluster labels. Uses stratified sampling to ensure equal representation of each label. """ if labels.numel() == 0: - return 0. + return 0.0 if self.stratified: # Build index and sample stratified anchors self.selector._build_index(labels) anchor_indices = self.selector.stratified_sample_anchors(labels) - + # Sample triplets for stratified anchors ( positives, @@ -336,7 +297,7 @@ def forward( dists_pos, dists_neg, ) = self.selector.sample_triplets_stratified(labels, anchor_indices) - + anchor = embeddings[anchor_indices] positive = embeddings[positives] negative = embeddings[negatives] @@ -348,7 +309,7 @@ def forward( dists_pos, dists_neg, ) = self.selector.sample_triplets(labels) - + anchor = embeddings positive = embeddings[positives] negative = embeddings[negatives] @@ -356,7 +317,6 @@ def forward( cos_pos = torch.cosine_similarity(anchor, positive) cos_neg = torch.cosine_similarity(anchor, negative) - return ( - mse_loss(cos_pos, 1 - dists_pos.to(torch.float), reduction="mean") + - mse_loss(cos_neg, 1 - dists_neg.to(torch.float), reduction="mean") + return mse_loss(cos_pos, 1 - dists_pos.to(torch.float), reduction="mean") + mse_loss( + cos_neg, 1 - dists_neg.to(torch.float), reduction="mean" ) diff --git a/src/segger/models/ist_encoder.py 
b/src/segger/models/ist_encoder.py index 94ad932..4fd3c49 100644 --- a/src/segger/models/ist_encoder.py +++ b/src/segger/models/ist_encoder.py @@ -1,42 +1,36 @@ -from torch_geometric.nn import GATv2Conv, Linear, HeteroDictLinear, HeteroConv -from typing import Dict, Tuple, List, Union, Optional -from torch import Tensor -from torch.nn import ( - Sequential, - ModuleDict, - ModuleList, - Embedding, - Module, - Linear as NNLinear, - SiLU, - functional as F -) -import torch import math +from typing import Dict, Optional, Tuple + +import torch +from torch import Tensor +from torch.nn import Embedding +from torch.nn import functional as F +from torch.nn import Linear as NNLinear +from torch.nn import Module, ModuleDict, ModuleList, Sequential, SiLU + +from torch_geometric.nn import GATv2Conv, HeteroConv, HeteroDictLinear, Linear # --- Test positional encoding --- + def sinusoidal_embedding(x, dim, max_period=1000): half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half - ).to(device=x.device) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=x.device + ) args = x[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding + class Positional2dEmbedder(Module): - """ - Embeds scalar timesteps into vector representations. 
- """ - def __init__( - self, - hidden_size:int, - frequency_embedding_size:int=256): + """Embeds scalar timesteps into vector representations.""" + + def __init__(self, hidden_size: int, frequency_embedding_size: int = 256): super().__init__() - self.dim = hidden_size//2 + self.dim = hidden_size // 2 self.mlp = Sequential( NNLinear(frequency_embedding_size, self.dim, bias=True), SiLU(), @@ -45,25 +39,25 @@ def __init__( self.frequency_embedding_size = frequency_embedding_size @staticmethod - def embed(x:torch.Tensor, dim:int, max_period:int=10000): + def embed(x: torch.Tensor, dim: int, max_period: int = 10000): shape = x.shape embedding_flat = sinusoidal_embedding(x.flatten(), dim, max_period=max_period) - embedding = embedding_flat.reshape(shape+(dim,)) + embedding = embedding_flat.reshape(shape + (dim,)) return embedding def forward( - self, - pos: torch.Tensor, - batch: Optional[torch.Tensor] = None, + self, + pos: torch.Tensor, + batch: Optional[torch.Tensor] = None, ) -> torch.Tensor: if batch is None: pos = pos - pos.min(dim=0).values pos = pos / pos.max(dim=0).values else: # normalize per batch - mins = torch.zeros((batch.max()+1, 2), device=pos.device) - maxs = torch.zeros((batch.max()+1, 2), device=pos.device) - for b in range(batch.max()+1): + mins = torch.zeros((batch.max() + 1, 2), device=pos.device) + maxs = torch.zeros((batch.max() + 1, 2), device=pos.device) + for b in range(batch.max() + 1): mask = batch == b if mask.any(): mins[b] = pos[mask].min(dim=0).values @@ -75,10 +69,10 @@ def forward( pos_emb = pos_emb.flatten(-2) # ... x 2*dim return pos_emb + # --- Test positional encoding --- class SkipGAT(Module): - """ - Graph Attention module that encapsulates a HeteroConv layer with two GATv2 + """Graph Attention module that encapsulates a HeteroConv layer with two GATv2 convolutions for different edge types. The attention weights from the last forward pass are stored internally and can be accessed via the `attention_weights` property. 
@@ -105,21 +99,21 @@ def __init__( # Build a HeteroConv that internally uses GATv2Conv for each edge type. self.conv = HeteroConv( convs={ - ('tx', 'neighbors', 'tx'): GATv2Conv( + ("tx", "neighbors", "tx"): GATv2Conv( in_channels=in_channels, out_channels=out_channels, heads=n_heads, add_self_loops=add_self_loops_tx, dropout=0.2, ), - ('tx', 'belongs', 'bd'): GATv2Conv( + ("tx", "belongs", "bd"): GATv2Conv( in_channels=in_channels, out_channels=out_channels, heads=n_heads, add_self_loops=False, dropout=0.2, ), - ('bd', 'contains', 'tx'): GATv2Conv( + ("bd", "contains", "tx"): GATv2Conv( in_channels=in_channels, out_channels=out_channels, heads=n_heads, @@ -127,22 +121,21 @@ def __init__( dropout=0.2, ), }, - aggr='sum' + aggr="sum", ) # This will store the attention weights from the last forward pass. self._attn_weights: Dict[Tuple[str, str, str], Tensor] = {} # Register a forward hook to capture attention weights internally. - edge_type = 'tx', 'neighbors', 'tx' + edge_type = "tx", "neighbors", "tx" self.conv.convs[edge_type].register_forward_hook( self._make_hook(edge_type), with_kwargs=True, ) def _make_hook(self, edge_type: Tuple[str, str, str]): - """ - Internal hook function that captures attention weights from the + """Internal hook function that captures attention weights from the forward pass of each GATv2Conv submodule. Parameters @@ -150,17 +143,14 @@ def _make_hook(self, edge_type: Tuple[str, str, str]): edge_type : tuple of str The edge type associated with this GATv2Conv. """ + def _store_attn_weights(module, inputs, kwargs, outputs) -> None: self._attn_weights[edge_type] = outputs[1][1] + return _store_attn_weights - def forward( - self, - x_dict: Dict[str, Tensor], - edge_index_dict: Dict[str, Tensor] - ) -> Dict[str, Tensor]: - """ - Forward pass for SkipGAT. Always calls HeteroConv with + def forward(self, x_dict: Dict[str, Tensor], edge_index_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: + """Forward pass for SkipGAT. 
Always calls HeteroConv with `return_attention_weights=True`, but never returns them from this method. Attention weights are stored internally via the hook. @@ -180,16 +170,13 @@ def forward( x_dict = self.conv( x_dict, edge_index_dict, - return_attention_weights_dict = { - edge: False for edge in self.conv.convs - }, + return_attention_weights_dict={edge: False for edge in self.conv.convs}, ) return x_dict @property def attention_weights(self) -> Dict[Tuple[str, str, str], Tensor]: - """ - The attention weights from the most recent forward pass. + """The attention weights from the most recent forward pass. Raises ------ @@ -209,9 +196,7 @@ def attention_weights(self) -> Dict[Tuple[str, str, str], Tensor]: class ISTEncoder(torch.nn.Module): - """ - TODO: Description. - """ + """TODO: Description.""" def __init__( self, @@ -224,8 +209,7 @@ def __init__( normalize_embeddings: bool = True, use_positional_embeddings: bool = True, ): - """ - Initialize the Segger model. + """Initialize the Segger model. Parameters ---------- @@ -249,13 +233,13 @@ def __init__( self.use_positional_embeddings = use_positional_embeddings # Store hyperparameters for PyTorch Lightning self.hparams = locals() - for k in ['self', '__class__']: + for k in ["self", "__class__"]: self.hparams.pop(k) # First layer: ? 
-> in self.lin_first = ModuleDict( { - 'tx': Embedding(n_genes, in_channels), - 'bd': Linear(-1, in_channels), + "tx": Embedding(n_genes, in_channels), + "bd": Linear(-1, in_channels), } ) # Positional encoding: in @@ -263,24 +247,14 @@ def __init__( self.conv_layers = ModuleList() # First convolution: in -> hidden x heads - self.conv_layers.append( - SkipGAT((-1, -1), hidden_channels, n_heads) - ) + self.conv_layers.append(SkipGAT((-1, -1), hidden_channels, n_heads)) # Middle convolutions: hidden x heads -> hidden x heads for _ in range(n_mid_layers): - self.conv_layers.append( - SkipGAT((-1, -1), hidden_channels, n_heads) - ) + self.conv_layers.append(SkipGAT((-1, -1), hidden_channels, n_heads)) # Last convolution: hidden x heads -> out x heads - self.conv_layers.append( - SkipGAT((-1, -1), out_channels, n_heads) - ) + self.conv_layers.append(SkipGAT((-1, -1), out_channels, n_heads)) # Last layer: out x heads -> out - self.lin_last = HeteroDictLinear( - -1, - out_channels, - types=("tx", "bd") - ) + self.lin_last = HeteroDictLinear(-1, out_channels, types=("tx", "bd")) def forward( self, @@ -289,8 +263,7 @@ def forward( pos_dict: dict[str, Tensor], batch_dict: dict[str, Tensor], ) -> dict[str, Tensor]: - """ - Forward pass for the Segger model. + """Forward pass for the Segger model. 
Parameters ---------- @@ -308,10 +281,7 @@ def forward( x_dict = {k: self.lin_first[k](x) for k, x in x_dict.items()} if self.use_positional_embeddings: - x_dict = { - k: torch.cat((x, self.pos_emb(pos_dict[k], batch_dict[k])), -1) - for k, x in x_dict.items() - } + x_dict = {k: torch.cat((x, self.pos_emb(pos_dict[k], batch_dict[k])), -1) for k, x in x_dict.items()} x_dict = {k: F.gelu(x) for k, x in x_dict.items()} diff --git a/src/segger/models/lightning_model.py b/src/segger/models/lightning_model.py index de212ca..a131e11 100644 --- a/src/segger/models/lightning_model.py +++ b/src/segger/models/lightning_model.py @@ -1,20 +1,17 @@ -from torch.nn import Embedding, BCEWithLogitsLoss, TripletMarginLoss -from torch_geometric.data import Batch -from lightning import LightningModule -from torch_scatter import scatter_max -from torch.nn import functional as F -from typing import Any -import polars as pl -import pandas as pd -import numpy as np -import torch import math -import os +from io.fields import StandardTranscriptFields + +import torch +from data.data_module import ISTDataModule +from torch.nn import BCEWithLogitsLoss, Embedding, TripletMarginLoss +from torch_scatter import scatter_max + +from lightning import LightningModule +from torch_geometric.data import Batch -from .triplet_loss import TripletLoss, MetricLoss -from ..io.fields import StandardTranscriptFields -from ..data.data_module import ISTDataModule from .ist_encoder import ISTEncoder +from .triplet_loss import MetricLoss, TripletLoss + class LitISTEncoder(LightningModule): """TODO: Description. @@ -24,6 +21,7 @@ class LitISTEncoder(LightningModule): output_directory : Path Description. 
""" + def __init__( self, n_genes: int, @@ -33,14 +31,14 @@ def __init__( n_mid_layers: int = 2, n_heads: int = 2, learning_rate: float = 1e-3, - sg_loss_type: str = 'triplet', + sg_loss_type: str = "triplet", tx_margin: float = 0.3, sg_margin: float = 0.4, - tx_weight_start: float = 1., - tx_weight_end: float = 1., - bd_weight_start: float = 1., - bd_weight_end: float = 1., - sg_weight_start: float = 0., + tx_weight_start: float = 1.0, + tx_weight_end: float = 1.0, + bd_weight_start: float = 1.0, + bd_weight_end: float = 1.0, + sg_weight_start: float = 0.0, sg_weight_end: float = 0.5, update_gene_embedding: bool = True, use_positional_embeddings: bool = True, @@ -54,7 +52,7 @@ def __init__( Description. """ super().__init__() - + self.save_hyperparameters() self.model = ISTEncoder( @@ -71,36 +69,36 @@ def __init__( self._sg_loss_type = sg_loss_type self._tx_margin = tx_margin self._sg_margin = sg_margin - self._w_start = torch.tensor([ - tx_weight_start, - bd_weight_start, - sg_weight_start, - ]) - self._w_end = torch.tensor([ - tx_weight_end, - bd_weight_end, - sg_weight_end, - ]) + self._w_start = torch.tensor( + [ + tx_weight_start, + bd_weight_start, + sg_weight_start, + ] + ) + self._w_end = torch.tensor( + [ + tx_weight_end, + bd_weight_end, + sg_weight_end, + ] + ) self._freeze_gene_embedding = not update_gene_embedding def setup(self, stage): # LitISTEncoder needs supp. data from ISTDataModule to train if not isinstance(self.trainer.datamodule, ISTDataModule): raise TypeError( - f"Expected data module to be `ISTDataModule` but got " - f"{type(self.trainer.datamodule).__name__}." + f"Expected data module to be `ISTDataModule` but got " f"{type(self.trainer.datamodule).__name__}." 
) # Only set gene embeddings if exist in data module if hasattr(self.trainer.datamodule, "gene_embedding"): tx_fields = StandardTranscriptFields() embedding_weights = ( - self.trainer.datamodule.gene_embedding - .drop(tx_fields.feature) - .to_torch() - .to(torch.float) + self.trainer.datamodule.gene_embedding.drop(tx_fields.feature).to_torch().to(torch.float) ) - self.model.lin_first['tx'] = Embedding.from_pretrained( + self.model.lin_first["tx"] = Embedding.from_pretrained( embedding_weights, freeze=self._freeze_gene_embedding, ) @@ -113,9 +111,9 @@ def setup(self, stage): self.loss_bd = MetricLoss( self.trainer.datamodule.bd_similarity, ) - if self._sg_loss_type == 'triplet': + if self._sg_loss_type == "triplet": self.loss_sg = TripletMarginLoss(margin=self._sg_margin) - elif self._sg_loss_type == 'bce': + elif self._sg_loss_type == "bce": self.loss_sg = BCEWithLogitsLoss() else: raise ValueError( @@ -145,47 +143,44 @@ def _scheduled_weights( alpha = 0.5 * (1.0 + math.cos(math.pi * t)) w = w_end + (w_start - w_end) * alpha if normalize: - w /= (w.sum() + 1e-8) + w /= w.sum() + 1e-8 return w.to(self.device) - + def get_losses(self, batch: Batch) -> tuple[torch.Tensor]: """Get all training losses and combine.""" embeddings = self.forward(batch) - tx_mask = batch['tx']['mask'] - bd_mask = batch['bd']['mask'] & (batch['bd']['cluster'] >= 0) + tx_mask = batch["tx"]["mask"] + bd_mask = batch["bd"]["mask"] & (batch["bd"]["cluster"] >= 0) # Both triplet losses loss_tx = self.loss_tx.forward( - embeddings['tx'][tx_mask], - batch['tx']['cluster'][tx_mask], + embeddings["tx"][tx_mask], + batch["tx"]["cluster"][tx_mask], ) loss_bd = self.loss_bd.forward( - embeddings['bd'][bd_mask], - batch['bd']['cluster'][bd_mask], + embeddings["bd"][bd_mask], + batch["bd"]["cluster"][bd_mask], ) - + # Segmentation loss - src_pos, dst_pos = batch['tx', 'belongs', 'bd'].edge_index - num_bd = embeddings['bd'].size(0) + src_pos, dst_pos = batch["tx", "belongs", "bd"].edge_index + num_bd = 
embeddings["bd"].size(0) N = src_pos.size(0) # Handle edge case where there are too few boundaries for sampling if num_bd <= 1: - loss_sg = torch.tensor(0.0, device=embeddings['bd'].device, - requires_grad=True) + loss_sg = torch.tensor(0.0, device=embeddings["bd"].device, requires_grad=True) else: # Generate negative destination nodes - dst_neg = ( - dst_pos + torch.randint(1, num_bd, (N,), device=dst_pos.device) - ) % num_bd + dst_neg = (dst_pos + torch.randint(1, num_bd, (N,), device=dst_pos.device)) % num_bd - if self._sg_loss_type == 'triplet': - anchor = embeddings['tx'][src_pos] - positive = embeddings['bd'][dst_pos] - negative = embeddings['bd'][dst_neg] + if self._sg_loss_type == "triplet": + anchor = embeddings["tx"][src_pos] + positive = embeddings["bd"][dst_pos] + negative = embeddings["bd"][dst_neg] loss_sg = self.loss_sg(anchor, positive, negative) - + # BCE loss else: src = torch.cat([src_pos, src_pos]) @@ -194,15 +189,12 @@ def get_losses(self, batch: Batch) -> tuple[torch.Tensor]: uniq_src, inv_src = torch.unique(src, return_inverse=True) uniq_dst, inv_dst = torch.unique(dst, return_inverse=True) - src_vecs = embeddings['tx'].index_select(0, uniq_src) - dst_vecs = embeddings['bd'].index_select(0, uniq_dst) + src_vecs = embeddings["tx"].index_select(0, uniq_src) + dst_vecs = embeddings["bd"].index_select(0, uniq_dst) logits = (src_vecs[inv_src] * dst_vecs[inv_dst]).sum(dim=-1) - labels = torch.cat([ - torch.ones(N, device=logits.device), - torch.zeros(N, device=logits.device) - ]) + labels = torch.cat([torch.ones(N, device=logits.device), torch.zeros(N, device=logits.device)]) loss_sg = self.loss_sg(logits, labels) @@ -259,7 +251,7 @@ def validation_step(self, batch: Batch, batch_idx: int) -> torch.Tensor: batch_size=batch.num_graphs, ) return loss - + def predict_step( self, batch: Batch, @@ -267,32 +259,31 @@ def predict_step( min_similarity: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Prediction pass for the batch of data.""" 
- # Compute embeddings on full dataset embeddings = self.forward(batch) - + # Compute all top assignments - src, dst = batch['tx', 'neighbors', 'bd'].edge_index + src, dst = batch["tx", "neighbors", "bd"].edge_index sim = torch.cosine_similarity( - embeddings['tx'][src], - embeddings['bd'][dst], + embeddings["tx"][src], + embeddings["bd"][dst], ) max_sim, max_idx = scatter_max( sim, src, - dim_size=batch['tx'].num_nodes, + dim_size=batch["tx"].num_nodes, ) # Filter by similarity valid = max_idx < dst.shape[0] if min_similarity is not None: valid &= max_sim >= min_similarity - src_idx = batch['tx']['index'] - dst_idx = batch['bd']['index'].to(torch.long) + src_idx = batch["tx"]["index"] + dst_idx = batch["bd"]["index"].to(torch.long) seg_idx = torch.full_like(max_idx, -1) seg_idx[valid] = dst_idx[dst[max_idx[valid]]] - gen_idx = batch['tx']['x'] - mask = batch['tx']['predict_mask'] + gen_idx = batch["tx"]["x"] + mask = batch["tx"]["predict_mask"] return src_idx[mask], seg_idx[mask], max_sim[mask], gen_idx[mask] diff --git a/src/segger/models/triplet_loss.py b/src/segger/models/triplet_loss.py index 9451852..4cbb046 100644 --- a/src/segger/models/triplet_loss.py +++ b/src/segger/models/triplet_loss.py @@ -1,13 +1,12 @@ -from torch.nn import TripletMarginLoss, TripletMarginWithDistanceLoss -from torch.nn.functional import mse_loss, cosine_similarity -from torch_geometric.data import Data from typing import Tuple + import torch +from torch.nn import TripletMarginLoss +from torch.nn.functional import mse_loss -class FastTripletSelector(): - """ - Efficient triplet sampling using a pre-computed node clustering and +class FastTripletSelector: + """Efficient triplet sampling using a pre-computed node clustering and similarity matrix used to weight sampling probabilities of positive and negative examples. 
""" @@ -25,20 +24,19 @@ def __init__( self._index_built = False @torch.no_grad() - def _build_index( - self, - labels: torch.Tensor - ) -> None: + def _build_index(self, labels: torch.Tensor) -> None: C = self.similarity.size(0) device = labels.device counts = torch.bincount(labels, minlength=C).to(torch.long) - offsets = torch.cat([ - torch.zeros(1, dtype=torch.long, device=device), - counts.cumsum(0), - ])[:-1] - - sorted_idx = torch.argsort(labels) + offsets = torch.cat( + [ + torch.zeros(1, dtype=torch.long, device=device), + counts.cumsum(0), + ] + )[:-1] + + sorted_idx = torch.argsort(labels) sorted_labels = labels[sorted_idx] N = labels.numel() @@ -50,41 +48,33 @@ def _build_index( present = torch.nonzero(counts > 0, as_tuple=False).flatten() diss_pres = self.dissimilarity.to(device)[present][:, present] - pdf_neg = diss_pres / diss_pres.sum(dim=1, keepdim=True) - cdf_neg = torch.cumsum(pdf_neg, dim=1) + pdf_neg = diss_pres / diss_pres.sum(dim=1, keepdim=True) + cdf_neg = torch.cumsum(pdf_neg, dim=1) try: cdf_neg[:, -1] = 1.0 except: - print( - f"No. labels: {N}\n" - f"Present: {present.sum()}\n" - f"PDF Neg.: {pdf_neg}" - ) + print(f"No. 
labels: {N}\n" f"Present: {present.sum()}\n" f"PDF Neg.: {pdf_neg}") sim_pres = self.similarity.to(device)[present][:, present] - pdf_pos = sim_pres / sim_pres.sum(dim=1, keepdim=True) - cdf_pos = torch.cumsum(pdf_pos, dim=1) + pdf_pos = sim_pres / sim_pres.sum(dim=1, keepdim=True) + cdf_pos = torch.cumsum(pdf_pos, dim=1) cdf_pos[:, -1] = 1.0 present_idx = -torch.ones(C, dtype=torch.long, device=device) present_idx[present] = torch.arange(present.numel(), device=device) - self._counts = counts.to(device) - self._offsets = offsets.to(device) - self._sorted_idx = sorted_idx.to(device) - self._present = present.to(device) - self._cdf_neg = cdf_neg.to(device) - self._cdf_pos = cdf_pos.to(device) - self._present_idx = present_idx.to(device) + self._counts = counts.to(device) + self._offsets = offsets.to(device) + self._sorted_idx = sorted_idx.to(device) + self._present = present.to(device) + self._cdf_neg = cdf_neg.to(device) + self._cdf_pos = cdf_pos.to(device) + self._present_idx = present_idx.to(device) self._index_built = True @torch.no_grad() - def sample_triplets( - self, - labels: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - + def sample_triplets(self, labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: self._build_index(labels) device = labels.device @@ -93,9 +83,7 @@ def sample_triplets( # positive sampling uni_pos = torch.rand(N, device=device) - pos_pres = torch.searchsorted( - self._cdf_pos[pres_idx], uni_pos.unsqueeze(-1) - ).squeeze(-1) + pos_pres = torch.searchsorted(self._cdf_pos[pres_idx], uni_pos.unsqueeze(-1)).squeeze(-1) pos_clust = self._present[pos_pres] pos_sizes = self._counts[pos_clust] uni2 = torch.rand(N, device=device) * pos_sizes.float() @@ -104,85 +92,60 @@ def sample_triplets( # negative sampling uni_neg = torch.rand(N, device=device) - neg_pres = torch.searchsorted( - self._cdf_neg[pres_idx], uni_neg.unsqueeze(-1) - ).squeeze(-1) + neg_pres = torch.searchsorted(self._cdf_neg[pres_idx], 
uni_neg.unsqueeze(-1)).squeeze(-1) neg_clust = self._present[neg_pres] neg_sizes = self._counts[neg_clust] uni3 = torch.rand(N, device=device) * neg_sizes.float() neg_pos = uni3.floor().to(torch.long) negatives = self._sorted_idx[self._offsets[neg_clust] + neg_pos] - dists = 1. - self.similarity.to(labels.device) + dists = 1.0 - self.similarity.to(labels.device) dists_pos = dists[labels, labels[positives]] dists_neg = dists[labels, labels[negatives]] return ( - positives.detach(), + positives.detach(), negatives.detach(), dists_pos.detach(), dists_neg.detach(), ) - + class TripletLoss(TripletMarginLoss): - """ - Triplet margin loss on triplets sampled from FastTripletSelector. - """ - def __init__( - self, - cluster_similarity: torch.Tensor, - margin: float = 1.0, - **kwargs - ) -> None: - """ - Initialize TripletLoss with cluster similarity and margin. - """ + """Triplet margin loss on triplets sampled from FastTripletSelector.""" + + def __init__(self, cluster_similarity: torch.Tensor, margin: float = 1.0, **kwargs) -> None: + """Initialize TripletLoss with cluster similarity and margin.""" super().__init__(margin=margin, **kwargs) self.selector = FastTripletSelector(cluster_similarity) - def forward( - self, - embeddings: torch.Tensor, - labels: torch.Tensor - ) -> torch.Tensor: - """ - Compute triplet loss on embeddings given cluster labels. - """ + def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """Compute triplet loss on embeddings given cluster labels.""" if labels.numel() == 0: - return 0. - + return 0.0 + positives, negatives, _, _ = self.selector.sample_triplets(labels) anchor = embeddings positive = embeddings[positives] negative = embeddings[negatives] - + return super().forward(anchor, positive, negative) class MetricLoss: - """ - Metric loss on triplets sampled from FastTripletSelector. 
- """ + """Metric loss on triplets sampled from FastTripletSelector.""" + def __init__( self, cluster_similarity: torch.Tensor, ) -> None: - """ - Initialize TripletLoss with cluster similarity and margin. - """ + """Initialize TripletLoss with cluster similarity and margin.""" self.selector = FastTripletSelector(cluster_similarity) - def forward( - self, - embeddings: torch.Tensor, - labels: torch.Tensor - ) -> torch.Tensor: - """ - Compute triplet loss on embeddings given cluster labels. - """ + def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """Compute triplet loss on embeddings given cluster labels.""" if labels.numel() == 0: - return 0. + return 0.0 ( positives, @@ -198,7 +161,6 @@ def forward( cos_pos = torch.cosine_similarity(anchor, positive) cos_neg = torch.cosine_similarity(anchor, negative) - return ( - mse_loss(cos_pos, 1 - dists_pos.to(torch.float), reduction="mean") + - mse_loss(cos_neg, 1 - dists_neg.to(torch.float), reduction="mean") + return mse_loss(cos_pos, 1 - dists_pos.to(torch.float), reduction="mean") + mse_loss( + cos_neg, 1 - dists_neg.to(torch.float), reduction="mean" ) diff --git a/src/segger/validation/contamination.py b/src/segger/validation/contamination.py index ed10548..7c5a0ed 100644 --- a/src/segger/validation/contamination.py +++ b/src/segger/validation/contamination.py @@ -1,21 +1,21 @@ from __future__ import annotations -from typing import Dict, List - -import cupy as cp -import scanpy as sc import cuml +import cupy as cp + import numpy as np import pandas as pd import polars as pl import scipy.sparse as sp -from anndata import AnnData from scipy import sparse +import scanpy as sc +from anndata import AnnData + def map_with_default( - keys: List[str] | np.ndarray, - mapping: Dict[str, int], + keys: list[str] | np.ndarray, + mapping: dict[str, int], default: int = -1, dtype: np.dtype = np.int32, ) -> np.ndarray: @@ -37,6 +37,7 @@ def map_with_default( out[i] = mapping.get(str(k), default) 
return out + def get_neighbor_frequencies( ad: AnnData, k: int, @@ -99,6 +100,7 @@ def get_neighbor_frequencies( ad.obsm[key_added] = df return df + def calculate_contamination( adata: AnnData, reference: pl.DataFrame, @@ -157,10 +159,8 @@ def calculate_contamination( X_layer = adata.layers[counts_layer] X = X_layer.tocoo() if isinstance(X_layer, sp.spmatrix) else X_layer.to_coo() rows, cols, vals = X.row, X.col, X.data - - host_ct_idx_all = map_with_default( - adata.obs[cell_type_key].astype(str), ct_map, -1 - ) + + host_ct_idx_all = map_with_default(adata.obs[cell_type_key].astype(str), ct_map, -1) host_ct_idx = host_ct_idx_all[rows] gene_idx_all = map_with_default(adata.var_names, gn_map, -1) gene_idx = gene_idx_all[cols] @@ -192,31 +192,20 @@ def calculate_contamination( q_back[missing_gene] = 0 shape = adata.layers[counts_layer].shape - adata.layers["q_self"] = sparse.coo_matrix( - (q_self, (rows, cols)), shape=shape - ).tocsr() - adata.layers["q_neighbor"] = sparse.coo_matrix( - (q_neigh, (rows, cols)), shape=shape - ).tocsr() - adata.layers["q_background"] = sparse.coo_matrix( - (q_back, (rows, cols)), shape=shape - ).tocsr() + adata.layers["q_self"] = sparse.coo_matrix((q_self, (rows, cols)), shape=shape).tocsr() + adata.layers["q_neighbor"] = sparse.coo_matrix((q_neigh, (rows, cols)), shape=shape).tocsr() + adata.layers["q_background"] = sparse.coo_matrix((q_back, (rows, cols)), shape=shape).tocsr() # percent contamination per cell contam_mask = q_self < contam_cutoff contam_mask[missing_gene] = False contam_vals = np.where(contam_mask, vals, 0.0) - adata.layers["contamination"] = sparse.coo_matrix( - (contam_vals, (rows, cols)), shape=shape - ).tocsr() + adata.layers["contamination"] = sparse.coo_matrix((contam_vals, (rows, cols)), shape=shape).tocsr() - contam_counts = np.bincount( - rows[contam_mask], weights=vals[contam_mask], minlength=adata.n_obs - ) + contam_counts = np.bincount(rows[contam_mask], weights=vals[contam_mask], minlength=adata.n_obs) 
total_counts = np.bincount(rows, weights=vals, minlength=adata.n_obs) - adata.obs["percent_contamination"] = ( - 100.0 * contam_counts / np.maximum(total_counts, 1) - ) + adata.obs["percent_contamination"] = 100.0 * contam_counts / np.maximum(total_counts, 1) + def contamination_flow( ad: AnnData, @@ -284,15 +273,15 @@ def contamination_flow( flow[d] = sums / np.maximum(cell_counts, 1) flow = pd.DataFrame(flow, index=donor_types, columns=host_types) - flow.index.name = 'source' - flow.columns.name = 'host' + flow.index.name = "source" + flow.columns.name = "host" return flow def group_reference( reference: pl.DataFrame, - grouping: Dict[str, str], + grouping: dict[str, str], *, cell_type_name_col: str = "cell_type_name", gene_name_col: str = "gene_name", @@ -327,11 +316,7 @@ def group_reference( .alias(cell_type_name_col) ) - ref = ref.with_columns( - ( - pl.col(mean_expr_col) * pl.col(n_pos_cells_col) - ).alias("weighted_expr") - ) + ref = ref.with_columns((pl.col(mean_expr_col) * pl.col(n_pos_cells_col)).alias("weighted_expr")) agg = ( ref.group_by([cell_type_name_col, gene_name_col]) @@ -341,68 +326,59 @@ def group_reference( pl.sum("weighted_expr").alias("expr_sum"), ) .with_columns( - ( - pl.col("expr_sum") / pl.col(n_pos_cells_col) - ).fill_null(0).alias(mean_expr_col), - ( - pl.col(n_pos_cells_col) / pl.col(n_cells_col) - ).fill_null(0).alias(percent_col), + (pl.col("expr_sum") / pl.col(n_pos_cells_col)).fill_null(0).alias(mean_expr_col), + (pl.col(n_pos_cells_col) / pl.col(n_cells_col)).fill_null(0).alias(percent_col), ) .drop("expr_sum") ) return agg + def expression_summary_from_anndata( - ad: sc.AnnData, - cell_type_col: str, - raw_layer: str, - min_counts: int = 2 + ad: sc.AnnData, cell_type_col: str, raw_layer: str, min_counts: int = 2 ) -> pl.DataFrame: # TODO: Add documentation - + # Normalize as in CellxGene - ad.layers['_cxg_norm'] = ad.layers[raw_layer].copy() - sc.pp.normalize_total(ad, target_sum=1e4, layer='_cxg_norm') - sc.pp.log1p(ad, 
layer='_cxg_norm') + ad.layers["_cxg_norm"] = ad.layers[raw_layer].copy() + sc.pp.normalize_total(ad, target_sum=1e4, layer="_cxg_norm") + sc.pp.log1p(ad, layer="_cxg_norm") # Filter as in CellxGene mask = ad.layers[raw_layer] >= min_counts - ad.layers['_cxg_norm'] = ad.layers['_cxg_norm'].multiply(mask) - ad.layers['_cxg_norm'].eliminate_zeros() + ad.layers["_cxg_norm"] = ad.layers["_cxg_norm"].multiply(mask) + ad.layers["_cxg_norm"].eliminate_zeros() # Summary data from CellxGene expression summary aggs = { - 'n': 'count_nonzero', # 1) Non-zero counts per cell type - 'me': 'sum', # 2) Mean expression in positive cells + "n": "count_nonzero", # 1) Non-zero counts per cell type + "me": "sum", # 2) Mean expression in positive cells } - stats = dict() + stats = {} for name, func in aggs.items(): stats[name] = pl.from_pandas( - sc.get.aggregate(ad, by=cell_type_col, func=func, layer='_cxg_norm') + sc.get.aggregate(ad, by=cell_type_col, func=func, layer="_cxg_norm") .to_df(layer=func) - .melt(value_name=name, ignore_index=False, var_name='gene_name') - .reset_index(names='cell_type_name') + .melt(value_name=name, ignore_index=False, var_name="gene_name") + .reset_index(names="cell_type_name") ) # 3) Number of cells per cell type n_ct = pl.from_pandas( ad.obs.value_counts(cell_type_col) .reset_index() - .rename( - {cell_type_col: 'cell_type_name', 'count': 'n_cells_cell_type'}, - axis=1 - ) - ).with_columns(pl.col('cell_type_name').cast(pl.String)) + .rename({cell_type_col: "cell_type_name", "count": "n_cells_cell_type"}, axis=1) + ).with_columns(pl.col("cell_type_name").cast(pl.String)) # Join into summary dataframe summary = ( - stats['n'] - .join(stats['me'], on=['cell_type_name', 'gene_name']) - .join(n_ct, on='cell_type_name') - .filter(pl.col('n') > 0) - .with_columns(pl.col('me') / pl.col('n')) - .with_columns(pc=pl.col('n') / pl.col('n_cells_cell_type')) - .with_columns(pl.col('n').cast(pl.Int64)) + stats["n"] + .join(stats["me"], on=["cell_type_name", 
"gene_name"]) + .join(n_ct, on="cell_type_name") + .filter(pl.col("n") > 0) + .with_columns(pl.col("me") / pl.col("n")) + .with_columns(pc=pl.col("n") / pl.col("n_cells_cell_type")) + .with_columns(pl.col("n").cast(pl.Int64)) ) - return summary \ No newline at end of file + return summary