1
+ import numpy as np
2
+ import nibabel as nb
3
+ from nipype .utils .filemanip import fname_presuffix
4
+
5
+
6
+ def extract_b0 (in_file , b0_ixs , newpath = None ):
7
+ """Extract the *b0* volumes from a DWI dataset."""
8
+ out_file = fname_presuffix (in_file , suffix = "_b0" , newpath = newpath )
9
+
10
+ img = nb .load (in_file )
11
+ data = img .get_fdata (dtype = "float32" )
12
+
13
+ b0 = data [..., b0_ixs ]
14
+
15
+ hdr = img .header .copy ()
16
+ hdr .set_data_shape (b0 .shape )
17
+ hdr .set_xyzt_units ("mm" )
18
+ hdr .set_data_dtype (np .float32 )
19
+ nb .Nifti1Image (b0 , img .affine , hdr ).to_filename (out_file )
20
+ return out_file
21
+
22
+
23
+ def rescale_b0 (in_file , mask_file , newpath = None ):
24
+ """Rescale the input volumes using the median signal intensity."""
25
+ out_file = fname_presuffix (in_file , suffix = "_rescaled_b0" , newpath = newpath )
26
+
27
+ img = nb .load (in_file )
28
+ if img .dataobj .ndim == 3 :
29
+ return in_file
30
+
31
+ data = img .get_fdata (dtype = "float32" )
32
+ mask_img = nb .load (mask_file )
33
+ mask_data = mask_img .get_fdata (dtype = "float32" )
34
+
35
+ median_signal = np .median (data [mask_data > 0 , ...], axis = 0 )
36
+ rescaled_data = 1000 * data / median_signal
37
+ hdr = img .header .copy ()
38
+ nb .Nifti1Image (rescaled_data , img .affine , hdr ).to_filename (out_file )
39
+ return out_file
40
+
41
+
42
+ def median (in_file , newpath = None ):
43
+ """Average a 4D dataset across the last dimension using median."""
44
+ out_file = fname_presuffix (in_file , suffix = "_b0ref" , newpath = newpath )
45
+
46
+ img = nb .load (in_file )
47
+ if img .dataobj .ndim == 3 :
48
+ return in_file
49
+ if img .shape [- 1 ] == 1 :
50
+ nb .squeeze_image (img ).to_filename (out_file )
51
+ return out_file
52
+
53
+ median_data = np .median (img .get_fdata (dtype = "float32" ), axis = - 1 )
54
+
55
+ hdr = img .header .copy ()
56
+ hdr .set_xyzt_units ("mm" )
57
+ hdr .set_data_dtype (np .float32 )
58
+ nb .Nifti1Image (median_data , img .affine , hdr ).to_filename (out_file )
59
+ return out_file
60
+
61
+
62
+ def average_images (images ):
63
+ """Average the voxel-wise signal intensity across a list of 3D image files to produce a 3D mean output image."""
64
+ from nilearn .image import mean_img
65
+
66
+ average_img = mean_img ([nb .load (img ) for img in images ])
67
+ output_average_image = fname_presuffix (
68
+ images [0 ], use_ext = False , suffix = "_mean.nii.gz"
69
+ )
70
+ average_img .to_filename (output_average_image )
71
+ return output_average_image
72
+
73
+
74
+ def quick_load_images (image_list , dtype = np .float32 ):
75
+ """Iteratively loads 3D dwi volume files from a list of filepaths directly into a 4d array to use for signal
76
+ prediction. A helper function for EMC."""
77
+ example_img = nb .load (image_list [0 ])
78
+ num_images = len (image_list )
79
+ output_matrix = np .zeros (tuple (example_img .shape ) + (num_images ,), dtype = dtype )
80
+ for image_num , image_path in enumerate (image_list ):
81
+ output_matrix [..., image_num ] = nb .load (image_path ).get_fdata (dtype = dtype )
82
+ return output_matrix
83
+
84
+
85
+ def match_transforms (dwi_files , transforms , b0_indices ):
86
+ """Arranges the order of a list of affine transforms to correspond with that of each individual dwi volume file,
87
+ accounting for the indices of B0s. A helper function for EMC."""
88
+ original_b0_indices = np .array (b0_indices )
89
+ num_dwis = len (dwi_files )
90
+ num_transforms = len (transforms )
91
+
92
+ if num_dwis == num_transforms :
93
+ return transforms
94
+
95
+ # Do sanity checks
96
+ if not len (transforms ) == len (b0_indices ):
97
+ raise Exception ("number of transforms does not match number of b0 images" )
98
+
99
+ # Create a list of which emc affines go with each of the split images
100
+ nearest_affines = []
101
+ for index in range (num_dwis ):
102
+ nearest_b0_num = np .argmin (np .abs (index - original_b0_indices ))
103
+ this_transform = transforms [nearest_b0_num ]
104
+ nearest_affines .append (this_transform )
105
+
106
+ return nearest_affines
107
+
108
+
109
+ def save_4d_to_3d (in_file ):
110
+ """Loads a 4D input file and splits it in the 4th dimension to produce a list of 3D output files."""
111
+ files_3d = nb .four_to_three (nb .load (in_file ))
112
+ out_files = []
113
+ for i , file_3d in enumerate (files_3d ):
114
+ out_file = fname_presuffix (in_file , suffix = "_tmp_{}" .format (i ))
115
+ file_3d .to_filename (out_file )
116
+ out_files .append (out_file )
117
+ del files_3d
118
+ return out_files
119
+
120
+
121
+ def prune_b0s_from_dwis (in_files , b0_ixs ):
122
+ """Removes B0 volume files from a complete list of dwi volume files."""
123
+ if in_files [0 ].endswith ("_warped.nii.gz" ):
124
+ out_files = [
125
+ i
126
+ for j , i in enumerate (
127
+ sorted (
128
+ in_files , key = lambda x : int (x .split ("_" )[- 2 ].split (".nii.gz" )[0 ])
129
+ )
130
+ )
131
+ if j not in b0_ixs
132
+ ]
133
+ else :
134
+ out_files = [
135
+ i
136
+ for j , i in enumerate (
137
+ sorted (
138
+ in_files , key = lambda x : int (x .split ("_" )[- 1 ].split (".nii.gz" )[0 ])
139
+ )
140
+ )
141
+ if j not in b0_ixs
142
+ ]
143
+ return out_files
144
+
145
+
146
+ def save_3d_to_4d (in_files ):
147
+ """Loads a list of 3D input files and concatenates it to produce a 4D output file."""
148
+ img_4d = nb .funcs .concat_images ([nb .load (img_3d ) for img_3d in in_files ])
149
+ out_file = fname_presuffix (in_files [0 ], suffix = "_merged" )
150
+ img_4d .to_filename (out_file )
151
+ del img_4d
152
+ return out_file
153
+
154
+
155
+ def get_params (A ):
156
+ """This is a copy of spm's spm_imatrix where
157
+ we already know the rotations and translations matrix,
158
+ shears and zooms (as outputs from fsl FLIRT/avscale)
159
+ Let A = the 4x4 rotation and translation matrix
160
+ R = [ c5*c6, c5*s6, s5]
161
+ [-s4*s5*c6-c4*s6, -s4*s5*s6+c4*c6, s4*c5]
162
+ [-c4*s5*c6+s4*s6, -c4*s5*s6-s4*c6, c4*c5]
163
+ """
164
+
165
+ def rang (b ):
166
+ a = min (max (b , - 1 ), 1 )
167
+ return a
168
+
169
+ Ry = np .arcsin (A [0 , 2 ])
170
+ # Rx = np.arcsin(A[1, 2] / np.cos(Ry))
171
+ # Rz = np.arccos(A[0, 1] / np.sin(Ry))
172
+
173
+ if (abs (Ry ) - np .pi / 2 ) ** 2 < 1e-9 :
174
+ Rx = 0
175
+ Rz = np .arctan2 (- rang (A [1 , 0 ]), rang (- A [2 , 0 ] / A [0 , 2 ]))
176
+ else :
177
+ c = np .cos (Ry )
178
+ Rx = np .arctan2 (rang (A [1 , 2 ] / c ), rang (A [2 , 2 ] / c ))
179
+ Rz = np .arctan2 (rang (A [0 , 1 ] / c ), rang (A [0 , 0 ] / c ))
180
+
181
+ rotations = [Rx , Ry , Rz ]
182
+ translations = [A [0 , 3 ], A [1 , 3 ], A [2 , 3 ]]
183
+
184
+ return rotations , translations
0 commit comments