fix: MrVI compatible with large sample sizes (#3146)
Fixes #3145, #3166

---------

Co-authored-by: Ori Kronfeld <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justin Hong <[email protected]>
Co-authored-by: Can Ergen <[email protected]>
5 people authored Feb 20, 2025
1 parent 8e4ec96 commit c3926eb
Showing 2 changed files with 87 additions and 32 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -26,6 +26,9 @@ to [Semantic Versioning]. Full commit history is available in the

- Fixed bug in distributed {class}`scvi.dataloaders.ConcatDataLoader` {pr}`3053`.
- Fixed bug when loading Pyro-based models and scArches support for Pyro {pr}`3138`
- Fixed out-of-memory errors in {class}`scvi.external.MRVI` for large sample sizes by
disabling vmap, and reduced memory usage by storing distance matrices as numpy
arrays in xarray {pr}`3146`.

#### Changed

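For illustration, a usage sketch of the changed option in the same doctest style this diff uses below (`adata` is assumed to be set up already; `use_vmap` now accepts `"auto"`, `True`, or `False`):

>>> from scvi.external import MRVI
>>> MRVI.setup_anndata(adata, sample_key="sample_id")
>>> model = MRVI(adata)
>>> model.train()
>>> # "auto" (the new default) enables vmap only when n_sample < 500
>>> dists = model.get_local_sample_distances(use_vmap="auto")
>>> # force the memory-friendly path for very large sample counts
>>> dists = model.get_local_sample_distances(use_vmap=False)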
116 changes: 84 additions & 32 deletions src/scvi/external/mrvi/_model.py
@@ -105,9 +105,7 @@ def __init__(self, adata: AnnData, **model_kwargs):
n_batch = self.summary_stats.n_batch
n_labels = self.summary_stats.n_labels

obs_df = adata.obs.copy()
obs_df = obs_df.loc[~obs_df._scvi_sample.duplicated("first")]
self.sample_info = obs_df.set_index("_scvi_sample").sort_index()
self.update_sample_info(adata)
self.sample_key = self.adata_manager.get_state_registry(
REGISTRY_KEYS.SAMPLE_KEY
).original_key
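The removed lines move into the new `update_sample_info` helper at the bottom of this diff. A self-contained toy sketch of that dedupe-and-index pandas step, with made-up data:

import pandas as pd

obs = pd.DataFrame({"_scvi_sample": [0, 0, 1, 1], "donor": ["a", "a", "b", "b"]})
# keep the first row per integer sample code, then index and sort by it
sample_info = (
    obs.loc[~obs["_scvi_sample"].duplicated(keep="first")]
    .set_index("_scvi_sample")
    .sort_index()
)
print(sample_info)  # one row per sample: donor "a" for 0, "b" for 1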
@@ -309,7 +307,7 @@ def compute_local_statistics(
adata: AnnData | None = None,
indices: npt.ArrayLike | None = None,
batch_size: int | None = None,
use_vmap: bool = True,
use_vmap: Literal["auto", True, False] = "auto",
norm: str = "l2",
mc_samples: int = 10,
) -> xr.Dataset:
@@ -330,7 +328,8 @@
batch_size
Batch size to use for computing the local statistics.
use_vmap
Whether to use vmap to compute the local statistics.
Whether to use vmap to compute the local statistics. If "auto", vmap will be used if
the number of samples is less than 500.
norm
Norm to use for computing the distances.
mc_samples
@@ -341,6 +340,8 @@

from scvi.external.mrvi._utils import _parse_local_statistics_requirements

use_vmap = use_vmap if use_vmap != "auto" else self.summary_stats.n_sample < 500

if not reductions or len(reductions) == 0:
raise ValueError("At least one reduction must be provided.")

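The "auto" resolution added above is a one-line rule; an equivalent standalone sketch (`resolve_use_vmap` and `n_sample` are illustrative names, the 500-sample threshold comes from the line above):

from typing import Literal

def resolve_use_vmap(use_vmap: Literal["auto", True, False], n_sample: int) -> bool:
    """Mirror the resolution above: vmap only below 500 samples."""
    return n_sample < 500 if use_vmap == "auto" else use_vmap

assert resolve_use_vmap("auto", 100) is True
assert resolve_use_vmap("auto", 5000) is False
assert resolve_use_vmap(False, 100) is False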
@@ -418,15 +419,24 @@ def per_sample_inference_fn(pair):

# OK to use stacked rngs here since there is no stochasticity for mean rep.
if reqs.needs_mean_representations:
mean_zs_ = mapped_inference_fn(
stacked_rngs=stacked_rngs,
x=jnp.array(inf_inputs["x"]),
sample_index=jnp.array(inf_inputs["sample_index"]),
cf_sample=jnp.array(cf_sample),
use_mean=True,
)
try:
mean_zs_ = mapped_inference_fn(
stacked_rngs=stacked_rngs,
x=jnp.array(inf_inputs["x"]),
sample_index=jnp.array(inf_inputs["sample_index"]),
cf_sample=jnp.array(cf_sample),
use_mean=True,
)
except jax.errors.JaxRuntimeError as e:
if use_vmap:
raise RuntimeError(
"JAX ran out of memory. Try setting use_vmap=False."
) from e
else:
raise e

mean_zs = xr.DataArray(
mean_zs_,
np.array(mean_zs_),
dims=["cell_name", "sample", "latent_dim"],
coords={
"cell_name": self.adata.obs_names[indices].values,
@@ -445,7 +455,7 @@ def per_sample_inference_fn(pair):
) # (n_mc_samples, n_cells, n_samples, n_latent)
sampled_zs_ = sampled_zs_.transpose((1, 0, 2, 3))
sampled_zs = xr.DataArray(
sampled_zs_,
np.array(sampled_zs_),
dims=["cell_name", "mc_sample", "sample", "latent_dim"],
coords={
"cell_name": self.adata.obs_names[indices].values,
@@ -456,12 +466,12 @@

if reqs.needs_mean_distances:
mean_dists = self._compute_distances_from_representations(
mean_zs_, indices, norm=norm
mean_zs_, indices, norm=norm, return_numpy=True
)

if reqs.needs_sampled_distances or reqs.needs_normalized_distances:
sampled_dists = self._compute_distances_from_representations(
sampled_zs_, indices, norm=norm
sampled_zs_, indices, norm=norm, return_numpy=True
)

if reqs.needs_normalized_distances:
@@ -570,6 +580,7 @@ def _compute_distances_from_representations(
reps: jax.typing.ArrayLike,
indices: jax.typing.ArrayLike,
norm: Literal["l2", "l1", "linf"] = "l2",
return_numpy: bool = True,
) -> xr.DataArray:
if norm not in ("l2", "l1", "linf"):
raise ValueError(f"`norm` {norm} not supported")
@@ -588,6 +599,8 @@ def _compute_distance(rep: jax.typing.ArrayLike):

if reps.ndim == 3:
dists = jax.vmap(_compute_distance)(reps)
if return_numpy:
dists = np.array(dists)
return xr.DataArray(
dists,
dims=["cell_name", "sample_x", "sample_y"],
@@ -601,6 +614,8 @@ def _compute_distance(rep: jax.typing.ArrayLike):
else:
# Case with sampled representations
dists = jax.vmap(jax.vmap(_compute_distance))(reps)
if return_numpy:
dists = np.array(dists)
return xr.DataArray(
dists,
dims=["cell_name", "mc_sample", "sample_x", "sample_y"],
@@ -619,7 +634,7 @@ def get_local_sample_representation(
indices: npt.ArrayLike | None = None,
batch_size: int = 256,
use_mean: bool = True,
use_vmap: bool = True,
use_vmap: Literal["auto", True, False] = "auto",
) -> xr.DataArray:
"""Compute the local sample representation of the cells in the ``adata`` object.
@@ -660,7 +675,7 @@ def get_local_sample_distances(
batch_size: int = 256,
use_mean: bool = True,
normalize_distances: bool = False,
use_vmap: bool = True,
use_vmap: Literal["auto", True, False] = "auto",
groupby: list[str] | str | None = None,
keep_cell: bool = True,
norm: str = "l2",
@@ -698,6 +713,8 @@
Number of Monte Carlo samples to use for computing the local sample distances. Only
relevant if ``use_mean=False``.
"""
use_vmap = "auto" if use_vmap == "auto" else use_vmap

input = "mean_distances" if use_mean else "sampled_distances"
if normalize_distances:
if use_mean:
@@ -1053,7 +1070,7 @@ def differential_expression(
sample_cov_keys: list[str] | None = None,
sample_subset: list[str] | None = None,
batch_size: int = 128,
use_vmap: bool = True,
use_vmap: Literal["auto", True, False] = "auto",
normalize_design_matrix: bool = True,
add_batch_specific_offsets: bool = False,
mc_samples: int = 100,
@@ -1142,6 +1159,8 @@

from scipy.stats import false_discovery_control

use_vmap = use_vmap if use_vmap != "auto" else self.summary_stats.n_sample < 500

if sample_cov_keys is None:
# Hack: kept as kwarg to maintain order of arguments.
raise ValueError("Must assign `sample_cov_keys`")
@@ -1371,19 +1390,26 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):
Amat = jax.device_put(Amat, self.device)
prefactor = jax.device_put(prefactor, self.device)

res = mapped_inference_fn(
stacked_rngs=stacked_rngs,
x=jnp.array(inf_inputs["x"]),
sample_index=jnp.array(inf_inputs["sample_index"]),
cf_sample=jnp.array(cf_sample),
Amat=Amat,
prefactor=prefactor,
n_samples_per_cell=n_samples_per_cell,
admissible_samples_mat=admissible_samples_mat,
use_mean=False,
rngs_de=rngs_de,
mc_samples=mc_samples,
)
try:
res = mapped_inference_fn(
stacked_rngs=stacked_rngs,
x=jnp.array(inf_inputs["x"]),
sample_index=jnp.array(inf_inputs["sample_index"]),
cf_sample=jnp.array(cf_sample),
Amat=Amat,
prefactor=prefactor,
n_samples_per_cell=n_samples_per_cell,
admissible_samples_mat=admissible_samples_mat,
use_mean=False,
rngs_de=rngs_de,
mc_samples=mc_samples,
)
except jax.errors.JaxRuntimeError as e:
if use_vmap:
raise RuntimeError("JAX ran out of memory. Try setting use_vmap=False.") from e
else:
raise e

beta.append(np.array(res["beta"]))
effect_size.append(np.array(res["effect_size"]))
pvalue.append(np.array(res["pvalue"]))
@@ -1557,3 +1583,29 @@ def _construct_design_matrix(
covariates_require_lfc = jnp.array(covariates_require_lfc)

return Xmat, Xmat_names, covariates_require_lfc, offset_indices

def update_sample_info(self, adata):
"""Initialize/update metadata in the case where additional covariates are added.
Parameters
----------
adata
AnnData object to update the sample info with. Typically, this corresponds to the
working dataset, where additional sample-specific covariates have been added.
Examples
--------
>>> import scanpy as sc
>>> from scvi.external import MRVI
>>> MRVI.setup_anndata(adata, sample_key="sample_id")
>>> model = MRVI(adata)
>>> model.train()
>>> # Update sample info with new covariates
>>> sample_mapper = {"sample_1": "healthy", "sample_2": "disease"}
>>> adata.obs["disease_status"] = adata.obs["sample_id"].map(sample_mapper)
>>> model.update_sample_info(adata)
"""
adata = self._validate_anndata(adata)
obs_df = adata.obs.copy()
obs_df = obs_df.loc[~obs_df._scvi_sample.duplicated("first")]
self.sample_info = obs_df.set_index("_scvi_sample").sort_index()
