diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8561302ac7..1e03717f92 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -26,6 +26,9 @@ to [Semantic Versioning]. Full commit history is available in the
 - Fixed bug in distributed {class}`scvi.dataloaders.ConcatDataLoader` {pr}`3053`.
 - Fixed bug when loading Pyro-based models and scArches support for Pyro {pr}`3138`
+- Fixed out-of-memory errors in {class}`scvi.external.MRVI` for large sample sizes
+  by disabling vmap, and reduced memory usage by storing distance matrices as NumPy
+  arrays in xarray {pr}`3146`.
 
 #### Changed
 
diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py
index e7e7b7c2fa..baad88751d 100644
--- a/src/scvi/external/mrvi/_model.py
+++ b/src/scvi/external/mrvi/_model.py
@@ -105,9 +105,7 @@ def __init__(self, adata: AnnData, **model_kwargs):
         n_batch = self.summary_stats.n_batch
         n_labels = self.summary_stats.n_labels
 
-        obs_df = adata.obs.copy()
-        obs_df = obs_df.loc[~obs_df._scvi_sample.duplicated("first")]
-        self.sample_info = obs_df.set_index("_scvi_sample").sort_index()
+        self.update_sample_info(adata)
         self.sample_key = self.adata_manager.get_state_registry(
             REGISTRY_KEYS.SAMPLE_KEY
         ).original_key
@@ -309,7 +307,7 @@ def compute_local_statistics(
         adata: AnnData | None = None,
         indices: npt.ArrayLike | None = None,
         batch_size: int | None = None,
-        use_vmap: bool = True,
+        use_vmap: Literal["auto", True, False] = "auto",
         norm: str = "l2",
         mc_samples: int = 10,
     ) -> xr.Dataset:
@@ -330,7 +328,8 @@
         batch_size
             Batch size to use for computing the local statistics.
         use_vmap
-            Whether to use vmap to compute the local statistics.
+            Whether to use vmap to compute the local statistics. If "auto", vmap is used
+            only if the number of samples is less than 500.
         norm
             Norm to use for computing the distances.
         mc_samples
@@ -341,6 +340,8 @@
 
         from scvi.external.mrvi._utils import _parse_local_statistics_requirements
 
+        use_vmap = use_vmap if use_vmap != "auto" else self.summary_stats.n_sample < 500
+
         if not reductions or len(reductions) == 0:
             raise ValueError("At least one reduction must be provided.")
 
@@ -418,15 +419,24 @@ def per_sample_inference_fn(pair):
 
             # OK to use stacked rngs here since there is no stochasticity for mean rep.
             if reqs.needs_mean_representations:
-                mean_zs_ = mapped_inference_fn(
-                    stacked_rngs=stacked_rngs,
-                    x=jnp.array(inf_inputs["x"]),
-                    sample_index=jnp.array(inf_inputs["sample_index"]),
-                    cf_sample=jnp.array(cf_sample),
-                    use_mean=True,
-                )
+                try:
+                    mean_zs_ = mapped_inference_fn(
+                        stacked_rngs=stacked_rngs,
+                        x=jnp.array(inf_inputs["x"]),
+                        sample_index=jnp.array(inf_inputs["sample_index"]),
+                        cf_sample=jnp.array(cf_sample),
+                        use_mean=True,
+                    )
+                except jax.errors.JaxRuntimeError as e:
+                    if use_vmap:
+                        raise RuntimeError(
+                            "JAX ran out of memory. Try setting use_vmap=False."
+                        ) from e
+                    else:
+                        raise e
+
                 mean_zs = xr.DataArray(
-                    mean_zs_,
+                    np.array(mean_zs_),
                     dims=["cell_name", "sample", "latent_dim"],
                     coords={
                         "cell_name": self.adata.obs_names[indices].values,
@@ -445,7 +455,7 @@ def per_sample_inference_fn(pair):
                 )  # (n_mc_samples, n_cells, n_samples, n_latent)
                 sampled_zs_ = sampled_zs_.transpose((1, 0, 2, 3))
                 sampled_zs = xr.DataArray(
-                    sampled_zs_,
+                    np.array(sampled_zs_),
                     dims=["cell_name", "mc_sample", "sample", "latent_dim"],
                     coords={
                         "cell_name": self.adata.obs_names[indices].values,
@@ -456,12 +466,12 @@
 
         if reqs.needs_mean_distances:
             mean_dists = self._compute_distances_from_representations(
-                mean_zs_, indices, norm=norm
+                mean_zs_, indices, norm=norm, return_numpy=True
             )
 
         if reqs.needs_sampled_distances or reqs.needs_normalized_distances:
             sampled_dists = self._compute_distances_from_representations(
-                sampled_zs_, indices, norm=norm
+                sampled_zs_, indices, norm=norm, return_numpy=True
             )
 
         if reqs.needs_normalized_distances:
@@ -570,6 +580,7 @@ def _compute_distances_from_representations(
         reps: jax.typing.ArrayLike,
         indices: jax.typing.ArrayLike,
         norm: Literal["l2", "l1", "linf"] = "l2",
+        return_numpy: bool = True,
     ) -> xr.DataArray:
         if norm not in ("l2", "l1", "linf"):
             raise ValueError(f"`norm` {norm} not supported")
@@ -588,6 +599,8 @@ def _compute_distance(rep: jax.typing.ArrayLike):
 
         if reps.ndim == 3:
             dists = jax.vmap(_compute_distance)(reps)
+            if return_numpy:
+                dists = np.array(dists)
             return xr.DataArray(
                 dists,
                 dims=["cell_name", "sample_x", "sample_y"],
@@ -601,6 +614,8 @@
         else:
             # Case with sampled representations
             dists = jax.vmap(jax.vmap(_compute_distance))(reps)
+            if return_numpy:
+                dists = np.array(dists)
             return xr.DataArray(
                 dists,
                 dims=["cell_name", "mc_sample", "sample_x", "sample_y"],
@@ -619,7 +634,7 @@ def get_local_sample_representation(
         indices: npt.ArrayLike | None = None,
         batch_size: int = 256,
         use_mean: bool = True,
-        use_vmap: bool = True,
+        use_vmap: Literal["auto", True, False] = "auto",
     ) -> xr.DataArray:
         """Compute the local sample representation of the cells in the ``adata`` object.
 
@@ -660,7 +675,7 @@ def get_local_sample_distances(
         batch_size: int = 256,
         use_mean: bool = True,
         normalize_distances: bool = False,
-        use_vmap: bool = True,
+        use_vmap: Literal["auto", True, False] = "auto",
         groupby: list[str] | str | None = None,
         keep_cell: bool = True,
         norm: str = "l2",
@@ -698,6 +713,8 @@
             Number of Monte Carlo samples to use for computing the local sample distances.
             Only relevant if ``use_mean=False``.
         """
+        use_vmap = use_vmap if use_vmap != "auto" else self.summary_stats.n_sample < 500
+
        input = "mean_distances" if use_mean else "sampled_distances"
         if normalize_distances:
             if use_mean:
@@ -1053,7 +1070,7 @@ def differential_expression(
         sample_cov_keys: list[str] | None = None,
         sample_subset: list[str] | None = None,
         batch_size: int = 128,
-        use_vmap: bool = True,
+        use_vmap: Literal["auto", True, False] = "auto",
         normalize_design_matrix: bool = True,
         add_batch_specific_offsets: bool = False,
         mc_samples: int = 100,
@@ -1142,6 +1159,8 @@
 
         from scipy.stats import false_discovery_control
 
+        use_vmap = use_vmap if use_vmap != "auto" else self.summary_stats.n_sample < 500
+
         if sample_cov_keys is None:
             # Hack: kept as kwarg to maintain order of arguments.
             raise ValueError("Must assign `sample_cov_keys`")
@@ -1371,19 +1390,26 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):
             Amat = jax.device_put(Amat, self.device)
             prefactor = jax.device_put(prefactor, self.device)
 
-            res = mapped_inference_fn(
-                stacked_rngs=stacked_rngs,
-                x=jnp.array(inf_inputs["x"]),
-                sample_index=jnp.array(inf_inputs["sample_index"]),
-                cf_sample=jnp.array(cf_sample),
-                Amat=Amat,
-                prefactor=prefactor,
-                n_samples_per_cell=n_samples_per_cell,
-                admissible_samples_mat=admissible_samples_mat,
-                use_mean=False,
-                rngs_de=rngs_de,
-                mc_samples=mc_samples,
-            )
+            try:
+                res = mapped_inference_fn(
+                    stacked_rngs=stacked_rngs,
+                    x=jnp.array(inf_inputs["x"]),
+                    sample_index=jnp.array(inf_inputs["sample_index"]),
+                    cf_sample=jnp.array(cf_sample),
+                    Amat=Amat,
+                    prefactor=prefactor,
+                    n_samples_per_cell=n_samples_per_cell,
+                    admissible_samples_mat=admissible_samples_mat,
+                    use_mean=False,
+                    rngs_de=rngs_de,
+                    mc_samples=mc_samples,
+                )
+            except jax.errors.JaxRuntimeError as e:
+                if use_vmap:
+                    raise RuntimeError("JAX ran out of memory. Try setting use_vmap=False.") from e
+                else:
+                    raise e
+
             beta.append(np.array(res["beta"]))
             effect_size.append(np.array(res["effect_size"]))
             pvalue.append(np.array(res["pvalue"]))
@@ -1557,3 +1583,29 @@ def _construct_design_matrix(
         covariates_require_lfc = jnp.array(covariates_require_lfc)
 
         return Xmat, Xmat_names, covariates_require_lfc, offset_indices
+
+    def update_sample_info(self, adata):
+        """Initialize or update the sample metadata, e.g., after sample-level covariates are added.
+
+        Parameters
+        ----------
+        adata
+            AnnData object from which to update the sample info. Typically, this is the
+            working dataset, to which additional sample-specific covariates have been added.
+
+        Examples
+        --------
+        >>> import scanpy as sc
+        >>> from scvi.external import MRVI
+        >>> MRVI.setup_anndata(adata, sample_key="sample_id")
+        >>> model = MRVI(adata)
+        >>> model.train()
+        >>> # Update sample info with new covariates
+        >>> sample_mapper = {"sample_1": "healthy", "sample_2": "disease"}
+        >>> adata.obs["disease_status"] = adata.obs["sample_id"].map(sample_mapper)
+        >>> model.update_sample_info(adata)
+        """
+        adata = self._validate_anndata(adata)
+        obs_df = adata.obs.copy()
+        obs_df = obs_df.loc[~obs_df._scvi_sample.duplicated("first")]
+        self.sample_info = obs_df.set_index("_scvi_sample").sort_index()
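
For review context, a minimal usage sketch of the new `use_vmap="auto"` behavior follows. This is hypothetical illustration, not part of the patch: `adata` is assumed to be an AnnData with a `"sample_id"` column in `.obs`, and the try/except fallback is an illustrative pattern rather than API behavior.

```python
# `adata` is an AnnData whose .obs contains a "sample_id" column, prepared upstream.
from scvi.external import MRVI

MRVI.setup_anndata(adata, sample_key="sample_id")
model = MRVI(adata)
model.train()

# use_vmap now defaults to "auto": vmap is enabled only when fewer than 500
# samples are registered, avoiding out-of-memory errors on large cohorts.
# Distance matrices in the returned xarray object are plain NumPy arrays
# rather than device-backed JAX arrays, reducing accelerator memory usage.
dists = model.get_local_sample_distances()

# Forcing vmap on a large cohort can still exhaust device memory; the new
# error handling surfaces this as a RuntimeError suggesting the fallback.
try:
    dists = model.get_local_sample_distances(use_vmap=True)
except RuntimeError:
    dists = model.get_local_sample_distances(use_vmap=False)
```

The 500-sample cutoff mirrors the threshold hard-coded in `compute_local_statistics` and `differential_expression`, so passing `use_vmap=True` or `use_vmap=False` explicitly remains the escape hatch.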