From 0142c4ad704efaec0063c4275aa0940a053d56c2 Mon Sep 17 00:00:00 2001 From: "Lumberbot (aka Jack)" <39504233+meeseeksmachine@users.noreply.github.com> Date: Thu, 27 Feb 2025 08:32:17 -0800 Subject: [PATCH] Backport PR #3164 on branch 1.3.x (feat: SemiSupervised Training Mixin class) (#3222) Backport PR #3164: feat: SemiSupervised Training Mixin class Co-authored-by: Ori Kronfeld --- .github/workflows/test_linux_autotune.yml | 70 ++++ CHANGELOG.md | 10 +- docs/tutorials/index.md | 2 +- docs/tutorials/index_hub.md | 1 - docs/tutorials/index_tuning.md | 8 - docs/tutorials/index_use_cases.md | 9 + .../use_case/downstream_analysis_tasks.md | 1 + pyproject.toml | 6 +- src/scvi/model/_scanvi.py | 212 +---------- src/scvi/model/base/__init__.py | 3 +- src/scvi/model/base/_training_mixin.py | 357 +++++++++++++++++- tests/autotune/test_tune.py | 4 + tests/conftest.py | 20 +- tests/model/test_scanvi.py | 75 ++++ 14 files changed, 558 insertions(+), 220 deletions(-) create mode 100644 .github/workflows/test_linux_autotune.yml delete mode 100644 docs/tutorials/index_tuning.md create mode 100644 docs/tutorials/index_use_cases.md diff --git a/.github/workflows/test_linux_autotune.yml b/.github/workflows/test_linux_autotune.yml new file mode 100644 index 0000000000..83b8095bd4 --- /dev/null +++ b/.github/workflows/test_linux_autotune.yml @@ -0,0 +1,70 @@ +name: test (Autotune) + +on: + push: + branches: [main, "[0-9]+.[0-9]+.x"] #this is new + pull_request: + branches: [main, "[0-9]+.[0-9]+.x"] + types: [labeled, synchronize, opened] + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + test: + # if PR has label "autotune" or "all tests" or if scheduled or manually triggered or on push + if: >- + ( + contains(github.event.pull_request.labels.*.name, 'autotune') || + contains(github.event.pull_request.labels.*.name, 'all tests') || + contains(github.event_name, 'schedule') || + contains(github.event_name, 'workflow_dispatch') + ) + + runs-on: [self-hosted, Linux, X64, CUDA] + + defaults: + run: + shell: bash -e {0} # -e to fail on error + + container: + image: ghcr.io/scverse/scvi-tools:py3.12-cu12-base + options: --user root --gpus all --pull always + + name: integration + + env: + OS: ${{ matrix.os }} + PYTHON: ${{ matrix.python }} + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + cache: "pip" + cache-dependency-path: "**/pyproject.toml" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip wheel uv + python -m uv pip install --system "scvi-tools[tests] @ ." + python -m pip install "jax[cuda]==0.4.35" + python -m pip install nvidia-nccl-cu12 + + - name: Run pytest + env: + MPLBACKEND: agg + PLATFORM: ${{ matrix.os }} + DISPLAY: :42 + COLUMNS: 120 + run: | + coverage run -m pytest -v --color=yes --autotune-tests --accelerator cuda --devices auto + coverage report + + - uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e52ef4e48..2fde65420c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,7 @@ to [Semantic Versioning]. Full commit history is available in the ## Version 1.3 -### 1.3.0 (2025-02-XX) +### 1.3.0 (2025-02-28) #### Added @@ -18,15 +18,15 @@ to [Semantic Versioning]. Full commit history is available in the - Add an exception callback to {class}`scvi.train._callbacks.SaveCheckpoint` in order to save optimal model during training, in case of failure because of Nan's in gradients. {pr}`3159`. - Add {meth}`~scvi.model.SCVI.get_normalized_expression` for models: {class}`~scvi.model.PEAKVI`, - {class}`~scvi.external.PoissonVI`, {class}`~scvi.model.CondSCVI`, {class}`~scvi.model.AutoZI`, - {class}`~scvi.external.CellAssign` and {class}`~scvi.external.GimVI`. {pr}`3121`. + {class}`~scvi.external.POISSONVI`, {class}`~scvi.model.CondSCVI`, {class}`~scvi.model.AUTOZI`, + {class}`~scvi.external.CellAssign` and {class}`~scvi.external.GIMVI`. {pr}`3121`. - Add {class}`scvi.external.RESOLVI` for bias correction in single-cell resolved spatial transcriptomics {pr}`3144`. +- Add semisupervised training mixin class {class}`scvi.model.base.SemisupervisedTrainingMixin` {pr}`3164`. - Add scib-metrics support for {class}`scvi.autotune.AutotuneExperiment` and {class}`scvi.train._callbacks.ScibCallback` for autotune for scib metrics {pr}`3168`. - Add Support of dask arrays in AnnTorchDataset. {pr}`3193`. -- Add a [use cases](%22https://docs.scvi-tools.org/en/latest/user_guide/index.html#common-use-cases%22) - section in the docs, {pr}`3200`. +- Add a {doc}`/user_guide/use_case` section in the docs, {pr}`3200`. - Add {class}`scvi.external.SysVI` for cycle consistency loss and VampPrior {pr}`3195`. #### Fixed diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md index 4aa0539323..261b2bb374 100644 --- a/docs/tutorials/index.md +++ b/docs/tutorials/index.md @@ -19,7 +19,7 @@ index_scbs index_multimodal index_spatial index_hub -index_tuning +index_use_cases index_dev ``` diff --git a/docs/tutorials/index_hub.md b/docs/tutorials/index_hub.md index 0febb7b381..beef4a1eda 100644 --- a/docs/tutorials/index_hub.md +++ b/docs/tutorials/index_hub.md @@ -6,5 +6,4 @@ notebooks/hub/cellxgene_census_model notebooks/hub/scvi_hub_intro_and_download notebooks/hub/scvi_hub_upload_and_large_files -notebooks/hub/minification ``` diff --git a/docs/tutorials/index_tuning.md b/docs/tutorials/index_tuning.md deleted file mode 100644 index b5f9f1a809..0000000000 --- a/docs/tutorials/index_tuning.md +++ /dev/null @@ -1,8 +0,0 @@ -# Hyperparameter tuning - -```{toctree} -:maxdepth: 1 - -notebooks/tuning/autotune_scvi -notebooks/tuning/autotune_new_model -``` diff --git a/docs/tutorials/index_use_cases.md b/docs/tutorials/index_use_cases.md new file mode 100644 index 0000000000..1c9b728987 --- /dev/null +++ b/docs/tutorials/index_use_cases.md @@ -0,0 +1,9 @@ +# Common Modelling Use Cases + +```{toctree} +:maxdepth: 1 + +notebooks/use_cases/autotune_scvi +notebooks/use_cases/minification +notebooks/use_cases/interpretability +``` diff --git a/docs/user_guide/use_case/downstream_analysis_tasks.md b/docs/user_guide/use_case/downstream_analysis_tasks.md index 509ca948c3..c9839eaf6b 100644 --- a/docs/user_guide/use_case/downstream_analysis_tasks.md +++ b/docs/user_guide/use_case/downstream_analysis_tasks.md @@ -22,6 +22,7 @@ You can compare the expression of genes between clusters to identify which genes differential_expression = scvi.model.SCVI().differential_expression() ``` Log-fold Change (LFC) and p-values are typically used to assess which genes have significant expression differences between groups. +Refer to [SCVI-Hub]("https://huggingface.co/scvi-tools") for use cases of DE. 3. Cell Type Identification Mapping to Known Labels: After training a model with SCVI, you can use the latent space to assign cells to known or predicted cell types. You can compare how well SCVI clusters cells by their latent representations and match them to known biological annotations. If you have labeled data (e.g., cell types), you can assess how well the model’s clusters correspond to these labels. diff --git a/pyproject.toml b/pyproject.toml index 04fefbc37f..566a6c5291 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,6 @@ dependencies = [ "torchmetrics>=0.11.0", "tqdm", "xarray>=2023.2.0", - "dask", ] [project.optional-dependencies] @@ -98,10 +97,12 @@ scanpy = ["scanpy>=1.10", "scikit-misc"] file_sharing = ["pooch"] # for parallelization engine parallel = ["dask[array]>=2023.5.1,<2024.8.0"] +# for supervised models interpretability +interpretability = ["captum","shap"] optional = [ - "scvi-tools[autotune,aws,hub,file_sharing,regseq,scanpy,parallel]" + "scvi-tools[autotune,aws,hub,file_sharing,regseq,scanpy,parallel,interpretability]" ] tutorials = [ "cell2location", @@ -135,6 +136,7 @@ markers = [ "optional: mark optional tests, usually take more time", "private: mark tests that uses private keys, like HF", "multigpu: mark tests that are used to check multi GPU performance", + "autotune: mark tests that are used to check ray autotune capabilities", ] [tool.ruff] diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index 5a43075754..4e0af95102 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -5,17 +5,13 @@ from copy import deepcopy from typing import TYPE_CHECKING -import numpy as np -import pandas as pd -import torch - from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager from scvi.data._constants import ( _SETUP_ARGS_KEY, ADATA_MINIFY_TYPE, ) -from scvi.data._utils import _get_adata_minify_type, _is_minified, get_anndata_attribute +from scvi.data._utils import _get_adata_minify_type, _is_minified from scvi.data.fields import ( CategoricalJointObsField, CategoricalObsField, @@ -24,18 +20,20 @@ NumericalJointObsField, NumericalObsField, ) -from scvi.dataloaders import SemiSupervisedDataSplitter -from scvi.model._utils import _init_library_size, get_max_epochs_heuristic, use_distributed_sampler +from scvi.model._utils import _init_library_size from scvi.module import SCANVAE -from scvi.train import SemiSupervisedTrainingPlan, TrainRunner -from scvi.train._callbacks import SubSampleLabels +from scvi.train import SemiSupervisedTrainingPlan from scvi.utils import setup_anndata_dsp -from scvi.utils._docstrings import devices_dsp -from .base import ArchesMixin, BaseMinifiedModeModelClass, RNASeqMixin, VAEMixin +from .base import ( + ArchesMixin, + BaseMinifiedModeModelClass, + RNASeqMixin, + SemisupervisedTrainingMixin, + VAEMixin, +) if TYPE_CHECKING: - from collections.abc import Sequence from typing import Literal from anndata import AnnData @@ -45,7 +43,9 @@ logger = logging.getLogger(__name__) -class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMinifiedModeModelClass): +class SCANVI( + RNASeqMixin, SemisupervisedTrainingMixin, VAEMixin, ArchesMixin, BaseMinifiedModeModelClass +): """Single-cell annotation using variational inference :cite:p:`Xu21`. Inspired from M1 + M2 model, as described in (https://arxiv.org/pdf/1406.5298.pdf). @@ -250,192 +250,6 @@ def from_scvi_model( return scanvi_model - def _set_indices_and_labels(self): - """Set indices for labeled and unlabeled cells.""" - labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) - self.original_label_key = labels_state_registry.original_key - self.unlabeled_category_ = labels_state_registry.unlabeled_category - - labels = get_anndata_attribute( - self.adata, - self.adata_manager.data_registry.labels.attr_name, - self.original_label_key, - ).ravel() - self._label_mapping = labels_state_registry.categorical_mapping - - # set unlabeled and labeled indices - self._unlabeled_indices = np.argwhere(labels == self.unlabeled_category_).ravel() - self._labeled_indices = np.argwhere(labels != self.unlabeled_category_).ravel() - self._code_to_label = dict(enumerate(self._label_mapping)) - - def predict( - self, - adata: AnnData | None = None, - indices: Sequence[int] | None = None, - soft: bool = False, - batch_size: int | None = None, - use_posterior_mean: bool = True, - ) -> np.ndarray | pd.DataFrame: - """Return cell label predictions. - - Parameters - ---------- - adata - AnnData object that has been registered via :meth:`~scvi.model.SCANVI.setup_anndata`. - indices - Return probabilities for each class label. - soft - If True, returns per class probabilities - batch_size - Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. - use_posterior_mean - If ``True``, uses the mean of the posterior distribution to predict celltype - labels. Otherwise, uses a sample from the posterior distribution - this - means that the predictions will be stochastic. - """ - adata = self._validate_anndata(adata) - - if indices is None: - indices = np.arange(adata.n_obs) - - scdl = self._make_data_loader( - adata=adata, - indices=indices, - batch_size=batch_size, - ) - y_pred = [] - for _, tensors in enumerate(scdl): - x = tensors[REGISTRY_KEYS.X_KEY] - batch = tensors[REGISTRY_KEYS.BATCH_KEY] - - cont_key = REGISTRY_KEYS.CONT_COVS_KEY - cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None - - cat_key = REGISTRY_KEYS.CAT_COVS_KEY - cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None - - pred = self.module.classify( - x, - batch_index=batch, - cat_covs=cat_covs, - cont_covs=cont_covs, - use_posterior_mean=use_posterior_mean, - ) - if self.module.classifier.logits: - pred = torch.nn.functional.softmax(pred, dim=-1) - if not soft: - pred = pred.argmax(dim=1) - y_pred.append(pred.detach().cpu()) - - y_pred = torch.cat(y_pred).numpy() - if not soft: - predictions = [] - for p in y_pred: - predictions.append(self._code_to_label[p]) - - return np.array(predictions) - else: - n_labels = len(pred[0]) - pred = pd.DataFrame( - y_pred, - columns=self._label_mapping[:n_labels], - index=adata.obs_names[indices], - ) - return pred - - @devices_dsp.dedent - def train( - self, - max_epochs: int | None = None, - n_samples_per_label: float | None = None, - check_val_every_n_epoch: int | None = None, - train_size: float | None = None, - validation_size: float | None = None, - shuffle_set_split: bool = True, - batch_size: int = 128, - accelerator: str = "auto", - devices: int | list[int] | str = "auto", - datasplitter_kwargs: dict | None = None, - plan_kwargs: dict | None = None, - **trainer_kwargs, - ): - """Train the model. - - Parameters - ---------- - max_epochs - Number of passes through the dataset for semisupervised training. - n_samples_per_label - Number of subsamples for each label class to sample per epoch. By default, there - is no label subsampling. - check_val_every_n_epoch - Frequency with which metrics are computed on the data for validation set for both - the unsupervised and semisupervised trainers. If you'd like a different frequency for - the semisupervised trainer, set check_val_every_n_epoch in semisupervised_train_kwargs. - train_size - Size of training set in the range [0.0, 1.0]. - validation_size - Size of the test set. If `None`, defaults to 1 - `train_size`. If - `train_size + validation_size < 1`, the remaining cells belong to a test set. - shuffle_set_split - Whether to shuffle indices before splitting. If `False`, the val, train, and test set - are split in the sequential order of the data according to `validation_size` and - `train_size` percentages. - batch_size - Minibatch size to use during training. - %(param_accelerator)s - %(param_devices)s - datasplitter_kwargs - Additional keyword arguments passed into - :class:`~scvi.dataloaders.SemiSupervisedDataSplitter`. - plan_kwargs - Keyword args for :class:`~scvi.train.SemiSupervisedTrainingPlan`. Keyword arguments - passed to `train()` will overwrite values present in `plan_kwargs`, when appropriate. - **trainer_kwargs - Other keyword args for :class:`~scvi.train.Trainer`. - """ - if max_epochs is None: - max_epochs = get_max_epochs_heuristic(self.adata.n_obs) - - if self.was_pretrained: - max_epochs = int(np.min([10, np.max([2, round(max_epochs / 3.0)])])) - - logger.info(f"Training for {max_epochs} epochs.") - - plan_kwargs = {} if plan_kwargs is None else plan_kwargs - datasplitter_kwargs = datasplitter_kwargs or {} - - # if we have labeled cells, we want to subsample labels each epoch - sampler_callback = [SubSampleLabels()] if len(self._labeled_indices) != 0 else [] - - data_splitter = SemiSupervisedDataSplitter( - adata_manager=self.adata_manager, - train_size=train_size, - validation_size=validation_size, - shuffle_set_split=shuffle_set_split, - n_samples_per_label=n_samples_per_label, - distributed_sampler=use_distributed_sampler(trainer_kwargs.get("strategy", None)), - batch_size=batch_size, - **datasplitter_kwargs, - ) - training_plan = self._training_plan_cls(self.module, self.n_labels, **plan_kwargs) - if "callbacks" in trainer_kwargs.keys(): - trainer_kwargs["callbacks"] + [sampler_callback] - else: - trainer_kwargs["callbacks"] = sampler_callback - - runner = TrainRunner( - self, - training_plan=training_plan, - data_splitter=data_splitter, - max_epochs=max_epochs, - accelerator=accelerator, - devices=devices, - check_val_every_n_epoch=check_val_every_n_epoch, - **trainer_kwargs, - ) - return runner() - @classmethod @setup_anndata_dsp.dedent def setup_anndata( diff --git a/src/scvi/model/base/__init__.py b/src/scvi/model/base/__init__.py index 4b38494caf..97f10675d6 100644 --- a/src/scvi/model/base/__init__.py +++ b/src/scvi/model/base/__init__.py @@ -14,7 +14,7 @@ PyroSviTrainMixin, ) from ._rnamixin import RNASeqMixin -from ._training_mixin import UnsupervisedTrainingMixin +from ._training_mixin import SemisupervisedTrainingMixin, UnsupervisedTrainingMixin from ._vaemixin import VAEMixin __all__ = [ @@ -32,4 +32,5 @@ "BaseMinifiedModeModelClass", "BaseMudataMinifiedModeModelClass", "EmbeddingMixin", + "SemisupervisedTrainingMixin", ] diff --git a/src/scvi/model/base/_training_mixin.py b/src/scvi/model/base/_training_mixin.py index ebace98445..e803473c34 100644 --- a/src/scvi/model/base/_training_mixin.py +++ b/src/scvi/model/base/_training_mixin.py @@ -1,15 +1,31 @@ from __future__ import annotations +import importlib +import logging from typing import TYPE_CHECKING -from scvi.dataloaders import DataSplitter +import anndata +import numpy as np +import pandas as pd +import torch + +from scvi import REGISTRY_KEYS +from scvi.data._utils import get_anndata_attribute +from scvi.dataloaders import DataSplitter, SemiSupervisedDataSplitter from scvi.model._utils import get_max_epochs_heuristic, use_distributed_sampler -from scvi.train import TrainingPlan, TrainRunner +from scvi.train import SemiSupervisedTrainingPlan, TrainingPlan, TrainRunner +from scvi.train._callbacks import SubSampleLabels from scvi.utils._docstrings import devices_dsp if TYPE_CHECKING: + from collections.abc import Sequence + from lightning import LightningDataModule + from scvi._types import AnnOrMuData + +logger = logging.getLogger(__name__) + class UnsupervisedTrainingMixin: """General purpose unsupervised train method.""" @@ -143,3 +159,340 @@ def train( **trainer_kwargs, ) return runner() + + +class SemisupervisedTrainingMixin: + _training_plan_cls = SemiSupervisedTrainingPlan + + def _set_indices_and_labels(self): + """Set indices for labeled and unlabeled cells.""" + labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) + self.original_label_key = labels_state_registry.original_key + self.unlabeled_category_ = labels_state_registry.unlabeled_category + + labels = get_anndata_attribute( + self.adata, + self.adata_manager.data_registry.labels.attr_name, + self.original_label_key, + mod_key=getattr(self.adata_manager.data_registry.labels, "mod_key", None), + ).ravel() + self._label_mapping = labels_state_registry.categorical_mapping + + # set unlabeled and labeled indices + self._unlabeled_indices = np.argwhere(labels == self.unlabeled_category_).ravel() + self._labeled_indices = np.argwhere(labels != self.unlabeled_category_).ravel() + self._code_to_label = dict(enumerate(self._label_mapping)) + + def predict( + self, + adata: AnnOrMuData | None = None, + indices: Sequence[int] | None = None, + soft: bool = False, + batch_size: int | None = None, + use_posterior_mean: bool = True, + ig_interpretability: bool = False, + ig_args: dict | None = None, + ) -> (np.ndarray | pd.DataFrame, None | np.ndarray): + """Return cell label predictions. + + Parameters + ---------- + adata + AnnData or MuData object that has been registered via corresponding setup + method in model class. + indices + Return probabilities for each class label. + soft + If True, returns per class probabilities + batch_size + Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + use_posterior_mean + If ``True``, uses the mean of the posterior distribution to predict celltype + labels. Otherwise, uses a sample from the posterior distribution - this + means that the predictions will be stochastic. + ig_interpretability + If True, run the integrated circuits interpretability per sample and returns a score + matrix, in which for each sample we score each gene for its contribution to the + sample prediction + ig_args + Keyword args for IntegratedGradients + """ + adata = self._validate_anndata(adata) + + if indices is None: + indices = np.arange(adata.n_obs) + + attributions = None + if ig_interpretability: + missing_modules = [] + try: + importlib.import_module("captum") + except ImportError: + missing_modules.append("captum") + if len(missing_modules) > 0: + raise ModuleNotFoundError("Please install captum to use this functionality.") + from captum.attr import IntegratedGradients + + ig = IntegratedGradients(self.module.classify) + attributions = [] + + # in case of no indices to predict return empty values + if len(indices) == 0: + pred = [] + if ig_interpretability: + return pred, attributions + else: + return pred + + scdl = self._make_data_loader( + adata=adata, + indices=indices, + batch_size=batch_size, + ) + + y_pred = [] + for _, tensors in enumerate(scdl): + inference_inputs = self.module._get_inference_input(tensors) # (n_obs, n_vars) + data_inputs = { + key: inference_inputs[key] + for key in inference_inputs.keys() + if key not in ["batch_index", "cont_covs", "cat_covs"] + } + + batch = tensors[REGISTRY_KEYS.BATCH_KEY] + + cont_key = REGISTRY_KEYS.CONT_COVS_KEY + cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None + + cat_key = REGISTRY_KEYS.CAT_COVS_KEY + cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None + + pred = self.module.classify( + **data_inputs, + batch_index=batch, + cat_covs=cat_covs, + cont_covs=cont_covs, + use_posterior_mean=use_posterior_mean, + ) + if self.module.classifier.logits: + pred = torch.nn.functional.softmax(pred, dim=-1) + if not soft: + pred = pred.argmax(dim=1) + y_pred.append(pred.detach().cpu()) + + if ig_interpretability: + # we need the hard prediction if was not done yet + hard_pred = pred.argmax(dim=1) if soft else pred + ig_args = ig_args or {} + attribution = ig.attribute( + tuple(data_inputs.values()), target=hard_pred, **ig_args + ) + attributions.append(attribution[0]) + + if ig_interpretability: + if attributions is not None and len(attributions) > 0: + attributions = torch.cat(attributions, dim=0).detach().numpy() + attributions = self.get_ranked_genes(adata, attributions) + + if len(y_pred) > 0: + y_pred = torch.cat(y_pred).numpy() + if not soft: + predictions = [self._code_to_label[p] for p in y_pred] + if ig_interpretability: + return np.array(predictions), attributions + else: + return np.array(predictions) + else: + n_labels = len(pred[0]) + pred = pd.DataFrame( + y_pred, + columns=self._label_mapping[:n_labels], + index=adata.obs_names[indices], + ) + if ig_interpretability: + return pred, attributions + else: + return pred + + @devices_dsp.dedent + def train( + self, + max_epochs: int | None = None, + n_samples_per_label: float | None = None, + check_val_every_n_epoch: int | None = None, + train_size: float = 0.9, + validation_size: float | None = None, + shuffle_set_split: bool = True, + batch_size: int = 128, + accelerator: str = "auto", + devices: int | list[int] | str = "auto", + datasplitter_kwargs: dict | None = None, + plan_kwargs: dict | None = None, + **trainer_kwargs, + ): + """Train the model. + + Parameters + ---------- + max_epochs + Number of passes through the dataset for semisupervised training. + n_samples_per_label + Number of subsamples for each label class to sample per epoch. By default, there + is no label subsampling. + check_val_every_n_epoch + Frequency with which metrics are computed on the data for validation set for both + the unsupervised and semisupervised trainers. If you'd like a different frequency for + the semisupervised trainer, set check_val_every_n_epoch in semisupervised_train_kwargs. + train_size + Size of training set in the range [0.0, 1.0]. + validation_size + Size of the test set. If `None`, defaults to 1 - `train_size`. If + `train_size + validation_size < 1`, the remaining cells belong to a test set. + shuffle_set_split + Whether to shuffle indices before splitting. If `False`, the val, train, + and test set are split in the sequential order of the data according to + `validation_size` and `train_size` percentages. + batch_size + Minibatch size to use during training. + %(param_accelerator)s + %(param_devices)s + datasplitter_kwargs + Additional keyword arguments passed into + :class:`~scvi.dataloaders.SemiSupervisedDataSplitter`. + plan_kwargs + Keyword args for :class:`~scvi.train.SemiSupervisedTrainingPlan`. Keyword + arguments passed to `train()` will overwrite values present in `plan_kwargs`, + when appropriate. + **trainer_kwargs + Other keyword args for :class:`~scvi.train.Trainer`. + """ + if max_epochs is None: + max_epochs = get_max_epochs_heuristic(self.adata.n_obs) + + if self.was_pretrained: + max_epochs = int(np.min([10, np.max([2, round(max_epochs / 3.0)])])) + + logger.info(f"Training for {max_epochs} epochs.") + + plan_kwargs = {} if plan_kwargs is None else plan_kwargs + datasplitter_kwargs = datasplitter_kwargs or {} + + # if we have labeled cells, we want to subsample labels each epoch + sampler_callback = [SubSampleLabels()] if len(self._labeled_indices) != 0 else [] + + data_splitter = SemiSupervisedDataSplitter( + adata_manager=self.adata_manager, + train_size=train_size, + validation_size=validation_size, + shuffle_set_split=shuffle_set_split, + n_samples_per_label=n_samples_per_label, + distributed_sampler=use_distributed_sampler(trainer_kwargs.get("strategy", None)), + batch_size=batch_size, + **datasplitter_kwargs, + ) + training_plan = self._training_plan_cls(self.module, self.n_labels, **plan_kwargs) + + if "callbacks" in trainer_kwargs.keys(): + trainer_kwargs["callbacks"] + [sampler_callback] + else: + trainer_kwargs["callbacks"] = sampler_callback + + runner = TrainRunner( + self, + training_plan=training_plan, + data_splitter=data_splitter, + max_epochs=max_epochs, + accelerator=accelerator, + devices=devices, + check_val_every_n_epoch=check_val_every_n_epoch, + **trainer_kwargs, + ) + return runner() + + def get_ranked_genes( + self, adata: AnnOrMuData | None = None, attrs: np.ndarray | None = None + ) -> pd.DataFrame: + """Get the ranked gene list based on highest attributions. + + Parameters + ---------- + attr: numpy.ndarray + Attributions matrix. + + Returns + ------- + pandas.DataFrame + A pandas dataframe containing the ranked attributions for each gene + + Examples + -------- + >>> attrs_df = interpreter.get_ranked_genes(attrs) + """ + if attrs is None: + Warning("Missing Attributions matrix") + return + + adata = self._validate_anndata(adata) + + mean_attrs = attrs.mean(axis=0) + idx = mean_attrs.argsort()[::-1] + df = { + "gene": np.array(adata.var_names)[idx], + "gene_idx": idx, + "attribution_mean": mean_attrs[idx], + "attribution_std": attrs.std(axis=0)[idx], + "cells": attrs.shape[0], + } + return pd.DataFrame(df) + + def shap_adata_predict( + self, + X, + ): + adata = self._validate_anndata() + + # we need to adjust adata to the shap random selection .. + if len(X) > len(adata): + # Repeat the data to expand to a larger size + n_repeats = len(X) / len(adata) # how many times you want to repeat the data + adata_to_pred = adata[adata.obs.index.repeat(n_repeats), :] + if len(X) > len(adata_to_pred): + adata_to_pred = anndata.concat( + [adata_to_pred, adata[0 : (len(X) - len(adata_to_pred))]] + ) + else: + adata_to_pred = adata[0 : len(X)] + adata_to_pred.X = X + + return self.predict(adata_to_pred, soft=True) + + def shap_predict(self, adata: AnnOrMuData | None = None, max_size: int = 100): + missing_modules = [] + try: + importlib.import_module("shap") + except ImportError: + missing_modules.append("shap") + if len(missing_modules) > 0: + raise ModuleNotFoundError("Please install shap to use this functionality.") + import shap + + adata_orig = self._validate_anndata() + adata = self._validate_anndata(adata) + + if type(adata_orig.X).__name__ == "csr_matrix": + feature_matrix_background = pd.DataFrame.sparse.from_spmatrix( + adata_orig.X, columns=adata_orig.var_names + ) + else: + feature_matrix_background = pd.DataFrame(adata_orig.X, columns=adata_orig.var_names) + if type(adata.X).__name__ == "csr_matrix": + feature_matrix = pd.DataFrame.sparse.from_spmatrix( + adata.X, columns=adata_orig.var_names + ) + else: + feature_matrix = pd.DataFrame(adata.X, columns=adata_orig.var_names) + feature_matrix_background = shap.sample(feature_matrix_background, max_size) + feature_matrix = shap.sample(feature_matrix, max_size) + explainer = shap.KernelExplainer(self.shap_adata_predict, feature_matrix_background) + shap_values = explainer.shap_values(feature_matrix) + return shap_values diff --git a/tests/autotune/test_tune.py b/tests/autotune/test_tune.py index ca1f934440..9c6a1c0497 100644 --- a/tests/autotune/test_tune.py +++ b/tests/autotune/test_tune.py @@ -9,6 +9,7 @@ from scvi.model import SCANVI, SCVI +@pytest.mark.autotune def test_run_autotune_scvi_basic(save_path: str): settings.logging_dir = save_path adata = synthetic_iid() @@ -38,6 +39,7 @@ def test_run_autotune_scvi_basic(save_path: str): assert isinstance(experiment.result_grid, ResultGrid) +@pytest.mark.autotune def test_run_autotune_scvi_no_anndata(save_path: str, n_batches: int = 3): settings.logging_dir = save_path adata = synthetic_iid(n_batches=n_batches) @@ -72,6 +74,7 @@ def test_run_autotune_scvi_no_anndata(save_path: str, n_batches: int = 3): assert isinstance(experiment.result_grid, ResultGrid) +@pytest.mark.autotune @pytest.mark.parametrize("metric", ["Total", "Bio conservation", "iLISI"]) @pytest.mark.parametrize("model_cls", [SCVI, SCANVI]) def test_run_autotune_scvi_with_scib(model_cls, metric: str, save_path: str): @@ -117,6 +120,7 @@ def test_run_autotune_scvi_with_scib(model_cls, metric: str, save_path: str): assert isinstance(experiment.result_grid, ResultGrid) +@pytest.mark.autotune def test_run_autotune_scvi_with_scib_ext_indices(save_path: str, metric: str = "iLISI"): settings.logging_dir = save_path adata = synthetic_iid() diff --git a/tests/conftest.py b/tests/conftest.py index 6ef9467efc..78f3934b82 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,13 @@ def pytest_addoption(parser): "--multigpu-tests", action="store_true", default=False, - help="Run tests that are desinged for multiGPU.", + help="Run tests that are designed for multiGPU.", + ) + parser.addoption( + "--autotune-tests", + action="store_true", + default=False, + help="Run tests that are designed for Ray Autotune.", ) parser.addoption( "--optional", @@ -108,6 +114,18 @@ def pytest_collection_modifyitems(config, items): elif run_multigpu and ("multigpu" not in item.keywords): item.add_marker(skip_non_multigpu) + run_autotune = config.getoption("--autotune-tests") + skip_autotune = pytest.mark.skip(reason="need --autotune-tests option to run") + skip_non_autotune = pytest.mark.skip(reason="test not having a pytest.mark.autotune decorator") + for item in items: + # All tests marked with `pytest.mark.autotune` get skipped unless + # `--autotune-tests` passed + if not run_autotune and ("autotune" in item.keywords): + item.add_marker(skip_autotune) + # Skip all tests not marked with `pytest.mark.autotune` if `--autotune-tests` passed + elif run_autotune and ("autotune" not in item.keywords): + item.add_marker(skip_non_autotune) + @pytest.fixture(scope="session") def save_path(tmp_path_factory): diff --git a/tests/model/test_scanvi.py b/tests/model/test_scanvi.py index cb6e965574..632d1eace1 100644 --- a/tests/model/test_scanvi.py +++ b/tests/model/test_scanvi.py @@ -577,3 +577,78 @@ def check_no_logits_and_softmax(model: SCANVI): model = SCANVI.load(resave_model_path, adata) check_no_logits_and_softmax(model) + + +@pytest.mark.parametrize("unlabeled_cat", ["label_0"]) +def test_scanvi_interpertability_ig(unlabeled_cat: str): + adata = synthetic_iid(batch_size=50) + adata.obs["cont1"] = np.random.normal(size=(adata.shape[0],)) + adata.obs["cont2"] = np.random.normal(size=(adata.shape[0],)) + adata.obs["cat1"] = np.random.randint(0, 5, size=(adata.shape[0],)) + adata.obs["cat2"] = np.random.randint(0, 5, size=(adata.shape[0],)) + SCANVI.setup_anndata( + adata, + labels_key="labels", + unlabeled_category=unlabeled_cat, + batch_key="batch", + continuous_covariate_keys=["cont1", "cont2"], + categorical_covariate_keys=["cat1", "cat2"], + ) + model = SCANVI(adata, n_latent=10) + model.train(1, train_size=0.5, check_val_every_n_epoch=1) + + # get the IG for all data + predictions, attributions = model.predict(ig_interpretability=True) # orignal predictions + # let's see an avg of score of top 5 genes for all samples put together + ig_top_features = attributions.head(5) + print(ig_top_features) + + # new data ig prediction specific for samples, top 5 genes + adata2 = synthetic_iid(batch_size=10) + adata2.obs["cont1"] = np.random.normal(size=(adata2.shape[0],)) + adata2.obs["cont2"] = np.random.normal(size=(adata2.shape[0],)) + adata2.obs["cat1"] = np.random.randint(0, 5, size=(adata2.shape[0],)) + adata2.obs["cat2"] = np.random.randint(0, 5, size=(adata2.shape[0],)) + predictions, attributions = model.predict(adata2, indices=[1, 2, 3], ig_interpretability=True) + ig_top_features_3_samples = attributions.head(5) + print(ig_top_features_3_samples) + + +@pytest.mark.parametrize("unlabeled_cat", ["label_0"]) +def test_scanvi_interpertability_shap(unlabeled_cat: str): + adata = synthetic_iid(batch_size=50) + adata.obs["cont1"] = np.random.normal(size=(adata.shape[0],)) + adata.obs["cont2"] = np.random.normal(size=(adata.shape[0],)) + adata.obs["cat1"] = np.random.randint(0, 5, size=(adata.shape[0],)) + adata.obs["cat2"] = np.random.randint(0, 5, size=(adata.shape[0],)) + SCANVI.setup_anndata( + adata, + labels_key="labels", + unlabeled_category=unlabeled_cat, + batch_key="batch", + continuous_covariate_keys=["cont1", "cont2"], + categorical_covariate_keys=["cat1", "cat2"], + ) + model = SCANVI(adata, n_latent=10) + model.train(1, train_size=0.5, check_val_every_n_epoch=1) + + # new data ig prediction specific for samples, top 5 genes + adata2 = synthetic_iid(batch_size=10) + adata2.obs["cont1"] = np.random.normal(size=(adata2.shape[0],)) + adata2.obs["cont2"] = np.random.normal(size=(adata2.shape[0],)) + adata2.obs["cat1"] = np.random.randint(0, 5, size=(adata2.shape[0],)) + adata2.obs["cat2"] = np.random.randint(0, 5, size=(adata2.shape[0],)) + + # now run shap values and compare to previous results + # (here, the more labels the more time it will take to run) + shap_values = model.shap_predict() + # select the label we want to understand (usually the '1' class) + shap_top_features = model.get_ranked_genes(attrs=shap_values[:, :, 1]).head(5) + print(shap_top_features) + + # now run shap values for the test set + # (here, the more labels the more time it will take to run) + shap_values_test = model.shap_predict(adata2) + # # select the label we want to understand (usually the '1' class) + shap_top_features_test = model.get_ranked_genes(attrs=shap_values_test[:, :, 1]).head(5) + print(shap_top_features_test)