Skip to content

Add permutation test #726

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions docs/usage/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ Pertpy provides utilities to conduct differential gene expression tests through
tools.EdgeR
tools.WilcoxonTest
tools.TTest
tools.PermutationTest
tools.Statsmodels
```

Expand Down Expand Up @@ -563,33 +564,33 @@ including cell line annotation, bulk RNA and protein expression data.

Available databases for cell line metadata:

- [The Cancer Dependency Map Project at Broad](https://depmap.org/portal/)
- [The Cancer Dependency Map Project at Sanger](https://depmap.sanger.ac.uk/)
- [Genomics of Drug Sensitivity in Cancer (GDSC)](https://www.cancerrxgene.org/)
- [The Cancer Dependency Map Project at Broad](https://depmap.org/portal/)
- [The Cancer Dependency Map Project at Sanger](https://depmap.sanger.ac.uk/)
- [Genomics of Drug Sensitivity in Cancer (GDSC)](https://www.cancerrxgene.org/)

### Compound

The Compound module enables the retrieval of various types of information related to compounds of interest, including the most common synonym, pubchemID and canonical SMILES.

Available databases for compound metadata:

- [PubChem](https://pubchem.ncbi.nlm.nih.gov/)
- [PubChem](https://pubchem.ncbi.nlm.nih.gov/)

### Mechanism of Action

This module aims to retrieve metadata of mechanism of action studies related to perturbagens of interest, depending on the molecular targets.

Available databases for mechanism of action metadata:

- [CLUE](https://clue.io/)
- [CLUE](https://clue.io/)

### Drug

This module allows for the retrieval of Drug target information.

Available databases for drug metadata:

- [chembl](https://www.ebi.ac.uk/chembl/)
- [chembl](https://www.ebi.ac.uk/chembl/)

```{eval-rst}
.. autosummary::
Expand Down
2 changes: 2 additions & 0 deletions pertpy/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(self, *args, **kwargs):

DE_EXTRAS = ["formulaic", "pydeseq2"]
EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS) # edgeR will be imported via rpy2
PermutationTest = lazy_import("pertpy.tools._differential_gene_expression", "PermutationTest", DE_EXTRAS)
PyDESeq2 = lazy_import("pertpy.tools._differential_gene_expression", "PyDESeq2", DE_EXTRAS)
Statsmodels = lazy_import("pertpy.tools._differential_gene_expression", "Statsmodels", DE_EXTRAS + ["statsmodels"])
TTest = lazy_import("pertpy.tools._differential_gene_expression", "TTest", DE_EXTRAS)
Expand All @@ -62,6 +63,7 @@ def __init__(self, *args, **kwargs):
"PyDESeq2",
"WilcoxonTest",
"TTest",
"PermutationTest",
"Statsmodels",
"DistanceTest",
"Distance",
Expand Down
5 changes: 3 additions & 2 deletions pertpy/tools/_differential_gene_expression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ._dge_comparison import DGEEVAL
from ._edger import EdgeR
from ._pydeseq2 import PyDESeq2
from ._simple_tests import SimpleComparisonBase, TTest, WilcoxonTest
from ._simple_tests import PermutationTest, SimpleComparisonBase, TTest, WilcoxonTest
from ._statsmodels import Statsmodels

__all__ = [
Expand All @@ -14,6 +14,7 @@
"SimpleComparisonBase",
"WilcoxonTest",
"TTest",
"PermutationTest",
]

AVAILABLE_METHODS = [Statsmodels, EdgeR, PyDESeq2, WilcoxonTest, TTest]
AVAILABLE_METHODS = [Statsmodels, EdgeR, PyDESeq2, WilcoxonTest, TTest, PermutationTest]
2 changes: 1 addition & 1 deletion pertpy/tools/_differential_gene_expression/_pydeseq2.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def fit(self, **kwargs) -> pd.DataFrame:
**kwargs: Keyword arguments specific to DeseqDataSet(), except for `n_cpus` which will use all available CPUs minus one if the argument is not passed.
"""
try:
usable_cpus = len(os.sched_getaffinity(0))
usable_cpus = len(os.sched_getaffinity(0)) # type: ignore # os.sched_getaffinity is not available on Windows and macOS
except AttributeError:
usable_cpus = os.cpu_count()

Expand Down
150 changes: 138 additions & 12 deletions pertpy/tools/_differential_gene_expression/_simple_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@

import warnings
from abc import abstractmethod
from collections.abc import Mapping, Sequence
from collections.abc import Callable, Mapping, Sequence
from inspect import signature
from types import MappingProxyType

import numpy as np
import pandas as pd
import scipy.stats
import statsmodels
from anndata import AnnData
from lamin_utils import logger
from pandas.core.api import DataFrame as DataFrame
from scipy.sparse import diags, issparse
from tqdm.auto import tqdm
Expand All @@ -33,7 +35,7 @@ def fdr_correction(
class SimpleComparisonBase(MethodBase):
@staticmethod
@abstractmethod
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float:
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> dict[str, float]:
"""Perform a statistical test between values in x0 and x1.

