diff --git a/scripts/dwi_gp_estimation_error_analysis.py b/scripts/dwi_gp_estimation_error_analysis.py index eb393b97..a0616493 100644 --- a/scripts/dwi_gp_estimation_error_analysis.py +++ b/scripts/dwi_gp_estimation_error_analysis.py @@ -31,6 +31,7 @@ import argparse from collections import defaultdict from pathlib import Path +from typing import DefaultDict, List import numpy as np import pandas as pd @@ -49,7 +50,7 @@ def cross_validate( cv: int, n_repeats: int, gpr: DiffusionGPR, -) -> dict[int, list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]]: +) -> np.ndarray: """ Perform the experiment by estimating the dMRI signal using a Gaussian process model. @@ -68,7 +69,7 @@ def cross_validate( Returns ------- - :obj:`dict` + :obj:`~numpy.ndarray` Data for the predicted signal and its error. """ @@ -202,12 +203,14 @@ def main() -> None: # max_iter=2e5, ) + n_repeats = 10 + if args.kfold: # Use Scikit-learn cross validation - scores = defaultdict(list, {}) + scores: DefaultDict[str, List[float | str]] = defaultdict(list) for n in args.kfold: for i in range(args.repeats): - cv_scores = -1.0 * cross_validate(X, y.T, n, gpr) + cv_scores = -1.0 * cross_validate(X, y.T, n, n_repeats, gpr) scores["rmse"] += cv_scores.tolist() scores["repeat"] += [i] * len(cv_scores) scores["n_folds"] += [n] * len(cv_scores) @@ -217,7 +220,7 @@ def main() -> None: print(f"Finished {n}-fold cross-validation") scores_df = pd.DataFrame(scores) - scores_df.to_csv(args.output_scores, sep="\t", index=None, na_rep="n/a") + scores_df.to_csv(args.output_scores, sep="\t", index=False, na_rep="n/a") grouped = scores_df.groupby(["n_folds"]) print(grouped[["rmse"]].mean())