Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: vmap False by default #3146

Merged
merged 18 commits into from
Feb 20, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
Next commit
more informative tracebacks + auto vmap
PierreBoyeau committed Jan 14, 2025
commit d685dbb8336f91d9629346c37b9f7acc84f4aef6
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
73 changes: 48 additions & 25 deletions src/scvi/external/mrvi/_model.py
Original file line number Diff line number Diff line change
@@ -309,7 +309,7 @@ def compute_local_statistics(
adata: AnnData | None = None,
indices: npt.ArrayLike | None = None,
batch_size: int | None = None,
use_vmap: bool = False,
use_vmap: Literal["auto", True, False] = "auto",
norm: str = "l2",
mc_samples: int = 10,
) -> xr.Dataset:
@@ -330,7 +330,8 @@ def compute_local_statistics(
batch_size
Batch size to use for computing the local statistics.
use_vmap
Whether to use vmap to compute the local statistics.
Whether to use vmap to compute the local statistics. If "auto", vmap will be used if
the number of samples is less than 500.
norm
Norm to use for computing the distances.
mc_samples
@@ -341,6 +342,8 @@ def compute_local_statistics(

from scvi.external.mrvi._utils import _parse_local_statistics_requirements

use_vmap = use_vmap if use_vmap != "auto" else self.summary_stats.n_sample < 500

if not reductions or len(reductions) == 0:
raise ValueError("At least one reduction must be provided.")

@@ -418,13 +421,22 @@ def per_sample_inference_fn(pair):

# OK to use stacked rngs here since there is no stochasticity for mean rep.
if reqs.needs_mean_representations:
mean_zs_ = mapped_inference_fn(
stacked_rngs=stacked_rngs,
x=jnp.array(inf_inputs["x"]),
sample_index=jnp.array(inf_inputs["sample_index"]),
cf_sample=jnp.array(cf_sample),
use_mean=True,
)
try:
mean_zs_ = mapped_inference_fn(
stacked_rngs=stacked_rngs,
x=jnp.array(inf_inputs["x"]),
sample_index=jnp.array(inf_inputs["sample_index"]),
cf_sample=jnp.array(cf_sample),
use_mean=True,
)
except jax.errors.JaxRuntimeError as e:
if use_vmap:
raise RuntimeError(
"JAX ran out of memory. Try setting use_vmap=False."
) from e
else:
raise e

mean_zs = xr.DataArray(
mean_zs_,
dims=["cell_name", "sample", "latent_dim"],
@@ -619,7 +631,7 @@ def get_local_sample_representation(
indices: npt.ArrayLike | None = None,
batch_size: int = 256,
use_mean: bool = True,
use_vmap: bool = False,
use_vmap: Literal["auto", True, False] = "auto",
) -> xr.DataArray:
"""Compute the local sample representation of the cells in the ``adata`` object.

@@ -660,7 +672,7 @@ def get_local_sample_distances(
batch_size: int = 256,
use_mean: bool = True,
normalize_distances: bool = False,
use_vmap: bool = False,
use_vmap: Literal["auto", True, False] = "auto",
groupby: list[str] | str | None = None,
keep_cell: bool = True,
norm: str = "l2",
@@ -698,6 +710,8 @@ def get_local_sample_distances(
Number of Monte Carlo samples to use for computing the local sample distances. Only
relevant if ``use_mean=False``.
"""
use_vmap = "auto" if use_vmap == "auto" else use_vmap

input = "mean_distances" if use_mean else "sampled_distances"
if normalize_distances:
if use_mean:
@@ -1053,7 +1067,7 @@ def differential_expression(
sample_cov_keys: list[str] | None = None,
sample_subset: list[str] | None = None,
batch_size: int = 128,
use_vmap: bool = False,
use_vmap: Literal["auto", True, False] = "auto",
normalize_design_matrix: bool = True,
add_batch_specific_offsets: bool = False,
mc_samples: int = 100,
@@ -1142,6 +1156,8 @@ def differential_expression(

from scipy.stats import false_discovery_control

use_vmap = use_vmap if use_vmap != "auto" else self.summary_stats.n_sample < 500

if sample_cov_keys is None:
# Hack: kept as kwarg to maintain order of arguments.
raise ValueError("Must assign `sample_cov_keys`")
@@ -1371,19 +1387,26 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):
Amat = jax.device_put(Amat, self.device)
prefactor = jax.device_put(prefactor, self.device)

res = mapped_inference_fn(
stacked_rngs=stacked_rngs,
x=jnp.array(inf_inputs["x"]),
sample_index=jnp.array(inf_inputs["sample_index"]),
cf_sample=jnp.array(cf_sample),
Amat=Amat,
prefactor=prefactor,
n_samples_per_cell=n_samples_per_cell,
admissible_samples_mat=admissible_samples_mat,
use_mean=False,
rngs_de=rngs_de,
mc_samples=mc_samples,
)
try:
res = mapped_inference_fn(
stacked_rngs=stacked_rngs,
x=jnp.array(inf_inputs["x"]),
sample_index=jnp.array(inf_inputs["sample_index"]),
cf_sample=jnp.array(cf_sample),
Amat=Amat,
prefactor=prefactor,
n_samples_per_cell=n_samples_per_cell,
admissible_samples_mat=admissible_samples_mat,
use_mean=False,
rngs_de=rngs_de,
mc_samples=mc_samples,
)
except jax.errors.JaxRuntimeError as e:
if use_vmap:
raise RuntimeError("JAX ran out of memory. Try setting use_vmap=False.") from e
else:
raise e

beta.append(np.array(res["beta"]))
effect_size.append(np.array(res["effect_size"]))
pvalue.append(np.array(res["pvalue"]))