Skip to content

Commit

Permalink
feat: get_norm function property (#3238)
Browse files Browse the repository at this point in the history
added get_norm property and check for all models + also the criticism
guide
  • Loading branch information
ori-kron-wis authored Mar 9, 2025
1 parent bf8daf4 commit c538dbc
Show file tree
Hide file tree
Showing 22 changed files with 119 additions and 69 deletions.
12 changes: 7 additions & 5 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ to [Semantic Versioning]. Full commit history is available in the

- Add {class}`scvi.external.METHYLANVI` for modeling methylation labelled data from single-cell
bisulfite sequencing (scBS-seq) {pr}`3066`.
- Add supervised module class {class}`scvi.model.base.SupervisedModuleClass`. {pr}`3237`.
- Add supervised module class {class}`scvi.module.base.SupervisedModuleClass`. {pr}`3237`.
- Add get normalized function model property for any generative model {pr}`3238` and changed
get_accessibility_estimates to get_normalized_accessibility, where needed.

#### Fixed

Expand Down Expand Up @@ -101,7 +103,7 @@ to [Semantic Versioning]. Full commit history is available in the
- Added adaptive handling for last training minibatch of 1-2 cells in case of
`datasplitter_kwargs={"drop_last": False}` and `train_size = None` by moving them into
validation set, if available. {pr}`3036`.
- Add `batch_key` and `labels_key` to `scvi.external.SCAR.setup_anndata`. {pr}`3045`.
- Add `batch_key` and `labels_key` to {meth}`scvi.external.SCAR.setup_anndata`. {pr}`3045`.
- Implemented variance of ZINB distribution. {pr}`3044`.
- Support for minified mode while retaining counts to skip the encoder.
- New Trainingplan argument `update_only_decoder` to use stored latent codes and skip training of
Expand All @@ -115,7 +117,7 @@ to [Semantic Versioning]. Full commit history is available in the
- Breaking Change: Fix `get_outlier_cell_sample_pairs` function in {class}`scvi.external.MRVI`
to correctly compute the maxmimum log-density across in-sample cells rather than the
aggregated posterior log-density {pr}`3007`.
- Fix references to `scvi.external` in `scvi.external.SCAR.setup_anndata`.
- Fix references to `scvi.external` in {meth}`scvi.external.SCAR.setup_anndata`.
- Fix gimVI to append mini batches first into CPU during get_imputed and get_latent operations {pr}`3058`.

#### Changed
Expand All @@ -127,9 +129,9 @@ to [Semantic Versioning]. Full commit history is available in the
#### Added

- Add support for Python 3.12 {pr}`2966`.
- Add support for categorial covariates in scArches in `scvi.model.archesmixin` {pr}`2936`.
- Add support for categorial covariates in scArches in {meth}`scvi.model.archesmixin` {pr}`2936`.
- Add assertion error in cellAssign for checking duplicates in celltype markers {pr}`2951`.
- Add `scvi.external.poissonvi.get_region_factors` {pr}`2940`.
- Add {meth}`scvi.external.poissonvi.get_region_factors` {pr}`2940`.
- {attr}`scvi.settings.dl_persistent_workers` allows using persistent workers in
{class}`scvi.dataloaders.AnnDataLoader` {pr}`2924`.
- Add option for using external indexes in data splitting classes that are under `scvi.dataloaders`
Expand Down
1 change: 1 addition & 0 deletions docs/user_guide/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ scvi-tools is composed of models that can perform one or many analysis tasks. In
- {doc}`/user_guide/use_case/hyper_parameters_tuning`
- {doc}`/user_guide/use_case/multi_gpu_training`
- {doc}`/user_guide/use_case/custom_dataloaders`
- {doc}`/user_guide/use_case/scvi_criticism`

## Glossary

Expand Down
4 changes: 2 additions & 2 deletions docs/user_guide/models/multivi.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ When gene expression data is passed, MultiVI computes denoised gene expression d

### Denoising/imputation of accessibility

In {func}`~scvi.model.MULTIVI.get_accessibility_estimates` MultiVI returns the expected value of $y_i$ under the
In {func}`~scvi.model.MULTIVI.get_normalized_accessibility` MultiVI returns the expected value of $y_i$ under the
approximate posterior. For one cell $i$, this can be written as:

```{math}
Expand All @@ -199,7 +199,7 @@ then, using this value to decode the accessibility probability estimate $p_r$. A
the variational mean, a number of samples can be passed as an argument in the code:

```
>>> model.get_accessibility_estimates(n_samples_overall=10)
>>> model.get_normalized_accessibility(n_samples_overall=10)
```

This value is used to compute the mean of the latent variable over these samples. Notably, this function also has
Expand Down
2 changes: 1 addition & 1 deletion docs/user_guide/models/peakvi.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ A PeakVI model can be pre-trained on reference data and updated with query data

### Estimation of accessibility

In {func}`~scvi.model.PEAKVI.get_accessibility_estimates` PeakVI returns the expected value of $y_i$ under the approximate posterior. For one cell $i$, this can be written as:
In {func}`~scvi.model.PEAKVI.get_normalized_accessibility` PeakVI returns the expected value of $y_i$ under the approximate posterior. For one cell $i$, this can be written as:

```{math}
:nowrap: true
Expand Down
4 changes: 2 additions & 2 deletions docs/user_guide/use_case/custom_dataloaders.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ For example, we offer custom dataloaders that do not necessarily store the data
Without dataloader large data can be on disk but inefficient. We increase efficiency for data on disk and enable data on cloud storage.

In SCVI, we work with several collaborators in order to construct efficient custom dataloaders:
1. [lamin.ai]("https://lamin.ai/") custom dataloader is based on MappedCollectionDataModule and can run a collection of adata based on lamindb backend.
1. [lamin.ai](https://lamin.ai/) custom dataloader is based on MappedCollectionDataModule and can run a collection of adata based on lamindb backend.

LamindDB is a key-value store designed specifically for machine learning, particularly focusing on making it easier to manage large-scale training datasets. It is optimized for storing and querying tabular data, providing fast read and write access. LamindDB’s design allows for efficient handling of large datasets with a focus on machine learning workflows, such as those used in SCVI.

Expand Down Expand Up @@ -39,7 +39,7 @@ model = scvi.model.SCVI(adata=None, registry=datamodule.registry)
```
LamindDB may not be as efficient or flexible as TileDB for handling complex multi-dimensional data

2. [CZI]("https://chanzuckerberg.com/") based [tiledb]("https://tiledb.com/") custom dataloader is based on CensusSCVIDataModule and can run a large multi-dimensional datasets that are stored in TileDB’s format.
2. [CZI](https://chanzuckerberg.com/) based [tiledb](https://tiledb.com/) custom dataloader is based on CensusSCVIDataModule and can run a large multi-dimensional datasets that are stored in TileDB’s format.

TileDB is a general-purpose, multi-dimensional array storage engine designed for high-performance, scalable data access. It supports various data types, including dense and sparse arrays, and is optimized for handling large datasets efficiently. TileDB’s strength lies in its ability to store and query data across multiple dimensions and scale efficiently with large volumes of data.

Expand Down
32 changes: 32 additions & 0 deletions docs/user_guide/use_case/scvi_criticism.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# SCVI Criticism

:::{note}
This page is under construction.
:::

SCVI-Criticism is a tool in the scvi-tools ecosystem that helps assess the performance and quality of single-cell RNA sequencing (scRNA-seq) data analysis using generative models like scVI (single-cell Variational Inference). It provides a framework for evaluating the models' predictions and helps researchers understand the strengths and weaknesses of their models. One of its main advantages is that it allows for robust evaluation using real or simulated datasets, offering insight into model robustness, overfitting, and predictive performance.

Underneath the hood it uses posterior predictive checks or PPC ({class}`scvi.criticism.PosteriorPredictiveCheck`) for comparing scRNA-seq generative models.

The method works well for any SCVI model out of the box to provide insights into the quality of predictions and model evaluation.

### There are few metrics we calculate in order to achieve that:
- **Cell Wise Coefficient of variation:** The cell-wise coefficient of variation summarizes how well variation between different cells is preserved by the generated model expression. Below a squared Pearson correlation coefficient of 0.4 , we would recommend not to use generated data for downstream analysis, while the generated latent space might still be useful for analysis.
- **Gene Wise Coefficient of variation:** The gene-wise coefficient of variation summarizes how well variation between different genes is preserved by the generated model expression. This value is usually quite high.
- **Differential expression metric:** The differential expression metric provides a summary of the differential expression analysis between cell types or input clusters. We provide the F1-score, Pearson Correlation Coefficient of Log-Foldchanges, Spearman Correlation Coefficient, and Area Under the Precision Recall Curve (AUPRC) for the differential expression analysis using Wilcoxon Rank Sum test for each cell-type.

### Example of use:
We can compute and compare 2 models PPC's by:
```python
models_dict = {"model1": model1, "model2": model2}
ppc = PPC(adata, models_dict)
```

### Creating Report:
A good practice will be to save this report together with the model, just like we do for weights and adata.
For example, for all our [SCVI-Hub](https://huggingface.co/scvi-tools) models we attached the SCVI criticism reports.

We can create the criticism report simply by:
```python
create_criticism_report(model)
```
2 changes: 1 addition & 1 deletion src/scvi/external/cellassign/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
B = 10


class CellAssign(UnsupervisedTrainingMixin, BaseModelClass, RNASeqMixin):
class CellAssign(UnsupervisedTrainingMixin, RNASeqMixin, BaseModelClass):
"""Reimplementation of CellAssign for reference-based annotation :cite:p:`Zhang19`.
Original implementation: https://github.com/irrationone/cellassign.
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/external/gimvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _unpack_tensors(tensors):
return x, batch_index, y


class GIMVI(VAEMixin, BaseModelClass, RNASeqMixin):
class GIMVI(VAEMixin, RNASeqMixin, BaseModelClass):
"""Joint VAE for imputing missing genes in spatial data :cite:p:`Lopez19`.
Parameters
Expand Down
3 changes: 2 additions & 1 deletion src/scvi/external/methylvi/_methylanvi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __init__(
"methylation_contexts"
]
self.num_features_per_context = [mdata[context].shape[1] for context in self.contexts]
self.get_normalized_function_name = "get_normalized_methylation"

n_input = np.sum(self.num_features_per_context)

Expand Down Expand Up @@ -196,7 +197,7 @@ def setup_mudata(
(specified by `mc_layer`) and total number of counts (specified by `cov_layer`) for
each genomic region feature.
%(param_batch_key)s
%(param_categorical_covariate_keys)s
%(param_cat_cov_keys)s
%(param_modalities)s
Examples
Expand Down
3 changes: 2 additions & 1 deletion src/scvi/external/methylvi/_methylvi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(
"methylation_contexts"
]
self.num_features_per_context = [mdata[context].shape[1] for context in self.contexts]
self.get_normalized_function_name = "get_normalized_methylation"

n_input = np.sum(self.num_features_per_context)

Expand Down Expand Up @@ -151,7 +152,7 @@ def setup_mudata(
(specified by `mc_layer`) and total number of counts (specified by `cov_layer`) for
each genomic region feature.
%(param_batch_key)s
%(param_categorical_covariate_keys)s
%(param_cat_cov_keys)s
%(param_modalities)s
Examples
Expand Down
5 changes: 3 additions & 2 deletions src/scvi/external/poissonvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(
library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch)

self._module_cls = VAE
self.get_normalized_function_name = "get_normalized_accessibility"

self.module = self._module_cls(
n_input=self.summary_stats.n_vars,
Expand Down Expand Up @@ -135,7 +136,7 @@ def __init__(
self.init_params_ = self._get_init_params(locals())

@torch.inference_mode()
def get_accessibility_estimates(
def get_normalized_accessibility(
self,
adata: AnnData | None = None,
indices: Sequence[int] = None,
Expand Down Expand Up @@ -334,7 +335,7 @@ def differential_accessibility(
col_names = adata.var_names
importance_weighting_kwargs = importance_weighting_kwargs or {}
model_fn = partial(
self.get_accessibility_estimates,
self.get_normalized_accessibility,
return_numpy=True,
n_samples=1,
batch_size=batch_size,
Expand Down
4 changes: 2 additions & 2 deletions src/scvi/external/sysvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ def setup_anndata(
layer
AnnData layer to use, default is X.
Should contain normalized and log1p transformed expression.
%(param_categorical_covariate_keys)s
%(param_continuous_covariate_keys)s
%(param_cat_cov_keys)s
%(param_cont_cov_keys)s
"""
setup_method_args = cls._get_setup_method_args(**locals())

Expand Down
16 changes: 8 additions & 8 deletions src/scvi/hub/_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
clustering.
scVI takes as input a scRNA-seq gene expression matrix with cells and genes.
We provide an extensive [user guide](https://docs.scvi-tools.org/en/1.2.0/user_guide/models/scvi.html).
We provide an extensive [user guide](https://docs.scvi-tools.org/en/stable/user_guide/models/scvi.html).
- See our original manuscript for further details of the model:
[scVI manuscript](https://www.nature.com/articles/s41592-018-0229-2).
- See our manuscript on [scvi-hub](https://www.biorxiv.org/content/10.1101/2024.03.01.582887v2) how
to leverage pre-trained models.
This model can be used for fine tuning on new data using our Arches framework:
[Arches tutorial](https://docs.scvi-tools.org/en/1.0.0/tutorials/notebooks/scarches_scvi_tools.html).
[Arches tutorial](https://docs.scvi-tools.org/en/stable/tutorials/notebooks/scrna/scarches_scvi_tools.html).
"""

scanvi_pretext = """
Expand All @@ -26,15 +26,15 @@
scANVI takes as input a scRNA-seq gene expression matrix with cells and genes as well as a
cell-type annotation for a subset of cells.
We provide an extensive [user guide](https://docs.scvi-tools.org/en/1.2.0/user_guide/models/scanvi.html).
We provide an extensive [user guide](https://docs.scvi-tools.org/en/stable/user_guide/models/scanvi.html).
- See our original manuscript for further details of the model:
[scANVI manuscript](https://www.embopress.org/doi/full/10.15252/msb.20209620).
- See our manuscript on [scvi-hub](https://www.biorxiv.org/content/10.1101/2024.03.01.582887v2)
how to leverage pre-trained models.
This model can be used for fine tuning on new data using our Arches framework:
[Arches tutorial](https://docs.scvi-tools.org/en/1.0.0/tutorials/notebooks/scarches_scvi_tools.html).
[Arches tutorial](https://docs.scvi-tools.org/en/stable/tutorials/notebooks/scrna/scarches_scvi_tools.html).
"""

condscvi_pretext = """
Expand All @@ -46,7 +46,7 @@
CondSCVI takes as input a scRNA-seq gene expression matrix with cells and genes as well as a
cell-type annotation for all cells.
We provide an extensive [user guide](https://docs.scvi-tools.org/en/1.2.0/user_guide/models/destvi.html)
We provide an extensive [user guide](https://docs.scvi-tools.org/en/stable/user_guide/models/destvi.html)
for DestVI including a description of CondSCVI.
- See our original manuscript for further details of the model:
Expand All @@ -64,7 +64,7 @@
Stereoscope takes as input a scRNA-seq gene expression matrix with cells and genes as well as a
cell-type annotation for all cells.
We provide an extensive for DestVI including a description of CondSCVI
[user guide](https://docs.scvi-tools.org/en/1.2.0/user_guide/models/destvi.html).
[user guide](https://docs.scvi-tools.org/en/stable/user_guide/models/destvi.html).
- See our original manuscript for further details of the model:
[Stereoscope manuscript](https://www.nature.com/articles/s42003-020-01247-y) as well as the
Expand All @@ -84,15 +84,15 @@
TotalVI takes as input a scRNA-seq gene expression and protein expression matrix with cells and
genes.
We provide an extensive [user guide](https://docs.scvi-tools.org/en/1.2.0/user_guide/models/totalvi.html).
We provide an extensive [user guide](https://docs.scvi-tools.org/en/stable/user_guide/models/totalvi.html).
- See our original manuscript for further details of the model:
[TotalVI manuscript](https://www.nature.com/articles/s41592-020-01050-x).
- See our manuscript on [scvi-hub](https://www.biorxiv.org/content/10.1101/2024.03.01.582887v2)
how to leverage pre-trained models.
This model can be used for fine tuning on new data using our Arches framework:
[Arches tutorial](https://docs.scvi-tools.org/en/1.0.0/tutorials/notebooks/scarches_scvi_tools.html).
[Arches tutorial](https://docs.scvi-tools.org/en/stable/tutorials/notebooks/scrna/scarches_scvi_tools.html).
"""


Expand Down
2 changes: 1 addition & 1 deletion src/scvi/model/_autozi.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
# register buffer


class AUTOZI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, RNASeqMixin):
class AUTOZI(VAEMixin, UnsupervisedTrainingMixin, RNASeqMixin, BaseModelClass):
"""Automatic identification of zero-inflated genes :cite:p:`Clivio19`.
Parameters
Expand Down
7 changes: 0 additions & 7 deletions src/scvi/model/_destvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,6 @@ def get_proportions(
index=index_names,
)

def get_normalized_expression(
self,
):
# Refer to function get_accessibility_estimates
msg = f"get_normalized_expression is not implemented for {self.__class__.__name__}."
raise NotImplementedError(msg)

def get_gamma(
self,
indices: Sequence[int] | None = None,
Expand Down
5 changes: 3 additions & 2 deletions src/scvi/model/_multivi.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def __init__(
self.n_genes = n_genes
self.n_regions = n_regions
self.n_proteins = n_proteins
self.get_normalized_function_name = "get_normalized_accessibility"

@devices_dsp.dedent
def train(
Expand Down Expand Up @@ -489,7 +490,7 @@ def get_latent_representation(
return torch.cat(latent).numpy()

@torch.inference_mode()
def get_accessibility_estimates(
def get_normalized_accessibility(
self,
adata: AnnOrMuData | None = None,
indices: Sequence[int] = None,
Expand Down Expand Up @@ -801,7 +802,7 @@ def differential_accessibility(
adata = self._validate_anndata(adata)
col_names = adata.var_names[: self.n_genes]
model_fn = partial(
self.get_accessibility_estimates, use_z_mean=False, batch_size=batch_size
self.get_normalized_accessibility, use_z_mean=False, batch_size=batch_size
)

all_stats_fn = partial(
Expand Down
5 changes: 3 additions & 2 deletions src/scvi/model/_peakvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(
if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry
else []
)
self.get_normalized_function_name = "get_normalized_accessibility"

self.module = self._module_cls(
n_input_regions=self.summary_stats.n_vars,
Expand Down Expand Up @@ -306,7 +307,7 @@ def get_region_factors(self):
return torch.sigmoid(self.module.region_factors).cpu().numpy()

@torch.inference_mode()
def get_accessibility_estimates(
def get_normalized_accessibility(
self,
adata: AnnData | None = None,
indices: Sequence[int] = None,
Expand Down Expand Up @@ -501,7 +502,7 @@ def differential_accessibility(
adata = self._validate_anndata(adata)
col_names = adata.var_names
model_fn = partial(
self.get_accessibility_estimates, use_z_mean=False, batch_size=batch_size
self.get_normalized_accessibility, use_z_mean=False, batch_size=batch_size
)

result = _de_core(
Expand Down
16 changes: 16 additions & 0 deletions src/scvi/model/base/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def __init__(self, adata: AnnOrMuData | None = None):
self.test_indices_ = None
self.validation_indices_ = None
self.history_ = None
self.get_normalized_function_name_ = "get_normalized_expression"

@property
def adata(self) -> AnnOrMuData:
Expand Down Expand Up @@ -808,6 +809,15 @@ def summary_string(self):
)
return summary_string

@property
def get_normalized_function_name(self) -> str:
"""What the get normalized functions name is"""
return self.get_normalized_function_name_

@get_normalized_function_name.setter
def get_normalized_function_name(self, value):
self.get_normalized_function_name_ = value

def __repr__(self):
rich.print(self.summary_string)
return ""
Expand Down Expand Up @@ -893,6 +903,12 @@ def view_anndata_setup(
) from err
adata_manager.view_registry(hide_state_registries=hide_state_registries)

def get_normalized_expression(
self,
):
msg = f"get_normalized_expression is not implemented for {self.__class__.__name__}."
raise NotImplementedError(msg)


class BaseMinifiedModeModelClass(BaseModelClass):
"""Abstract base class for scvi-tools models that can handle minified data."""
Expand Down
Loading

0 comments on commit c538dbc

Please sign in to comment.