-
Notifications
You must be signed in to change notification settings - Fork 30
Add permutation test #726
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add permutation test #726
Changes from all commits
f99e5ac
66f228e
74e3ec1
4fdd140
64b7114
af7ca5d
325f1db
25d1689
1b6c65e
3e81976
14736b3
5873b87
676b4f0
442b603
8ae69ce
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,14 +2,16 @@ | |
|
||
import warnings | ||
from abc import abstractmethod | ||
from collections.abc import Mapping, Sequence | ||
from collections.abc import Callable, Mapping, Sequence | ||
from inspect import signature | ||
from types import MappingProxyType | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import scipy.stats | ||
import statsmodels | ||
from anndata import AnnData | ||
from lamin_utils import logger | ||
from pandas.core.api import DataFrame as DataFrame | ||
from scipy.sparse import diags, issparse | ||
from tqdm.auto import tqdm | ||
|
@@ -33,7 +35,7 @@ def fdr_correction( | |
class SimpleComparisonBase(MethodBase): | ||
@staticmethod | ||
@abstractmethod | ||
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float: | ||
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> dict[str, float]: | ||
"""Perform a statistical test between values in x0 and x1. | ||
|
||
If `paired` is True, x0 and x1 must be of the same length and ordered such that | ||
|
@@ -71,16 +73,16 @@ def _compare_single_group( | |
x0 = x0.tocsc() | ||
x1 = x1.tocsc() | ||
|
||
res = [] | ||
res: list[dict[str, float | np.ndarray]] = [] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How could this be a ndarray? Doesn't |
||
for var in tqdm(self.adata.var_names): | ||
tmp_x0 = x0[:, self.adata.var_names == var] | ||
tmp_x0 = np.asarray(tmp_x0.todense()).flatten() if issparse(tmp_x0) else tmp_x0.flatten() | ||
tmp_x1 = x1[:, self.adata.var_names == var] | ||
tmp_x1 = np.asarray(tmp_x1.todense()).flatten() if issparse(tmp_x1) else tmp_x1.flatten() | ||
pval = self._test(tmp_x0, tmp_x1, paired, **kwargs) | ||
test_result = self._test(tmp_x0, tmp_x1, paired, **kwargs) | ||
mean_x0 = np.mean(tmp_x0) | ||
mean_x1 = np.mean(tmp_x1) | ||
res.append({"variable": var, "p_value": pval, "log_fc": np.log2(mean_x1) - np.log2(mean_x0)}) | ||
res.append({"variable": var, "log_fc": np.log2(mean_x1) - np.log2(mean_x0), **test_result}) | ||
return pd.DataFrame(res).sort_values("p_value") | ||
|
||
@classmethod | ||
|
@@ -94,9 +96,28 @@ def compare_groups( | |
paired_by: str | None = None, | ||
mask: str | None = None, | ||
layer: str | None = None, | ||
n_permutations: int = 1000, | ||
permutation_test_statistic: type["SimpleComparisonBase"] | None = None, | ||
Comment on lines
+99
to
+100
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this should be in the function signature of the base class. I remember that we had a discussion about this with respect to the usability, i.e. you didn't want this to be hidden in
|
||
fit_kwargs: Mapping = MappingProxyType({}), | ||
test_kwargs: Mapping = MappingProxyType({}), | ||
) -> DataFrame: | ||
"""Perform a comparison between groups. | ||
|
||
Args: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please remove all typing from the docstring here. We only type the function signature which is enough for Sphinx. |
||
adata: Data with observations to compare. | ||
column: Column in `adata.obs` that contains the groups to compare. | ||
baseline: Reference group. | ||
groups_to_compare: Groups to compare against the baseline. If None, all other groups | ||
are compared. | ||
paired_by: Column in `adata.obs` to use for pairing. If None, an unpaired test is performed. | ||
mask: Mask to apply to the data. | ||
layer: Layer to use for the comparison. | ||
n_permutations: Number of permutations to perform if a permutation test is used. | ||
permutation_test_statistic: The statistic to use if performing a permutation test. If None, the default | ||
t-statistic from `TTest` is used. | ||
fit_kwargs: Unused argument for compatibility with the `MethodBase` interface, do not specify. | ||
test_kwargs: Additional kwargs passed to the test function. | ||
""" | ||
if len(fit_kwargs): | ||
warnings.warn("fit_kwargs not used for simple tests.", UserWarning, stacklevel=2) | ||
paired = paired_by is not None | ||
|
@@ -125,6 +146,12 @@ def _get_idx(column, value): | |
else: | ||
return np.where(mask)[0] | ||
|
||
if permutation_test_statistic: | ||
test_kwargs = dict(test_kwargs) | ||
test_kwargs.update({"test_statistic": permutation_test_statistic, "n_permutations": n_permutations}) | ||
elif permutation_test_statistic is None and cls.__name__ == "PermutationTest": | ||
logger.warning("No permutation test statistic specified. Using TTest statistic as default.") | ||
|
||
Comment on lines
+149
to
+154
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With either solution mentioned above, this wouldn't be necessary here. |
||
res_dfs = [] | ||
baseline_idx = _get_idx(column, baseline) | ||
for group_to_compare in groups_to_compare: | ||
|
@@ -144,19 +171,118 @@ class WilcoxonTest(SimpleComparisonBase): | |
""" | ||
|
||
@staticmethod | ||
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float: | ||
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> dict[str, float]: | ||
"""Perform an unpaired or paired Wilcoxon/Mann-Whitney-U test.""" | ||
|
||
if paired: | ||
return scipy.stats.wilcoxon(x0, x1, **kwargs).pvalue | ||
test_result = scipy.stats.wilcoxon(x0, x1, **kwargs) | ||
else: | ||
return scipy.stats.mannwhitneyu(x0, x1, **kwargs).pvalue | ||
test_result = scipy.stats.mannwhitneyu(x0, x1, **kwargs) | ||
|
||
return { | ||
"p_value": test_result.pvalue, | ||
"statistic": test_result.statistic, | ||
} | ||
|
||
|
||
class TTest(SimpleComparisonBase): | ||
"""Perform a unpaired or paired T-test""" | ||
"""Perform a unpaired or paired T-test.""" | ||
|
||
@staticmethod | ||
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float: | ||
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> dict[str, float]: | ||
if paired: | ||
return scipy.stats.ttest_rel(x0, x1, **kwargs).pvalue | ||
test_result = scipy.stats.ttest_rel(x0, x1, **kwargs) | ||
else: | ||
return scipy.stats.ttest_ind(x0, x1, **kwargs).pvalue | ||
test_result = scipy.stats.ttest_ind(x0, x1, **kwargs) | ||
|
||
return { | ||
"p_value": test_result.pvalue, | ||
"statistic": test_result.statistic, | ||
} | ||
|
||
|
||
class PermutationTest(SimpleComparisonBase): | ||
"""Perform a permutation test. | ||
|
||
The permutation test relies on another test statistic (e.g. t-statistic or your own) to obtain a p-value through | ||
random permutations of the data and repeated generation of the test statistic. | ||
|
||
For paired tests, each paired observation is permuted together and distributed randomly between the two groups. For | ||
unpaired tests, all observations are permuted independently. | ||
|
||
The null hypothesis for the unpaired test is that all observations come from the same underlying distribution and | ||
have been randomly assigned to one of the samples. | ||
|
||
The null hypothesis for the paired permutation test is that the observations within each pair are drawn from the | ||
same underlying distribution and that their assignment to a sample is random. | ||
""" | ||
|
||
@staticmethod | ||
def _test( | ||
x0: np.ndarray, | ||
Zethson marked this conversation as resolved.
Show resolved
Hide resolved
|
||
x1: np.ndarray, | ||
paired: bool, | ||
Zethson marked this conversation as resolved.
Show resolved
Hide resolved
|
||
test_statistic: type["SimpleComparisonBase"] | Callable = WilcoxonTest, | ||
n_permutations: int = 1000, | ||
**kwargs, | ||
) -> dict[str, float]: | ||
"""Perform a permutation test. | ||
|
||
This function relies on another test (e.g. WilcoxonTest) to generate a test statistic for each permutation. | ||
|
||
Args: | ||
x0: Array with baseline values. | ||
x1: Array with values to compare. | ||
paired: Whether to perform a paired test | ||
test_statistic: The class or function to generate the test statistic from permuted data. If a function is | ||
passed, it must have the signature `test_statistic(x0, x1, paired[, axis], **kwargs)`. If it accepts the | ||
parameter axis, vectorization will be used. | ||
n_permutations: Number of permutations to perform. | ||
**kwargs: kwargs passed to the permutation test function, not the test function after permutation. | ||
|
||
Examples: | ||
You can use the `PermutationTest` class to perform a permutation test with a custom test statistic or an | ||
existing test statistic like `TTest`. The test statistic must be a class that implements the `_test` method | ||
or a function that takes the arguments `x0`, `x1`, `paired` and `**kwargs`. | ||
|
||
>>> from pertpy.tools import PermutationTest, TTest | ||
>>> # Perform a permutation test with a t-statistic | ||
>>> p_value = PermutationTest._test(x0, x1, paired=True, test=TTest, n_permutations=1000, rng=0) | ||
>>> # Perform a permutation test with a custom test statistic | ||
>>> p_value = PermutationTest._test(x0, x1, paired=False, test=your_custom_test_statistic) | ||
""" | ||
if test_statistic is PermutationTest: | ||
raise ValueError( | ||
"The `test_statistic` argument cannot be `PermutationTest`. Use a base test like `TTest` or a custom test." | ||
) | ||
|
||
vectorized = hasattr(test_statistic, "_test") or "axis" in signature(test_statistic).parameters | ||
|
||
def call_test(data_baseline, data_comparison, axis: int | None = None, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do you really need another test here? Essentially the statistic we want to test is the fold change. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The advantage of permutation tests is their generalizability without (almost) any assumptions, as long as the test statistic is related to the null hypothesis we want to test. I would thus view the ability to use any statistic, either those that are already implemented in pertpy or any callable that accepts two NDArrays and **kwargs as core functionality. With the latest update, this PR supports any statistic, e.g., it would also be trivial to use a comparison of means, of medians or any other function that you can implement in < 5 lines of numpy code with the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @grst is this resolved? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We changed this to the t-statistic after in person discussion, as the wilcoxon statistic would just be the wilcoxon test and from what I understood from the in person discussion the rationale for keeping this function was agreed to, but perhaps @grst can confirm again. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I still don't get why you would want to use a permutation test with a test statistics. If I'm interested in the difference in means, I would use the difference in means as statistics. In my understanding, the whole point of using Wilcoxon or T test is that one can compare against a theoretical distribution, avoiding permutations in the first place. In any case, I would prefer passing a simple lambda over accepting another pertpy test class. This could be a simple test_statistic: Callable[[np.ndarray, np.ndarray], float] = lambda x,y: np.log2(np.mean(x)) - np.log2(np.mean(y)) or if you really want to use a t-statistics test_statistic = lambda x, y: scipy.stats.ttest_ind(x, y).statistic |
||
"""Perform the actual test.""" | ||
if not hasattr(test_statistic, "_test"): | ||
if vectorized: | ||
return test_statistic(data_baseline, data_comparison, paired=paired, axis=axis, **kwargs)[ | ||
"statistic" | ||
] | ||
|
||
return test_statistic(data_baseline, data_comparison, paired=paired, **kwargs)["statistic"] | ||
|
||
if vectorized: | ||
kwargs.update({"axis": axis}) | ||
|
||
return test_statistic._test(data_baseline, data_comparison, paired, **kwargs)["statistic"] | ||
|
||
test_result = scipy.stats.permutation_test( | ||
[x0, x1], | ||
statistic=call_test, | ||
n_resamples=n_permutations, | ||
permutation_type=("samples" if paired else "independent"), | ||
vectorized=vectorized, | ||
**kwargs, | ||
) | ||
|
||
return { | ||
"p_value": test_result.pvalue, | ||
"statistic": test_result.statistic, | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please document the return type, e.g.