3030
3131from eddymotion .data .dmri import DWI
3232from eddymotion .estimator import EddyMotionEstimator
33+ from eddymotion .registration .utils import displacements_within_mask
3334
3435
35- def test_proximity_estimator_trivial_model (datadir ):
36+ def test_proximity_estimator_trivial_model (datadir , tmp_path ):
3637 """Check the proximity of transforms estimated by the estimator with a trivial B0 model."""
3738
3839 dwdata = DWI .from_filename (datadir / "dwi.h5" )
3940 b0nii = nb .Nifti1Image (dwdata .bzero , dwdata .affine , None )
41+ masknii = nb .Nifti1Image (dwdata .brainmask .astype (np .uint8 ), dwdata .affine , None )
4042
4143 # Generate a list of large-yet-plausible bulk-head motion.
4244 xfms = nt .linear .LinearTransformsMapping (
@@ -56,8 +58,8 @@ def test_proximity_estimator_trivial_model(datadir):
5658 moved_nii = (~ xfms ).apply (b0nii , reference = b0nii )
5759
5860 # Uncomment to see the moved dataset
59- # moved_nii.to_filename(tmp_path / "test.nii.gz")
60- # xfms.apply(moved_nii).to_filename(tmp_path / "ground_truth.nii.gz")
61+ moved_nii .to_filename (tmp_path / "test.nii.gz" )
62+ xfms .apply (moved_nii ).to_filename (tmp_path / "ground_truth.nii.gz" )
6163
6264 # Wrap into dataset object
6365 dwi_motion = DWI (
@@ -70,7 +72,7 @@ def test_proximity_estimator_trivial_model(datadir):
7072
7173 estimator = EddyMotionEstimator ()
7274 em_affines = estimator .estimate (
73- dwdata = dwi_motion ,
75+ data = dwi_motion ,
7476 models = ("b0" ,),
7577 seed = None ,
7678 align_kwargs = {
@@ -81,14 +83,16 @@ def test_proximity_estimator_trivial_model(datadir):
8183 )
8284
8385 # Uncomment to see the realigned dataset
84- # nt.linear.LinearTransformsMapping(
85- # em_affines,
86- # reference=b0nii,
87- # ).apply(moved_nii).to_filename(tmp_path / "realigned.nii.gz")
86+ nt .linear .LinearTransformsMapping (
87+ em_affines ,
88+ reference = b0nii ,
89+ ).apply (moved_nii ).to_filename (tmp_path / "realigned.nii.gz" )
8890
8991 # For each moved b0 volume
9092 coords = xfms .reference .ndcoords .T
9193 for i , est in enumerate (em_affines ):
92- xfm = nt .linear .Affine (xfms .matrix [i ], reference = b0nii )
93- est = nt .linear .Affine (est , reference = b0nii )
94- assert np .sqrt (((xfm .map (coords ) - est .map (coords )) ** 2 ).sum (1 )).mean () < 0.2
94+ assert displacements_within_mask (
95+ masknii ,
96+ nt .linear .Affine (est ),
97+ xfms [i ],
98+ ).max () < 0.2
0 commit comments