Skip to content

Commit 3bee02d

Browse files
author
dPys
committed
[ENH] Add a series of general-purpose and emc-related image-handling helper functions in a new module utils/images.py, and relocate dangling image helper functions that were previously in interfaces/images.py into this module
1 parent 7e194c3 commit 3bee02d

File tree

2 files changed

+198
-57
lines changed

2 files changed

+198
-57
lines changed

dmriprep/interfaces/images.py

+20-57
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
"""Image tools interfaces."""
2-
import numpy as np
3-
import nibabel as nb
4-
from nipype.utils.filemanip import fname_presuffix
2+
from dmriprep.utils.images import rescale_b0, median, match_transforms, extract_b0
53
from nipype import logging
64
from nipype.interfaces.base import (
7-
traits, TraitedSpec, BaseInterfaceInputSpec, SimpleInterface, File
5+
traits, TraitedSpec, BaseInterfaceInputSpec, SimpleInterface, File,
6+
InputMultiObject, OutputMultiObject
87
)
98

109
LOGGER = logging.getLogger('nipype.interface')
@@ -45,24 +44,6 @@ def _run_interface(self, runtime):
4544
return runtime
4645

4746

48-
def extract_b0(in_file, b0_ixs, newpath=None):
49-
"""Extract the *b0* volumes from a DWI dataset."""
50-
out_file = fname_presuffix(
51-
in_file, suffix='_b0', newpath=newpath)
52-
53-
img = nb.load(in_file)
54-
data = img.get_fdata(dtype='float32')
55-
56-
b0 = data[..., b0_ixs]
57-
58-
hdr = img.header.copy()
59-
hdr.set_data_shape(b0.shape)
60-
hdr.set_xyzt_units('mm')
61-
hdr.set_data_dtype(np.float32)
62-
nb.Nifti1Image(b0, img.affine, hdr).to_filename(out_file)
63-
return out_file
64-
65-
6647
class _RescaleB0InputSpec(BaseInterfaceInputSpec):
6748
in_file = File(exists=True, mandatory=True, desc='b0s file')
6849
mask_file = File(exists=True, mandatory=True, desc='mask file')
@@ -103,43 +84,25 @@ def _run_interface(self, runtime):
10384
return runtime
10485

10586

106-
def rescale_b0(in_file, mask_file, newpath=None):
107-
"""Rescale the input volumes using the median signal intensity."""
108-
out_file = fname_presuffix(
109-
in_file, suffix='_rescaled_b0', newpath=newpath)
110-
111-
img = nb.load(in_file)
112-
if img.dataobj.ndim == 3:
113-
return in_file
114-
115-
data = img.get_fdata(dtype='float32')
116-
mask_img = nb.load(mask_file)
117-
mask_data = mask_img.get_fdata(dtype='float32')
87+
class MatchTransformsInputSpec(BaseInterfaceInputSpec):
88+
b0_indices = traits.List(mandatory=True)
89+
dwi_files = InputMultiObject(File(exists=True), mandatory=True)
90+
transforms = InputMultiObject(File(exists=True), mandatory=True)
11891

119-
median_signal = np.median(data[mask_data > 0, ...], axis=0)
120-
rescaled_data = 1000 * data / median_signal
121-
hdr = img.header.copy()
122-
nb.Nifti1Image(rescaled_data, img.affine, hdr).to_filename(out_file)
123-
return out_file
12492

93+
class MatchTransformsOutputSpec(TraitedSpec):
94+
transforms = OutputMultiObject(File(exists=True), mandatory=True)
12595

126-
def median(in_file, newpath=None):
127-
"""Average a 4D dataset across the last dimension using median."""
128-
out_file = fname_presuffix(
129-
in_file, suffix='_b0ref', newpath=newpath)
13096

