Summary
discrimination_score (src/cell_eval/metrics/_anndata.py) loops over every
perturbation and calls sklearn.metrics.pairwise_distances once per
perturbation to produce a single row of an n_pert x n_pert distance matrix:
```python
for p_idx, p in enumerate(data.perts):
    ...
    distances = skm.pairwise_distances(
        real_effects[:, include_mask],
        pred_effects[p_idx, include_mask].reshape(1, -1),
        metric=metric,
    ).flatten()
```
That is n_pert separate sklearn dispatches that together compute the full
matrix. For l2/cosine the per-row work is a BLAS matrix-vector product that
could be a single matrix-matrix product; the loop forfeits that. With three
registered variants (discrimination_score_l1/l2/cosine) and datasets that
have thousands of perturbations, this is a large part of the anndata-metric
runtime.
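Without any column masking, the loop is equivalent to one call with the arguments swapped, which for l2/cosine lets sklearn use a single matrix-matrix product. The sketch below uses hypothetical helper names and random stand-in arrays in place of the real effect matrices. Entry `[p, q]` is the distance between `pred[p]` and `real[q]`, matching row `p` of the loop.

```python
import numpy as np
from sklearn.metrics import pairwise_distances

# Hypothetical helpers for illustration; not the cell_eval API.


def per_row_distances(real, pred, metric):
    """Current pattern: one pairwise_distances dispatch per prediction row."""
    out = np.empty((pred.shape[0], real.shape[0]))
    for p in range(pred.shape[0]):
        out[p] = pairwise_distances(
            real, pred[p].reshape(1, -1), metric=metric
        ).ravel()
    return out


def full_matrix_distances(real, pred, metric):
    """Single dispatch: entry [p, q] is the distance between pred[p] and real[q]."""
    return pairwise_distances(pred, real, metric=metric)
```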
Why it isn't a one-line pairwise_distances(real, pred)
When exclude_target_gene=True on expression data (the default), each
perturbation excludes a different feature column (the gene whose name matches the
perturbation), so a single unmasked call does not reproduce the per-row masked
distances. The fix has to compute the full matrix once and then remove each
row's target-gene contribution exactly.
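For l1, l2 and cosine, the excluded column's contribution can be removed in closed form. The following is a sketch of the algebra, with notation introduced here for illustration: $D_{pq}$ is the full-feature distance between prediction $x_p$ and real effect $r_q$, and $g_p$ is the column excluded for row $p$. The masked distances $\tilde D_{pq}$ are then:

$$
\begin{aligned}
\ell_1:\quad & \tilde D_{pq} = D_{pq} - \lvert r_{q,g_p} - x_{p,g_p} \rvert \\
\ell_2:\quad & \tilde D_{pq} = \sqrt{D_{pq}^2 - \left(r_{q,g_p} - x_{p,g_p}\right)^2} \\
\text{cosine}:\quad & \tilde D_{pq} = 1 - \frac{x_p^\top r_q - x_{p,g_p}\, r_{q,g_p}}{\sqrt{\lVert x_p \rVert^2 - x_{p,g_p}^2}\;\sqrt{\lVert r_q \rVert^2 - r_{q,g_p}^2}}
\end{aligned}
$$

Every term is either a full-matrix quantity computed once or a gather of column $g_p$, so the correction vectorizes over all $(p, q)$. The cosine case needs the dot products and squared norms rather than the finished distance matrix. In floating point the $\ell_2$ radicand can come out slightly negative and should be clamped at zero.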
Proposal
Compute the full distance matrix once per metric and derive ranks via
argsort. Apply an exact, vectorized rank-1 column correction for the
target-gene-exclusion path (l1/l2/cosine), falling back to exact per-row masked
distances for other metrics. Output stays numerically identical (same ranks).
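The following is a minimal NumPy/scikit-learn sketch of the proposal. It assumes target genes are already resolved to integer column indices (`target_col[p]`, with -1 meaning nothing to exclude). It also assumes the score is the rank of the matching real perturbation within each row. All helper names are hypothetical, and this is not the code in the forthcoming PR. It is checked against the per-row masked loop with `np.allclose` rather than bit for bit.

```python
import numpy as np
from sklearn.metrics import pairwise_distances
from sklearn.metrics.pairwise import euclidean_distances

# Hypothetical helpers for illustration; not the cell_eval API.


def masked_reference(real, pred, target_col, metric):
    """Per-row loop as in the current code: row p drops column target_col[p] (-1 = none)."""
    n_pred, n_genes = pred.shape
    out = np.empty((n_pred, real.shape[0]))
    for p in range(n_pred):
        keep = np.ones(n_genes, dtype=bool)
        if target_col[p] >= 0:
            keep[target_col[p]] = False
        out[p] = pairwise_distances(
            real[:, keep], pred[p, keep].reshape(1, -1), metric=metric
        ).ravel()
    return out


def masked_vectorized(real, pred, target_col, metric):
    """One full-matrix computation plus a per-row correction for the excluded column."""
    rows = np.arange(pred.shape[0])
    has_t = target_col >= 0
    g = np.where(has_t, target_col, 0)    # dummy column 0 where there is no target
    r_g = real[:, g].T * has_t[:, None]   # r_g[p, q] = real[q, g_p], 0 if no target
    x_g = pred[rows, g] * has_t           # x_g[p] = pred[p, g_p], 0 if no target
    if metric == "l1":
        full = pairwise_distances(pred, real, metric="l1")
        return full - np.abs(r_g - x_g[:, None])
    if metric == "l2":
        full_sq = euclidean_distances(pred, real, squared=True)
        # Clamp tiny negatives caused by floating-point cancellation.
        return np.sqrt(np.maximum(full_sq - (r_g - x_g[:, None]) ** 2, 0.0))
    if metric == "cosine":
        dot = pred @ real.T - x_g[:, None] * r_g
        nx2 = np.sum(pred**2, axis=1)[:, None] - x_g[:, None] ** 2
        nr2 = np.sum(real**2, axis=1)[None, :] - r_g**2
        denom = np.sqrt(np.maximum(nx2, 0.0) * np.maximum(nr2, 0.0))
        # Zero-norm vectors get similarity 0, i.e. distance 1.
        sim = np.divide(dot, denom, out=np.zeros_like(dot), where=denom > 0)
        return np.clip(1.0 - sim, 0.0, 2.0)
    # Any other metric: fall back to the exact per-row masked loop.
    return masked_reference(real, pred, target_col, metric)


def match_ranks(dist):
    """Assumed convention: position of real[p] in row p's stable argsort (0 = closest)."""
    order = np.argsort(dist, axis=1, kind="stable")
    return np.argmax(order == np.arange(dist.shape[0])[:, None], axis=1)
```

Only the correction terms depend on $g_p$, so the expensive part (one `pairwise_distances` call, one `euclidean_distances` call, or one matrix product) is shared across all rows. The fallback keeps other metrics exact at the old per-row cost.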
Local numbers (Apple M2 Pro, Python 3.12, scikit-learn 1.8): bit-identical
ranks across a large synthetic sweep, with speedups that grow with n_pert.
At n_pert=10000: ~8x (l1), ~51x (l2), ~61x (cosine).
I have the implementation, tests, and a benchmark ready and will open a PR.