Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
f99e5ac
Add permutation test
maltekuehl Mar 18, 2025
66f228e
Fix test kwargs update
maltekuehl Mar 18, 2025
74e3ec1
Add n_jobs argument and change test to check for significance agreeme…
maltekuehl Mar 18, 2025
4fdd140
Merge branch 'scverse:main' into main
maltekuehl Mar 18, 2025
64b7114
Simplify and generalize compare_groups by adding most important permu…
maltekuehl Mar 18, 2025
af7ca5d
Merge branch 'main' of https://github.com/complextissue/pertpy
maltekuehl Mar 18, 2025
325f1db
Make permutation_test argument optional but raise warning if not prov…
maltekuehl Mar 18, 2025
25d1689
Merge branch 'main' into main
maltekuehl Mar 18, 2025
1b6c65e
Merge branch 'main' of https://github.com/complextissue/pertpy
maltekuehl Mar 18, 2025
3e81976
Make test case a bit stricter again for significant values, enable re…
maltekuehl Mar 18, 2025
14736b3
Remove unnecessary import
maltekuehl Mar 18, 2025
5873b87
Remove parallelization and return statistic and p-value everywhere
maltekuehl Mar 19, 2025
676b4f0
Remove parallelization and return statistic and p-value everywhere
maltekuehl Mar 19, 2025
442b603
Remove parallelization and return statistic and p-value everywhere
maltekuehl Mar 19, 2025
8ae69ce
Fix docstring and examples of permutation test
maltekuehl Apr 7, 2025
ebc30fb
Merge remote-tracking branch 'origin/remote' into dev
maltekuehl Sep 11, 2025
1336ddb
Simplify permutation test with callable only
maltekuehl Sep 11, 2025
95d7da4
Default on user facing function only
maltekuehl Sep 12, 2025
e2c53fb
Undo last commit
maltekuehl Sep 12, 2025
412ab3b
Merge branch 'main' of https://github.com/complextissue/pertpy
maltekuehl Sep 12, 2025
52d2d58
Actually revert
maltekuehl Sep 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/tools_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Pertpy enables differential gene expression tests through a common interface tha
tools.EdgeR
tools.WilcoxonTest
tools.TTest
tools.PermutationTest
tools.Statsmodels
```

Expand Down
3 changes: 2 additions & 1 deletion pertpy/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __getattr__(name: str):
raise ImportError(
"Extra dependencies required: toytree, ete4. Please install with: pip install toytree ete4"
) from None
elif name in ["EdgeR", "PyDESeq2", "Statsmodels", "TTest", "WilcoxonTest"]:
elif name in ["EdgeR", "PermutationTest", "PyDESeq2", "Statsmodels", "TTest", "WilcoxonTest"]:
module = import_module("pertpy.tools._differential_gene_expression")
return getattr(module, name)
elif name == "Scgen":
Expand Down Expand Up @@ -63,6 +63,7 @@ def __dir__():
"PyDESeq2",
"WilcoxonTest",
"TTest",
"PermutationTest",
"Statsmodels",
"DistanceTest",
"Distance",
Expand Down
4 changes: 2 additions & 2 deletions pertpy/tools/_differential_gene_expression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from importlib.util import find_spec

from ._base import LinearModelBase, MethodBase
from ._dge_comparison import DGEEVAL
from ._edger import EdgeR
from ._simple_tests import SimpleComparisonBase, TTest, WilcoxonTest
from ._simple_tests import PermutationTest, SimpleComparisonBase, TTest, WilcoxonTest


def __getattr__(name: str):
Expand Down Expand Up @@ -57,4 +56,5 @@ def _get_available_methods():
"SimpleComparisonBase",
"WilcoxonTest",
"TTest",
"PermutationTest",
]
2 changes: 1 addition & 1 deletion pertpy/tools/_differential_gene_expression/_pydeseq2.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def fit(self, **kwargs) -> pd.DataFrame:
**kwargs: Keyword arguments specific to DeseqDataSet(), except for `n_cpus` which will use all available CPUs minus one if the argument is not passed.
"""
try:
usable_cpus = len(os.sched_getaffinity(0))
usable_cpus = len(os.sched_getaffinity(0)) # type: ignore # os.sched_getaffinity is not available on Windows and macOS
except AttributeError:
usable_cpus = os.cpu_count()

Expand Down
179 changes: 164 additions & 15 deletions pertpy/tools/_differential_gene_expression/_simple_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings
from abc import abstractmethod
from collections.abc import Mapping, Sequence
from collections.abc import Callable, Mapping, Sequence
from types import MappingProxyType

import numpy as np
Expand Down Expand Up @@ -33,7 +33,7 @@ def fdr_correction(
class SimpleComparisonBase(MethodBase):
@staticmethod
@abstractmethod
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float:
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> dict[str, float]:
"""Perform a statistical test between values in x0 and x1.

If `paired` is True, x0 and x1 must be of the same length and ordered such that
Expand All @@ -44,6 +44,10 @@ def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float:
x1: Array with values to compare.
paired: Indicates whether to perform a paired test
**kwargs: kwargs passed to the test function

Returns:
A dictionary metric -> value.
This allows to return values for different metrics (e.g. p-value + test statistic).
"""
...

Expand Down Expand Up @@ -71,16 +75,16 @@ def _compare_single_group(
x0 = x0.tocsc()
x1 = x1.tocsc()

res = []
res: list[dict[str, float]] = []
for var in tqdm(self.adata.var_names):
tmp_x0 = x0[:, self.adata.var_names == var]
tmp_x0 = np.asarray(tmp_x0.todense()).flatten() if issparse(tmp_x0) else tmp_x0.flatten()
tmp_x1 = x1[:, self.adata.var_names == var]
tmp_x1 = np.asarray(tmp_x1.todense()).flatten() if issparse(tmp_x1) else tmp_x1.flatten()
pval = self._test(tmp_x0, tmp_x1, paired, **kwargs)
test_result = self._test(tmp_x0, tmp_x1, paired, **kwargs)
mean_x0 = np.mean(tmp_x0)
mean_x1 = np.mean(tmp_x1)
res.append({"variable": var, "p_value": pval, "log_fc": np.log2(mean_x1) - np.log2(mean_x0)})
res.append({"variable": var, "log_fc": np.log2(mean_x1) - np.log2(mean_x0), **test_result})
return pd.DataFrame(res).sort_values("p_value")

@classmethod
Expand All @@ -97,6 +101,20 @@ def compare_groups(
fit_kwargs: Mapping = MappingProxyType({}),
test_kwargs: Mapping = MappingProxyType({}),
) -> DataFrame:
"""Perform a comparison between groups.

Args:
adata: Data with observations to compare.
column: Column in `adata.obs` that contains the groups to compare.
baseline: Reference group.
groups_to_compare: Groups to compare against the baseline. If None, all other groups
are compared.
paired_by: Column in `adata.obs` to use for pairing. If None, an unpaired test is performed.
mask: Mask to apply to the data.
layer: Layer to use for the comparison.
fit_kwargs: Unused argument for compatibility with the `MethodBase` interface, do not specify.
test_kwargs: Additional kwargs passed to the test function.
"""
if len(fit_kwargs):
warnings.warn("fit_kwargs not used for simple tests.", UserWarning, stacklevel=2)
paired = paired_by is not None
Expand Down Expand Up @@ -144,19 +162,150 @@ class WilcoxonTest(SimpleComparisonBase):
"""

@staticmethod
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> dict[str, float]:
    """Perform an unpaired or paired Wilcoxon/Mann-Whitney-U test.

    Args:
        x0: First array of values.
        x1: Second array of values, compared against `x0`.
        paired: If True, `scipy.stats.wilcoxon` is used; otherwise `scipy.stats.mannwhitneyu`.
        **kwargs: Passed on unchanged to the selected scipy function.

    Returns:
        A dictionary with the keys `p_value` and `statistic`, taken from the scipy result object.
    """
    # Select the scipy routine first so the call itself is written only once.
    test_fn = scipy.stats.wilcoxon if paired else scipy.stats.mannwhitneyu
    test_result = test_fn(x0, x1, **kwargs)

    return {
        "p_value": test_result.pvalue,
        "statistic": test_result.statistic,
    }


class TTest(SimpleComparisonBase):
    """Perform an unpaired or paired T-test."""

    @staticmethod
    def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> dict[str, float]:
        """Perform an unpaired or paired T-test on two arrays.

        Args:
            x0: First array of values.
            x1: Second array of values, compared against `x0`.
            paired: If True, `scipy.stats.ttest_rel` is used; otherwise `scipy.stats.ttest_ind`.
            **kwargs: Passed on unchanged to the selected scipy function.

        Returns:
            A dictionary with the keys `p_value` and `statistic`, taken from the scipy result object.
        """
        # Select the scipy routine first so the call itself is written only once.
        test_fn = scipy.stats.ttest_rel if paired else scipy.stats.ttest_ind
        test_result = test_fn(x0, x1, **kwargs)

        return {
            "p_value": test_result.pvalue,
            "statistic": test_result.statistic,
        }


def _permutation_default_statistic(x0: np.ndarray, x1: np.ndarray) -> float:
    """Return the log2 fold change of the means of `x1` over `x0`.

    A pseudocount of 1e-8 is added to each mean before taking the logarithm.
    """
    return np.log2(np.mean(x1) + 1e-8) - np.log2(np.mean(x0) + 1e-8)


class PermutationTest(SimpleComparisonBase):
    """Perform a permutation test.

    The permutation test relies on another test statistic (e.g. t-statistic or your own) to obtain a p-value through
    random permutations of the data and repeated generation of the test statistic.

    For paired tests, each paired observation is permuted together and distributed randomly between the two groups. For
    unpaired tests, all observations are permuted independently.

    The null hypothesis for the unpaired test is that all observations come from the same underlying distribution and
    have been randomly assigned to one of the samples.

    The null hypothesis for the paired permutation test is that the observations within each pair are drawn from the
    same underlying distribution and that their assignment to a sample is random.
    """

    @classmethod
    def compare_groups(
        cls,
        adata: AnnData,
        column: str,
        baseline: str,
        groups_to_compare: str | Sequence[str],
        *,
        paired_by: str | None = None,
        mask: str | None = None,
        layer: str | None = None,
        n_permutations: int = 1000,
        test_statistic: Callable[[np.ndarray, np.ndarray], float] = _permutation_default_statistic,
        fit_kwargs: Mapping = MappingProxyType({}),
        test_kwargs: Mapping = MappingProxyType({}),
    ) -> DataFrame:
        """Perform a permutation test comparison between groups.

        Args:
            adata: Data with observations to compare.
            column: Column in `adata.obs` that contains the groups to compare.
            baseline: Reference group.
            groups_to_compare: Groups to compare against the baseline. If None, all other groups
                are compared.
            paired_by: Column in `adata.obs` to use for pairing. If None, an unpaired test is performed.
            mask: Mask to apply to the data.
            layer: Layer to use for the comparison.
            n_permutations: Number of permutations to perform.
            test_statistic: A callable that takes two arrays (x0, x1) and returns a float statistic.
                Defaults to log2 fold change with pseudocount: log2(mean(x1) + 1e-8) - log2(mean(x0) + 1e-8).
                The callable should have signature: test_statistic(x0, x1) -> float.
            fit_kwargs: Unused argument for compatibility with the `MethodBase` interface, do not specify.
            test_kwargs: Additional kwargs passed to the permutation test function (not the test statistic). The
                permutation test function is `scipy.stats.permutation_test`, so please refer to its documentation for
                available options. Note that `test_statistic` and `n_permutations` are set by this function and should
                not be provided here; if they are, a `UserWarning` is emitted and the dedicated arguments win.

        Returns:
            The result of `SimpleComparisonBase.compare_groups` called with the augmented `test_kwargs`.

        Examples:
            >>> # Default statistic: log2 fold change of the group means (with pseudocount)
            >>> PermutationTest.compare_groups(
            ...     adata,
            ...     column="condition",
            ...     baseline="A",
            ...     groups_to_compare="B",
            ...     n_permutations=1000,
            ...     test_kwargs={"rng": 0},
            ... )
            >>>
            >>> # Custom statistic: difference in medians
            >>> PermutationTest.compare_groups(
            ...     adata,
            ...     column="condition",
            ...     baseline="A",
            ...     groups_to_compare="B",
            ...     test_statistic=lambda x, y: np.median(y) - np.median(x),
            ... )
        """
        # These two keys are owned by the dedicated arguments; tell the user instead of overriding silently.
        overridden = sorted({"test_statistic", "n_permutations"} & set(test_kwargs))
        if overridden:
            warnings.warn(
                f"Ignoring {overridden} in test_kwargs; use the corresponding arguments of compare_groups instead.",
                UserWarning,
                stacklevel=2,
            )

        # Copy into a fresh dict: `test_kwargs` may be a read-only mapping and must not be mutated.
        enhanced_test_kwargs = {
            **test_kwargs,
            "test_statistic": test_statistic,
            "n_permutations": n_permutations,
        }

        return super().compare_groups(
            adata=adata,
            column=column,
            baseline=baseline,
            groups_to_compare=groups_to_compare,
            paired_by=paired_by,
            mask=mask,
            layer=layer,
            fit_kwargs=fit_kwargs,
            test_kwargs=enhanced_test_kwargs,
        )

    @staticmethod
    def _test(
        x0: np.ndarray,
        x1: np.ndarray,
        paired: bool,
        test_statistic: Callable[[np.ndarray, np.ndarray], float] = _permutation_default_statistic,
        n_permutations: int = 1000,
        **kwargs,
    ) -> dict[str, float]:
        """Perform a permutation test.

        This function uses a simple test statistic function to compute p-values through permutations.

        Args:
            x0: Array with baseline values.
            x1: Array with values to compare.
            paired: Whether to perform a paired test.
            test_statistic: A callable that takes two arrays (x0, x1) and returns a float statistic. Please refer to
                the examples below for usage. The callable should have signature: test_statistic(x0, x1) -> float.
            n_permutations: Number of permutations to perform.
            **kwargs: Additional kwargs passed to scipy.stats.permutation_test.

        Returns:
            A dictionary with the keys `p_value` and `statistic`, taken from the scipy result object.

        Examples:
            >>> # Difference in means (log fold change)
            >>> PermutationTest._test(x0, x1, paired=False)
            >>>
            >>> # Difference in medians
            >>> median_diff = lambda x, y: np.median(y) - np.median(x)
            >>> PermutationTest._test(x0, x1, paired=False, test_statistic=median_diff)
        """
        test_result = scipy.stats.permutation_test(
            [x0, x1],
            # The two-argument wrapper exposes no `axis` parameter, so with scipy's default
            # `vectorized=None` the statistic is treated as non-vectorized even if the
            # user-supplied callable happens to accept `axis`.
            statistic=lambda x0_perm, x1_perm: test_statistic(x0_perm, x1_perm),
            n_resamples=n_permutations,
            # "samples" swaps observations within a pair; "independent" reassigns observations freely.
            permutation_type=("samples" if paired else "independent"),
            **kwargs,
        )

        return {
            "p_value": test_result.pvalue,
            "statistic": test_result.statistic,
        }
48 changes: 47 additions & 1 deletion tests/tools/_differential_gene_expression/test_simple_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
if find_spec("formulaic_contrasts") is None or find_spec("formulaic") is None:
pytestmark = pytest.mark.skip(reason="formulaic_contrasts and formulaic not available")

from pertpy.tools._differential_gene_expression import SimpleComparisonBase, TTest, WilcoxonTest
from pertpy.tools._differential_gene_expression import PermutationTest, SimpleComparisonBase, TTest, WilcoxonTest


@pytest.mark.parametrize(
Expand Down Expand Up @@ -67,6 +67,52 @@ def test_t(test_adata_minimal, paired_by, expected):
assert actual[gene] == pytest.approx(expected[gene], abs=0.02)


@pytest.mark.parametrize(
    "paired_by,expected",
    [
        pytest.param(
            None,
            {"gene1": {"p_value": 2.13e-26, "log_fc": -5.14}, "gene2": {"p_value": 0.96, "log_fc": -0.016}},
            id="unpaired",
        ),
        pytest.param(
            "pairing",
            {"gene1": {"p_value": 1.63e-26, "log_fc": -5.14}, "gene2": {"p_value": 0.85, "log_fc": -0.016}},
            id="paired",
        ),
    ],
)
def test_permutation(test_adata_minimal, paired_by, expected):
    """Check that the permutation test agrees with the reference values on significance.

    For every statistic, the significance call at the 0.05 level must match the reference. For genes called
    significant, `p_value` and `log_fc` must additionally match the reference within an absolute tolerance of 0.02.

    NOTE(review): the original note stated that the reference values were computed in R using wilcox.test;
    presumably they were carried over from the Wilcoxon test case - confirm.
    """
    # Several simple statistics are exercised; all are expected to lead to the same significance calls.
    statistics = (
        lambda x, y: np.log2(np.mean(y)) - np.log2(np.mean(x)),  # log fold change between means (no pseudocount)
        lambda x, y: np.mean(y) - np.mean(x),  # mean difference
        lambda x, y: np.max(y) - np.max(x),  # max difference
    )

    for statistic in statistics:
        result = PermutationTest.compare_groups(
            adata=test_adata_minimal,
            column="condition",
            baseline="A",
            groups_to_compare="B",
            test_statistic=statistic,
            paired_by=paired_by,
            n_permutations=1000,
            test_kwargs={"rng": 0},  # fixed seed passed through to scipy.stats.permutation_test
        )
        assert isinstance(result, DataFrame), "PermutationTest.compare_groups should return a DataFrame"

        per_gene = result[["variable", "p_value", "log_fc"]].set_index("variable")
        actual = per_gene.to_dict(orient="index")

        for gene, reference in expected.items():
            observed = actual[gene]
            # Only the significance decision is compared strictly.
            assert (observed["p_value"] < 0.05) == (reference["p_value"] < 0.05)
            if observed["p_value"] < 0.05:
                assert observed == pytest.approx(reference, abs=0.02)


@pytest.mark.parametrize("seed", range(10))
def test_simple_comparison_pairing(test_adata_minimal, seed):
"""Test that paired samples are properly matched in a paired test"""
Expand Down
Loading