feat: supervised module class to be used with semisupervised models. (#…
ori-kron-wis authored Mar 5, 2025
1 parent 3ff1826 commit 7e91540
Showing 10 changed files with 141 additions and 186 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -12,7 +12,7 @@ to [Semantic Versioning]. Full commit history is available in the

- Add {class}`scvi.external.METHYLANVI` for modeling methylation labelled data from single-cell
bisulfite sequencing (scBS-seq) {pr}`3066`.
-
- Add supervised module class {class}`scvi.module.base.SupervisedModuleClass`. {pr}`3237`.

#### Fixed

1 change: 1 addition & 0 deletions docs/api/developer.md
@@ -204,6 +204,7 @@ These classes should be used to construct module classes that define generative
module.base.BaseModuleClass
module.base.BaseMinifiedModeModuleClass
module.base.SupervisedModuleClass
module.base.PyroBaseModuleClass
module.base.JaxBaseModuleClass
module.base.EmbeddingModuleMixin
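A minimal sketch of how a custom semi-supervised module might adopt the new class (the composition mirrors the `SCANVAE` and `METHYLANVAE` refactors in this commit; the module name is hypothetical):

```python
from scvi.module import VAE
from scvi.module.base import SupervisedModuleClass


class MySemiSupervisedVAE(SupervisedModuleClass, VAE):
    """Hypothetical module: `classify` and `classification_loss` are
    inherited from SupervisedModuleClass instead of being reimplemented."""
```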
2 changes: 1 addition & 1 deletion docs/user_guide/models/resolvi.md
@@ -16,7 +16,7 @@ The limitations of resolVI include:

```{topic} Tutorials:
- {doc}`/tutorials/notebooks/spatial/resolVI_tutorial.ipynb`
- {doc}`/tutorials/notebooks/spatial/resolVI_tutorial`
```

## Preliminaries
10 changes: 5 additions & 5 deletions docs/user_guide/use_case/downstream_analysis_tasks.md
@@ -4,7 +4,7 @@ SCVI provides useful tools for exploring and understanding the learned latent re

1. Latent Space Exploration
Visualization: You can visualize the learned latent representation of cells to observe patterns, clusters, and structure in the data.
We usually apply [scanpy's UMAP]("https://scanpy.readthedocs.io/en/1.10.x/tutorials/plotting/core.html"); These dimensionality reduction techniques can be applied to the model’s latent space to visualize how cells are grouped in 2D or 3D.
We usually apply [scanpy's UMAP](https://scanpy.readthedocs.io/en/1.10.x/tutorials/plotting/core.html); these dimensionality-reduction techniques can be applied to the model's latent space to visualize how cells are grouped in 2D or 3D.

Clustering: After visualizing the latent space, you can perform clustering (e.g., Leiden, k-means, hierarchical clustering) to identify distinct groups of cells with similar gene expression profiles, as sketched below.
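A minimal sketch of this workflow, assuming a trained `model` and the `adata` it was set up on:

```python
import scanpy as sc

# store the scVI latent representation and build a neighbor graph on it
adata.obsm["X_scVI"] = model.get_latent_representation()
sc.pp.neighbors(adata, use_rep="X_scVI")

# visualize in 2D and cluster the latent space
sc.tl.umap(adata)
sc.tl.leiden(adata, key_added="scvi_clusters")
sc.pl.umap(adata, color="scvi_clusters")
```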

@@ -17,12 +17,12 @@ For example:
:::
2. Differential Expression (DE) Analysis
Gene Expression Comparisons: SCVI allows you to perform differential expression analysis between different clusters or conditions (e.g., different cell types or experimental conditions).
You can compare the expression of genes between clusters to identify which genes are differentially expressed. See more information [here]("https://decoupler-py.readthedocs.io/en/latest/notebooks/bulk.html#Differential-expression-analysis")
You can compare the expression of genes between clusters to identify which genes are differentially expressed. See more information [here](https://decoupler-py.readthedocs.io/en/latest/notebooks/bulk.html#Differential-expression-analysis).
```python
# `model` is a trained scvi.model.SCVI instance; "cell_type" is an assumed
# .obs column defining the groups to compare
differential_expression = model.differential_expression(groupby="cell_type")
```
Log-fold Change (LFC) and p-values are typically used to assess which genes have significant expression differences between groups.
Refer to [SCVI-Hub]("https://huggingface.co/scvi-tools") for use cases of DE.
Refer to [SCVI-Hub](https://huggingface.co/scvi-tools) for use cases of DE.
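For instance, the resulting table can be filtered by effect size and significance; the column names below (`lfc_mean`, `is_de_fdr_0.05`) follow scvi-tools' default "change"-mode DE output and should be checked against your version:

```python
# keep genes with a strong, confident change between the compared groups
hits = differential_expression[
    (differential_expression["lfc_mean"].abs() > 1.0)
    & differential_expression["is_de_fdr_0.05"]
]
```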
3. Cell Type Identification
Mapping to Known Labels: After training a model with SCVI, you can use the latent space to assign cells to known or predicted cell types. You can compare how well SCVI clusters cells by their latent representations and match them to known biological annotations.
If you have labeled data (e.g., cell types), you can assess how well the model’s clusters correspond to these labels.
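A sketch of label transfer with scANVI, assuming a trained SCVI `model`, partial labels in `adata.obs["cell_type"]`, and `"Unknown"` marking unlabelled cells:

```python
from scvi.model import SCANVI

# warm-start from a trained SCVI model, then train semi-supervised
scanvi_model = SCANVI.from_scvi_model(
    model, labels_key="cell_type", unlabeled_category="Unknown"
)
scanvi_model.train()
adata.obs["predicted_cell_type"] = scanvi_model.predict()
```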
@@ -35,7 +35,7 @@ See {class}`scvi.module.LDVAE`
5. Predictive Modeling and Imputation
Gene Expression/Protein Imputation: SCVI can be used to impute missing or dropout gene expression values. This is useful when dealing with sparse data, as SCVI can recover hidden information based on the patterns it learns in the data.
Prediction: After training, SCVI can also be used to predict gene expression under different conditions or experimental setups.
We can get those with get_normalized_expression function, exists for most models, and plot it, per gene over the already generated latent embedding umap. See {doc}`/user_guide/models/mrvi` as example
We can obtain these with the `get_normalized_expression` function, which exists for most models, and plot the result per gene over the previously generated latent-embedding UMAP, as sketched below.
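For example (a sketch; the gene name is an assumption, and the UMAP comes from the latent-space step above):

```python
import scanpy as sc

# model-denoised expression for one gene, plotted over the latent UMAP
adata.obs["GeneX_scvi"] = model.get_normalized_expression(
    gene_list=["GeneX"], library_size=1e4
)["GeneX"].values
sc.pl.umap(adata, color="GeneX_scvi")
```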
6. Batch Effect Removal
Identifying and Correcting Batch Effects: SCVI allows you to account for and remove batch effects during training. You can use the model to test if there is any batch-related structure in the latent space or gene expression.
You can evaluate how much of the variability in gene expression is due to biological factors versus technical factors like batch.
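A sketch of registering a batch covariate so the model can correct for it (`"batch"` is an assumed `adata.obs` column):

```python
import scvi

scvi.model.SCVI.setup_anndata(adata, batch_key="batch")
model = scvi.model.SCVI(adata)
model.train()
# the latent space is encouraged to mix batches while preserving biology
latent = model.get_latent_representation()
```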
@@ -50,7 +50,7 @@ Log-Likelihood: SCVI models the data through a probabilistic framework, so you c
Cross-validation: Perform cross-validation to check for overfitting and assess the model's generalizability.
10. Model Comparison
Comparing SCVI Models: You can compare different SCVI models trained with different configurations or hyperparameters to assess which model performs better. This is especially useful in selecting the best model for downstream tasks like differential expression or clustering.
We usually use to run [scib-metrics]("https://github.com/YosefLab/scib-metrics") to compare different models bio conservation and batch correction metrics.
We usually run [scib-metrics](https://github.com/YosefLab/scib-metrics) to compare bio-conservation and batch-correction metrics across models, as sketched below.
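A sketch with scib-metrics' `Benchmarker`, assuming latent embeddings were stored in `adata.obsm` under the listed keys:

```python
from scib_metrics.benchmark import Benchmarker

bm = Benchmarker(
    adata,
    batch_key="batch",
    label_key="cell_type",
    embedding_obsm_keys=["X_scVI", "X_scANVI"],
)
bm.benchmark()
bm.plot_results_table()
```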

In summary:
SCVI provides a broad set of downstream analysis capabilities, including differential expression analysis, cell type identification, latent factor exploration, trajectory inference, and batch effect correction, among others. By using SCVI’s probabilistic framework, you can explore complex patterns in single-cell RNA-seq data, visualize latent representations, impute missing data, and integrate metadata to gain deeper insights into cellular behaviors. These tools are crucial for understanding biological processes, making SCVI a versatile tool for single-cell genomics analysis.
78 changes: 8 additions & 70 deletions src/scvi/external/methylvi/_methylanvi_module.py
Original file line number Diff line number Diff line change
@@ -14,11 +14,11 @@
from scvi.external.methylvi._methylvi_module import METHYLVAE
from scvi.module._classifier import Classifier
from scvi.module._utils import broadcast_labels
from scvi.module.base import LossOutput, auto_move_data
from scvi.module.base import LossOutput, SupervisedModuleClass, auto_move_data
from scvi.nn import Decoder, Encoder


class METHYLANVAE(METHYLVAE, BSSeqModuleMixin):
class METHYLANVAE(SupervisedModuleClass, METHYLVAE, BSSeqModuleMixin):
"""Methylation annotation using variational inference.
This is an implementation of the MethylANVI model described in :cite:p:`Weinberger2023a`.
@@ -51,8 +51,6 @@ class METHYLANVAE(METHYLVAE, BSSeqModuleMixin):
One of the following
* ``'region'`` - dispersion parameter of BetaBinomial is constant per region across cells
* ``'region-cell'`` - dispersion can differ for every region in every cell
log_variational
Log(data+1) prior to encoding for numerical stability. Not normalization.
y_prior
If None, initialized to uniform probability over cell types
labels_groups
@@ -196,76 +194,16 @@ def classify(
cat_covs=None,
use_posterior_mean: bool = True,
) -> torch.Tensor:
"""Forward pass through the encoder and classifier.
Parameters
----------
x
Tensor of shape ``(n_obs, n_vars)``.
batch_index
Tensor of shape ``(n_obs,)`` denoting batch indices.
cont_covs
Tensor of shape ``(n_obs, n_continuous_covariates)``.
cat_covs
Tensor of shape ``(n_obs, n_categorical_covariates)``.
use_posterior_mean
Whether to use the posterior mean of the latent distribution for
classification.
Returns
-------
Tensor of shape ``(n_obs, n_labels)`` denoting logit scores per label.
Before v1.1, this method by default returned probabilities per label,
see #2301 for more details.
"""
# log the inputs to the variational distribution for numerical stability
mc_ = torch.log(1 + mc)
cov_ = torch.log(1 + cov)

"""Forward pass through the encoder and classifier of methylANVI."""
# get variational parameters via the encoder networks
# we input both the methylated reads (mc) and coverage (cov)
encoder_input = torch.cat((mc_, cov_), dim=-1)
if cont_covs is not None and self.encode_covariates:
encoder_input = torch.cat((encoder_input, cont_covs), dim=-1)
if cat_covs is not None and self.encode_covariates:
categorical_input = torch.split(cat_covs, 1, dim=1)
else:
categorical_input = ()

qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input)
z = qz.loc if use_posterior_mean else z

if self.use_labels_groups:
w_g = self.classifier_groups(z)
unw_y = self.classifier(z)
w_y = torch.zeros_like(unw_y)
for i, group_index in enumerate(self.groups_index):
unw_y_g = unw_y[:, group_index]
w_y[:, group_index] = unw_y_g / (unw_y_g.sum(dim=-1, keepdim=True) + 1e-8)
w_y[:, group_index] *= w_g[:, [i]]
else:
w_y = self.classifier(z)
return w_y

@auto_move_data
def classification_loss(self, labelled_dataset):
"""Computes scANVI-style classification loss."""
inference_inputs = self._get_inference_input(labelled_dataset) # (n_obs, n_vars)
data_inputs = {key: inference_inputs[key] for key in self.data_input_keys}
y = labelled_dataset[REGISTRY_KEYS.LABELS_KEY] # (n_obs, 1)
batch_idx = labelled_dataset[REGISTRY_KEYS.BATCH_KEY]
cat_covs = inference_inputs["cat_covs"]

logits = self.classify(
**data_inputs,
batch_index=batch_idx,
return super().classify(
x=torch.cat((mc, cov), dim=-1),
batch_index=batch_index,
cont_covs=cont_covs,
cat_covs=cat_covs,
) # (n_obs, n_labels)
ce_loss = F.cross_entropy(
logits,
y.view(-1).long(),
use_posterior_mean=use_posterior_mean,
)
return ce_loss, y, logits

def loss(
self,
4 changes: 4 additions & 0 deletions src/scvi/external/methylvi/_methylvi_module.py
@@ -42,6 +42,8 @@ class METHYLVAE(BaseModuleClass, BSSeqModuleMixin):
Number of hidden layers used for encoder and decoder NNs
dropout_rate
Dropout rate for neural networks
log_variational
Log(data+1) prior to encoding for numerical stability. Not normalization.
likelihood
One of
* ``'betabinomial'`` - BetaBinomial distribution
@@ -63,6 +65,7 @@ def __init__(
n_latent: int = 10,
n_layers: int = 1,
dropout_rate: float = 0.1,
log_variational: bool = True,
likelihood: Literal["betabinomial", "binomial"] = "betabinomial",
dispersion: Literal["region", "region-cell"] = "region",
):
@@ -74,6 +77,7 @@
self.dispersion = dispersion
self.likelihood = likelihood
self.contexts = contexts
self.log_variational = log_variational
self.num_features_per_context = num_features_per_context

cat_list = [n_batch] + list([] if n_cats_per_cov is None else n_cats_per_cov)
109 changes: 2 additions & 107 deletions src/scvi/module/_scanvae.py
@@ -9,8 +9,7 @@
from torch.nn import functional as F

from scvi import REGISTRY_KEYS
from scvi.data import _constants
from scvi.module.base import LossOutput, auto_move_data
from scvi.module.base import LossOutput, SupervisedModuleClass
from scvi.nn import Decoder, Encoder

from ._classifier import Classifier
@@ -23,10 +22,8 @@

from torch.distributions import Distribution

from scvi.model.base import BaseModelClass


class SCANVAE(VAE):
class SCANVAE(SupervisedModuleClass, VAE):
"""Single-cell annotation using variational inference.
This is an implementation of the scANVI model described in :cite:p:`Xu21`,
@@ -202,87 +199,6 @@ def __init__(
]
)

@auto_move_data
def classify(
self,
x: torch.Tensor,
batch_index: torch.Tensor | None = None,
cont_covs: torch.Tensor | None = None,
cat_covs: torch.Tensor | None = None,
use_posterior_mean: bool = True,
) -> torch.Tensor:
"""Forward pass through the encoder and classifier.
Parameters
----------
x
Tensor of shape ``(n_obs, n_vars)``.
batch_index
Tensor of shape ``(n_obs,)`` denoting batch indices.
cont_covs
Tensor of shape ``(n_obs, n_continuous_covariates)``.
cat_covs
Tensor of shape ``(n_obs, n_categorical_covariates)``.
use_posterior_mean
Whether to use the posterior mean of the latent distribution for
classification.
Returns
-------
Tensor of shape ``(n_obs, n_labels)`` denoting logit scores per label.
Before v1.1, this method by default returned probabilities per label,
see #2301 for more details.
"""
if self.log_variational:
x = torch.log1p(x)

if cont_covs is not None and self.encode_covariates:
encoder_input = torch.cat((x, cont_covs), dim=-1)
else:
encoder_input = x
if cat_covs is not None and self.encode_covariates:
categorical_input = torch.split(cat_covs, 1, dim=1)
else:
categorical_input = ()

qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input)
z = qz.loc if use_posterior_mean else z

if self.use_labels_groups:
w_g = self.classifier_groups(z)
unw_y = self.classifier(z)
w_y = torch.zeros_like(unw_y)
for i, group_index in enumerate(self.groups_index):
unw_y_g = unw_y[:, group_index]
w_y[:, group_index] = unw_y_g / (unw_y_g.sum(dim=-1, keepdim=True) + 1e-8)
w_y[:, group_index] *= w_g[:, [i]]
else:
w_y = self.classifier(z)
return w_y

@auto_move_data
def classification_loss(
self, labelled_dataset: dict[str, torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
x = labelled_dataset[REGISTRY_KEYS.X_KEY] # (n_obs, n_vars)
y = labelled_dataset[REGISTRY_KEYS.LABELS_KEY] # (n_obs, 1)
batch_idx = labelled_dataset[REGISTRY_KEYS.BATCH_KEY]
cont_key = REGISTRY_KEYS.CONT_COVS_KEY
cont_covs = labelled_dataset[cont_key] if cont_key in labelled_dataset.keys() else None

cat_key = REGISTRY_KEYS.CAT_COVS_KEY
cat_covs = labelled_dataset[cat_key] if cat_key in labelled_dataset.keys() else None
# NOTE: prior to v1.1, this method returned probabilities per label by
# default, see #2301 for more details
logits = self.classify(
x, batch_index=batch_idx, cat_covs=cat_covs, cont_covs=cont_covs
) # (n_obs, n_labels)
ce_loss = F.cross_entropy(
logits,
y.view(-1).long(),
)
return ce_loss, y, logits

def loss(
self,
tensors: dict[str, torch.Tensor],
@@ -373,24 +289,3 @@ def loss(
},
)
return LossOutput(loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_divergence)

def on_load(self, model: BaseModelClass, **kwargs):
manager = model.get_anndata_manager(model.adata, required=True)
source_version = manager._source_registry[_constants._SCVI_VERSION_KEY]
version_split = source_version.split(".")

if int(version_split[0]) >= 1 and int(version_split[1]) >= 1:
return

# need this if <1.1 model is resaved with >=1.1 as new registry is
# updated on setup
manager.registry[_constants._SCVI_VERSION_KEY] = source_version

# pre 1.1 logits fix
model_kwargs = model.init_params_.get("model_kwargs", {})
cls_params = model_kwargs.get("classifier_parameters", {})
user_logits = cls_params.get("logits", False)

if not user_logits:
self.classifier.logits = False
self.classifier.classifier.append(torch.nn.Softmax(dim=-1))
2 changes: 2 additions & 0 deletions src/scvi/module/base/__init__.py
@@ -4,6 +4,7 @@
JaxBaseModuleClass,
LossOutput,
PyroBaseModuleClass,
SupervisedModuleClass,
TrainStateWithState,
)
from ._decorators import auto_move_data, flax_configure
@@ -19,6 +20,7 @@
"JaxBaseModuleClass",
"TrainStateWithState",
"BaseMinifiedModeModuleClass",
"SupervisedModuleClass",
"EmbeddingModuleMixin",
"GaussianPrior",
"MogPrior",