diff --git a/pyem/utils/stats.py b/pyem/utils/stats.py index 2d38d85..9ef8d37 100644 --- a/pyem/utils/stats.py +++ b/pyem/utils/stats.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd from joblib import Parallel, delayed +from scipy.special import logsumexp from scipy.stats import norm def calc_LME(inv_h: np.ndarray, NPL: np.ndarray) -> tuple[np.ndarray, float, np.ndarray]: @@ -31,17 +32,23 @@ def calc_BICint( all_data, param_names, mu, sigma, fit_func, nsamples: int = 2000, func_output: str = "all", nll_key: str = "nll" ) -> float: npar = len(param_names) - # count trials of the first subject - if isinstance(all_data[0], pd.DataFrame): - total_trials = len(all_data[0]) - else: - first = all_data[0] - if isinstance(first, (list, tuple)) and hasattr(first[0], "size"): - total_trials = int(np.sum([x.size for x in first if hasattr(x, "size")])) - else: - raise ValueError("Unrecognized data structure in all_data") + + def subject_trials(beh) -> int: + if isinstance(beh, pd.DataFrame): + return len(beh) + if isinstance(beh, np.ndarray): + return int(beh.size) + if isinstance(beh, (list, tuple)): + for item in beh: + if hasattr(item, "size"): + return int(item.size) + raise ValueError("Unrecognized data structure in all_data") + + total_trials = int(np.sum([subject_trials(beh) for beh in all_data])) + sigmasqrt = np.sqrt(np.asarray(sigma).reshape(-1)) mu = np.asarray(mu).reshape(-1) + def subj_iLog(beh): G = norm.rvs(loc=mu[:, None], scale=sigmasqrt[:, None], size=(len(mu), nsamples)) subnll = [] @@ -50,8 +57,9 @@ def subj_iLog(beh): # fit_func expected to return dict when output="all" info = fit_func(pars, *beh, output=func_output) subnll.append(info[nll_key]) - iLog = np.log(np.sum(np.exp(-np.asarray(subnll))) / nsamples) + iLog = logsumexp(-np.asarray(subnll)) - np.log(nsamples) return iLog + iLogs = Parallel(n_jobs=-1)(delayed(subj_iLog)(beh) for beh in all_data) iLogs = np.asarray(iLogs) finite = np.isfinite(iLogs) @@ -61,6 +69,9 @@ def subj_iLog(beh): return float(bicint) def pseudo_r2_from_nll(nll: np.ndarray, ntrials_total: int, noptions: int, metric: str = 'median') -> float: + if metric not in {'median', 'mean'}: + raise ValueError("metric must be 'median' or 'mean'") + if metric == 'median': median_nll = float(np.median(nll)) random_baseline = float(np.median(-np.log(1.0 / noptions) * ntrials_total)) @@ -70,6 +81,54 @@ def pseudo_r2_from_nll(nll: np.ndarray, ntrials_total: int, noptions: int, metri random_baseline = float(np.mean(-np.log(1.0 / noptions) * ntrials_total)) return 1.0 - (mean_nll / random_baseline) + +def overall_predictive_probability_from_nll( + nll: np.ndarray, + nchoices_total: int, + *, + return_log: bool = False, +) -> float: + """Compute geometric-mean predictive probability from summed NLL values. + + For per-subject summed negative log-likelihood values ``nll``, the overall + predictive probability over all modeled choices is: + + p = exp(-sum(nll) / nchoices_total) + + This corresponds to the geometric mean of trial-level predictive + probabilities across all subjects and trials. + + Parameters + ---------- + nll : np.ndarray + 1D array containing per-subject summed negative log-likelihood values. + nchoices_total : int + Total number of modeled choices across all subjects. + return_log : bool, default False + If True, return ``log(p)`` instead of ``p``. + + Returns + ------- + float + Geometric mean predictive probability (or its log if ``return_log`` is + True). + """ + nll = np.asarray(nll, dtype=float) + if nll.ndim != 1: + raise ValueError("nll must be a 1D array of shape (nsubjects,)") + if nchoices_total <= 0: + raise ValueError("nchoices_total must be a positive integer") + + finite_nll = nll[np.isfinite(nll)] + if finite_nll.size == 0: + return float("nan") + + log_gmean = -float(np.sum(finite_nll)) / float(nchoices_total) + if return_log: + return log_gmean + return float(np.exp(log_gmean)) + + def likelihood_r2(nll: np.ndarray, metric: str = 'median') -> float: """ R^2-style score from per-subject summed negative log-likelihoods. @@ -103,4 +162,4 @@ def likelihood_r2(nll: np.ndarray, metric: str = 'median') -> float: likelihoods = np.exp(-nll) # per-subject joint likelihood in (0, 1], NaN-preserving agg = np.nanmedian(likelihoods) if metric == 'median' else np.nanmean(likelihoods) - return float(agg ** 2) \ No newline at end of file + return float(agg ** 2) diff --git a/tests/test_stats.py b/tests/test_stats.py new file mode 100644 index 0000000..d4832e3 --- /dev/null +++ b/tests/test_stats.py @@ -0,0 +1,82 @@ +import numpy as np +import pytest + +from pyem.utils.stats import ( + calc_BICint, + overall_predictive_probability_from_nll, + pseudo_r2_from_nll, +) + + +def _dummy_fit_func(pars, choices, rewards, output="all"): + del pars, output + return {"nll": float(choices.size), "rewards_sum": float(np.sum(rewards))} + + +def test_calc_bicint_uses_total_trials_across_subjects(): + all_data = [ + (np.array([0, 1, 0]), np.array([1, 0, 1])), + (np.array([1, 1]), np.array([0, 1])), + ] + mu = np.array([0.1, -0.1]) + sigma = np.zeros(2) + + bicint = calc_BICint( + all_data=all_data, + param_names=["beta", "alpha"], + mu=mu, + sigma=sigma, + fit_func=_dummy_fit_func, + nsamples=5, + func_output="all", + nll_key="nll", + ) + + total_trials = 5 + expected = -2 * (-(3 + 2)) + 2 * np.log(total_trials) + assert np.isclose(bicint, expected) + + +def test_calc_bicint_is_numerically_stable_with_large_nll(): + def huge_nll_fit(pars, choices, rewards, output="all"): + del pars, choices, rewards, output + return {"nll": 1_000.0} + + all_data = [ + (np.array([0, 1, 0]), np.array([1, 0, 1])), + ] + + bicint = calc_BICint( + all_data=all_data, + param_names=["beta"], + mu=np.array([0.0]), + sigma=np.zeros(1), + fit_func=huge_nll_fit, + nsamples=10, + func_output="all", + nll_key="nll", + ) + + assert np.isfinite(bicint) + + +def test_pseudo_r2_rejects_invalid_metric(): + with pytest.raises(ValueError, match="metric must be 'median' or 'mean'"): + pseudo_r2_from_nll(np.array([2.0, 3.0]), ntrials_total=10, noptions=2, metric="mode") + + +def test_overall_predictive_probability_from_nll_matches_formula(): + nll = np.array([2.0, 3.0]) + nchoices_total = 10 + expected = np.exp(-(2.0 + 3.0) / nchoices_total) + got = overall_predictive_probability_from_nll(nll, nchoices_total) + assert np.isclose(got, expected) + + +def test_overall_predictive_probability_from_nll_return_log_and_validation(): + nll = np.array([1.0, 2.0, np.nan]) + log_val = overall_predictive_probability_from_nll(nll, nchoices_total=6, return_log=True) + assert np.isclose(log_val, -0.5) + + with pytest.raises(ValueError, match="nchoices_total must be a positive integer"): + overall_predictive_probability_from_nll(np.array([1.0]), nchoices_total=0)