diff --git a/CHANGELOG.md b/CHANGELOG.md index 6fa825537a..dccaf6ea20 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ to [Semantic Versioning]. Full commit history is available in the - Add {class}`scvi.external.Decipher` for dimensionality reduction and interpretable representation learning in single-cell RNA sequencing data {pr}`3015`, {pr}`3091`. +- Add {meth}`~scvi.model.SCVI.get_normalized_expression` for models: PeakVI, PoissonVI, CondSCVI, + AutoZI, CellAssign and GimVI. {pr}`3121` - Add {class}`scvi.external.RESOLVI` for bias correction in single-cell resolved spatial transcriptomics {pr}`3144`. diff --git a/src/scvi/external/cellassign/_model.py b/src/scvi/external/cellassign/_model.py index 263df9dfce..dacd7dcbbd 100644 --- a/src/scvi/external/cellassign/_model.py +++ b/src/scvi/external/cellassign/_model.py @@ -20,7 +20,7 @@ from scvi.dataloaders import DataSplitter from scvi.external.cellassign._module import CellAssignModule from scvi.model._utils import get_max_epochs_heuristic -from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin +from scvi.model.base import BaseModelClass, RNASeqMixin, UnsupervisedTrainingMixin from scvi.train import LoudEarlyStopping, TrainingPlan, TrainRunner from scvi.utils import setup_anndata_dsp from scvi.utils._docstrings import devices_dsp @@ -33,7 +33,7 @@ B = 10 -class CellAssign(UnsupervisedTrainingMixin, BaseModelClass): +class CellAssign(UnsupervisedTrainingMixin, BaseModelClass, RNASeqMixin): """Reimplementation of CellAssign for reference-based annotation :cite:p:`Zhang19`. Original implementation: https://github.com/irrationone/cellassign. diff --git a/src/scvi/external/cellassign/_module.py b/src/scvi/external/cellassign/_module.py index ee68e97a31..775957bed8 100644 --- a/src/scvi/external/cellassign/_module.py +++ b/src/scvi/external/cellassign/_module.py @@ -6,14 +6,15 @@ from scvi import REGISTRY_KEYS from scvi.distributions import NegativeBinomial -from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data +from scvi.module import VAE +from scvi.module.base import LossOutput, auto_move_data LOWER_BOUND = 1e-10 THETA_LOWER_BOUND = 1e-20 B = 10 -class CellAssignModule(BaseModuleClass): +class CellAssignModule(VAE): """Model for CellAssign. Parameters @@ -51,7 +52,7 @@ def __init__( n_cats_per_cov: Iterable[int] | None = None, n_continuous_cov: int = 0, ): - super().__init__() + super().__init__(n_genes) self.n_genes = n_genes self.n_labels = rho.shape[1] self.n_batch = n_batch @@ -103,10 +104,7 @@ def __init__( self.register_buffer("basis_means", torch.tensor(basis_means, dtype=torch.float32)) - def _get_inference_input(self, tensors): - return {} - - def _get_generative_input(self, tensors, inference_outputs): + def _get_generative_input(self, tensors, inference_outputs, transform_batch=None): x = tensors[REGISTRY_KEYS.X_KEY] size_factor = tensors[REGISTRY_KEYS.SIZE_FACTOR_KEY] @@ -127,19 +125,27 @@ def _get_generative_input(self, tensors, inference_outputs): design_matrix = torch.cat(to_cat, dim=1) if len(to_cat) > 0 else None + batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] + if transform_batch is not None: + batch_index = torch.ones_like(batch_index) * transform_batch + input_dict = { "x": x, "size_factor": size_factor, "design_matrix": design_matrix, + "batch_index": batch_index, } return input_dict @auto_move_data - def inference(self): - return {} - - @auto_move_data - def generative(self, x, size_factor, design_matrix=None): + def generative( + self, + x, + size_factor, + batch_index, + design_matrix=None, + transform_batch: torch.Tensor | None = None, + ): """Run the generative model.""" # x has shape (n, g) delta = torch.exp(self.delta_log) # (g, c) @@ -193,12 +199,22 @@ def generative(self, x, size_factor, design_matrix=None): normalizer_over_c = normalizer_over_c.unsqueeze(-1).expand(n_cells, self.n_labels) gamma = torch.exp(p_x_c - normalizer_over_c) # (n, c) + px = torch.sum(x_log_prob_raw, -1) + normalizer_over_c2 = torch.logsumexp(px, 1) + normalizer_over_c2 = normalizer_over_c2.unsqueeze(-1).expand(n_cells, self.n_genes) + gamma2 = torch.exp(px - normalizer_over_c2) # (n, g) + + if transform_batch is not None: + batch_index = torch.ones_like(batch_index) * transform_batch + return { "mu": mu_ngc, "phi": phi, "gamma": gamma, "p_x_c": p_x_c, + "px": gamma2, "s": size_factor, + "batch_index": batch_index, } def loss( diff --git a/src/scvi/external/gimvi/_model.py b/src/scvi/external/gimvi/_model.py index 8bbde7326b..608e1de938 100644 --- a/src/scvi/external/gimvi/_model.py +++ b/src/scvi/external/gimvi/_model.py @@ -17,7 +17,7 @@ from scvi.data.fields import CategoricalObsField, LayerField from scvi.dataloaders import DataSplitter from scvi.model._utils import _init_library_size, parse_device_args -from scvi.model.base import BaseModelClass, VAEMixin +from scvi.model.base import BaseModelClass, RNASeqMixin, VAEMixin from scvi.train import Trainer from scvi.utils import setup_anndata_dsp from scvi.utils._docstrings import devices_dsp @@ -41,7 +41,7 @@ def _unpack_tensors(tensors): return x, batch_index, y -class GIMVI(VAEMixin, BaseModelClass): +class GIMVI(VAEMixin, BaseModelClass, RNASeqMixin): """Joint VAE for imputing missing genes in spatial data :cite:p:`Lopez19`. Parameters diff --git a/src/scvi/external/gimvi/_module.py b/src/scvi/external/gimvi/_module.py index a2db2637c3..8969946aba 100644 --- a/src/scvi/external/gimvi/_module.py +++ b/src/scvi/external/gimvi/_module.py @@ -364,20 +364,31 @@ def reconstruction_loss( reconstruction_loss = -Poisson(px_rate).log_prob(x).sum(dim=1) return reconstruction_loss - def _get_inference_input(self, tensors): + def _get_inference_input(self, tensors) -> dict[str, torch.Tensor | None]: """Get the input for the inference model.""" - return {"x": tensors[REGISTRY_KEYS.X_KEY]} + return { + "x": tensors[REGISTRY_KEYS.X_KEY], + "batch_index": tensors.get(REGISTRY_KEYS.BATCH_KEY, None), + } - def _get_generative_input(self, tensors, inference_outputs): + def _get_generative_input(self, tensors, inference_outputs, transform_batch=None): """Get the input for the generative model.""" z = inference_outputs["z"] library = inference_outputs["library"] batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] y = tensors[REGISTRY_KEYS.LABELS_KEY] + if transform_batch is not None: + batch_index = torch.ones_like(batch_index) * transform_batch return {"z": z, "library": library, "batch_index": batch_index, "y": y} @auto_move_data - def inference(self, x: torch.Tensor, mode: int | None = None) -> dict: + def inference( + self, + x: torch.Tensor, + mode: int | None = 0, + n_samples: int | None = 1, + batch_index: torch.Tensor | None = None, + ) -> dict: """Run the inference model.""" x_ = x if self.log_variational: @@ -390,6 +401,11 @@ def inference(self, x: torch.Tensor, mode: int | None = None) -> dict: else: library = torch.log(torch.sum(x, dim=1)).view(-1, 1) + if n_samples > 1: + # when z is normal, untran_z == z + untran_z = qz.sample((n_samples,)) + z = self.z_encoder.z_transformation(untran_z) + return {"qz": qz, "z": z, "ql": ql, "library": library} @auto_move_data @@ -399,7 +415,8 @@ def generative( library: torch.Tensor, batch_index: torch.Tensor | None = None, y: torch.Tensor | None = None, - mode: int | None = None, + mode: int | None = 0, + transform_batch: torch.Tensor | None = None, ) -> dict: """Run the generative model.""" px_scale, px_r, px_rate, px_dropout = self.decoder( @@ -418,11 +435,18 @@ def generative( ) px_rate = px_scale * torch.exp(library) + if transform_batch is not None: + batch_index = torch.ones_like(batch_index) * transform_batch + + px = NegativeBinomial(mu=px_rate, theta=px_r, scale=px_scale) + return { "px_scale": px_scale, + "px": px, "px_r": px_r, "px_rate": px_rate, "px_dropout": px_dropout, + "batch_index": batch_index, } def loss( diff --git a/src/scvi/external/poissonvi/_model.py b/src/scvi/external/poissonvi/_model.py index c7b7bc61c0..82bbcdcc5b 100644 --- a/src/scvi/external/poissonvi/_model.py +++ b/src/scvi/external/poissonvi/_model.py @@ -244,16 +244,6 @@ def get_region_factors(self): raise RuntimeError("region factors were not included in this model") return region_factors - def get_normalized_expression( - self, - ): - # Refer to function get_accessibility_estimates - msg = ( - f"differential_expression is not implemented for {self.__class__.__name__}, please " - f"use {self.__class__.__name__}.get_accessibility_estimates" - ) - raise NotImplementedError(msg) - @de_dsp.dedent def differential_accessibility( self, diff --git a/src/scvi/model/_autozi.py b/src/scvi/model/_autozi.py index 08b4e35131..e6ad3499cd 100644 --- a/src/scvi/model/_autozi.py +++ b/src/scvi/model/_autozi.py @@ -16,7 +16,7 @@ from scvi.module import AutoZIVAE from scvi.utils import setup_anndata_dsp -from .base import BaseModelClass, VAEMixin +from .base import BaseModelClass, RNASeqMixin, VAEMixin if TYPE_CHECKING: from collections.abc import Sequence @@ -29,7 +29,7 @@ # register buffer -class AUTOZI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass): +class AUTOZI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, RNASeqMixin): """Automatic identification of zero-inflated genes :cite:p:`Clivio19`. Parameters diff --git a/src/scvi/model/_condscvi.py b/src/scvi/model/_condscvi.py index b5e49a711d..45af751ff2 100644 --- a/src/scvi/model/_condscvi.py +++ b/src/scvi/model/_condscvi.py @@ -297,9 +297,8 @@ def setup_anndata( anndata_fields = [ fields.LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), fields.CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), + fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), ] - if batch_key is not None: - anndata_fields.append(fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key)) adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/src/scvi/model/_destvi.py b/src/scvi/model/_destvi.py index fb5f5a9add..0640e052c6 100644 --- a/src/scvi/model/_destvi.py +++ b/src/scvi/model/_destvi.py @@ -220,6 +220,13 @@ def get_proportions( index=index_names, ) + def get_normalized_expression( + self, + ): + # Refer to function get_accessibility_estimates + msg = f"get_normalized_expression is not implemented for {self.__class__.__name__}." + raise NotImplementedError(msg) + def get_gamma( self, indices: Sequence[int] | None = None, diff --git a/src/scvi/model/_peakvi.py b/src/scvi/model/_peakvi.py index f10b39c771..1d72636ed0 100644 --- a/src/scvi/model/_peakvi.py +++ b/src/scvi/model/_peakvi.py @@ -29,7 +29,7 @@ from scvi.train._callbacks import SaveBestState from scvi.utils._docstrings import de_dsp, devices_dsp, setup_anndata_dsp -from .base import ArchesMixin, BaseModelClass, VAEMixin +from .base import ArchesMixin, BaseModelClass, RNASeqMixin, VAEMixin if TYPE_CHECKING: from collections.abc import Iterable, Sequence @@ -40,7 +40,7 @@ logger = logging.getLogger(__name__) -class PEAKVI(ArchesMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass): +class PEAKVI(ArchesMixin, RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass): """Peak Variational Inference for chromatin accessilibity analysis :cite:p:`Ashuach22`. Parameters @@ -393,7 +393,7 @@ def get_accessibility_estimates( generative_kwargs=generative_kwargs, compute_loss=False, ) - p = generative_outputs["p"].cpu() + p = generative_outputs["px"].cpu() if normalize_cells: p *= inference_outputs["d"].cpu() diff --git a/src/scvi/model/base/_rnamixin.py b/src/scvi/model/base/_rnamixin.py index fb40d954a1..eadae9ccb0 100644 --- a/src/scvi/model/base/_rnamixin.py +++ b/src/scvi/model/base/_rnamixin.py @@ -269,7 +269,11 @@ def get_normalized_expression( generative_kwargs=generative_kwargs, compute_loss=False, ) - exp_ = generative_outputs["px"].get_normalized(generative_output_key) + px_generative = generative_outputs["px"] + if isinstance(px_generative, torch.Tensor): + exp_ = px_generative + else: + exp_ = px_generative.get_normalized(generative_output_key) exp_ = exp_[..., gene_mask] exp_ *= scaling per_batch_exprs.append(exp_[None].cpu()) diff --git a/src/scvi/module/_autozivae.py b/src/scvi/module/_autozivae.py index 63579e8480..50dcf8645b 100644 --- a/src/scvi/module/_autozivae.py +++ b/src/scvi/module/_autozivae.py @@ -283,6 +283,7 @@ def generative( cat_covs=None, n_samples: int = 1, eps_log: float = 1e-8, + transform_batch: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: """Run the generative model.""" outputs = super().generative( @@ -293,6 +294,7 @@ def generative( cat_covs=cat_covs, y=y, size_factor=size_factor, + transform_batch=transform_batch, ) # Rescale dropout rescaled_dropout = self.rescale_dropout(outputs["px"].zi_logits, eps_log=eps_log) diff --git a/src/scvi/module/_mrdeconv.py b/src/scvi/module/_mrdeconv.py index 5dd98d76aa..e0d6c40e3f 100644 --- a/src/scvi/module/_mrdeconv.py +++ b/src/scvi/module/_mrdeconv.py @@ -1,7 +1,7 @@ -from collections import OrderedDict -from typing import Literal +from __future__ import annotations + +from typing import TYPE_CHECKING -import numpy as np import torch from torch.distributions import Normal @@ -10,6 +10,12 @@ from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data from scvi.nn import FCLayers +if TYPE_CHECKING: + from collections import OrderedDict + from typing import Literal + + import numpy as np + def identity(x): """Identity function.""" @@ -187,7 +193,9 @@ def _get_generative_input(self, tensors, inference_outputs): x = tensors[REGISTRY_KEYS.X_KEY] ind_x = tensors[REGISTRY_KEYS.INDICES_KEY].long().ravel() - input_dict = {"x": x, "ind_x": ind_x} + batch_index = None # tensors[REGISTRY_KEYS.BATCH_KEY] + + input_dict = {"x": x, "ind_x": ind_x, "batch_index": batch_index} return input_dict @auto_move_data @@ -196,7 +204,7 @@ def inference(self): return {} @auto_move_data - def generative(self, x, ind_x): + def generative(self, x, ind_x, batch_index=None, transform_batch: torch.Tensor | None = None): """Build the deconvolution model for every cell in the minibatch.""" m = x.shape[0] library = torch.sum(x, dim=1, keepdim=True) @@ -206,6 +214,9 @@ def generative(self, x, ind_x): x_ = torch.log(1 + x) # subsample parameters + # if transform_batch is not None: + # batch_index = torch.ones_like(batch_index) * transform_batch + if self.amortization in ["both", "latent"]: gamma_ind = torch.transpose(self.gamma_encoder(x_), 0, 1).reshape( (self.n_latent, self.n_labels, -1) @@ -249,6 +260,7 @@ def generative(self, x, ind_x): "px_scale": px_scale, "gamma": gamma_ind, "v": v_ind, + "batch_index": batch_index, } def loss( diff --git a/src/scvi/module/_peakvae.py b/src/scvi/module/_peakvae.py index 10f4898354..5d5d140a33 100644 --- a/src/scvi/module/_peakvae.py +++ b/src/scvi/module/_peakvae.py @@ -303,6 +303,7 @@ def generative( cont_covs=None, cat_covs=None, use_z_mean=False, + transform_batch: torch.Tensor | None = None, ): """Runs the generative model.""" if cat_covs is not None: @@ -310,6 +311,9 @@ def generative( else: categorical_input = () + if transform_batch is not None: + batch_index = torch.ones_like(batch_index) * transform_batch + latent = z if not use_z_mean else qz_m if cont_covs is None: decoder_input = latent @@ -320,16 +324,16 @@ def generative( else: decoder_input = torch.cat([latent, cont_covs], dim=-1) - p = self.z_decoder(decoder_input, batch_index, *categorical_input) + px = self.z_decoder(decoder_input, batch_index, *categorical_input) - return {"p": p} + return {"px": px} def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0): """Compute the loss.""" x = tensors[REGISTRY_KEYS.X_KEY] qz = inference_outputs["qz"] d = inference_outputs["d"] - p = generative_outputs["p"] + px = generative_outputs["px"] kld = kl_divergence( qz, @@ -337,7 +341,7 @@ def loss(self, tensors, inference_outputs, generative_outputs, kl_weight: float ).sum(dim=1) f = torch.sigmoid(self.region_factors) if self.region_factors is not None else 1 - rl = self.get_reconstruction_loss(p, d, f, x) + rl = self.get_reconstruction_loss(px, d, f, x) loss = (rl.sum() + kld * kl_weight).sum() diff --git a/src/scvi/module/_vaec.py b/src/scvi/module/_vaec.py index 89a4720f85..e20b58d3e6 100644 --- a/src/scvi/module/_vaec.py +++ b/src/scvi/module/_vaec.py @@ -158,7 +158,7 @@ def inference( if self.log_variational: x_ = torch.log1p(x_) - encoder_input = [x, y] + encoder_input = [x_, y] if batch_index is not None and self.encode_covariates: encoder_input.append(batch_index) diff --git a/src/scvi/nn/_base_components.py b/src/scvi/nn/_base_components.py index e967fcf726..24e47b96a4 100644 --- a/src/scvi/nn/_base_components.py +++ b/src/scvi/nn/_base_components.py @@ -565,6 +565,7 @@ def __init__( n_layers_individual: int = 1, n_layers_shared: int = 2, n_cat_list: Iterable[int] = None, + distribution: str = "normal", dropout_rate: float = 0.1, return_dist: bool = False, **kwargs, @@ -597,6 +598,11 @@ def __init__( **kwargs, ) + if distribution == "ln": + self.z_transformation = nn.Softmax(dim=-1) + else: + self.z_transformation = _identity + self.mean_encoder = nn.Linear(n_hidden, n_output) self.var_encoder = nn.Linear(n_hidden, n_output) self.return_dist = return_dist diff --git a/tests/external/cellassign/test_model_cellassign.py b/tests/external/cellassign/test_model_cellassign.py index f468cee7e6..f4d9309e10 100644 --- a/tests/external/cellassign/test_model_cellassign.py +++ b/tests/external/cellassign/test_model_cellassign.py @@ -28,6 +28,9 @@ def test_cellassign(): model = CellAssign(adata, marker_df) model.train(max_epochs=1) model.predict() + model.get_normalized_expression() + model.get_normalized_expression(transform_batch="batch_1") + # model.get_normalized_expression(n_samples=2) def test_cellassign_error_duplicates(): diff --git a/tests/external/gimvi/test_gimvi.py b/tests/external/gimvi/test_gimvi.py index c97d2ed535..ceb6a3acbb 100644 --- a/tests/external/gimvi/test_gimvi.py +++ b/tests/external/gimvi/test_gimvi.py @@ -127,6 +127,12 @@ def test_gimvi(): model.get_latent_representation() model.get_imputed_values() model.get_imputed_values(normalized=False) + model.get_normalized_expression(adata=adata_seq) + model.get_normalized_expression(adata=adata_spatial) + model.get_normalized_expression(adata=adata_seq, transform_batch="batch_1") + model.get_normalized_expression(adata=adata_spatial, transform_batch="batch_1") + # model.get_normalized_expression(adata=adata_seq, n_samples=2) + # model.get_normalized_expression(adata=adata_spatial, n_samples=2) adata_spatial.var_names += "asdf" GIMVI.setup_anndata( diff --git a/tests/external/poissonvi/test_poissonvi.py b/tests/external/poissonvi/test_poissonvi.py index 27862b345f..257b4c2dfe 100644 --- a/tests/external/poissonvi/test_poissonvi.py +++ b/tests/external/poissonvi/test_poissonvi.py @@ -6,12 +6,15 @@ def test_poissonvi(): adata = synthetic_iid(batch_size=100) - POISSONVI.setup_anndata(adata) + POISSONVI.setup_anndata(adata, batch_key="batch") model = POISSONVI(adata) model.train(max_epochs=1) model.get_latent_representation() model.get_accessibility_estimates() model.get_region_factors() + model.get_normalized_expression() + model.get_normalized_expression(transform_batch="batch_1") + model.get_normalized_expression(n_samples=2) def test_poissonvi_default_params(): diff --git a/tests/model/test_autozi.py b/tests/model/test_autozi.py index fcf2b82cf5..68fec0fa3e 100644 --- a/tests/model/test_autozi.py +++ b/tests/model/test_autozi.py @@ -91,7 +91,7 @@ def legacy_save( def test_autozi(): data = synthetic_iid( - n_batches=1, + n_batches=2, ) AUTOZI.setup_anndata( data, @@ -112,6 +112,9 @@ def test_autozi(): autozivae.get_reconstruction_error(indices=autozivae.validation_indices) autozivae.get_marginal_ll(indices=autozivae.validation_indices, n_mc_samples=3) autozivae.get_alphas_betas() + autozivae.get_normalized_expression() + autozivae.get_normalized_expression(transform_batch="batch_1") + autozivae.get_normalized_expression(n_samples=2) # Model library size. for disp_zi in ["gene", "gene-label"]: @@ -130,3 +133,6 @@ def test_autozi(): autozivae.get_reconstruction_error(indices=autozivae.validation_indices) autozivae.get_marginal_ll(indices=autozivae.validation_indices, n_mc_samples=3) autozivae.get_alphas_betas() + autozivae.get_normalized_expression() + autozivae.get_normalized_expression(transform_batch="batch_1") + autozivae.get_normalized_expression(n_samples=2) diff --git a/tests/model/test_condscvi.py b/tests/model/test_condscvi.py index 6fed090147..3b90653898 100644 --- a/tests/model/test_condscvi.py +++ b/tests/model/test_condscvi.py @@ -32,10 +32,10 @@ def test_condscvi_batch_key( def test_condscvi_batch_key_compat_load(save_path: str): - adata = synthetic_iid() + adata = synthetic_iid(n_batches=1, n_labels=5) model = CondSCVI.load("tests/test_data/condscvi_pre_batch", adata=adata) - assert not hasattr(model.summary_stats, "n_batch") + # assert not hasattr(model.summary_stats, "n_batch") _ = model.get_latent_representation() _ = model.get_vamp_prior(adata) @@ -45,16 +45,16 @@ def test_condscvi_batch_key_compat_load(save_path: str): @pytest.mark.parametrize("weight_obs", [True, False]) -def test_condscvi_no_batch_key(save_path: str, weight_obs: bool): +def test_condscvi_no_batch_key(weight_obs: bool, save_path: str): adata = synthetic_iid() CondSCVI.setup_anndata(adata, labels_key="labels") - with pytest.raises(ValueError): - _ = CondSCVI(adata, encode_covariates=True) + # with pytest.raises(ValueError): + # _ = CondSCVI(adata, encode_covariates=True) model = CondSCVI(adata, weight_obs=weight_obs) model.train(max_epochs=1) - assert not hasattr(model.summary_stats, "n_batch") + # assert not hasattr(model.summary_stats, "n_batch") _ = model.get_elbo() _ = model.get_reconstruction_error() _ = model.get_latent_representation() diff --git a/tests/model/test_destvi.py b/tests/model/test_destvi.py index 99cb715d0e..0bcff17ab8 100644 --- a/tests/model/test_destvi.py +++ b/tests/model/test_destvi.py @@ -16,6 +16,12 @@ def test_destvi(): sc_model = CondSCVI(dataset, n_latent=n_latent, n_layers=n_layers) sc_model.train(1, train_size=1) + sc_model.get_normalized_expression(dataset) + sc_model.get_elbo() + sc_model.get_reconstruction_error() + sc_model.get_latent_representation() + sc_model.get_vamp_prior(dataset, p=100) + # step 2 Check model setup DestVI.setup_anndata(dataset, layer=None) @@ -52,3 +58,6 @@ def test_destvi(): 50, dataset.n_vars, ) + + with pytest.raises(NotImplementedError): + spatial_model.get_normalized_expression() diff --git a/tests/model/test_peakvi.py b/tests/model/test_peakvi.py index 01f93dc7c3..5402d5c64e 100644 --- a/tests/model/test_peakvi.py +++ b/tests/model/test_peakvi.py @@ -139,6 +139,9 @@ def test_peakvi(): vae.get_reconstruction_error(indices=vae.validation_indices) vae.get_latent_representation() vae.differential_accessibility(groupby="labels", group1="label_1") + vae.get_normalized_expression() + vae.get_normalized_expression(transform_batch="batch_1") + vae.get_normalized_expression(n_samples=2) def single_pass_for_online_update(model): diff --git a/tests/test_data/condscvi_pre_batch/model.pt b/tests/test_data/condscvi_pre_batch/model.pt index bfa7175a75..5ca2f04dd2 100644 Binary files a/tests/test_data/condscvi_pre_batch/model.pt and b/tests/test_data/condscvi_pre_batch/model.pt differ