fix: vmap False by default #3146

Merged 18 commits on Feb 20, 2025

Changes from all commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -26,6 +26,9 @@ to [Semantic Versioning]. Full commit history is available in the

- Fixed bug in distributed {class}`scvi.dataloaders.ConcatDataLoader` {pr}`3053`.
- Fixed bug when loading Pyro-based models and scArches support for Pyro {pr}`3138`
- Fixed out-of-memory errors in {class}`scvi.external.MRVI` for large sample counts by
  disabling vmap by default and storing distance matrices as NumPy arrays in xarray to
  reduce memory usage {pr}`3146`.

#### Changed

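For context, a minimal usage sketch of what the new default means for callers (`adata` is assumed to be a prepared AnnData object; setup follows the `update_sample_info` docstring further down):

```python
from scvi.external import MRVI

MRVI.setup_anndata(adata, sample_key="sample_id")  # `adata` assumed preloaded
model = MRVI(adata)
model.train()

# With the new default use_vmap="auto", the vmapped path is used only
# when the dataset has fewer than 500 samples; larger datasets fall
# back to the slower but memory-safe path automatically.
dists = model.get_local_sample_distances()

# The low-memory path can still be forced explicitly:
dists = model.get_local_sample_distances(use_vmap=False)
```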
116 changes: 84 additions & 32 deletions src/scvi/external/mrvi/_model.py
@@ -105,9 +105,7 @@ def __init__(self, adata: AnnData, **model_kwargs):
n_batch = self.summary_stats.n_batch
n_labels = self.summary_stats.n_labels

obs_df = adata.obs.copy()
obs_df = obs_df.loc[~obs_df._scvi_sample.duplicated("first")]
self.sample_info = obs_df.set_index("_scvi_sample").sort_index()
self.update_sample_info(adata)
self.sample_key = self.adata_manager.get_state_registry(
REGISTRY_KEYS.SAMPLE_KEY
).original_key
@@ -309,7 +307,7 @@ def compute_local_statistics(
adata: AnnData | None = None,
indices: npt.ArrayLike | None = None,
batch_size: int | None = None,
use_vmap: bool = True,
use_vmap: Literal["auto", True, False] = "auto",
norm: str = "l2",
mc_samples: int = 10,
) -> xr.Dataset:
@@ -330,7 +328,8 @@ def compute_local_statistics(
batch_size
Batch size to use for computing the local statistics.
use_vmap
Whether to use vmap to compute the local statistics.
Whether to use vmap to compute the local statistics. If "auto", vmap will be used if
the number of samples is less than 500.
norm
Norm to use for computing the distances.
mc_samples
@@ -341,6 +340,8 @@ def compute_local_statistics(

from scvi.external.mrvi._utils import _parse_local_statistics_requirements

use_vmap = use_vmap if use_vmap != "auto" else self.summary_stats.n_sample < 500

if not reductions or len(reductions) == 0:
raise ValueError("At least one reduction must be provided.")

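The same one-line resolution rule reappears in `differential_expression` below. As a standalone sketch (the helper name is hypothetical; the 500-sample threshold is taken from the diff):

```python
def resolve_use_vmap(use_vmap, n_sample):
    """Resolve the "auto" setting: vmap vectorizes inference across all
    samples at once, so its memory footprint grows with n_sample."""
    return use_vmap if use_vmap != "auto" else n_sample < 500

assert resolve_use_vmap("auto", 499) is True
assert resolve_use_vmap("auto", 500) is False
assert resolve_use_vmap(False, 10) is False  # explicit values always win
```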
@@ -418,15 +419,24 @@ def per_sample_inference_fn(pair):

# OK to use stacked rngs here since there is no stochasticity for mean rep.
if reqs.needs_mean_representations:
mean_zs_ = mapped_inference_fn(
stacked_rngs=stacked_rngs,
x=jnp.array(inf_inputs["x"]),
sample_index=jnp.array(inf_inputs["sample_index"]),
cf_sample=jnp.array(cf_sample),
use_mean=True,
)
try:
mean_zs_ = mapped_inference_fn(
stacked_rngs=stacked_rngs,
x=jnp.array(inf_inputs["x"]),
sample_index=jnp.array(inf_inputs["sample_index"]),
cf_sample=jnp.array(cf_sample),
use_mean=True,
)
except jax.errors.JaxRuntimeError as e:
if use_vmap:
raise RuntimeError(
"JAX ran out of memory. Try setting use_vmap=False."
) from e
else:
raise e

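The try/except above translates a device out-of-memory failure into an actionable hint. A self-contained sketch of the same pattern (the helper and the toy computation are illustrative, not library code):

```python
import jax
import jax.numpy as jnp

def run_with_oom_hint(fn, *args, use_vmap=True):
    # Re-raise device OOM errors with a hint to disable vmap.
    try:
        return fn(*args)
    except jax.errors.JaxRuntimeError as e:
        if use_vmap:
            raise RuntimeError(
                "JAX ran out of memory. Try setting use_vmap=False."
            ) from e
        raise

# Toy vmapped computation whose memory scales with the sample dimension:
pairwise = jax.vmap(lambda z: jnp.linalg.norm(z[:, None] - z[None, :], axis=-1))
reps = jnp.ones((8, 16, 4))  # (cells, samples, latent_dim)
dists = run_with_oom_hint(pairwise, reps)  # shape (8, 16, 16)
```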
mean_zs = xr.DataArray(
mean_zs_,
np.array(mean_zs_),
dims=["cell_name", "sample", "latent_dim"],
coords={
"cell_name": self.adata.obs_names[indices].values,
@@ -445,7 +455,7 @@ def per_sample_inference_fn(pair):
) # (n_mc_samples, n_cells, n_samples, n_latent)
sampled_zs_ = sampled_zs_.transpose((1, 0, 2, 3))
sampled_zs = xr.DataArray(
sampled_zs_,
np.array(sampled_zs_),
dims=["cell_name", "mc_sample", "sample", "latent_dim"],
coords={
"cell_name": self.adata.obs_names[indices].values,
@@ -456,12 +466,12 @@ def per_sample_inference_fn(pair):

if reqs.needs_mean_distances:
mean_dists = self._compute_distances_from_representations(
mean_zs_, indices, norm=norm
mean_zs_, indices, norm=norm, return_numpy=True
)

if reqs.needs_sampled_distances or reqs.needs_normalized_distances:
sampled_dists = self._compute_distances_from_representations(
sampled_zs_, indices, norm=norm
sampled_zs_, indices, norm=norm, return_numpy=True
)

if reqs.needs_normalized_distances:
@@ -570,6 +580,7 @@ def _compute_distances_from_representations(
reps: jax.typing.ArrayLike,
indices: jax.typing.ArrayLike,
norm: Literal["l2", "l1", "linf"] = "l2",
return_numpy: bool = True,
) -> xr.DataArray:
if norm not in ("l2", "l1", "linf"):
raise ValueError(f"`norm` {norm} not supported")
@@ -588,6 +599,8 @@ def _compute_distance(rep: jax.typing.ArrayLike):

if reps.ndim == 3:
dists = jax.vmap(_compute_distance)(reps)
if return_numpy:
dists = np.array(dists)
return xr.DataArray(
dists,
dims=["cell_name", "sample_x", "sample_y"],
@@ -601,6 +614,8 @@ def _compute_distance(rep: jax.typing.ArrayLike):
else:
# Case with sampled representations
dists = jax.vmap(jax.vmap(_compute_distance))(reps)
if return_numpy:
dists = np.array(dists)
return xr.DataArray(
dists,
dims=["cell_name", "mc_sample", "sample_x", "sample_y"],
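The `return_numpy` conversions above are the second half of the memory fix: `np.array(...)` copies the result to host memory, so the large distance tensors stored in xarray no longer pin device (GPU/TPU) buffers. A minimal sketch with stand-in data:

```python
import jax.numpy as jnp
import numpy as np
import xarray as xr

dists = jnp.zeros((100, 50, 50))  # stand-in for (cell, sample_x, sample_y) distances
da = xr.DataArray(
    np.array(dists),  # host copy; the device buffer can be garbage-collected
    dims=["cell_name", "sample_x", "sample_y"],
)
```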
@@ -619,7 +634,7 @@ def get_local_sample_representation(
indices: npt.ArrayLike | None = None,
batch_size: int = 256,
use_mean: bool = True,
use_vmap: bool = True,
use_vmap: Literal["auto", True, False] = "auto",
) -> xr.DataArray:
"""Compute the local sample representation of the cells in the ``adata`` object.

@@ -660,7 +675,7 @@ def get_local_sample_distances(
batch_size: int = 256,
use_mean: bool = True,
normalize_distances: bool = False,
use_vmap: bool = True,
use_vmap: Literal["auto", True, False] = "auto",
groupby: list[str] | str | None = None,
keep_cell: bool = True,
norm: str = "l2",
@@ -698,6 +713,8 @@ def get_local_sample_distances(
Number of Monte Carlo samples to use for computing the local sample distances. Only
relevant if ``use_mean=False``.
"""
use_vmap = "auto" if use_vmap == "auto" else use_vmap

input = "mean_distances" if use_mean else "sampled_distances"
if normalize_distances:
if use_mean:
@@ -1053,7 +1070,7 @@ def differential_expression(
sample_cov_keys: list[str] | None = None,
sample_subset: list[str] | None = None,
batch_size: int = 128,
use_vmap: bool = True,
use_vmap: Literal["auto", True, False] = "auto",
normalize_design_matrix: bool = True,
add_batch_specific_offsets: bool = False,
mc_samples: int = 100,
@@ -1142,6 +1159,8 @@ def differential_expression(

from scipy.stats import false_discovery_control

use_vmap = use_vmap if use_vmap != "auto" else self.summary_stats.n_sample < 500

if sample_cov_keys is None:
# Hack: kept as kwarg to maintain order of arguments.
raise ValueError("Must assign `sample_cov_keys`")
@@ -1371,19 +1390,26 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):
Amat = jax.device_put(Amat, self.device)
prefactor = jax.device_put(prefactor, self.device)

res = mapped_inference_fn(
stacked_rngs=stacked_rngs,
x=jnp.array(inf_inputs["x"]),
sample_index=jnp.array(inf_inputs["sample_index"]),
cf_sample=jnp.array(cf_sample),
Amat=Amat,
prefactor=prefactor,
n_samples_per_cell=n_samples_per_cell,
admissible_samples_mat=admissible_samples_mat,
use_mean=False,
rngs_de=rngs_de,
mc_samples=mc_samples,
)
try:
res = mapped_inference_fn(
stacked_rngs=stacked_rngs,
x=jnp.array(inf_inputs["x"]),
sample_index=jnp.array(inf_inputs["sample_index"]),
cf_sample=jnp.array(cf_sample),
Amat=Amat,
prefactor=prefactor,
n_samples_per_cell=n_samples_per_cell,
admissible_samples_mat=admissible_samples_mat,
use_mean=False,
rngs_de=rngs_de,
mc_samples=mc_samples,
)
except jax.errors.JaxRuntimeError as e:
if use_vmap:
raise RuntimeError("JAX ran out of memory. Try setting use_vmap=False.") from e
else:
raise e

beta.append(np.array(res["beta"]))
effect_size.append(np.array(res["effect_size"]))
pvalue.append(np.array(res["pvalue"]))
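On the user side, `differential_expression` now picks its execution path the same way. A hypothetical call (`disease_status` follows the covariate example in the `update_sample_info` docstring below):

```python
de_res = model.differential_expression(
    sample_cov_keys=["disease_status"],  # required; raises ValueError if omitted
    use_vmap="auto",  # default: vmap only when n_sample < 500
)
```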
@@ -1557,3 +1583,29 @@ def _construct_design_matrix(
covariates_require_lfc = jnp.array(covariates_require_lfc)

return Xmat, Xmat_names, covariates_require_lfc, offset_indices

def update_sample_info(self, adata):
"""Initialize/update metadata in the case where additional covariates are added.

Parameters
----------
adata
AnnData object to update the sample info with. Typically, this corresponds to the
working dataset, where additional sample-specific covariates have been added.

Examples
--------
>>> import scanpy as sc
>>> from scvi.external import MRVI
>>> MRVI.setup_anndata(adata, sample_key="sample_id")
>>> model = MRVI(adata)
>>> model.train()
>>> # Update sample info with new covariates
>>> sample_mapper = {"sample_1": "healthy", "sample_2": "disease"}
>>> adata.obs["disease_status"] = adata.obs["sample_id"].map(sample_mapper)
>>> model.update_sample_info(adata)
"""
adata = self._validate_anndata(adata)
obs_df = adata.obs.copy()
obs_df = obs_df.loc[~obs_df._scvi_sample.duplicated("first")]
self.sample_info = obs_df.set_index("_scvi_sample").sort_index()
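
The dedup-and-index pattern at the end of `update_sample_info` can be illustrated in isolation (toy data; `_scvi_sample` is the integer code assigned during `setup_anndata`):

```python
import pandas as pd

obs_df = pd.DataFrame(
    {
        "_scvi_sample": [0, 0, 1, 1],
        "disease_status": ["healthy", "healthy", "disease", "disease"],
    }
)
# Keep the first obs row per sample code, then index by that code,
# yielding one row of sample-level metadata per sample.
sample_info = obs_df.loc[~obs_df["_scvi_sample"].duplicated(keep="first")]
sample_info = sample_info.set_index("_scvi_sample").sort_index()
```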