Skip to content

Commit 254a27e

Browse files
committed
Merge remote-tracking branch 'upstream/main' into enh/estimator-refactor
2 parents 7322e93 + d9b390b commit 254a27e

21 files changed

+175
-98
lines changed

.github/workflows/test.yml

+1 -2
Original file line number | Diff line number | Diff line change
@@ -95,8 +95,7 @@ jobs:
9595
continue-on-error: true
9696
strategy:
9797
matrix:
98-
check: ['spellcheck']
99-
98+
check: ['spellcheck', 'typecheck']
10099
steps:
101100
- uses: actions/checkout@v4
102101
- name: Install the latest version of uv

docs/conf.py

+14 -15
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,6 @@
5454
"nipype",
5555
"nitime",
5656
"nitransforms",
57-
"numpy",
5857
"pandas",
5958
"scipy",
6059
"seaborn",
@@ -154,20 +153,20 @@
154153

155154
# -- Options for LaTeX output ------------------------------------------------
156155

157-
latex_elements = {
158-
# The paper size ('letterpaper' or 'a4paper').
159-
#
160-
# 'papersize': 'letterpaper',
161-
# The font size ('10pt', '11pt' or '12pt').
162-
#
163-
# 'pointsize': '10pt',
164-
# Additional stuff for the LaTeX preamble.
165-
#
166-
# 'preamble': '',
167-
# Latex figure (float) alignment
168-
#
169-
# 'figure_align': 'htbp',
170-
}
156+
# latex_elements = {
157+
# # The paper size ('letterpaper' or 'a4paper').
158+
# #
159+
# # 'papersize': 'letterpaper',
160+
# # The font size ('10pt', '11pt' or '12pt').
161+
# #
162+
# # 'pointsize': '10pt',
163+
# # Additional stuff for the LaTeX preamble.
164+
# #
165+
# # 'preamble': '',
166+
# # Latex figure (float) alignment
167+
# #
168+
# # 'figure_align': 'htbp',
169+
# }
171170

172171
# Grouping the document tree into LaTeX files. List of tuples
173172
# (source start file, target name, title,

pyproject.toml

+24
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,15 @@ test = [
7575
"pytest-env",
7676
"pytest-xdist >= 1.28"
7777
]
78+
types = [
79+
"pandas-stubs",
80+
"types-setuptools",
81+
"scipy-stubs",
82+
"types-PyYAML",
83+
"types-tqdm",
84+
"pytest",
85+
"microsoft-python-type-stubs @ git+https://github.com/microsoft/python-type-stubs.git",
86+
]
7887

7988
notebooks = [
8089
"jupyter",
@@ -138,6 +147,21 @@ version-file = "src/nifreeze/_version.py"
138147
# Developer tool configurations
139148
#
140149

150+
[[tool.mypy.overrides]]
151+
module = [
152+
"nipype.*",
153+
"nilearn.*",
154+
"nireports.*",
155+
"nitransforms.*",
156+
"seaborn",
157+
"dipy.*",
158+
"smac.*",
159+
"joblib",
160+
"h5py",
161+
"ConfigSpace",
162+
]
163+
ignore_missing_imports = true
164+
141165
[tool.ruff]
142166
line-length = 99
143167
target-version = "py310"

scripts/dwi_gp_estimation_error_analysis.py

+12 -5
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,7 @@ def cross_validate(
4949
cv: int,
5050
n_repeats: int,
5151
gpr: DiffusionGPR,
52-
) -> dict[int, list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]]:
52+
) -> np.ndarray:
5353
"""
5454
Perform the experiment by estimating the dMRI signal using a Gaussian process model.
5555
@@ -74,7 +74,14 @@ def cross_validate(
7474
"""
7575

7676
rkf = RepeatedKFold(n_splits=cv, n_repeats=n_repeats)
77-
scores = cross_val_score(gpr, X, y, scoring="neg_root_mean_squared_error", cv=rkf)
77+
# scikit-learn stubs do not recognize rkf as a BaseCrossValidator
78+
scores = cross_val_score(
79+
gpr,
80+
X,
81+
y,
82+
scoring="neg_root_mean_squared_error",
83+
cv=rkf, # type: ignore[arg-type]
84+
)
7885
return scores
7986

8087

@@ -204,10 +211,10 @@ def main() -> None:
204211

205212
if args.kfold:
206213
# Use Scikit-learn cross validation
207-
scores = defaultdict(list, {})
214+
scores: dict[str, list] = defaultdict(list, {})
208215
for n in args.kfold:
209216
for i in range(args.repeats):
210-
cv_scores = -1.0 * cross_validate(X, y.T, n, gpr)
217+
cv_scores = -1.0 * cross_validate(X, y.T, n, i, gpr)
211218
scores["rmse"] += cv_scores.tolist()
212219
scores["repeat"] += [i] * len(cv_scores)
213220
scores["n_folds"] += [n] * len(cv_scores)
@@ -217,7 +224,7 @@ def main() -> None:
217224
print(f"Finished {n}-fold cross-validation")
218225

