From c538dbcf59be6fc1a2ab5504c4211d74241ac400 Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Sun, 9 Mar 2025 15:46:23 +0200 Subject: [PATCH] feat: get_norm function property (#3238) added get_norm property and check for all models + also the criticism guide --- CHANGELOG.md | 12 +++-- docs/user_guide/index.md | 1 + docs/user_guide/models/multivi.md | 4 +- docs/user_guide/models/peakvi.md | 2 +- .../user_guide/use_case/custom_dataloaders.md | 4 +- docs/user_guide/use_case/scvi_criticism.md | 32 +++++++++++ src/scvi/external/cellassign/_model.py | 2 +- src/scvi/external/gimvi/_model.py | 2 +- .../external/methylvi/_methylanvi_model.py | 3 +- src/scvi/external/methylvi/_methylvi_model.py | 3 +- src/scvi/external/poissonvi/_model.py | 5 +- src/scvi/external/sysvi/_model.py | 4 +- src/scvi/hub/_template.py | 16 +++--- src/scvi/model/_autozi.py | 2 +- src/scvi/model/_destvi.py | 7 --- src/scvi/model/_multivi.py | 5 +- src/scvi/model/_peakvi.py | 5 +- src/scvi/model/base/_base_model.py | 16 ++++++ tests/external/poissonvi/test_poissonvi.py | 2 +- tests/model/test_multivi.py | 54 +++++++++---------- tests/model/test_peakvi.py | 6 +-- tests/model/test_scanvi.py | 1 + 22 files changed, 119 insertions(+), 69 deletions(-) create mode 100644 docs/user_guide/use_case/scvi_criticism.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 24a13aee60..3a663777d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,9 @@ to [Semantic Versioning]. Full commit history is available in the - Add {class}`scvi.external.METHYLANVI` for modeling methylation labelled data from single-cell bisulfite sequencing (scBS-seq) {pr}`3066`. -- Add supervised module class {class}`scvi.model.base.SupervisedModuleClass`. {pr}`3237`. +- Add supervised module class {class}`scvi.module.base.SupervisedModuleClass`. {pr}`3237`. +- Add get normalized function model property for any generative model {pr}`3238` and changed + get_accessibility_estimates to get_normalized_accessibility, where needed. #### Fixed @@ -101,7 +103,7 @@ to [Semantic Versioning]. Full commit history is available in the - Added adaptive handling for last training minibatch of 1-2 cells in case of `datasplitter_kwargs={"drop_last": False}` and `train_size = None` by moving them into validation set, if available. {pr}`3036`. -- Add `batch_key` and `labels_key` to `scvi.external.SCAR.setup_anndata`. {pr}`3045`. +- Add `batch_key` and `labels_key` to {meth}`scvi.external.SCAR.setup_anndata`. {pr}`3045`. - Implemented variance of ZINB distribution. {pr}`3044`. - Support for minified mode while retaining counts to skip the encoder. - New Trainingplan argument `update_only_decoder` to use stored latent codes and skip training of @@ -115,7 +117,7 @@ to [Semantic Versioning]. Full commit history is available in the - Breaking Change: Fix `get_outlier_cell_sample_pairs` function in {class}`scvi.external.MRVI` to correctly compute the maxmimum log-density across in-sample cells rather than the aggregated posterior log-density {pr}`3007`. -- Fix references to `scvi.external` in `scvi.external.SCAR.setup_anndata`. +- Fix references to `scvi.external` in {meth}`scvi.external.SCAR.setup_anndata`. - Fix gimVI to append mini batches first into CPU during get_imputed and get_latent operations {pr}`3058`. #### Changed @@ -127,9 +129,9 @@ to [Semantic Versioning]. Full commit history is available in the #### Added - Add support for Python 3.12 {pr}`2966`. -- Add support for categorial covariates in scArches in `scvi.model.archesmixin` {pr}`2936`. +- Add support for categorial covariates in scArches in {meth}`scvi.model.archesmixin` {pr}`2936`. - Add assertion error in cellAssign for checking duplicates in celltype markers {pr}`2951`. -- Add `scvi.external.poissonvi.get_region_factors` {pr}`2940`. +- Add {meth}`scvi.external.poissonvi.get_region_factors` {pr}`2940`. - {attr}`scvi.settings.dl_persistent_workers` allows using persistent workers in {class}`scvi.dataloaders.AnnDataLoader` {pr}`2924`. - Add option for using external indexes in data splitting classes that are under `scvi.dataloaders` diff --git a/docs/user_guide/index.md b/docs/user_guide/index.md index 06e18aabe7..1e6cb10200 100644 --- a/docs/user_guide/index.md +++ b/docs/user_guide/index.md @@ -187,6 +187,7 @@ scvi-tools is composed of models that can perform one or many analysis tasks. In - {doc}`/user_guide/use_case/hyper_parameters_tuning` - {doc}`/user_guide/use_case/multi_gpu_training` - {doc}`/user_guide/use_case/custom_dataloaders` +- {doc}`/user_guide/use_case/scvi_criticism` ## Glossary diff --git a/docs/user_guide/models/multivi.md b/docs/user_guide/models/multivi.md index a552622706..2cb230cca2 100644 --- a/docs/user_guide/models/multivi.md +++ b/docs/user_guide/models/multivi.md @@ -183,7 +183,7 @@ When gene expression data is passed, MultiVI computes denoised gene expression d ### Denoising/imputation of accessibility -In {func}`~scvi.model.MULTIVI.get_accessibility_estimates` MultiVI returns the expected value of $y_i$ under the +In {func}`~scvi.model.MULTIVI.get_normalized_accessibility` MultiVI returns the expected value of $y_i$ under the approximate posterior. For one cell $i$, this can be written as: ```{math} @@ -199,7 +199,7 @@ then, using this value to decode the accessibility probability estimate $p_r$. A the variational mean, a number of samples can be passed as an argument in the code: ``` ->>> model.get_accessibility_estimates(n_samples_overall=10) +>>> model.get_normalized_accessibility(n_samples_overall=10) ``` This value is used to compute the mean of the latent variable over these samples. Notably, this function also has diff --git a/docs/user_guide/models/peakvi.md b/docs/user_guide/models/peakvi.md index 44411cb661..3816364260 100644 --- a/docs/user_guide/models/peakvi.md +++ b/docs/user_guide/models/peakvi.md @@ -132,7 +132,7 @@ A PeakVI model can be pre-trained on reference data and updated with query data ### Estimation of accessibility -In {func}`~scvi.model.PEAKVI.get_accessibility_estimates` PeakVI returns the expected value of $y_i$ under the approximate posterior. For one cell $i$, this can be written as: +In {func}`~scvi.model.PEAKVI.get_normalized_accessibility` PeakVI returns the expected value of $y_i$ under the approximate posterior. For one cell $i$, this can be written as: ```{math} :nowrap: true diff --git a/docs/user_guide/use_case/custom_dataloaders.md b/docs/user_guide/use_case/custom_dataloaders.md index 841b68b3ae..086b8d42ec 100644 --- a/docs/user_guide/use_case/custom_dataloaders.md +++ b/docs/user_guide/use_case/custom_dataloaders.md @@ -10,7 +10,7 @@ For example, we offer custom dataloaders that do not necessarily store the data Without dataloader large data can be on disk but inefficient. We increase efficiency for data on disk and enable data on cloud storage. In SCVI, we work with several collaborators in order to construct efficient custom dataloaders: -1. [lamin.ai]("https://lamin.ai/") custom dataloader is based on MappedCollectionDataModule and can run a collection of adata based on lamindb backend. +1. [lamin.ai](https://lamin.ai/) custom dataloader is based on MappedCollectionDataModule and can run a collection of adata based on lamindb backend. LamindDB is a key-value store designed specifically for machine learning, particularly focusing on making it easier to manage large-scale training datasets. It is optimized for storing and querying tabular data, providing fast read and write access. LamindDB’s design allows for efficient handling of large datasets with a focus on machine learning workflows, such as those used in SCVI. @@ -39,7 +39,7 @@ model = scvi.model.SCVI(adata=None, registry=datamodule.registry) ``` LamindDB may not be as efficient or flexible as TileDB for handling complex multi-dimensional data -2. [CZI]("https://chanzuckerberg.com/") based [tiledb]("https://tiledb.com/") custom dataloader is based on CensusSCVIDataModule and can run a large multi-dimensional datasets that are stored in TileDB’s format. +2. [CZI](https://chanzuckerberg.com/) based [tiledb](https://tiledb.com/) custom dataloader is based on CensusSCVIDataModule and can run a large multi-dimensional datasets that are stored in TileDB’s format. TileDB is a general-purpose, multi-dimensional array storage engine designed for high-performance, scalable data access. It supports various data types, including dense and sparse arrays, and is optimized for handling large datasets efficiently. TileDB’s strength lies in its ability to store and query data across multiple dimensions and scale efficiently with large volumes of data. diff --git a/docs/user_guide/use_case/scvi_criticism.md b/docs/user_guide/use_case/scvi_criticism.md new file mode 100644 index 0000000000..c498b2dbdb --- /dev/null +++ b/docs/user_guide/use_case/scvi_criticism.md @@ -0,0 +1,32 @@ +# SCVI Criticism + +:::{note} +This page is under construction. +::: + +SCVI-Criticism is a tool in the scvi-tools ecosystem that helps assess the performance and quality of single-cell RNA sequencing (scRNA-seq) data analysis using generative models like scVI (single-cell Variational Inference). It provides a framework for evaluating the models' predictions and helps researchers understand the strengths and weaknesses of their models. One of its main advantages is that it allows for robust evaluation using real or simulated datasets, offering insight into model robustness, overfitting, and predictive performance. + +Underneath the hood it uses posterior predictive checks or PPC ({class}`scvi.criticism.PosteriorPredictiveCheck`) for comparing scRNA-seq generative models. + +The method works well for any SCVI model out of the box to provide insights into the quality of predictions and model evaluation. + +### There are few metrics we calculate in order to achieve that: +- **Cell Wise Coefficient of variation:** The cell-wise coefficient of variation summarizes how well variation between different cells is preserved by the generated model expression. Below a squared Pearson correlation coefficient of 0.4 , we would recommend not to use generated data for downstream analysis, while the generated latent space might still be useful for analysis. +- **Gene Wise Coefficient of variation:** The gene-wise coefficient of variation summarizes how well variation between different genes is preserved by the generated model expression. This value is usually quite high. +- **Differential expression metric:** The differential expression metric provides a summary of the differential expression analysis between cell types or input clusters. We provide the F1-score, Pearson Correlation Coefficient of Log-Foldchanges, Spearman Correlation Coefficient, and Area Under the Precision Recall Curve (AUPRC) for the differential expression analysis using Wilcoxon Rank Sum test for each cell-type. + +### Example of use: +We can compute and compare 2 models PPC's by: +```python +models_dict = {"model1": model1, "model2": model2} +ppc = PPC(adata, models_dict) +``` + +### Creating Report: +A good practice will be to save this report together with the model, just like we do for weights and adata. +For example, for all our [SCVI-Hub](https://huggingface.co/scvi-tools) models we attached the SCVI criticism reports. + +We can create the criticism report simply by: +```python +create_criticism_report(model) +``` diff --git a/src/scvi/external/cellassign/_model.py b/src/scvi/external/cellassign/_model.py index dacd7dcbbd..5fba75eb2c 100644 --- a/src/scvi/external/cellassign/_model.py +++ b/src/scvi/external/cellassign/_model.py @@ -33,7 +33,7 @@ B = 10 -class CellAssign(UnsupervisedTrainingMixin, BaseModelClass, RNASeqMixin): +class CellAssign(UnsupervisedTrainingMixin, RNASeqMixin, BaseModelClass): """Reimplementation of CellAssign for reference-based annotation :cite:p:`Zhang19`. Original implementation: https://github.com/irrationone/cellassign. diff --git a/src/scvi/external/gimvi/_model.py b/src/scvi/external/gimvi/_model.py index 608e1de938..ab187b7cdc 100644 --- a/src/scvi/external/gimvi/_model.py +++ b/src/scvi/external/gimvi/_model.py @@ -41,7 +41,7 @@ def _unpack_tensors(tensors): return x, batch_index, y -class GIMVI(VAEMixin, BaseModelClass, RNASeqMixin): +class GIMVI(VAEMixin, RNASeqMixin, BaseModelClass): """Joint VAE for imputing missing genes in spatial data :cite:p:`Lopez19`. Parameters diff --git a/src/scvi/external/methylvi/_methylanvi_model.py b/src/scvi/external/methylvi/_methylanvi_model.py index 64a5e6fd0a..b35ac32001 100644 --- a/src/scvi/external/methylvi/_methylanvi_model.py +++ b/src/scvi/external/methylvi/_methylanvi_model.py @@ -109,6 +109,7 @@ def __init__( "methylation_contexts" ] self.num_features_per_context = [mdata[context].shape[1] for context in self.contexts] + self.get_normalized_function_name = "get_normalized_methylation" n_input = np.sum(self.num_features_per_context) @@ -196,7 +197,7 @@ def setup_mudata( (specified by `mc_layer`) and total number of counts (specified by `cov_layer`) for each genomic region feature. %(param_batch_key)s - %(param_categorical_covariate_keys)s + %(param_cat_cov_keys)s %(param_modalities)s Examples diff --git a/src/scvi/external/methylvi/_methylvi_model.py b/src/scvi/external/methylvi/_methylvi_model.py index 2e44b78666..dd80d827a6 100644 --- a/src/scvi/external/methylvi/_methylvi_model.py +++ b/src/scvi/external/methylvi/_methylvi_model.py @@ -81,6 +81,7 @@ def __init__( "methylation_contexts" ] self.num_features_per_context = [mdata[context].shape[1] for context in self.contexts] + self.get_normalized_function_name = "get_normalized_methylation" n_input = np.sum(self.num_features_per_context) @@ -151,7 +152,7 @@ def setup_mudata( (specified by `mc_layer`) and total number of counts (specified by `cov_layer`) for each genomic region feature. %(param_batch_key)s - %(param_categorical_covariate_keys)s + %(param_cat_cov_keys)s %(param_modalities)s Examples diff --git a/src/scvi/external/poissonvi/_model.py b/src/scvi/external/poissonvi/_model.py index 82bbcdcc5b..62186a2ac1 100644 --- a/src/scvi/external/poissonvi/_model.py +++ b/src/scvi/external/poissonvi/_model.py @@ -96,6 +96,7 @@ def __init__( library_log_means, library_log_vars = _init_library_size(self.adata_manager, n_batch) self._module_cls = VAE + self.get_normalized_function_name = "get_normalized_accessibility" self.module = self._module_cls( n_input=self.summary_stats.n_vars, @@ -135,7 +136,7 @@ def __init__( self.init_params_ = self._get_init_params(locals()) @torch.inference_mode() - def get_accessibility_estimates( + def get_normalized_accessibility( self, adata: AnnData | None = None, indices: Sequence[int] = None, @@ -334,7 +335,7 @@ def differential_accessibility( col_names = adata.var_names importance_weighting_kwargs = importance_weighting_kwargs or {} model_fn = partial( - self.get_accessibility_estimates, + self.get_normalized_accessibility, return_numpy=True, n_samples=1, batch_size=batch_size, diff --git a/src/scvi/external/sysvi/_model.py b/src/scvi/external/sysvi/_model.py index 352c3635a8..978fd5c4c0 100644 --- a/src/scvi/external/sysvi/_model.py +++ b/src/scvi/external/sysvi/_model.py @@ -183,8 +183,8 @@ def setup_anndata( layer AnnData layer to use, default is X. Should contain normalized and log1p transformed expression. - %(param_categorical_covariate_keys)s - %(param_continuous_covariate_keys)s + %(param_cat_cov_keys)s + %(param_cont_cov_keys)s """ setup_method_args = cls._get_setup_method_args(**locals()) diff --git a/src/scvi/hub/_template.py b/src/scvi/hub/_template.py index 372b224a3f..f768ad5211 100644 --- a/src/scvi/hub/_template.py +++ b/src/scvi/hub/_template.py @@ -5,7 +5,7 @@ clustering. scVI takes as input a scRNA-seq gene expression matrix with cells and genes. -We provide an extensive [user guide](https://docs.scvi-tools.org/en/1.2.0/user_guide/models/scvi.html). +We provide an extensive [user guide](https://docs.scvi-tools.org/en/stable/user_guide/models/scvi.html). - See our original manuscript for further details of the model: [scVI manuscript](https://www.nature.com/articles/s41592-018-0229-2). @@ -13,7 +13,7 @@ to leverage pre-trained models. This model can be used for fine tuning on new data using our Arches framework: -[Arches tutorial](https://docs.scvi-tools.org/en/1.0.0/tutorials/notebooks/scarches_scvi_tools.html). +[Arches tutorial](https://docs.scvi-tools.org/en/stable/tutorials/notebooks/scrna/scarches_scvi_tools.html). """ scanvi_pretext = """ @@ -26,7 +26,7 @@ scANVI takes as input a scRNA-seq gene expression matrix with cells and genes as well as a cell-type annotation for a subset of cells. -We provide an extensive [user guide](https://docs.scvi-tools.org/en/1.2.0/user_guide/models/scanvi.html). +We provide an extensive [user guide](https://docs.scvi-tools.org/en/stable/user_guide/models/scanvi.html). - See our original manuscript for further details of the model: [scANVI manuscript](https://www.embopress.org/doi/full/10.15252/msb.20209620). @@ -34,7 +34,7 @@ how to leverage pre-trained models. This model can be used for fine tuning on new data using our Arches framework: -[Arches tutorial](https://docs.scvi-tools.org/en/1.0.0/tutorials/notebooks/scarches_scvi_tools.html). +[Arches tutorial](https://docs.scvi-tools.org/en/stable/tutorials/notebooks/scrna/scarches_scvi_tools.html). """ condscvi_pretext = """ @@ -46,7 +46,7 @@ CondSCVI takes as input a scRNA-seq gene expression matrix with cells and genes as well as a cell-type annotation for all cells. -We provide an extensive [user guide](https://docs.scvi-tools.org/en/1.2.0/user_guide/models/destvi.html) +We provide an extensive [user guide](https://docs.scvi-tools.org/en/stable/user_guide/models/destvi.html) for DestVI including a description of CondSCVI. - See our original manuscript for further details of the model: @@ -64,7 +64,7 @@ Stereoscope takes as input a scRNA-seq gene expression matrix with cells and genes as well as a cell-type annotation for all cells. We provide an extensive for DestVI including a description of CondSCVI -[user guide](https://docs.scvi-tools.org/en/1.2.0/user_guide/models/destvi.html). +[user guide](https://docs.scvi-tools.org/en/stable/user_guide/models/destvi.html). - See our original manuscript for further details of the model: [Stereoscope manuscript](https://www.nature.com/articles/s42003-020-01247-y) as well as the @@ -84,7 +84,7 @@ TotalVI takes as input a scRNA-seq gene expression and protein expression matrix with cells and genes. -We provide an extensive [user guide](https://docs.scvi-tools.org/en/1.2.0/user_guide/models/totalvi.html). +We provide an extensive [user guide](https://docs.scvi-tools.org/en/stable/user_guide/models/totalvi.html). - See our original manuscript for further details of the model: [TotalVI manuscript](https://www.nature.com/articles/s41592-020-01050-x). @@ -92,7 +92,7 @@ how to leverage pre-trained models. This model can be used for fine tuning on new data using our Arches framework: -[Arches tutorial](https://docs.scvi-tools.org/en/1.0.0/tutorials/notebooks/scarches_scvi_tools.html). +[Arches tutorial](https://docs.scvi-tools.org/en/stable/tutorials/notebooks/scrna/scarches_scvi_tools.html). """ diff --git a/src/scvi/model/_autozi.py b/src/scvi/model/_autozi.py index e6ad3499cd..8e3bb7ed3e 100644 --- a/src/scvi/model/_autozi.py +++ b/src/scvi/model/_autozi.py @@ -29,7 +29,7 @@ # register buffer -class AUTOZI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, RNASeqMixin): +class AUTOZI(VAEMixin, UnsupervisedTrainingMixin, RNASeqMixin, BaseModelClass): """Automatic identification of zero-inflated genes :cite:p:`Clivio19`. Parameters diff --git a/src/scvi/model/_destvi.py b/src/scvi/model/_destvi.py index 0640e052c6..fb5f5a9add 100644 --- a/src/scvi/model/_destvi.py +++ b/src/scvi/model/_destvi.py @@ -220,13 +220,6 @@ def get_proportions( index=index_names, ) - def get_normalized_expression( - self, - ): - # Refer to function get_accessibility_estimates - msg = f"get_normalized_expression is not implemented for {self.__class__.__name__}." - raise NotImplementedError(msg) - def get_gamma( self, indices: Sequence[int] | None = None, diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 3a657b1eac..4c69dad9a0 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -235,6 +235,7 @@ def __init__( self.n_genes = n_genes self.n_regions = n_regions self.n_proteins = n_proteins + self.get_normalized_function_name = "get_normalized_accessibility" @devices_dsp.dedent def train( @@ -489,7 +490,7 @@ def get_latent_representation( return torch.cat(latent).numpy() @torch.inference_mode() - def get_accessibility_estimates( + def get_normalized_accessibility( self, adata: AnnOrMuData | None = None, indices: Sequence[int] = None, @@ -801,7 +802,7 @@ def differential_accessibility( adata = self._validate_anndata(adata) col_names = adata.var_names[: self.n_genes] model_fn = partial( - self.get_accessibility_estimates, use_z_mean=False, batch_size=batch_size + self.get_normalized_accessibility, use_z_mean=False, batch_size=batch_size ) all_stats_fn = partial( diff --git a/src/scvi/model/_peakvi.py b/src/scvi/model/_peakvi.py index 1d72636ed0..ddf49f11e4 100644 --- a/src/scvi/model/_peakvi.py +++ b/src/scvi/model/_peakvi.py @@ -114,6 +114,7 @@ def __init__( if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry else [] ) + self.get_normalized_function_name = "get_normalized_accessibility" self.module = self._module_cls( n_input_regions=self.summary_stats.n_vars, @@ -306,7 +307,7 @@ def get_region_factors(self): return torch.sigmoid(self.module.region_factors).cpu().numpy() @torch.inference_mode() - def get_accessibility_estimates( + def get_normalized_accessibility( self, adata: AnnData | None = None, indices: Sequence[int] = None, @@ -501,7 +502,7 @@ def differential_accessibility( adata = self._validate_anndata(adata) col_names = adata.var_names model_fn = partial( - self.get_accessibility_estimates, use_z_mean=False, batch_size=batch_size + self.get_normalized_accessibility, use_z_mean=False, batch_size=batch_size ) result = _de_core( diff --git a/src/scvi/model/base/_base_model.py b/src/scvi/model/base/_base_model.py index 06c1b9a4c4..125ba19aea 100644 --- a/src/scvi/model/base/_base_model.py +++ b/src/scvi/model/base/_base_model.py @@ -120,6 +120,7 @@ def __init__(self, adata: AnnOrMuData | None = None): self.test_indices_ = None self.validation_indices_ = None self.history_ = None + self.get_normalized_function_name_ = "get_normalized_expression" @property def adata(self) -> AnnOrMuData: @@ -808,6 +809,15 @@ def summary_string(self): ) return summary_string + @property + def get_normalized_function_name(self) -> str: + """What the get normalized functions name is""" + return self.get_normalized_function_name_ + + @get_normalized_function_name.setter + def get_normalized_function_name(self, value): + self.get_normalized_function_name_ = value + def __repr__(self): rich.print(self.summary_string) return "" @@ -893,6 +903,12 @@ def view_anndata_setup( ) from err adata_manager.view_registry(hide_state_registries=hide_state_registries) + def get_normalized_expression( + self, + ): + msg = f"get_normalized_expression is not implemented for {self.__class__.__name__}." + raise NotImplementedError(msg) + class BaseMinifiedModeModelClass(BaseModelClass): """Abstract base class for scvi-tools models that can handle minified data.""" diff --git a/tests/external/poissonvi/test_poissonvi.py b/tests/external/poissonvi/test_poissonvi.py index 257b4c2dfe..d498efa15e 100644 --- a/tests/external/poissonvi/test_poissonvi.py +++ b/tests/external/poissonvi/test_poissonvi.py @@ -10,7 +10,7 @@ def test_poissonvi(): model = POISSONVI(adata) model.train(max_epochs=1) model.get_latent_representation() - model.get_accessibility_estimates() + model.get_normalized_accessibility() model.get_region_factors() model.get_normalized_expression() model.get_normalized_expression(transform_batch="batch_1") diff --git a/tests/model/test_multivi.py b/tests/model/test_multivi.py index 5c81e44084..5640bdf538 100644 --- a/tests/model/test_multivi.py +++ b/tests/model/test_multivi.py @@ -28,9 +28,9 @@ def test_multivi(): vae.train(1, adversarial_mixing=False) vae.train(3) vae.get_elbo(indices=vae.validation_indices) - vae.get_accessibility_estimates() - vae.get_accessibility_estimates(normalize_cells=True) - vae.get_accessibility_estimates(normalize_regions=True) + vae.get_normalized_accessibility() + vae.get_normalized_accessibility(normalize_cells=True) + vae.get_normalized_accessibility(normalize_regions=True) vae.get_normalized_expression() vae.get_library_size_factors() vae.get_region_factors() @@ -167,17 +167,17 @@ def test_multivi_mudata_trimodal_external(): model.get_elbo() model.get_reconstruction_error() model.get_normalized_expression() - model.get_accessibility_estimates() - model.get_accessibility_estimates(normalize_cells=True) - model.get_accessibility_estimates(normalize_regions=True) + model.get_normalized_accessibility() + model.get_normalized_accessibility(normalize_cells=True) + model.get_normalized_accessibility(normalize_regions=True) model.get_library_size_factors() model.get_region_factors() model.get_elbo(indices=model.validation_indices) model.get_reconstruction_error(indices=model.validation_indices) - model.get_accessibility_estimates() - model.get_accessibility_estimates(normalize_cells=True) - model.get_accessibility_estimates(normalize_regions=True) + model.get_normalized_accessibility() + model.get_normalized_accessibility(normalize_cells=True) + model.get_normalized_accessibility(normalize_regions=True) model.get_library_size_factors() model.get_region_factors() @@ -209,17 +209,17 @@ def test_multivi_mudata(n_genes: int, n_regions: int): model.get_reconstruction_error() model.get_normalized_expression() model.get_normalized_expression(transform_batch=["batch_0", "batch_1"]) - model.get_accessibility_estimates() - model.get_accessibility_estimates(normalize_cells=True) - model.get_accessibility_estimates(normalize_regions=True) + model.get_normalized_accessibility() + model.get_normalized_accessibility(normalize_cells=True) + model.get_normalized_accessibility(normalize_regions=True) model.get_library_size_factors() model.get_region_factors() model.get_elbo(indices=model.validation_indices) model.get_reconstruction_error(indices=model.validation_indices) - model.get_accessibility_estimates() - model.get_accessibility_estimates(normalize_cells=True) - model.get_accessibility_estimates(normalize_regions=True) + model.get_normalized_accessibility() + model.get_normalized_accessibility(normalize_cells=True) + model.get_normalized_accessibility(normalize_regions=True) model.get_library_size_factors() model.get_region_factors() @@ -236,9 +236,9 @@ def test_multivi_mudata(n_genes: int, n_regions: int): mdata3 = synthetic_iid(return_mudata=True) mdata3.obs["_indices"] = np.arange(mdata3.n_obs) model.get_elbo(mdata3[:10]) - model.get_accessibility_estimates() - model.get_accessibility_estimates(normalize_cells=True) - model.get_accessibility_estimates(normalize_regions=True) + model.get_normalized_accessibility() + model.get_normalized_accessibility(normalize_cells=True) + model.get_normalized_accessibility(normalize_regions=True) model.get_library_size_factors() model.get_region_factors() @@ -259,9 +259,9 @@ def test_multivi_auto_transfer_mudata(): mdata2 = MuData({"rna": adata2, "protein": protein_adata2}) mdata2.obs["_indices"] = np.arange(mdata2.n_obs) model.get_elbo(mdata2) - model.get_accessibility_estimates() - model.get_accessibility_estimates(normalize_cells=True) - model.get_accessibility_estimates(normalize_regions=True) + model.get_normalized_accessibility() + model.get_normalized_accessibility(normalize_cells=True) + model.get_normalized_accessibility(normalize_regions=True) model.get_library_size_factors() model.get_region_factors() @@ -302,9 +302,9 @@ def test_multivi_reordered_mapping_mudata(): adata2.obs.batch = adata2.obs.batch.cat.rename_categories(["batch_1", "batch_0"]) mdata2.obs["_indices"] = np.arange(mdata2.n_obs) model.get_elbo(mdata2) - model.get_accessibility_estimates() - model.get_accessibility_estimates(normalize_cells=True) - model.get_accessibility_estimates(normalize_regions=True) + model.get_normalized_accessibility() + model.get_normalized_accessibility(normalize_cells=True) + model.get_normalized_accessibility(normalize_regions=True) model.get_library_size_factors() model.get_region_factors() @@ -324,9 +324,9 @@ def test_multivi_model_library_size_mudata(): model.train(1, train_size=0.5) assert model.is_trained is True model.get_elbo() - model.get_accessibility_estimates() - model.get_accessibility_estimates(normalize_cells=True) - model.get_accessibility_estimates(normalize_regions=True) + model.get_normalized_accessibility() + model.get_normalized_accessibility(normalize_cells=True) + model.get_normalized_accessibility(normalize_regions=True) model.get_library_size_factors() model.get_region_factors() diff --git a/tests/model/test_peakvi.py b/tests/model/test_peakvi.py index 5402d5c64e..9aef1f4260 100644 --- a/tests/model/test_peakvi.py +++ b/tests/model/test_peakvi.py @@ -131,9 +131,9 @@ def test_peakvi(): ) vae.train(3) vae.get_elbo(indices=vae.validation_indices) - vae.get_accessibility_estimates() - vae.get_accessibility_estimates(normalize_cells=True) - vae.get_accessibility_estimates(normalize_regions=True) + vae.get_normalized_accessibility() + vae.get_normalized_accessibility(normalize_cells=True) + vae.get_normalized_accessibility(normalize_regions=True) vae.get_library_size_factors() vae.get_region_factors() vae.get_reconstruction_error(indices=vae.validation_indices) diff --git a/tests/model/test_scanvi.py b/tests/model/test_scanvi.py index 632d1eace1..aa478189c1 100644 --- a/tests/model/test_scanvi.py +++ b/tests/model/test_scanvi.py @@ -614,6 +614,7 @@ def test_scanvi_interpertability_ig(unlabeled_cat: str): print(ig_top_features_3_samples) +@pytest.mark.optional @pytest.mark.parametrize("unlabeled_cat", ["label_0"]) def test_scanvi_interpertability_shap(unlabeled_cat: str): adata = synthetic_iid(batch_size=50)