131-
img = nb.load(in_file)
132-
if img.dataobj.ndim == 3:
133-
return in_file
134-
if img.shape[-1] == 1:
135-
nb.squeeze_image(img).to_filename(out_file)
136-
return out_file
137-
138-
median_data = np.median(img.get_fdata(dtype='float32'),
139-
axis=-1)
97+
class MatchTransforms(SimpleInterface):
98+
"""
99+
Interface for mapping the `match_transforms` function across lists of inputs.
100+
"""
101+
input_spec = MatchTransformsInputSpec
102+
output_spec = MatchTransformsOutputSpec
140103

141-
hdr = img.header.copy()
142-
hdr.set_xyzt_units('mm')
143-
hdr.set_data_dtype(np.float32)
144-
nb.Nifti1Image(median_data, img.affine, hdr).to_filename(out_file)
145-
return out_file
104+
def _run_interface(self, runtime):
105+
self._results["transforms"] = match_transforms(
106+
self.inputs.dwi_files, self.inputs.transforms, self.inputs.b0_indices
107+
)
108+
return runtime

dmriprep/utils/images.py

+178
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
import numpy as np
2+
import nibabel as nb
3+
from nipype.utils.filemanip import fname_presuffix
4+
5+
6+
def extract_b0(in_file, b0_ixs, newpath=None):
7+
"""Extract the *b0* volumes from a DWI dataset."""
8+
out_file = fname_presuffix(in_file, suffix="_b0", newpath=newpath)
9+
10+
img = nb.load(in_file)
11+
data = img.get_fdata(dtype="float32")
12+
13+
b0 = data[..., b0_ixs]
14+
15+
hdr = img.header.copy()
16+
hdr.set_data_shape(b0.shape)
17+
hdr.set_xyzt_units("mm")
18+
hdr.set_data_dtype(np.float32)
19+
nb.Nifti1Image(b0, img.affine, hdr).to_filename(out_file)
20+
return out_file
21+
22+
23+
def rescale_b0(in_file, mask_file, newpath=None):
24+
"""Rescale the input volumes using the median signal intensity."""
25+
out_file = fname_presuffix(in_file, suffix="_rescaled_b0", newpath=newpath)
26+
27+
img = nb.load(in_file)
28+
if img.dataobj.ndim == 3:
29+
return in_file
30+
31+
data = img.get_fdata(dtype="float32")
32+
mask_img = nb.load(mask_file)
33+
mask_data = mask_img.get_fdata(dtype="float32")
34+
35+
median_signal = np.median(data[mask_data > 0, ...], axis=0)
36+
rescaled_data = 1000 * data / median_signal
37+
hdr = img.header.copy()
38+
nb.Nifti1Image(rescaled_data, img.affine, hdr).to_filename(out_file)
39+
return out_file
40+
41+
42+
def median(in_file, newpath=None):
43+
"""Average a 4D dataset across the last dimension using median."""
44+
out_file = fname_presuffix(in_file, suffix="_b0ref", newpath=newpath)
45+
46+
img = nb.load(in_file)
47+
if img.dataobj.ndim == 3:
48+
return in_file
49+
if img.shape[-1] == 1:
50+
nb.squeeze_image(img).to_filename(out_file)
51+
return out_file
52+
53+
median_data = np.median(img.get_fdata(dtype="float32"), axis=-1)
54+
55+
hdr = img.header.copy()
56+
hdr.set_xyzt_units("mm")
57+
hdr.set_data_dtype(np.float32)
58+
nb.Nifti1Image(median_data, img.affine, hdr).to_filename(out_file)
59+
return out_file
60+
61+
62+
def average_images(images, out_path=None):
63+
"""Average a 4D dataset across the last dimension using mean."""
64+
from nilearn.image import mean_img
65+
66+
average_img = mean_img([nb.load(img) for img in images])
67+
if out_path is None:
68+
out_path = fname_presuffix(
69+
images[0], use_ext=False, suffix="_mean.nii.gz"
70+
)
71+
average_img.to_filename(out_path)
72+
return out_path
73+
74+
75+
def quick_load_images(image_list, dtype=np.float32):
76+
"""Load 3D volumes from a list of file paths into a 4D array."""
77+
example_img = nb.load(image_list[0])
78+
num_images = len(image_list)
79+
output_matrix = np.zeros(tuple(example_img.shape) + (num_images,), dtype=dtype)
80+
for image_num, image_path in enumerate(image_list):
81+
output_matrix[..., image_num] = nb.load(image_path).get_fdata(dtype=dtype)
82+
return output_matrix
83+
84+
85+
def match_transforms(dwi_files, transforms, b0_indices):
86+
"""Arranges the order of a list of affine transforms to correspond with that of
87+
each individual dwi volume file, accounting for the indices of B0s. A helper
88+
function for EMC."""
89+
num_dwis = len(dwi_files)
90+
num_transforms = len(transforms)
91+
92+
if num_dwis == num_transforms:
93+
return transforms
94+
95+
# Do sanity checks
96+
if not len(transforms) == len(b0_indices):
97+
raise Exception("number of transforms does not match number of b0 images")
98+
99+
# Create a list of which emc affines go with each of the split images
100+
nearest_affines = []
101+
for index in range(num_dwis):
102+
nearest_b0_num = np.argmin(np.abs(index - np.array(b0_indices)))
103+
this_transform = transforms[nearest_b0_num]
104+
nearest_affines.append(this_transform)
105+
106+
return nearest_affines
107+
108+
109+
def save_4d_to_3d(in_file):
110+
"""Split a 4D dataset along the last dimension into multiple 3D volumes."""
111+
files_3d = nb.four_to_three(nb.load(in_file))
112+
out_files = []
113+
for i, file_3d in enumerate(files_3d):
114+
out_file = fname_presuffix(in_file, suffix="_tmp_{}".format(i))
115+
file_3d.to_filename(out_file)
116+
out_files.append(out_file)
117+
del files_3d
118+
return out_files
119+
120+
121+
def prune_b0s_from_dwis(in_files, b0_ixs):
122+
"""Remove *b0* volume files from a complete list of DWI volume files."""
123+
if in_files[0].endswith("_warped.nii.gz"):
124+
out_files = [
125+
i
126+
for j, i in enumerate(
127+
sorted(
128+
in_files, key=lambda x: int(x.split("_")[-2].split(".nii.gz")[0])
129+
)
130+
)
131+
if j not in b0_ixs
132+
]
133+
else:
134+
out_files = [
135+
i
136+
for j, i in enumerate(
137+
sorted(
138+
in_files, key=lambda x: int(x.split("_")[-1].split(".nii.gz")[0])
139+
)
140+
)
141+
if j not in b0_ixs
142+
]
143+
return out_files
144+
145+
146+
def save_3d_to_4d(in_files):
147+
"""Concatenate a list of 3D volumes into a 4D output."""
148+
img_4d = nb.funcs.concat_images([nb.load(img_3d) for img_3d in in_files])
149+
out_file = fname_presuffix(in_files[0], suffix="_merged")
150+
img_4d.to_filename(out_file)
151+
del img_4d
152+
return out_file
153+
154+
155+
def get_params(A):
156+
"""Takes an transformation affine matrix A and determines
157+
rotations and translations."""
158+
159+
def rang(b):
160+
a = min(max(b, -1), 1)
161+
return a
162+
163+
Ry = np.arcsin(A[0, 2])
164+
# Rx = np.arcsin(A[1, 2] / np.cos(Ry))
165+
# Rz = np.arccos(A[0, 1] / np.sin(Ry))
166+
167+
if (abs(Ry) - np.pi / 2) ** 2 < 1e-9:
168+
Rx = 0
169+
Rz = np.arctan2(-rang(A[1, 0]), rang(-A[2, 0] / A[0, 2]))
170+
else:
171+
c = np.cos(Ry)
172+
Rx = np.arctan2(rang(A[1, 2] / c), rang(A[2, 2] / c))
173+
Rz = np.arctan2(rang(A[0, 1] / c), rang(A[0, 0] / c))
174+
175+
rotations = [Rx, Ry, Rz]
176+
translations = [A[0, 3], A[1, 3], A[2, 3]]
177+
178+
return rotations, translations

0 commit comments

Comments
 (0)