219226
scores_df = pd.DataFrame(scores)
220-
scores_df.to_csv(args.output_scores, sep="\t", index=None, na_rep="n/a")
227+
scores_df.to_csv(args.output_scores, sep="\t", index=False, na_rep="n/a")
221228

222229
grouped = scores_df.groupby(["n_folds"])
223230
print(grouped[["rmse"]].mean())

scripts/dwi_gp_estimation_error_analysis_plot.py

+12 -4
Original file line number | Diff line number | Diff line change
@@ -89,10 +89,18 @@ def main() -> None:
8989
df = pd.read_csv(args.error_data_fname, sep="\t", keep_default_na=False, na_values="n/a")
9090

9191
# Plot the prediction error
92-
kfolds = sorted(np.unique(df["n_folds"].values))
93-
snr = np.unique(df["snr"].values).item()
94-
bval = np.unique(df["bval"].values).item()
95-
rmse_data = [df.groupby("n_folds").get_group(k)["rmse"].values for k in kfolds]
92+
kfolds = sorted(pd.unique(df["n_folds"]))
93+
snr = pd.unique(df["snr"])
94+
if len(snr) == 1:
95+
snr = snr[0]
96+
else:
97+
raise ValueError(f"More than one unique SNR value: {snr}")
98+
bval = pd.unique(df["bval"])
99+
if len(bval) == 1:
100+
bval = bval[0]
101+
else:
102+
raise ValueError(f"More than one unique bval value: {bval}")
103+
rmse_data = np.asarray([df.groupby("n_folds").get_group(k)["rmse"].values for k in kfolds])
96104
axis = 1
97105
mean = np.mean(rmse_data, axis=axis)
98106
std_dev = np.std(rmse_data, axis=axis)

scripts/dwi_gp_estimation_simulated_signal.py

+5 -3
Original file line number | Diff line number | Diff line change
@@ -132,11 +132,11 @@ def main() -> None:
132132

133133
# Fit the Gaussian Process regressor and predict on an arbitrary number of
134134
# directions
135-
a = 1.15
136-
lambda_s = 120
135+
beta_a = 1.15
136+
beta_l = 120
137137
alpha = 100
138138
gpr = DiffusionGPR(
139-
kernel=SphericalKriging(a=a, lambda_s=lambda_s),
139+
kernel=SphericalKriging(beta_a=beta_a, beta_l=beta_l),
140140
alpha=alpha,
141141
optimizer=None,
142142
)
@@ -154,6 +154,8 @@ def main() -> None:
154154
X_test = np.vstack([gtab[~gtab.b0s_mask].bvecs, sph.vertices])
155155

156156
predictions = gpr_fit.predict(X_test)
157+
if isinstance(predictions, tuple):
158+
predictions = predictions[0]
157159

158160
# Save the predicted data
159161
testsims.serialize_dwi(predictions.T, args.dwi_pred_data_fname)

scripts/optimize_registration.py

