From 060eedc12cc7048c554d477dfb5f62161e2a8e3b Mon Sep 17 00:00:00 2001
From: Pierre Boyeau
Date: Tue, 14 Jan 2025 02:24:12 -0800
Subject: [PATCH 01/14] vmap False by default

---
 src/scvi/external/mrvi/_model.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py
index e7e7b7c2fa..d6662adf57 100644
--- a/src/scvi/external/mrvi/_model.py
+++ b/src/scvi/external/mrvi/_model.py
@@ -309,7 +309,7 @@ def compute_local_statistics(
         adata: AnnData | None = None,
         indices: npt.ArrayLike | None = None,
         batch_size: int | None = None,
-        use_vmap: bool = True,
+        use_vmap: bool = False,
         norm: str = "l2",
         mc_samples: int = 10,
     ) -> xr.Dataset:
@@ -619,7 +619,7 @@ def get_local_sample_representation(
         indices: npt.ArrayLike | None = None,
         batch_size: int = 256,
         use_mean: bool = True,
-        use_vmap: bool = True,
+        use_vmap: bool = False,
     ) -> xr.DataArray:
         """Compute the local sample representation of the cells in the ``adata`` object.
@@ -660,7 +660,7 @@ def get_local_sample_distances(
         batch_size: int = 256,
         use_mean: bool = True,
         normalize_distances: bool = False,
-        use_vmap: bool = True,
+        use_vmap: bool = False,
         groupby: list[str] | str | None = None,
         keep_cell: bool = True,
         norm: str = "l2",
@@ -1053,7 +1053,7 @@ def differential_expression(
         sample_cov_keys: list[str] | None = None,
         sample_subset: list[str] | None = None,
         batch_size: int = 128,
-        use_vmap: bool = True,
+        use_vmap: bool = False,
         normalize_design_matrix: bool = True,
        add_batch_specific_offsets: bool = False,
         mc_samples: int = 100,
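Background for the new default: MRVI maps one counterfactual inference pass over every registered sample, and `jax.vmap` materializes all per-sample intermediates on the device at once, whereas a sequential map keeps only one slice alive at a time. A minimal sketch of the trade-off (the toy function and shapes are illustrative, not MRVI's actual inference code):

    import jax
    import jax.numpy as jnp

    def per_sample_fn(cf_sample):
        # Stand-in for one counterfactual inference pass over a minibatch.
        return jnp.tanh(cf_sample).sum(axis=-1)

    cf_samples = jnp.ones((500, 256, 32))  # (n_samples, n_cells, n_latent)

    # Vectorized: fastest, but peak device memory grows with n_samples.
    out_vmap = jax.vmap(per_sample_fn)(cf_samples)

    # Sequential: one sample at a time, much smaller peak memory.
    out_seq = jax.lax.map(per_sample_fn, cf_samples)

    assert jnp.allclose(out_vmap, out_seq)

Defaulting `use_vmap` to False therefore trades speed for a memory footprint that no longer scales with the number of samples.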
From d685dbb8336f91d9629346c37b9f7acc84f4aef6 Mon Sep 17 00:00:00 2001
From: Pierre Boyeau
Date: Tue, 14 Jan 2025 10:13:01 -0800
Subject: [PATCH 02/14] more informative tracebacks + auto vmap

---
 docs/tutorials/notebooks         |  2 +-
 src/scvi/external/mrvi/_model.py | 73 +++++++++++++++++++++-----------
 2 files changed, 49 insertions(+), 26 deletions(-)

diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks
index b5890651da..c2fc6d100e 160000
--- a/docs/tutorials/notebooks
+++ b/docs/tutorials/notebooks
@@ -1 +1 @@
-Subproject commit b5890651da3ad734cc12e7d54b39395aa6e9137d
+Subproject commit c2fc6d100ecc28e716f9ffc96bc68af48a7733b4
diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py
index d6662adf57..68f4ffcdf3 100644
--- a/src/scvi/external/mrvi/_model.py
+++ b/src/scvi/external/mrvi/_model.py
@@ -309,7 +309,7 @@ def compute_local_statistics(
         adata: AnnData | None = None,
         indices: npt.ArrayLike | None = None,
         batch_size: int | None = None,
-        use_vmap: bool = False,
+        use_vmap: Literal["auto", True, False] = "auto",
         norm: str = "l2",
         mc_samples: int = 10,
     ) -> xr.Dataset:
@@ -330,7 +330,8 @@ def compute_local_statistics(
         batch_size
             Batch size to use for computing the local statistics.
         use_vmap
-            Whether to use vmap to compute the local statistics.
+            Whether to use vmap to compute the local statistics. If "auto", vmap will be used if
+            the number of samples is less than 500.
         norm
             Norm to use for computing the distances.
         mc_samples
@@ -341,6 +342,8 @@ def compute_local_statistics(

         from scvi.external.mrvi._utils import _parse_local_statistics_requirements

+        use_vmap = use_vmap if use_vmap != "auto" else self.summary_stats.n_sample < 500
+
         if not reductions or len(reductions) == 0:
             raise ValueError("At least one reduction must be provided.")
@@ -418,13 +421,22 @@ def per_sample_inference_fn(pair):

         # OK to use stacked rngs here since there is no stochasticity for mean rep.
         if reqs.needs_mean_representations:
-            mean_zs_ = mapped_inference_fn(
-                stacked_rngs=stacked_rngs,
-                x=jnp.array(inf_inputs["x"]),
-                sample_index=jnp.array(inf_inputs["sample_index"]),
-                cf_sample=jnp.array(cf_sample),
-                use_mean=True,
-            )
+            try:
+                mean_zs_ = mapped_inference_fn(
+                    stacked_rngs=stacked_rngs,
+                    x=jnp.array(inf_inputs["x"]),
+                    sample_index=jnp.array(inf_inputs["sample_index"]),
+                    cf_sample=jnp.array(cf_sample),
+                    use_mean=True,
+                )
+            except jax.errors.JaxRuntimeError as e:
+                if use_vmap:
+                    raise RuntimeError(
+                        "JAX ran out of memory. Try setting use_vmap=False."
+                    ) from e
+                else:
+                    raise e
+
         mean_zs = xr.DataArray(
             mean_zs_,
             dims=["cell_name", "sample", "latent_dim"],
             coords={
                 "cell_name": self.adata.obs_names[indices].values,
@@ -619,7 +631,7 @@ def get_local_sample_representation(
         indices: npt.ArrayLike | None = None,
         batch_size: int = 256,
         use_mean: bool = True,
-        use_vmap: bool = False,
+        use_vmap: Literal["auto", True, False] = "auto",
     ) -> xr.DataArray:
         """Compute the local sample representation of the cells in the ``adata`` object.
@@ -660,7 +672,7 @@ def get_local_sample_distances(
         batch_size: int = 256,
         use_mean: bool = True,
         normalize_distances: bool = False,
-        use_vmap: bool = False,
+        use_vmap: Literal["auto", True, False] = "auto",
         groupby: list[str] | str | None = None,
         keep_cell: bool = True,
         norm: str = "l2",
@@ -698,6 +710,8 @@ def get_local_sample_distances(
            Number of Monte Carlo samples to use for computing the local sample distances.
            Only relevant if ``use_mean=False``.
        """
+        use_vmap = "auto" if use_vmap == "auto" else use_vmap
+
        input = "mean_distances" if use_mean else "sampled_distances"
        if normalize_distances:
            if use_mean:
@@ -1053,7 +1067,7 @@ def differential_expression(
         sample_cov_keys: list[str] | None = None,
         sample_subset: list[str] | None = None,
         batch_size: int = 128,
-        use_vmap: bool = False,
+        use_vmap: Literal["auto", True, False] = "auto",
         normalize_design_matrix: bool = True,
         add_batch_specific_offsets: bool = False,
         mc_samples: int = 100,
@@ -1142,6 +1156,8 @@ def differential_expression(

         from scipy.stats import false_discovery_control

+        use_vmap = use_vmap if use_vmap != "auto" else self.summary_stats.n_sample < 500
+
         if sample_cov_keys is None:
             # Hack: kept as kwarg to maintain order of arguments.
             raise ValueError("Must assign `sample_cov_keys`")
@@ -1371,19 +1387,26 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):

         Amat = jax.device_put(Amat, self.device)
         prefactor = jax.device_put(prefactor, self.device)

-        res = mapped_inference_fn(
-            stacked_rngs=stacked_rngs,
-            x=jnp.array(inf_inputs["x"]),
-            sample_index=jnp.array(inf_inputs["sample_index"]),
-            cf_sample=jnp.array(cf_sample),
-            Amat=Amat,
-            prefactor=prefactor,
-            n_samples_per_cell=n_samples_per_cell,
-            admissible_samples_mat=admissible_samples_mat,
-            use_mean=False,
-            rngs_de=rngs_de,
-            mc_samples=mc_samples,
-        )
+        try:
+            res = mapped_inference_fn(
+                stacked_rngs=stacked_rngs,
+                x=jnp.array(inf_inputs["x"]),
+                sample_index=jnp.array(inf_inputs["sample_index"]),
+                cf_sample=jnp.array(cf_sample),
+                Amat=Amat,
+                prefactor=prefactor,
+                n_samples_per_cell=n_samples_per_cell,
+                admissible_samples_mat=admissible_samples_mat,
+                use_mean=False,
+                rngs_de=rngs_de,
+                mc_samples=mc_samples,
+            )
+        except jax.errors.JaxRuntimeError as e:
+            if use_vmap:
+                raise RuntimeError("JAX ran out of memory. Try setting use_vmap=False.") from e
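The two try/except blocks in this patch follow one pattern: resolve the "auto" flag once from `summary_stats.n_sample`, then convert an opaque XLA allocation failure into an actionable message. A condensed sketch of that pattern, with `run_mapped_inference` as a hypothetical stand-in for `mapped_inference_fn`:

    import jax

    def call_with_oom_hint(run_mapped_inference, use_vmap="auto", n_sample=1000):
        # Heuristic from the patch: vectorize only below 500 samples.
        use_vmap = use_vmap if use_vmap != "auto" else n_sample < 500
        try:
            return run_mapped_inference(use_vmap=use_vmap)
        except jax.errors.JaxRuntimeError as err:
            if use_vmap:
                # Point the user at the fix rather than a bare XLA traceback.
                raise RuntimeError("JAX ran out of memory. Try setting use_vmap=False.") from err
            raise

Note that `get_local_sample_distances` only passes the flag through unchanged; the resolution to a boolean happens where the mapped function is actually built, in `compute_local_statistics` and `differential_expression`.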
+            else:
+                raise e
+
         beta.append(np.array(res["beta"]))
         effect_size.append(np.array(res["effect_size"]))
         pvalue.append(np.array(res["pvalue"]))

From e3e5ed61a203d5f2a10389816524097a2df61f12 Mon Sep 17 00:00:00 2001
From: Pierre Boyeau
Date: Mon, 27 Jan 2025 15:52:25 -0800
Subject: [PATCH 03/14] store dmats as numpy arrays instead of jax.

---
 src/scvi/external/mrvi/_model.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py
index 68f4ffcdf3..d8333ceb32 100644
--- a/src/scvi/external/mrvi/_model.py
+++ b/src/scvi/external/mrvi/_model.py
@@ -438,7 +438,7 @@ def per_sample_inference_fn(pair):
                     raise e

         mean_zs = xr.DataArray(
-            mean_zs_,
+            np.array(mean_zs_),
             dims=["cell_name", "sample", "latent_dim"],
             coords={
                 "cell_name": self.adata.obs_names[indices].values,
@@ -457,7 +457,7 @@ def per_sample_inference_fn(pair):
             )  # (n_mc_samples, n_cells, n_samples, n_latent)
             sampled_zs_ = sampled_zs_.transpose((1, 0, 2, 3))
             sampled_zs = xr.DataArray(
-                sampled_zs_,
+                np.array(sampled_zs_),
                 dims=["cell_name", "mc_sample", "sample", "latent_dim"],
                 coords={
                     "cell_name": self.adata.obs_names[indices].values,
@@ -468,12 +468,12 @@ def per_sample_inference_fn(pair):

         if reqs.needs_mean_distances:
             mean_dists = self._compute_distances_from_representations(
-                mean_zs_, indices, norm=norm
+                mean_zs_, indices, norm=norm, return_numpy=True
             )

         if reqs.needs_sampled_distances or reqs.needs_normalized_distances:
             sampled_dists = self._compute_distances_from_representations(
-                sampled_zs_, indices, norm=norm
+                sampled_zs_, indices, norm=norm, return_numpy=True
             )

         if reqs.needs_normalized_distances:
@@ -582,6 +582,7 @@ def _compute_distances_from_representations(
         reps: jax.typing.ArrayLike,
         indices: jax.typing.ArrayLike,
         norm: Literal["l2", "l1", "linf"] = "l2",
+        return_numpy: bool = True,
     ) -> xr.DataArray:
         if norm not in ("l2", "l1", "linf"):
             raise ValueError(f"`norm` {norm} not supported")
@@ -600,6 +601,8 @@ def _compute_distance(rep: jax.typing.ArrayLike):

         if reps.ndim == 3:
             dists = jax.vmap(_compute_distance)(reps)
+            if return_numpy:
+                dists = np.array(dists)
             return xr.DataArray(
                 dists,
                 dims=["cell_name", "sample_x", "sample_y"],
@@ -613,6 +616,8 @@ def _compute_distance(rep: jax.typing.ArrayLike):
         else:
             # Case with sampled representations
             dists = jax.vmap(jax.vmap(_compute_distance))(reps)
+            if return_numpy:
+                dists = np.array(dists)
             return xr.DataArray(
                 dists,
                 dims=["cell_name", "mc_sample", "sample_x", "sample_y"],
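The motivation for these conversions: a `jax.Array` stored inside an `xr.DataArray` keeps its buffer on the accelerator, so per-batch distance matrices accumulate in device memory for the lifetime of the dataset. Copying to host first releases those buffers. A small sketch of the idea (shapes illustrative):

    import jax.numpy as jnp
    import numpy as np
    import xarray as xr

    dists = jnp.zeros((1024, 50, 50))  # device-backed (n_cells, n_samples, n_samples)

    # np.array copies the buffer to host RAM, so the DataArray no longer
    # pins device memory.
    da = xr.DataArray(
        np.array(dists),
        dims=["cell_name", "sample_x", "sample_y"],
    )

`jax.device_get(dists)` would achieve the same thing; either way, downstream xarray reductions then run on NumPy arrays instead of holding JAX device buffers alive.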
From ebc9463926bb928a9d6c4455232544f17a09b2bd Mon Sep 17 00:00:00 2001
From: Pierre Boyeau
Date: Fri, 14 Feb 2025 20:10:09 -0800
Subject: [PATCH 04/14] add way to update self.sample_info

---
 src/scvi/external/mrvi/_model.py | 117 +++++++++++++++++++++++--------
 1 file changed, 87 insertions(+), 30 deletions(-)

diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py
index d8333ceb32..4ac5eb76d1 100644
--- a/src/scvi/external/mrvi/_model.py
+++ b/src/scvi/external/mrvi/_model.py
@@ -105,9 +105,7 @@ def __init__(self, adata: AnnData, **model_kwargs):
         n_batch = self.summary_stats.n_batch
         n_labels = self.summary_stats.n_labels

-        obs_df = adata.obs.copy()
-        obs_df = obs_df.loc[~obs_df._scvi_sample.duplicated("first")]
-        self.sample_info = obs_df.set_index("_scvi_sample").sort_index()
+        self.update_sample_info(adata)
         self.sample_key = self.adata_manager.get_state_registry(
             REGISTRY_KEYS.SAMPLE_KEY
         ).original_key
@@ -133,7 +131,9 @@ def to_device(self, device):
         # TODO(jhong): remove this once we have a better way to handle device.
         pass

-    def _generate_stacked_rngs(self, n_sets: int | tuple) -> dict[str, jax.random.KeyArray]:
+    def _generate_stacked_rngs(
+        self, n_sets: int | tuple
+    ) -> dict[str, jax.random.KeyArray]:
         return_1d = isinstance(n_sets, int)
         if return_1d:
             n_sets_1d = n_sets
@@ -191,7 +191,9 @@ def setup_anndata(
             fields.NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
         ]

-        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
+        adata_manager = AnnDataManager(
+            fields=anndata_fields, setup_method_args=setup_method_args
+        )
         adata_manager.register_fields(adata, **kwargs)
         cls.register_manager(adata_manager)
@@ -408,12 +410,16 @@ def per_sample_inference_fn(pair):
         for ur in reqs.ungrouped_reductions:
             ungrouped_data_arrs[ur.name] = []
         for gr in reqs.grouped_reductions:
-            grouped_data_arrs[gr.name] = {}  # Will map group category to running group sum.
+            grouped_data_arrs[gr.name] = (
+                {}
+            )  # Will map group category to running group sum.

         for array_dict in tqdm(scdl):
             indices = array_dict[REGISTRY_KEYS.INDICES_KEY].astype(int).flatten()
             n_cells = array_dict[REGISTRY_KEYS.X_KEY].shape[0]
-            cf_sample = np.broadcast_to(np.arange(n_sample)[:, None, None], (n_sample, n_cells, 1))
+            cf_sample = np.broadcast_to(
+                np.arange(n_sample)[:, None, None], (n_sample, n_cells, 1)
+            )
             inf_inputs = self.module._get_inference_input(
                 array_dict,
             )
@@ -490,8 +496,11 @@ def per_sample_inference_fn(pair):
                 normalization_means = normalization_means.reshape(-1, 1, 1, 1)
                 normalization_vars = normalization_vars.reshape(-1, 1, 1, 1)
                 normalized_dists = (
-                    (sampled_dists - normalization_means) / (normalization_vars**0.5)
-                ).mean(dim="mc_sample")  # (n_cells, n_samples, n_samples)
+                    (sampled_dists - normalization_means)
+                    / (normalization_vars**0.5)
+                ).mean(
+                    dim="mc_sample"
+                )  # (n_cells, n_samples, n_samples)

             # Compute each reduction
             for r in reductions:
@@ -515,7 +524,9 @@ def per_sample_inference_fn(pair):
                     group_by_cats = group_by.unique()
                     for cat in group_by_cats:
                         cat_summed_outputs = outputs.sel(
-                            cell_name=self.adata.obs_names[indices][group_by == cat].values
+                            cell_name=self.adata.obs_names[indices][
+                                group_by == cat
+                            ].values
                         ).sum(dim="cell_name")
                         cat_summed_outputs = cat_summed_outputs.assign_coords(
                             {f"{r.group_by}_name": cat}
@@ -537,8 +548,12 @@ def per_sample_inference_fn(pair):
             group_by_counts = group_by.value_counts()
             averaged_grouped_data_arrs = []
             for cat, count in group_by_counts.items():
-                averaged_grouped_data_arrs.append(grouped_data_arrs[gr.name][cat] / count)
-            final_data_arr = xr.concat(averaged_grouped_data_arrs, dim=f"{gr.group_by}_name")
+                averaged_grouped_data_arrs.append(
+                    grouped_data_arrs[gr.name][cat] / count
+                )
+            final_data_arr = xr.concat(
+                averaged_grouped_data_arrs, dim=f"{gr.group_by}_name"
+            )
             final_data_arrs[gr.name] = final_data_arr

         return xr.Dataset(data_vars=final_data_arrs)
@@ -731,7 +746,9 @@ def get_local_sample_distances(
         reductions = []

         if not keep_cell and not groupby:
-            raise ValueError("Undefined computation because not keep_cell and no groupby.")
+            raise ValueError(
+                "Undefined computation because not keep_cell and no groupby."
+            )
         if keep_cell:
             reductions.append(
                 MRVIReduction(
@@ -801,7 +818,9 @@ def get_aggregated_posterior(
         qu_locs = []
         qu_scales = []

-        jit_inference_fn = self.module.get_jit_inference_fn(inference_kwargs={"use_mean": True})
+        jit_inference_fn = self.module.get_jit_inference_fn(
+            inference_kwargs={"use_mean": True}
+        )

         for array_dict in scdl:
             outputs = jit_inference_fn(self.module.rngs, array_dict)
@@ -885,7 +904,9 @@ def differential_abundance(
             n_splits = max(adata.n_obs // batch_size, 1)
             log_probs_ = []
             for u_rep in np.array_split(us, n_splits):
-                log_probs_.append(jax.device_get(ap.log_prob(u_rep).sum(-1, keepdims=True)))
+                log_probs_.append(
+                    jax.device_get(ap.log_prob(u_rep).sum(-1, keepdims=True))
+                )

             log_probs.append(np.concatenate(log_probs_, axis=0))  # (n_cells, 1)
@@ -911,7 +932,9 @@ def differential_abundance(
             per_val_log_enrichs = {}
             for sample_cov_value in sample_cov_unique_values:
                 cov_samples = (
-                    self.sample_info[self.sample_info[sample_cov_key] == sample_cov_value]
+                    self.sample_info[
+                        self.sample_info[sample_cov_key] == sample_cov_value
+                    ]
                 )[self.sample_key].to_numpy()
                 if sample_subset is not None:
                     cov_samples = np.intersect1d(cov_samples, np.array(sample_subset))
@@ -919,7 +942,9 @@ def differential_abundance(
                     continue

                 sel_log_probs = log_probs_arr.log_probs.loc[{"sample": cov_samples}]
-                val_log_probs = logsumexp(sel_log_probs, axis=1) - np.log(sel_log_probs.shape[1])
+                val_log_probs = logsumexp(sel_log_probs, axis=1) - np.log(
+                    sel_log_probs.shape[1]
+                )
                 per_val_log_probs[sample_cov_value] = val_log_probs

                 if compute_log_enrichment:
@@ -928,7 +953,9 @@ def differential_abundance(
                             stacklevel=2,
                         )
                         continue
-                    rest_log_probs = log_probs_arr.log_probs.loc[{"sample": rest_samples}]
+                    rest_log_probs = log_probs_arr.log_probs.loc[
+                        {"sample": rest_samples}
+                    ]
                     rest_val_log_probs = logsumexp(rest_log_probs, axis=1) - np.log(
                         rest_log_probs.shape[1]
                     )
@@ -936,7 +963,9 @@ def differential_abundance(
                     enrichment_scores = val_log_probs - rest_val_log_probs
                     per_val_log_enrichs[sample_cov_value] = enrichment_scores
-            sample_cov_log_probs_map[sample_cov_key] = DataFrame.from_dict(per_val_log_probs)
+            sample_cov_log_probs_map[sample_cov_key] = DataFrame.from_dict(
+                per_val_log_probs
+            )
             if compute_log_enrichment and len(per_val_log_enrichs) > 0:
                 sample_cov_log_enrichs_map[sample_cov_key] = DataFrame.from_dict(
                     per_val_log_enrichs
                 )
@@ -1015,7 +1044,9 @@ def get_outlier_cell_sample_pairs(
         for sample_name in tqdm(unique_samples):
             sample_idxs = np.where(adata.obs[self.sample_key] == sample_name)[0]
             if subsample_size is not None and sample_idxs.shape[0] > subsample_size:
-                sample_idxs = np.random.choice(sample_idxs, size=subsample_size, replace=False)
+                sample_idxs = np.random.choice(
+                    sample_idxs, size=subsample_size, replace=False
+                )
             adata_s = adata[sample_idxs]

             ap = self.get_aggregated_posterior(adata=adata, indices=sample_idxs)
@@ -1328,13 +1359,17 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):
                 )
                 mc_samples, _, n_cells_, n_latent = betas_covariates.shape
                 betas_offset_ = (
-                    jnp.zeros((mc_samples, self.summary_stats.n_batch, n_cells_, n_latent))
+                    jnp.zeros(
+                        (mc_samples, self.summary_stats.n_batch, n_cells_, n_latent)
+                    )
                     + eps_mean_
                 )  # batch_offset shape (mc_samples, n_batch, n_cells, n_latent)

             f_ = jax.vmap(
                 h_inference_fn, in_axes=(0, None, 0), out_axes=0
             )  # fn over MC samples
-            f_ = jax.vmap(f_, in_axes=(1, None, None), out_axes=1)  # fn over covariates
+            f_ = jax.vmap(
+                f_, in_axes=(1, None, None), out_axes=1
+            )  # fn over covariates
             f_ = jax.vmap(f_, in_axes=(None, 0, 1), out_axes=0)  # fn over batches
             h_fn = jax.jit(f_)
@@ -1344,7 +1379,9 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):
             lfcs = jnp.log2(x_1 + eps_lfc) - jnp.log2(x_0 + eps_lfc)
             lfc_mean = jnp.average(lfcs.mean(1), weights=batch_weights, axis=0)
             if delta is not None:
-                lfc_std = jnp.sqrt(jnp.average(lfcs.var(1), weights=batch_weights, axis=0))
+                lfc_std = jnp.sqrt(
+                    jnp.average(lfcs.var(1), weights=batch_weights, axis=0)
+                )
                 pde = (jnp.abs(lfcs) >= delta).mean(1).mean(0)

             if store_baseline:
@@ -1382,7 +1419,9 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):
         stacked_rngs = self._generate_stacked_rngs(cf_sample.shape[0])
         rngs_de = self.module.rngs if store_lfc else None

-        admissible_samples_mat = jnp.array(admissible_samples[indices])  # (n_cells, n_samples)
+        admissible_samples_mat = jnp.array(
+            admissible_samples[indices]
+        )  # (n_cells, n_samples)
         n_samples_per_cell = admissible_samples_mat.sum(axis=1)
         admissible_samples_dmat = jax.vmap(jnp.diag)(admissible_samples_mat).astype(
             float
         )
@@ -1408,7 +1447,9 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):
             )
             except jax.errors.JaxRuntimeError as e:
                 if use_vmap:
-                    raise RuntimeError("JAX ran out of memory. Try setting use_vmap=False.") from e
+                    raise RuntimeError(
+                        "JAX ran out of memory. Try setting use_vmap=False."
+                    ) from e
                 else:
                     raise e
@@ -1426,7 +1467,9 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):
         effect_size = np.concatenate(effect_size, axis=0)
         pvalue = np.concatenate(pvalue, axis=0)
         pvalue_shape = pvalue.shape
-        padj = false_discovery_control(pvalue.flatten(), method="bh").reshape(pvalue_shape)
+        padj = false_discovery_control(pvalue.flatten(), method="bh").reshape(
+            pvalue_shape
+        )

         coords = {
             "cell_name": (("cell_name"), adata.obs_names),
@@ -1545,19 +1588,27 @@ def _construct_design_matrix(
         Xmat_dim_to_key = np.concatenate(Xmat_dim_to_key)

         if normalize_design_matrix:
-            Xmat = (Xmat - Xmat.min(axis=0)) / (1e-6 + Xmat.max(axis=0) - Xmat.min(axis=0))
+            Xmat = (Xmat - Xmat.min(axis=0)) / (
+                1e-6 + Xmat.max(axis=0) - Xmat.min(axis=0)
+            )
         if add_batch_specific_offsets:
             cov = sample_info["_scvi_batch"]
             if cov.nunique() == self.summary_stats.n_batch:
-                cov = np.eye(self.summary_stats.n_batch)[sample_info["_scvi_batch"].values]
-                cov_names = ["offset_batch_" + str(i) for i in range(self.summary_stats.n_batch)]
+                cov = np.eye(self.summary_stats.n_batch)[
+                    sample_info["_scvi_batch"].values
+                ]
+                cov_names = [
+                    "offset_batch_" + str(i) for i in range(self.summary_stats.n_batch)
+                ]
                 Xmat = np.concatenate([cov, Xmat], axis=1)
                 Xmat_names = np.concatenate([np.array(cov_names), Xmat_names])
                 Xmat_dim_to_key = np.concatenate([np.array(cov_names), Xmat_dim_to_key])

                 # Retrieve indices of offset covariates in the right order
                 offset_indices = (
-                    Series(np.arange(len(Xmat_names)), index=Xmat_names).loc[cov_names].values
+                    Series(np.arange(len(Xmat_names)), index=Xmat_names)
+                    .loc[cov_names]
+                    .values
                 )
                 offset_indices = jnp.array(offset_indices)
             else:
@@ -1585,3 +1636,9 @@ def _construct_design_matrix(
         covariates_require_lfc = jnp.array(covariates_require_lfc)

         return Xmat, Xmat_names, covariates_require_lfc, offset_indices
+
+    def update_sample_info(self, adata):
+        """Initialize/update metadata in the case where additional covariates are added."""
+        obs_df = adata.obs.copy()
+        obs_df = obs_df.loc[~obs_df._scvi_sample.duplicated("first")]
+        self.sample_info = obs_df.set_index("_scvi_sample").sort_index()
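The core of the new `update_sample_info` is collapsing the cell-level `obs` table to one row per registered sample. A standalone sketch of that reduction, using a toy table (column names follow the registry fields in the diff):

    import pandas as pd

    obs = pd.DataFrame(
        {
            "_scvi_sample": [0, 0, 1, 1, 2],
            "sample_id": ["s1", "s1", "s2", "s2", "s3"],
            "disease_status": ["healthy", "healthy", "disease", "disease", "healthy"],
        }
    )

    # Keep the first row seen per sample code, then index and sort by sample code.
    sample_info = (
        obs.loc[~obs["_scvi_sample"].duplicated(keep="first")]
        .set_index("_scvi_sample")
        .sort_index()
    )

Because the table is re-derived from `adata.obs`, adding a new sample-level column and calling `model.update_sample_info(adata)` exposes that covariate to downstream analyses without retraining.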
From 16f3d9ef993651d3cb1b8f02ebf549f583621915 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 15 Feb 2025 04:10:24 +0000
Subject: [PATCH 05/14] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/scvi/external/mrvi/_model.py | 107 ++++++++----------------------
 1 file changed, 27 insertions(+), 80 deletions(-)

diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py
index 4ac5eb76d1..1dc970dc0a 100644
--- a/src/scvi/external/mrvi/_model.py
+++ b/src/scvi/external/mrvi/_model.py
@@ -131,9 +131,7 @@ def to_device(self, device):
         # TODO(jhong): remove this once we have a better way to handle device.
         pass

-    def _generate_stacked_rngs(
-        self, n_sets: int | tuple
-    ) -> dict[str, jax.random.KeyArray]:
+    def _generate_stacked_rngs(self, n_sets: int | tuple) -> dict[str, jax.random.KeyArray]:
         return_1d = isinstance(n_sets, int)
         if return_1d:
             n_sets_1d = n_sets
@@ -191,9 +189,7 @@ def setup_anndata(
             fields.NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
         ]

-        adata_manager = AnnDataManager(
-            fields=anndata_fields, setup_method_args=setup_method_args
-        )
+        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
         adata_manager.register_fields(adata, **kwargs)
         cls.register_manager(adata_manager)
@@ -410,16 +406,12 @@ def per_sample_inference_fn(pair):
         for ur in reqs.ungrouped_reductions:
             ungrouped_data_arrs[ur.name] = []
         for gr in reqs.grouped_reductions:
-            grouped_data_arrs[gr.name] = (
-                {}
-            )  # Will map group category to running group sum.
+            grouped_data_arrs[gr.name] = {}  # Will map group category to running group sum.

         for array_dict in tqdm(scdl):
             indices = array_dict[REGISTRY_KEYS.INDICES_KEY].astype(int).flatten()
             n_cells = array_dict[REGISTRY_KEYS.X_KEY].shape[0]
-            cf_sample = np.broadcast_to(
-                np.arange(n_sample)[:, None, None], (n_sample, n_cells, 1)
-            )
+            cf_sample = np.broadcast_to(np.arange(n_sample)[:, None, None], (n_sample, n_cells, 1))
             inf_inputs = self.module._get_inference_input(
                 array_dict,
             )
@@ -496,11 +488,8 @@ def per_sample_inference_fn(pair):
                 normalization_means = normalization_means.reshape(-1, 1, 1, 1)
                 normalization_vars = normalization_vars.reshape(-1, 1, 1, 1)
                 normalized_dists = (
-                    (sampled_dists - normalization_means)
-                    / (normalization_vars**0.5)
-                ).mean(
-                    dim="mc_sample"
-                )  # (n_cells, n_samples, n_samples)
+                    (sampled_dists - normalization_means) / (normalization_vars**0.5)
+                ).mean(dim="mc_sample")  # (n_cells, n_samples, n_samples)

             # Compute each reduction
             for r in reductions:
@@ -524,9 +513,7 @@ def per_sample_inference_fn(pair):
                     group_by_cats = group_by.unique()
                     for cat in group_by_cats:
                         cat_summed_outputs = outputs.sel(
-                            cell_name=self.adata.obs_names[indices][
-                                group_by == cat
-                            ].values
+                            cell_name=self.adata.obs_names[indices][group_by == cat].values
                         ).sum(dim="cell_name")
                         cat_summed_outputs = cat_summed_outputs.assign_coords(
                             {f"{r.group_by}_name": cat}
@@ -548,12 +535,8 @@ def per_sample_inference_fn(pair):
             group_by_counts = group_by.value_counts()
             averaged_grouped_data_arrs = []
             for cat, count in group_by_counts.items():
-                averaged_grouped_data_arrs.append(
-                    grouped_data_arrs[gr.name][cat] / count
-                )
-            final_data_arr = xr.concat(
-                averaged_grouped_data_arrs, dim=f"{gr.group_by}_name"
-            )
+                averaged_grouped_data_arrs.append(grouped_data_arrs[gr.name][cat] / count)
+            final_data_arr = xr.concat(averaged_grouped_data_arrs, dim=f"{gr.group_by}_name")
             final_data_arrs[gr.name] = final_data_arr

         return xr.Dataset(data_vars=final_data_arrs)
@@ -746,9 +729,7 @@ def get_local_sample_distances(
         reductions = []

         if not keep_cell and not groupby:
-            raise ValueError(
-                "Undefined computation because not keep_cell and no groupby."
-            )
+            raise ValueError("Undefined computation because not keep_cell and no groupby.")
         if keep_cell:
             reductions.append(
                 MRVIReduction(
@@ -818,9 +799,7 @@ def get_aggregated_posterior(
         qu_locs = []
         qu_scales = []

-        jit_inference_fn = self.module.get_jit_inference_fn(
-            inference_kwargs={"use_mean": True}
-        )
+        jit_inference_fn = self.module.get_jit_inference_fn(inference_kwargs={"use_mean": True})

         for array_dict in scdl:
             outputs = jit_inference_fn(self.module.rngs, array_dict)
@@ -904,9 +883,7 @@ def differential_abundance(
             n_splits = max(adata.n_obs // batch_size, 1)
             log_probs_ = []
             for u_rep in np.array_split(us, n_splits):
-                log_probs_.append(
-                    jax.device_get(ap.log_prob(u_rep).sum(-1, keepdims=True))
-                )
+                log_probs_.append(jax.device_get(ap.log_prob(u_rep).sum(-1, keepdims=True)))

             log_probs.append(np.concatenate(log_probs_, axis=0))  # (n_cells, 1)
@@ -932,9 +909,7 @@ def differential_abundance(
             per_val_log_enrichs = {}
             for sample_cov_value in sample_cov_unique_values:
                 cov_samples = (
-                    self.sample_info[
-                        self.sample_info[sample_cov_key] == sample_cov_value
-                    ]
+                    self.sample_info[self.sample_info[sample_cov_key] == sample_cov_value]
                 )[self.sample_key].to_numpy()
                 if sample_subset is not None:
                     cov_samples = np.intersect1d(cov_samples, np.array(sample_subset))
@@ -942,9 +917,7 @@ def differential_abundance(
                     continue

                 sel_log_probs = log_probs_arr.log_probs.loc[{"sample": cov_samples}]
-                val_log_probs = logsumexp(sel_log_probs, axis=1) - np.log(
-                    sel_log_probs.shape[1]
-                )
+                val_log_probs = logsumexp(sel_log_probs, axis=1) - np.log(sel_log_probs.shape[1])
                 per_val_log_probs[sample_cov_value] = val_log_probs

                 if compute_log_enrichment:
@@ -953,9 +926,7 @@ def differential_abundance(
                             stacklevel=2,
                         )
                         continue
-                    rest_log_probs = log_probs_arr.log_probs.loc[
-                        {"sample": rest_samples}
-                    ]
+                    rest_log_probs = log_probs_arr.log_probs.loc[{"sample": rest_samples}]
                     rest_val_log_probs = logsumexp(rest_log_probs, axis=1) - np.log(
                         rest_log_probs.shape[1]
                     )
@@ -963,9 +934,7 @@ def differential_abundance(
                     enrichment_scores = val_log_probs - rest_val_log_probs
                     per_val_log_enrichs[sample_cov_value] = enrichment_scores
-            sample_cov_log_probs_map[sample_cov_key] = DataFrame.from_dict(
-                per_val_log_probs
-            )
+            sample_cov_log_probs_map[sample_cov_key] = DataFrame.from_dict(per_val_log_probs)
             if compute_log_enrichment and len(per_val_log_enrichs) > 0:
                 sample_cov_log_enrichs_map[sample_cov_key] = DataFrame.from_dict(
                     per_val_log_enrichs
                 )
@@ -1044,9 +1013,7 @@ def get_outlier_cell_sample_pairs(
         for sample_name in tqdm(unique_samples):
             sample_idxs = np.where(adata.obs[self.sample_key] == sample_name)[0]
             if subsample_size is not None and sample_idxs.shape[0] > subsample_size:
-                sample_idxs = np.random.choice(
-                    sample_idxs, size=subsample_size, replace=False
-                )
+                sample_idxs = np.random.choice(sample_idxs, size=subsample_size, replace=False)
             adata_s = adata[sample_idxs]

             ap = self.get_aggregated_posterior(adata=adata, indices=sample_idxs)
@@ -1359,17 +1326,13 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):
                 )
                 mc_samples, _, n_cells_, n_latent = betas_covariates.shape
                 betas_offset_ = (
-                    jnp.zeros(
-                        (mc_samples, self.summary_stats.n_batch, n_cells_, n_latent)
-                    )
+                    jnp.zeros((mc_samples, self.summary_stats.n_batch, n_cells_, n_latent))
                     + eps_mean_
                 )  # batch_offset shape (mc_samples, n_batch, n_cells, n_latent)

             f_ = jax.vmap(
                 h_inference_fn, in_axes=(0, None, 0), out_axes=0
             )  # fn over MC samples
-            f_ = jax.vmap(
-                f_, in_axes=(1, None, None), out_axes=1
-            )  # fn over covariates
+            f_ = jax.vmap(f_, in_axes=(1, None, None), out_axes=1)  # fn over covariates
             f_ = jax.vmap(f_, in_axes=(None, 0, 1), out_axes=0)  # fn over batches
             h_fn = jax.jit(f_)
@@ -1379,9 +1342,7 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):
             lfcs = jnp.log2(x_1 + eps_lfc) - jnp.log2(x_0 + eps_lfc)
             lfc_mean = jnp.average(lfcs.mean(1), weights=batch_weights, axis=0)
             if delta is not None:
-                lfc_std = jnp.sqrt(
-                    jnp.average(lfcs.var(1), weights=batch_weights, axis=0)
-                )
+                lfc_std = jnp.sqrt(jnp.average(lfcs.var(1), weights=batch_weights, axis=0))
                 pde = (jnp.abs(lfcs) >= delta).mean(1).mean(0)

             if store_baseline:
@@ -1419,9 +1380,7 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):
         stacked_rngs = self._generate_stacked_rngs(cf_sample.shape[0])
         rngs_de = self.module.rngs if store_lfc else None

-        admissible_samples_mat = jnp.array(
-            admissible_samples[indices]
-        )  # (n_cells, n_samples)
+        admissible_samples_mat = jnp.array(admissible_samples[indices])  # (n_cells, n_samples)
         n_samples_per_cell = admissible_samples_mat.sum(axis=1)
         admissible_samples_dmat = jax.vmap(jnp.diag)(admissible_samples_mat).astype(
             float
         )
@@ -1447,9 +1406,7 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):
             )
             except jax.errors.JaxRuntimeError as e:
                 if use_vmap:
-                    raise RuntimeError(
-                        "JAX ran out of memory. Try setting use_vmap=False."
-                    ) from e
+                    raise RuntimeError("JAX ran out of memory. Try setting use_vmap=False.") from e
                 else:
                     raise e
@@ -1467,9 +1424,7 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):
         effect_size = np.concatenate(effect_size, axis=0)
         pvalue = np.concatenate(pvalue, axis=0)
         pvalue_shape = pvalue.shape
-        padj = false_discovery_control(pvalue.flatten(), method="bh").reshape(
-            pvalue_shape
-        )
+        padj = false_discovery_control(pvalue.flatten(), method="bh").reshape(pvalue_shape)

         coords = {
             "cell_name": (("cell_name"), adata.obs_names),
@@ -1588,27 +1543,19 @@ def _construct_design_matrix(
         Xmat_dim_to_key = np.concatenate(Xmat_dim_to_key)

         if normalize_design_matrix:
-            Xmat = (Xmat - Xmat.min(axis=0)) / (
-                1e-6 + Xmat.max(axis=0) - Xmat.min(axis=0)
-            )
+            Xmat = (Xmat - Xmat.min(axis=0)) / (1e-6 + Xmat.max(axis=0) - Xmat.min(axis=0))
         if add_batch_specific_offsets:
             cov = sample_info["_scvi_batch"]
             if cov.nunique() == self.summary_stats.n_batch:
-                cov = np.eye(self.summary_stats.n_batch)[
-                    sample_info["_scvi_batch"].values
-                ]
-                cov_names = [
-                    "offset_batch_" + str(i) for i in range(self.summary_stats.n_batch)
-                ]
+                cov = np.eye(self.summary_stats.n_batch)[sample_info["_scvi_batch"].values]
+                cov_names = ["offset_batch_" + str(i) for i in range(self.summary_stats.n_batch)]
                 Xmat = np.concatenate([cov, Xmat], axis=1)
                 Xmat_names = np.concatenate([np.array(cov_names), Xmat_names])
                 Xmat_dim_to_key = np.concatenate([np.array(cov_names), Xmat_dim_to_key])

                 # Retrieve indices of offset covariates in the right order
                 offset_indices = (
-                    Series(np.arange(len(Xmat_names)), index=Xmat_names)
-                    .loc[cov_names]
-                    .values
+                    Series(np.arange(len(Xmat_names)), index=Xmat_names).loc[cov_names].values
                 )
                 offset_indices = jnp.array(offset_indices)
             else:
From 6377ee4f11450a214335504ba3a9e5e7cd60e621 Mon Sep 17 00:00:00 2001
From: Pierre Boyeau
Date: Wed, 19 Feb 2025 20:25:12 -0800
Subject: [PATCH 06/14] better docstring

---
 src/scvi/external/mrvi/_model.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py
index 1dc970dc0a..baad88751d 100644
--- a/src/scvi/external/mrvi/_model.py
+++ b/src/scvi/external/mrvi/_model.py
@@ -1585,7 +1585,27 @@ def _construct_design_matrix(
         return Xmat, Xmat_names, covariates_require_lfc, offset_indices

     def update_sample_info(self, adata):
-        """Initialize/update metadata in the case where additional covariates are added."""
+        """Initialize/update metadata in the case where additional covariates are added.
+
+        Parameters
+        ----------
+        adata
+            AnnData object to update the sample info with. Typically, this corresponds to the
+            working dataset, where additional sample-specific covariates have been added.
+
+        Examples
+        --------
+        >>> import scanpy as sc
+        >>> from scvi.external import MRVI
+        >>> MRVI.setup_anndata(adata, sample_key="sample_id")
+        >>> model = MRVI(adata)
+        >>> model.train()
+        >>> # Update sample info with new covariates
+        >>> sample_mapper = {"sample_1": "healthy", "sample_2": "disease"}
+        >>> adata.obs["disease_status"] = adata.obs["sample_id"].map(sample_mapper)
+        >>> model.update_sample_info(adata)
+        """
+        adata = self._validate_anndata(adata)
         obs_df = adata.obs.copy()
         obs_df = obs_df.loc[~obs_df._scvi_sample.duplicated("first")]
         self.sample_info = obs_df.set_index("_scvi_sample").sort_index()
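Following the docstring example, the refreshed `sample_info` table is what the sample-covariate analyses read from. A plausible follow-up call (keyword names as they appear in the diffs above; the exact signature may differ between releases):

    model.update_sample_info(adata)
    da_results = model.differential_abundance(
        adata,
        sample_cov_keys=["disease_status"],
        compute_log_enrichment=True,
    )

This aggregates per-sample log-probabilities over the samples sharing each `disease_status` value, per the `sample_cov_log_probs_map` logic introduced in PATCH 04.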
From f3b12a4ea012212da541d90d8d4779664a17c102 Mon Sep 17 00:00:00 2001
From: Pierre Boyeau
Date: Wed, 19 Feb 2025 20:26:59 -0800
Subject: [PATCH 07/14] readd submodule for scvi tutorials notebooks

---
 src/scvi/external/mrvi/.gitmodules | 5 +++++
 1 file changed, 5 insertions(+)
 create mode 100644 src/scvi/external/mrvi/.gitmodules

diff --git a/src/scvi/external/mrvi/.gitmodules b/src/scvi/external/mrvi/.gitmodules
new file mode 100644
index 0000000000..0f23124d6b
--- /dev/null
+++ b/src/scvi/external/mrvi/.gitmodules
@@ -0,0 +1,5 @@
+[submodule "docs/tutorials/notebooks"]
+	path = docs/tutorials/notebooks
+	url = https://github.com/YosefLab/scvi-tutorials.git
+	branch = main
+

From b7be59ec64370b49db0b6c7a33361b742d9e158f Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 20 Feb 2025 04:27:18 +0000
Subject: [PATCH 08/14] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/scvi/external/mrvi/.gitmodules | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/scvi/external/mrvi/.gitmodules b/src/scvi/external/mrvi/.gitmodules
index 0f23124d6b..d2b1a24352 100644
--- a/src/scvi/external/mrvi/.gitmodules
+++ b/src/scvi/external/mrvi/.gitmodules
@@ -2,4 +2,3 @@
 	path = docs/tutorials/notebooks
 	url = https://github.com/YosefLab/scvi-tutorials.git
 	branch = main
-

From c80a8f63747e38c20854c65e82614786d6e413e3 Mon Sep 17 00:00:00 2001
From: Pierre Boyeau
Date: Wed, 19 Feb 2025 20:28:04 -0800
Subject: [PATCH 09/14] undo

---
 src/scvi/external/mrvi/.gitmodules | 5 -----
 1 file changed, 5 deletions(-)
 delete mode 100644 src/scvi/external/mrvi/.gitmodules

diff --git a/src/scvi/external/mrvi/.gitmodules b/src/scvi/external/mrvi/.gitmodules
deleted file mode 100644
index 0f23124d6b..0000000000
--- a/src/scvi/external/mrvi/.gitmodules
+++ /dev/null
@@ -1,5 +0,0 @@
-[submodule "docs/tutorials/notebooks"]
-	path = docs/tutorials/notebooks
-	url = https://github.com/YosefLab/scvi-tutorials.git
-	branch = main
-

From 1d8192d4981910b92d72911883f672cbe3a95f05 Mon Sep 17 00:00:00 2001
From: Pierre Boyeau
Date: Wed, 19 Feb 2025 20:31:34 -0800
Subject: [PATCH 10/14] attempt undo

---
 .gitmodules | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/.gitmodules b/.gitmodules
index e69de29bb2..d2b1a24352 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -0,0 +1,4 @@
+[submodule "docs/tutorials/notebooks"]
+	path = docs/tutorials/notebooks
+	url = https://github.com/YosefLab/scvi-tutorials.git
+	branch = main

From aeb9670bb5639529c4ada554fe11676415f9b0eb Mon Sep 17 00:00:00 2001
From: Justin Hong
Date: Thu, 20 Feb 2025 10:27:37 -0500
Subject: [PATCH 11/14] revert submodule update

---
 docs/tutorials/notebooks | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks
index c2fc6d100e..943703f938 160000
--- a/docs/tutorials/notebooks
+++ b/docs/tutorials/notebooks
@@ -1 +1 @@
-Subproject commit c2fc6d100ecc28e716f9ffc96bc68af48a7733b4
+Subproject commit 943703f938c43ddc681e01c013d704db37fa3193

From c16f1b4dfc694f4d7139402ff28e8712f5d1586b Mon Sep 17 00:00:00 2001
From: Can Ergen
Date: Thu, 20 Feb 2025 16:43:00 +0100
Subject: [PATCH 12/14] Update CHANGELOG.md

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6fa825537a..8dbfde7f36 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,9 @@ to [Semantic Versioning]. Full commit history is available in the

 - Fixed bug in distributed {class}`scvi.dataloaders.ConcatDataLoader` {pr}`3053`.
 - Fixed bug when loading Pyro-based models and scArches support for Pyro {pr}`3138`
+- Fixed disable vmap in {class}`scvi.external.MRVI` for large sample sizes to avoid
+  out-of-memory errors. Store distance matrices as numpy array in xarray to reduce
+  memory usage.

 #### Changed

From 750a4f999161fc3e1566f375497c1d515c3db6ab Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 20 Feb 2025 15:43:25 +0000
Subject: [PATCH 13/14] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8dbfde7f36..4de2b3f885 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,7 +21,7 @@ to [Semantic Versioning]. Full commit history is available in the
 - Fixed bug when loading Pyro-based models and scArches support for Pyro {pr}`3138`
 - Fixed disable vmap in {class}`scvi.external.MRVI` for large sample sizes to avoid
   out-of-memory errors. Store distance matrices as numpy array in xarray to reduce
-  memory usage. 
+  memory usage.

 #### Changed

From 437c09da8c2e5e174fbf98b4429c7b36b84b9156 Mon Sep 17 00:00:00 2001
From: Can Ergen
Date: Thu, 20 Feb 2025 16:43:57 +0100
Subject: [PATCH 14/14] Update CHANGELOG.md

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4de2b3f885..8ae0421769 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,7 +21,7 @@ to [Semantic Versioning]. Full commit history is available in the
 - Fixed bug when loading Pyro-based models and scArches support for Pyro {pr}`3138`
 - Fixed disable vmap in {class}`scvi.external.MRVI` for large sample sizes to avoid
   out-of-memory errors. Store distance matrices as numpy array in xarray to reduce
-  memory usage.
+  memory usage {pr}`3146`.

 #### Changed
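Taken together, the series leaves MRVI's memory-sensitive paths defaulting to `use_vmap="auto"` and returning NumPy-backed xarray outputs. A minimal end-to-end sketch under those defaults (file and column names are hypothetical):

    import scanpy as sc
    from scvi.external import MRVI

    adata = sc.read_h5ad("large_cohort.h5ad")  # hypothetical dataset with many samples
    MRVI.setup_anndata(adata, sample_key="sample_id")
    model = MRVI(adata)
    model.train()

    # With >= 500 registered samples, "auto" resolves to use_vmap=False,
    # avoiding the device OOM this series guards against (PATCH 01/02).
    dists = model.get_local_sample_distances(use_vmap="auto")

    # Vectorized mode can still be forced when device memory allows.
    reps = model.get_local_sample_representation(use_vmap=True)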