27
27
from collections import namedtuple
28
28
from pathlib import Path
29
29
from tempfile import mkdtemp
30
- from typing import Any
30
+ from typing import Any , Generic , TypeVarTuple
31
31
32
32
import attr
33
33
import h5py
34
34
import nibabel as nb
35
35
import numpy as np
36
+ from nibabel .spatialimages import SpatialHeader , SpatialImage
36
37
from nitransforms .linear import Affine
37
38
39
+ from nifreeze .utils .ndimage import load_api
40
+
38
41
NFDH5_EXT = ".h5"
39
42
40
43
44
+ Ts = TypeVarTuple ("Ts" )
45
+
46
+
41
47
def _data_repr (value : np .ndarray | None ) -> str :
42
48
if value is None :
43
49
return "None"
@@ -52,7 +58,7 @@ def _cmp(lh: Any, rh: Any) -> bool:
52
58
53
59
54
60
@attr .s (slots = True )
55
- class BaseDataset :
61
+ class BaseDataset ( Generic [ * Ts ]) :
56
62
"""
57
63
Base dataset representation structure.
58
64
@@ -68,15 +74,15 @@ class BaseDataset:
68
74
69
75
"""
70
76
71
- dataobj = attr .ib (default = None , repr = _data_repr , eq = attr .cmp_using (eq = _cmp ))
77
+ dataobj : np . ndarray = attr .ib (default = None , repr = _data_repr , eq = attr .cmp_using (eq = _cmp ))
72
78
"""A :obj:`~numpy.ndarray` object for the data array."""
73
- affine = attr .ib (default = None , repr = _data_repr , eq = attr .cmp_using (eq = _cmp ))
79
+ affine : np . ndarray = attr .ib (default = None , repr = _data_repr , eq = attr .cmp_using (eq = _cmp ))
74
80
"""Best affine for RAS-to-voxel conversion of coordinates (NIfTI header)."""
75
- brainmask = attr .ib (default = None , repr = _data_repr , eq = attr .cmp_using (eq = _cmp ))
81
+ brainmask : np . ndarray = attr .ib (default = None , repr = _data_repr , eq = attr .cmp_using (eq = _cmp ))
76
82
"""A boolean ndarray object containing a corresponding brainmask."""
77
- motion_affines = attr .ib (default = None , eq = attr .cmp_using (eq = _cmp ))
83
+ motion_affines : np . ndarray = attr .ib (default = None , eq = attr .cmp_using (eq = _cmp ))
78
84
"""List of :obj:`~nitransforms.linear.Affine` realigning the dataset."""
79
- datahdr = attr .ib (default = None )
85
+ datahdr : SpatialHeader = attr .ib (default = None )
80
86
"""A :obj:`~nibabel.spatialimages.SpatialHeader` header corresponding to the data."""
81
87
82
88
_filepath = attr .ib (
@@ -93,9 +99,13 @@ def __len__(self) -> int:
93
99
94
100
return self .dataobj .shape [- 1 ]
95
101
102
+ def _getextra (self , idx : int | slice | tuple | np .ndarray ) -> tuple [* Ts ]:
103
+ # PY312: Default values for TypeVarTuples are not yet supported
104
+ return () # type: ignore[return-value]
105
+
96
106
def __getitem__ (
97
107
self , idx : int | slice | tuple | np .ndarray
98
- ) -> tuple [np .ndarray , np .ndarray | None ]:
108
+ ) -> tuple [np .ndarray , np .ndarray | None , * Ts ]:
99
109
"""
100
110
Returns volume(s) and corresponding affine(s) through fancy indexing.
101
111
@@ -118,7 +128,7 @@ def __getitem__(
118
128
raise ValueError ("No data available (dataobj is None)." )
119
129
120
130
affine = self .motion_affines [idx ] if self .motion_affines is not None else None
121
- return self .dataobj [..., idx ], affine
131
+ return self .dataobj [..., idx ], affine , * self . _getextra ( idx )
122
132
123
133
@classmethod
124
134
def from_filename (cls , filename : Path | str ) -> BaseDataset :
@@ -159,9 +169,8 @@ def set_transform(self, index: int, affine: np.ndarray, order: int = 3) -> None:
159
169
The order of the spline interpolation.
160
170
161
171
"""
162
- reference = namedtuple ("ImageGrid" , ("shape" , "affine" ))(
163
- shape = self .dataobj .shape [:3 ], affine = self .affine
164
- )
172
+ ImageGrid = namedtuple ("ImageGrid" , ("shape" , "affine" ))
173
+ reference = ImageGrid (shape = self .dataobj .shape [:3 ], affine = self .affine )
165
174
166
175
xform = Affine (matrix = affine , reference = reference )
167
176
@@ -227,7 +236,7 @@ def to_filename(
227
236
compression_opts = compression_opts ,
228
237
)
229
238
230
- def to_nifti (self , filename : Path ) -> None :
239
+ def to_nifti (self , filename : Path | str ) -> None :
231
240
"""
232
241
Write a NIfTI file to disk.
233
242
@@ -247,7 +256,7 @@ def load(
247
256
filename : Path | str ,
248
257
brainmask_file : Path | str | None = None ,
249
258
motion_file : Path | str | None = None ,
250
- ) -> BaseDataset :
259
+ ) -> BaseDataset [()] :
251
260
"""
252
261
Load 4D data from a filename or an HDF5 file.
253
262
@@ -279,11 +288,11 @@ def load(
279
288
if filename .name .endswith (NFDH5_EXT ):
280
289
return BaseDataset .from_filename (filename )
281
290
282
- img = nb . load (filename )
283
- retval = BaseDataset (dataobj = img .dataobj , affine = img .affine )
291
+ img = load_api (filename , SpatialImage )
292
+ retval : BaseDataset [()] = BaseDataset (dataobj = np . asanyarray ( img .dataobj ) , affine = img .affine )
284
293
285
294
if brainmask_file :
286
- mask = nb . load (brainmask_file )
295
+ mask = load_api (brainmask_file , SpatialImage )
287
296
retval .brainmask = np .asanyarray (mask .dataobj )
288
297
289
298
return retval
0 commit comments