+3 -2
Original file line number | Diff line number | Diff line change
@@ -127,12 +127,13 @@ async def train_coro(
127127
moving_path = tmp_folder / f"test-{index:04d}.nii.gz"
128128
(~xfm).apply(refnii, reference=refnii).to_filename(moving_path)
129129

130+
_kwargs = {"output_transform_prefix": f"conversion-{index:04d}", **align_kwargs}
131+
130132
cmdline = erants.generate_command(
131133
fixed_path,
132134
moving_path,
133135
fixedmask_path=brainmask_path,
134-
output_transform_prefix=f"conversion-{index:04d}",
135-
**align_kwargs,
136+
**_kwargs,
136137
).cmdline
137138

138139
tasks.append(

src/nifreeze/cli/parser.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -29,13 +29,13 @@
2929
import yaml
3030

3131

32-
def _parse_yaml_config(file_path: Path) -> dict:
32+
def _parse_yaml_config(file_path: str) -> dict:
3333
"""
3434
Parse YAML configuration file.
3535
3636
Parameters
3737
----------
38-
file_path : Path
38+
file_path : str
3939
Path to the YAML configuration file.
4040
4141
Returns

src/nifreeze/data/base.py

+26 -17
Original file line number | Diff line number | Diff line change
@@ -27,17 +27,23 @@
2727
from collections import namedtuple
2828
from pathlib import Path
2929
from tempfile import mkdtemp
30-
from typing import Any
30+
from typing import Any, Generic, TypeVarTuple
3131

3232
import attr
3333
import h5py
3434
import nibabel as nb
3535
import numpy as np
36+
from nibabel.spatialimages import SpatialHeader, SpatialImage
3637
from nitransforms.linear import Affine
3738

39+
from nifreeze.utils.ndimage import load_api
40+
3841
NFDH5_EXT = ".h5"
3942

4043

44+
Ts = TypeVarTuple("Ts")
45+
46+
4147
def _data_repr(value: np.ndarray | None) -> str:
4248
if value is None:
4349
return "None"
@@ -52,7 +58,7 @@ def _cmp(lh: Any, rh: Any) -> bool:
5258

5359

5460
@attr.s(slots=True)
55-
class BaseDataset:
61+
class BaseDataset(Generic[*Ts]):
5662
"""
5763
Base dataset representation structure.
5864
@@ -68,15 +74,15 @@ class BaseDataset:
6874
6975
"""
7076

71-
dataobj = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp))
77+
dataobj: np.ndarray = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp))
7278
"""A :obj:`~numpy.ndarray` object for the data array."""
73-
affine = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp))
79+
affine: np.ndarray = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp))
7480
"""Best affine for RAS-to-voxel conversion of coordinates (NIfTI header)."""
75-
brainmask = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp))
81+
brainmask: np.ndarray = attr.ib(default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp))
7682
"""A boolean ndarray object containing a corresponding brainmask."""
77-
motion_affines = attr.ib(default=None, eq=attr.cmp_using(eq=_cmp))
83+
motion_affines: np.ndarray = attr.ib(default=None, eq=attr.cmp_using(eq=_cmp))
7884
"""List of :obj:`~nitransforms.linear.Affine` realigning the dataset."""
79-
datahdr = attr.ib(default=None)
85+
datahdr: SpatialHeader = attr.ib(default=None)
8086
"""A :obj:`~nibabel.spatialimages.SpatialHeader` header corresponding to the data."""
8187

8288
_filepath = attr.ib(
@@ -93,9 +99,13 @@ def __len__(self) -> int:
9399

94100
return self.dataobj.shape[-1]
95101

102+
def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[*Ts]:
103+
# PY312: Default values for TypeVarTuples are not yet supported
104+
return () # type: ignore[return-value]
105+
96106
def __getitem__(
97107
self, idx: int | slice | tuple | np.ndarray
98-
) -> tuple[np.ndarray, np.ndarray | None]:
108+
) -> tuple[np.ndarray, np.ndarray | None, *Ts]:
99109
"""
100110
Returns volume(s) and corresponding affine(s) through fancy indexing.
101111
@@ -118,7 +128,7 @@ def __getitem__(
118128
raise ValueError("No data available (dataobj is None).")
119129

120130
affine = self.motion_affines[idx] if self.motion_affines is not None else None
121-
return self.dataobj[..., idx], affine
131+
return self.dataobj[..., idx], affine, *self._getextra(idx)
122132

123133
@classmethod
124134
def from_filename(cls, filename: Path | str) -> BaseDataset:
@@ -159,9 +169,8 @@ def set_transform(self, index: int, affine: np.ndarray, order: int = 3) -> None:
159169
The order of the spline interpolation.
160170
161171
"""
162-
reference = namedtuple("ImageGrid", ("shape", "affine"))(
163-
shape=self.dataobj.shape[:3], affine=self.affine
164-
)
172+
ImageGrid = namedtuple("ImageGrid", ("shape", "affine"))
173+
reference = ImageGrid(shape=self.dataobj.shape[:3], affine=self.affine)
165174

166175
xform = Affine(matrix=affine, reference=reference)
167176

@@ -227,7 +236,7 @@ def to_filename(
227236
compression_opts=compression_opts,
228237
)
229238

230-
def to_nifti(self, filename: Path) -> None:
239+
def to_nifti(self, filename: Path | str) -> None:
231240
"""
232241
Write a NIfTI file to disk.
233242
@@ -247,7 +256,7 @@ def load(
247256
filename: Path | str,
248257
brainmask_file: Path | str | None = None,
249258
motion_file: Path | str | None = None,
250-
) -> BaseDataset:
259+
) -> BaseDataset[()]:
251260
"""
252261
Load 4D data from a filename or an HDF5 file.
253262
@@ -279,11 +288,11 @@ def load(
279288
if filename.name.endswith(NFDH5_EXT):
280289
return BaseDataset.from_filename(filename)
281290

282-
img = nb.load(filename)
283-
retval = BaseDataset(dataobj=img.dataobj, affine=img.affine)
291+
img = load_api(filename, SpatialImage)
292+
retval: BaseDataset[()] = BaseDataset(dataobj=np.asanyarray(img.dataobj), affine=img.affine)
284293

285294
if brainmask_file:
286-
mask = nb.load(brainmask_file)
295+
mask = load_api(brainmask_file, SpatialImage)
287296
retval.brainmask = np.asanyarray(mask.dataobj)
288297

289298
return retval

0 commit comments

Comments
 (0)