Skip to content

Commit 5f5ebb1

Browse files
author
dPys
committed
[ENH] Add a series of general-purpose and emc-related image-handling helper functions in a new module utils/images.py, and relocate dangling image helper functions that were previously in interfaces/images.py into this module
1 parent 7e194c3 commit 5f5ebb1

File tree

2 files changed

+205
-36
lines changed

2 files changed

+205
-36
lines changed

dmriprep/interfaces/images.py

+21-36
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""Image tools interfaces."""
22
import numpy as np
33
import nibabel as nb
4+
from dmriprep.utils.images import rescale_b0, median, match_transforms
45
from nipype.utils.filemanip import fname_presuffix
56
from nipype import logging
67
from nipype.interfaces.base import (
7-
traits, TraitedSpec, BaseInterfaceInputSpec, SimpleInterface, File
8+
traits, TraitedSpec, BaseInterfaceInputSpec, SimpleInterface, File,
9+
InputMultiObject, OutputMultiObject
810
)
911

1012
LOGGER = logging.getLogger('nipype.interface')
@@ -103,43 +105,26 @@ def _run_interface(self, runtime):
103105
return runtime
104106

105107

106-
def rescale_b0(in_file, mask_file, newpath=None):
107-
"""Rescale the input volumes using the median signal intensity."""
108-
out_file = fname_presuffix(
109-
in_file, suffix='_rescaled_b0', newpath=newpath)
108+
class MatchTransformsInputSpec(BaseInterfaceInputSpec):
109+
b0_indices = traits.List(mandatory=True)
110+
dwi_files = InputMultiObject(File(exists=True), mandatory=True)
111+
transforms = InputMultiObject(File(exists=True), mandatory=True)
110112

111-
img = nb.load(in_file)
112-
if img.dataobj.ndim == 3:
113-
return in_file
114113

115-
data = img.get_fdata(dtype='float32')
116-
mask_img = nb.load(mask_file)
117-
mask_data = mask_img.get_fdata(dtype='float32')
114+
class MatchTransformsOutputSpec(TraitedSpec):
115+
transforms = OutputMultiObject(File(exists=True), mandatory=True)
118116

119-
median_signal = np.median(data[mask_data > 0, ...], axis=0)
120-
rescaled_data = 1000 * data / median_signal
121-
hdr = img.header.copy()
122-
nb.Nifti1Image(rescaled_data, img.affine, hdr).to_filename(out_file)
123-
return out_file
124117

118+
class MatchTransforms(SimpleInterface):
119+
input_spec = MatchTransformsInputSpec
120+
output_spec = MatchTransformsOutputSpec
125121

126-
def median(in_file, newpath=None):
127-
"""Average a 4D dataset across the last dimension using median."""
128-
out_file = fname_presuffix(
129-
in_file, suffix='_b0ref', newpath=newpath)
130-
131-
img = nb.load(in_file)
132-
if img.dataobj.ndim == 3:
133-
return in_file
134-
if img.shape[-1] == 1:
135-
nb.squeeze_image(img).to_filename(out_file)
136-
return out_file
137-
138-
median_data = np.median(img.get_fdata(dtype='float32'),
139-
axis=-1)
140-
141-
hdr = img.header.copy()
142-
hdr.set_xyzt_units('mm')
143-
hdr.set_data_dtype(np.float32)
144-
nb.Nifti1Image(median_data, img.affine, hdr).to_filename(out_file)
145-
return out_file
122+
def _run_interface(self, runtime):
123+
"""
124+
Interface for mapping the `match_transforms` function across lists of inputs.
125+
"""
126+
self._results["transforms"] = match_transforms(
127+
self.inputs.dwi_files, self.inputs.transforms, self.inputs.b0_indices
128+
)
129+
return runtime
130+

dmriprep/utils/images.py

+184
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import numpy as np
2+
import nibabel as nb
3+
from nipype.utils.filemanip import fname_presuffix
4+
5+
6+
def extract_b0(in_file, b0_ixs, newpath=None):
7+
"""Extract the *b0* volumes from a DWI dataset."""
8+
out_file = fname_presuffix(in_file, suffix="_b0", newpath=newpath)
9+
10+
img = nb.load(in_file)
11+
data = img.get_fdata(dtype="float32")
12+
13+
b0 = data[..., b0_ixs]
14+
15+
hdr = img.header.copy()
16+
hdr.set_data_shape(b0.shape)
17+
hdr.set_xyzt_units("mm")
18+
hdr.set_data_dtype(np.float32)
19+
nb.Nifti1Image(b0, img.affine, hdr).to_filename(out_file)
20+
return out_file
21+
22+
23+
def rescale_b0(in_file, mask_file, newpath=None):
24+
"""Rescale the input volumes using the median signal intensity."""
25+
out_file = fname_presuffix(in_file, suffix="_rescaled_b0", newpath=newpath)
26+
27+
img = nb.load(in_file)
28+
if img.dataobj.ndim == 3:
29+
return in_file
30+
31+
data = img.get_fdata(dtype="float32")
32+
mask_img = nb.load(mask_file)
33+
mask_data = mask_img.get_fdata(dtype="float32")
34+
35+
median_signal = np.median(data[mask_data > 0, ...], axis=0)
36+
rescaled_data = 1000 * data / median_signal
37+
hdr = img.header.copy()
38+
nb.Nifti1Image(rescaled_data, img.affine, hdr).to_filename(out_file)
39+
return out_file
40+
41+
42+
def median(in_file, newpath=None):
43+
"""Average a 4D dataset across the last dimension using median."""
44+
out_file = fname_presuffix(in_file, suffix="_b0ref", newpath=newpath)
45+
46+
img = nb.load(in_file)
47+
if img.dataobj.ndim == 3:
48+
return in_file
49+
if img.shape[-1] == 1:
50+
nb.squeeze_image(img).to_filename(out_file)
51+
return out_file
52+
53+
median_data = np.median(img.get_fdata(dtype="float32"), axis=-1)
54+
55+
hdr = img.header.copy()
56+
hdr.set_xyzt_units("mm")
57+
hdr.set_data_dtype(np.float32)
58+
nb.Nifti1Image(median_data, img.affine, hdr).to_filename(out_file)
59+
return out_file
60+
61+
62+
def average_images(images):
63+
"""Average the voxel-wise signal intensity across a list of 3D image files to produce a 3D mean output image."""
64+
from nilearn.image import mean_img
65+
66+
average_img = mean_img([nb.load(img) for img in images])
67+
output_average_image = fname_presuffix(
68+
images[0], use_ext=False, suffix="_mean.nii.gz"
69+
)
70+
average_img.to_filename(output_average_image)
71+
return output_average_image
72+
73+
74+
def quick_load_images(image_list, dtype=np.float32):
75+
"""Iteratively loads 3D dwi volume files from a list of filepaths directly into a 4d array to use for signal
76+
prediction. A helper function for EMC."""
77+
example_img = nb.load(image_list[0])
78+
num_images = len(image_list)
79+
output_matrix = np.zeros(tuple(example_img.shape) + (num_images,), dtype=dtype)
80+
for image_num, image_path in enumerate(image_list):
81+
output_matrix[..., image_num] = nb.load(image_path).get_fdata(dtype=dtype)
82+
return output_matrix
83+
84+
85+
def match_transforms(dwi_files, transforms, b0_indices):
86+
"""Arranges the order of a list of affine transforms to correspond with that of each individual dwi volume file,
87+
accounting for the indices of B0s. A helper function for EMC."""
88+
original_b0_indices = np.array(b0_indices)
89+
num_dwis = len(dwi_files)
90+
num_transforms = len(transforms)
91+
92+
if num_dwis == num_transforms:
93+
return transforms
94+
95+
# Do sanity checks
96+
if not len(transforms) == len(b0_indices):
97+
raise Exception("number of transforms does not match number of b0 images")
98+
99+
# Create a list of which emc affines go with each of the split images
100+
nearest_affines = []
101+
for index in range(num_dwis):
102+
nearest_b0_num = np.argmin(np.abs(index - original_b0_indices))
103+
this_transform = transforms[nearest_b0_num]
104+
nearest_affines.append(this_transform)
105+
106+
return nearest_affines
107+
108+
109+
def save_4d_to_3d(in_file):
110+
"""Loads a 4D input file and splits it in the 4th dimension to produce a list of 3D output files."""
111+
files_3d = nb.four_to_three(nb.load(in_file))
112+
out_files = []
113+
for i, file_3d in enumerate(files_3d):
114+
out_file = fname_presuffix(in_file, suffix="_tmp_{}".format(i))
115+
file_3d.to_filename(out_file)
116+
out_files.append(out_file)
117+
del files_3d
118+
return out_files
119+
120+
121+
def prune_b0s_from_dwis(in_files, b0_ixs):
122+
"""Removes B0 volume files from a complete list of dwi volume files."""
123+
if in_files[0].endswith("_warped.nii.gz"):
124+
out_files = [
125+
i
126+
for j, i in enumerate(
127+
sorted(
128+
in_files, key=lambda x: int(x.split("_")[-2].split(".nii.gz")[0])
129+
)
130+
)
131+
if j not in b0_ixs
132+
]
133+
else:
134+
out_files = [
135+
i
136+
for j, i in enumerate(
137+
sorted(
138+
in_files, key=lambda x: int(x.split("_")[-1].split(".nii.gz")[0])
139+
)
140+
)
141+
if j not in b0_ixs
142+
]
143+
return out_files
144+
145+
146+
def save_3d_to_4d(in_files):
147+
"""Loads a list of 3D input files and concatenates it to produce a 4D output file."""
148+
img_4d = nb.funcs.concat_images([nb.load(img_3d) for img_3d in in_files])
149+
out_file = fname_presuffix(in_files[0], suffix="_merged")
150+
img_4d.to_filename(out_file)
151+
del img_4d
152+
return out_file
153+
154+
155+
def get_params(A):
156+
"""This is a copy of spm's spm_imatrix where
157+
we already know the rotations and translations matrix,
158+
shears and zooms (as outputs from fsl FLIRT/avscale)
159+
Let A = the 4x4 rotation and translation matrix
160+
R = [ c5*c6, c5*s6, s5]
161+
[-s4*s5*c6-c4*s6, -s4*s5*s6+c4*c6, s4*c5]
162+
[-c4*s5*c6+s4*s6, -c4*s5*s6-s4*c6, c4*c5]
163+
"""
164+
165+
def rang(b):
166+
a = min(max(b, -1), 1)
167+
return a
168+
169+
Ry = np.arcsin(A[0, 2])
170+
# Rx = np.arcsin(A[1, 2] / np.cos(Ry))
171+
# Rz = np.arccos(A[0, 1] / np.sin(Ry))
172+
173+
if (abs(Ry) - np.pi / 2) ** 2 < 1e-9:
174+
Rx = 0
175+
Rz = np.arctan2(-rang(A[1, 0]), rang(-A[2, 0] / A[0, 2]))
176+
else:
177+
c = np.cos(Ry)
178+
Rx = np.arctan2(rang(A[1, 2] / c), rang(A[2, 2] / c))
179+
Rz = np.arctan2(rang(A[0, 1] / c), rang(A[0, 0] / c))
180+
181+
rotations = [Rx, Ry, Rz]
182+
translations = [A[0, 3], A[1, 3], A[2, 3]]
183+
184+
return rotations, translations

0 commit comments

Comments
 (0)