If `paired` is True, x0 and x1 must be of the same length and ordered such that
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please document the return type, e.g.

Returns:
            A dictionary metric -> value.
            This allows to return values for different metrics (e.g. p-value + test statistic).

Expand Down Expand Up @@ -71,16 +73,16 @@ def _compare_single_group(
x0 = x0.tocsc()
x1 = x1.tocsc()

res = []
res: list[dict[str, float | np.ndarray]] = []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How could this be a ndarray? Doesn't _test always return a float?

for var in tqdm(self.adata.var_names):
tmp_x0 = x0[:, self.adata.var_names == var]
tmp_x0 = np.asarray(tmp_x0.todense()).flatten() if issparse(tmp_x0) else tmp_x0.flatten()
tmp_x1 = x1[:, self.adata.var_names == var]
tmp_x1 = np.asarray(tmp_x1.todense()).flatten() if issparse(tmp_x1) else tmp_x1.flatten()
pval = self._test(tmp_x0, tmp_x1, paired, **kwargs)
test_result = self._test(tmp_x0, tmp_x1, paired, **kwargs)
mean_x0 = np.mean(tmp_x0)
mean_x1 = np.mean(tmp_x1)
res.append({"variable": var, "p_value": pval, "log_fc": np.log2(mean_x1) - np.log2(mean_x0)})
res.append({"variable": var, "log_fc": np.log2(mean_x1) - np.log2(mean_x0), **test_result})
return pd.DataFrame(res).sort_values("p_value")

@classmethod
Expand All @@ -94,9 +96,28 @@ def compare_groups(
paired_by: str | None = None,
mask: str | None = None,
layer: str | None = None,
n_permutations: int = 1000,
permutation_test_statistic: type["SimpleComparisonBase"] | None = None,
Comment on lines +99 to +100
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this should be in the function signature of the base class.

I remember that we had a discussion about this with respect to the usability, i.e. you didn't want this to be hidden in test_kwargs. This is fair, but here are two alternative ways this could be achieved:

  • instead of test_kwargs: Mapping, pass a **kwargs to compare_groups which will be passed on to the test function
  • Instead of using classmethods/staticmethods, go with an OOP approach. Test-specific parameters to into the constructor. The usage would become
    -res = TTest.compare_groups(adata, "A", "B")
    +res = TTest().compare_groups(adata, "A", "B")
    
    -res = PermuationTest.compare_groups(adata, "A", "B", n_permutations=1000)
    +res = PermuationTest(n_permutations=1000).compare_groups(adata, "A", "B")

fit_kwargs: Mapping = MappingProxyType({}),
test_kwargs: Mapping = MappingProxyType({}),
) -> DataFrame:
"""Perform a comparison between groups.

Args:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove all typing from the docstring here. We only type the function signature which is enough for Sphinx.

adata: Data with observations to compare.
column: Column in `adata.obs` that contains the groups to compare.
baseline: Reference group.
groups_to_compare: Groups to compare against the baseline. If None, all other groups
are compared.
paired_by: Column in `adata.obs` to use for pairing. If None, an unpaired test is performed.
mask: Mask to apply to the data.
layer: Layer to use for the comparison.
n_permutations: Number of permutations to perform if a permutation test is used.
permutation_test_statistic: The statistic to use if performing a permutation test. If None, the default
t-statistic from `TTest` is used.
fit_kwargs: Unused argument for compatibility with the `MethodBase` interface, do not specify.
test_kwargs: Additional kwargs passed to the test function.
"""
if len(fit_kwargs):
warnings.warn("fit_kwargs not used for simple tests.", UserWarning, stacklevel=2)
paired = paired_by is not None
Expand Down Expand Up @@ -125,6 +146,12 @@ def _get_idx(column, value):
else:
return np.where(mask)[0]

if permutation_test_statistic:
test_kwargs = dict(test_kwargs)
test_kwargs.update({"test_statistic": permutation_test_statistic, "n_permutations": n_permutations})
elif permutation_test_statistic is None and cls.__name__ == "PermutationTest":
logger.warning("No permutation test statistic specified. Using TTest statistic as default.")

Comment on lines +149 to +154
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With either solution mentioned above, this wouldn't be necessary here.

res_dfs = []
baseline_idx = _get_idx(column, baseline)
for group_to_compare in groups_to_compare:
Expand All @@ -144,19 +171,118 @@ class WilcoxonTest(SimpleComparisonBase):
"""

@staticmethod
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float:
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> dict[str, float]:
"""Perform an unpaired or paired Wilcoxon/Mann-Whitney-U test."""

if paired:
return scipy.stats.wilcoxon(x0, x1, **kwargs).pvalue
test_result = scipy.stats.wilcoxon(x0, x1, **kwargs)
else:
return scipy.stats.mannwhitneyu(x0, x1, **kwargs).pvalue
test_result = scipy.stats.mannwhitneyu(x0, x1, **kwargs)

return {
"p_value": test_result.pvalue,
"statistic": test_result.statistic,
}


class TTest(SimpleComparisonBase):
"""Perform a unpaired or paired T-test"""
"""Perform a unpaired or paired T-test."""

@staticmethod
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float:
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> dict[str, float]:
if paired:
return scipy.stats.ttest_rel(x0, x1, **kwargs).pvalue
test_result = scipy.stats.ttest_rel(x0, x1, **kwargs)
else:
return scipy.stats.ttest_ind(x0, x1, **kwargs).pvalue
test_result = scipy.stats.ttest_ind(x0, x1, **kwargs)

return {
"p_value": test_result.pvalue,
"statistic": test_result.statistic,
}


class PermutationTest(SimpleComparisonBase):
"""Perform a permutation test.

The permutation test relies on another test statistic (e.g. t-statistic or your own) to obtain a p-value through
random permutations of the data and repeated generation of the test statistic.

For paired tests, each paired observation is permuted together and distributed randomly between the two groups. For
unpaired tests, all observations are permuted independently.

The null hypothesis for the unpaired test is that all observations come from the same underlying distribution and
have been randomly assigned to one of the samples.

The null hypothesis for the paired permutation test is that the observations within each pair are drawn from the
same underlying distribution and that their assignment to a sample is random.
"""

@staticmethod
def _test(
x0: np.ndarray,
x1: np.ndarray,
paired: bool,
test_statistic: type["SimpleComparisonBase"] | Callable = WilcoxonTest,
n_permutations: int = 1000,
**kwargs,
) -> dict[str, float]:
"""Perform a permutation test.

This function relies on another test (e.g. WilcoxonTest) to generate a test statistic for each permutation.

Args:
x0: Array with baseline values.
x1: Array with values to compare.
paired: Whether to perform a paired test
test_statistic: The class or function to generate the test statistic from permuted data. If a function is
passed, it must have the signature `test_statistic(x0, x1, paired[, axis], **kwargs)`. If it accepts the
parameter axis, vectorization will be used.
n_permutations: Number of permutations to perform.
**kwargs: kwargs passed to the permutation test function, not the test function after permutation.

Examples:
You can use the `PermutationTest` class to perform a permutation test with a custom test statistic or an
existing test statistic like `TTest`. The test statistic must be a class that implements the `_test` method
or a function that takes the arguments `x0`, `x1`, `paired` and `**kwargs`.

>>> from pertpy.tools import PermutationTest, TTest
>>> # Perform a permutation test with a t-statistic
>>> p_value = PermutationTest._test(x0, x1, paired=True, test=TTest, n_permutations=1000, rng=0)
>>> # Perform a permutation test with a custom test statistic
>>> p_value = PermutationTest._test(x0, x1, paired=False, test=your_custom_test_statistic)
"""
if test_statistic is PermutationTest:
raise ValueError(
"The `test_statistic` argument cannot be `PermutationTest`. Use a base test like `TTest` or a custom test."
)

vectorized = hasattr(test_statistic, "_test") or "axis" in signature(test_statistic).parameters

def call_test(data_baseline, data_comparison, axis: int | None = None, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you really need another test here? Essentially the statistic we want to test is the fold change.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The advantage of permutation tests is their generalizability without (almost) any assumptions, as long as the test statistic is related to the null hypothesis we want to test. I would thus view the ability to use any statistic, either those that are already implemented in pertpy or any callable that accepts two NDArrays and **kwargs as core functionality. With the latest update, this PR supports any statistic, e.g., it would also be trivial to use a comparison of means, of medians or any other function that you can implement in < 5 lines of numpy code with the PermutationTest. I opted for Wilcoxon statistic as a default because the ranksum is fairly general and it's something that's already implemented in pertpy. Of course, we could also add an explicit collection of other statistics but it could never cover all use cases and defining this statistic should be part of the thought process when a user uses the permutation test, so I'm not convinced of the value and necessity of covering this as part of the library itself.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@grst is this resolved?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We changed this to the t-statistic after in person discussion, as the wilcoxon statistic would just be the wilcoxon test and from what I understood from the in person discussion the rationale for keeping this function was agreed to, but perhaps @grst can confirm again.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still don't get why you would want to use a permutation test with a test statistics. If I'm interested in the difference in means, I would use the difference in means as statistics.

In my understanding, the whole point of using Wilcoxon or T test is that one can compare against a theoretical distribution, avoiding permutations in the first place.

In any case, I would prefer passing a simple lambda over accepting another pertpy test class. This could be a simple

test_statistic: Callable[[np.ndarray, np.ndarray], float] = lambda x,y: np.log2(np.mean(x)) - np.log2(np.mean(y))

or if you really want to use a t-statistics

test_statistic = lambda x, y: scipy.stats.ttest_ind(x, y).statistic

"""Perform the actual test."""
if not hasattr(test_statistic, "_test"):
if vectorized:
return test_statistic(data_baseline, data_comparison, paired=paired, axis=axis, **kwargs)[
"statistic"
]

return test_statistic(data_baseline, data_comparison, paired=paired, **kwargs)["statistic"]

if vectorized:
kwargs.update({"axis": axis})

return test_statistic._test(data_baseline, data_comparison, paired, **kwargs)["statistic"]

test_result = scipy.stats.permutation_test(
[x0, x1],
statistic=call_test,
n_resamples=n_permutations,
permutation_type=("samples" if paired else "independent"),
vectorized=vectorized,
**kwargs,
)

return {
"p_value": test_result.pvalue,
"statistic": test_result.statistic,
}
41 changes: 40 additions & 1 deletion tests/tools/_differential_gene_expression/test_simple_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pandas as pd
import pytest
from pandas.core.api import DataFrame as DataFrame
from pertpy.tools._differential_gene_expression import SimpleComparisonBase, TTest, WilcoxonTest
from pertpy.tools._differential_gene_expression import PermutationTest, SimpleComparisonBase, TTest, WilcoxonTest


@pytest.mark.parametrize(
Expand Down Expand Up @@ -61,6 +61,45 @@ def test_t(test_adata_minimal, paired_by, expected):
assert actual[gene] == pytest.approx(expected[gene], abs=0.02)


@pytest.mark.parametrize(
"paired_by,expected",
[
pytest.param(
None,
{"gene1": {"p_value": 2.13e-26, "log_fc": -5.14}, "gene2": {"p_value": 0.96, "log_fc": -0.016}},
id="unpaired",
),
pytest.param(
"pairing",
{"gene1": {"p_value": 1.63e-26, "log_fc": -5.14}, "gene2": {"p_value": 0.85, "log_fc": -0.016}},
id="paired",
),
],
)
def test_permutation(test_adata_minimal, paired_by, expected):
"""Test that permutation test gives the correct values.

Reference values have been computed in R using wilcox.test
"""
for statistic in [TTest, WilcoxonTest]:
res_df = PermutationTest.compare_groups(
adata=test_adata_minimal,
column="condition",
baseline="A",
groups_to_compare="B",
paired_by=paired_by,
n_permutations=1000,
permutation_test_statistic=statistic,
test_kwargs={"rng": 0},
)
assert isinstance(res_df, DataFrame), "PermutationTest.compare_groups should return a DataFrame"
actual = res_df.loc[:, ["variable", "p_value", "log_fc"]].set_index("variable").to_dict(orient="index")
for gene in expected:
assert (actual[gene]["p_value"] < 0.05) == (expected[gene]["p_value"] < 0.05)
if actual[gene]["p_value"] < 0.05:
assert actual[gene] == pytest.approx(expected[gene], abs=0.02)


@pytest.mark.parametrize("seed", range(10))
def test_simple_comparison_pairing(test_adata_minimal, seed):
"""Test that paired samples are properly matched in a paired test"""
Expand Down
Loading