diff --git a/dmriprep/utils/registration.py b/dmriprep/utils/registration.py new file mode 100644 index 00000000..5d1dcd0f --- /dev/null +++ b/dmriprep/utils/registration.py @@ -0,0 +1,169 @@ +""" +Linear affine registration tools for motion correction. +""" +import numpy as np +import nibabel as nb +from dipy.align.metrics import CCMetric, EMMetric, SSDMetric +from dipy.align.imaffine import ( + transform_centers_of_mass, + AffineMap, + MutualInformationMetric, + AffineRegistration, +) +from dipy.align.transforms import ( + TranslationTransform3D, + RigidTransform3D, + AffineTransform3D, +) +from nipype.utils.filemanip import fname_presuffix + +syn_metric_dict = {"CC": CCMetric, "EM": EMMetric, "SSD": SSDMetric} + +__all__ = [ + "c_of_mass", + "translation", + "rigid", + "affine", + "affine_registration", +] + + +def apply_affine(moving, static, transform_affine, invert=False): + """Apply an affine to transform an image from one space to another. + + Parameters + ---------- + moving : array + The image to be resampled + + static : array + + Returns + ------- + warped_img : the moving array warped into the static array's space. + + """ + affine_map = AffineMap( + transform_affine, static.shape, static.affine, moving.shape, moving.affine + ) + if invert is True: + warped_arr = affine_map.transform_inverse(np.asarray(moving.dataobj)) + else: + warped_arr = affine_map.transform(np.asarray(moving.dataobj)) + + return nb.Nifti1Image(warped_arr, static.affine) + + +def average_affines(transforms): + affine_list = [np.load(aff) for aff in transforms] + average_affine_file = fname_presuffix( + transforms[0], use_ext=False, suffix="_average.npy" + ) + np.save(average_affine_file, np.mean(affine_list, axis=0)) + return average_affine_file + + +# Affine registration pipeline: +affine_metric_dict = {"MI": MutualInformationMetric, "CC": CCMetric} + + +def c_of_mass( + moving, static, static_affine, moving_affine, reg, starting_affine, params0=None +): + transform = transform_centers_of_mass(static, static_affine, moving, moving_affine) + transformed = transform.transform(moving) + return transformed, transform.affine + + +def translation( + moving, static, static_affine, moving_affine, reg, starting_affine, params0=None +): + transform = TranslationTransform3D() + translation = reg.optimize( + static, + moving, + transform, + params0, + static_affine, + moving_affine, + starting_affine=starting_affine, + ) + + return translation.transform(moving), translation.affine + + +def rigid( + moving, static, static_affine, moving_affine, reg, starting_affine, params0=None +): + transform = RigidTransform3D() + rigid = reg.optimize( + static, + moving, + transform, + params0, + static_affine, + moving_affine, + starting_affine=starting_affine, + ) + return rigid.transform(moving), rigid.affine + + +def affine( + moving, static, static_affine, moving_affine, reg, starting_affine, params0=None +): + transform = AffineTransform3D() + affine = reg.optimize( + static, + moving, + transform, + params0, + static_affine, + moving_affine, + starting_affine=starting_affine, + ) + + return affine.transform(moving), affine.affine + + +def affine_registration( + moving, + static, + nbins, + sampling_prop, + metric, + pipeline, + level_iters, + sigmas, + factors, + params0, +): + """ + Find the affine transformation between two 3D images. + + Parameters + ---------- + + """ + # Define the Affine registration object we'll use with the chosen metric: + use_metric = affine_metric_dict[metric](nbins, sampling_prop) + affreg = AffineRegistration( + metric=use_metric, level_iters=level_iters, sigmas=sigmas, factors=factors + ) + + if not params0: + starting_affine = np.eye(4) + else: + starting_affine = params0 + + # Go through the selected transformation: + for func in pipeline: + transformed, starting_affine = func( + np.asarray(moving.dataobj), + np.asarray(static.dataobj), + static.affine, + moving.affine, + affreg, + starting_affine, + params0, + ) + return nb.Nifti1Image(np.array(transformed), static.affine), starting_affine diff --git a/dmriprep/utils/vectors.py b/dmriprep/utils/vectors.py index d9284ef3..e409d187 100644 --- a/dmriprep/utils/vectors.py +++ b/dmriprep/utils/vectors.py @@ -243,7 +243,7 @@ def normalize_gradients(bvecs, bvals, b0_threshold=B0_THRESHOLD, # Check for bval-bvec discrepancy. if not np.all(b0s == b0_vecs): - raise ValueError( + raise UserWarning( 'Inconsistent bvals and bvecs (%d, %d low-b, respectively).' % (b0s.sum(), b0_vecs.sum())) @@ -375,3 +375,111 @@ def bvecs2ras(affine, bvecs, norm=True, bvec_norm_epsilon=0.2): rotated_bvecs[~b0s] /= norms_bvecs[~b0s, np.newaxis] rotated_bvecs[b0s] = np.zeros(3) return rotated_bvecs + + +def nonoverlapping_qspace_samples(sample_bval, sample_bvec, all_bvals, + all_bvecs, cutoff=2): + """ + Checks the q-space overlap (within some distance) between a sample + and a collection of q-space points. + + Parameters + ---------- + sample_bval : int + A single b-value sampled along the sphere. + sample_bvec : int + A single b-vector sampled along the sphere. + Should correspond to `sample_bval`. + all_bvals : ndarray + A 1D vector of all b-values from the diffusion series. + all_bvecs: ndarray + A 3 x n vector of all vectors from the diffusion series, + where n is the total number of samples. + cutoff : float + A minimal allowable q-space distance between points on + the sphere. + + Returns + ------- + ok_samples : boolean ndarray + True for q-vectors whose spatial distribution along + the sphere is non-overlapping, else False. + + Examples + -------- + >>> bvec1 = np.array([1, 0, 0]) + >>> bvec2 = np.array([1, 0, 0]) + >>> bvec3 = np.array([0, 1, 0]) + >>> bval1 = 1000 + >>> bval2 = 1000 + >>> bval3 = 1000 + >>> all_bvals = np.array([0, bval2, bval3]) + >>> all_bvecs = np.array([np.zeros(3), bvec2, bvec3]) + >>> # Case 1: overlapping + >>> nonoverlapping_qspace_samples(bval1, bvec1, all_bvals, all_bvecs, cutoff=2) + array([ True, False, True]) + >>> all_bvals = np.array([0, bval1, bval2]) + >>> all_bvecs = np.array([np.zeros(3), bvec1, bvec2]) + >>> # Case 2: non-overlapping + >>> nonoverlapping_qspace_samples(bval3, bvec3, all_bvals, all_bvecs, cutoff=2) + array([ True, True, True]) + """ + min_bval = min(min(all_bvals), sample_bval) + max_bval = max(max(all_bvals), sample_bval) + if min_bval == max_bval: + raise ValueError('All b-values are identical') + + all_qvals = np.sqrt(all_bvals - min_bval) + sample_qval = np.sqrt(sample_bval - min_bval) + + # Convert q values to percent of maximum qval + max_qval = max(max(all_qvals), sample_qval) + all_qvals_scaled = all_qvals / max_qval * 100 + scaled_qvecs = all_bvecs * all_qvals_scaled[:, np.newaxis] + scaled_sample_qvec = sample_bvec * (sample_qval / max_qval * 100) + + # Calculate the distance between all qvecs and the sample qvec + ok_samples = ( + np.linalg.norm(scaled_qvecs - scaled_sample_qvec, axis=1) > cutoff + ) * (np.linalg.norm(scaled_qvecs + scaled_sample_qvec, axis=1) > cutoff) + + return ok_samples + + +def _rasb_to_bvec_list(in_rasb): + """ + Create a list of b-vectors from a rasb gradient table. + + Parameters + ---------- + in_rasb : str or os.pathlike + File path to a RAS-B gradient table. + + Returns + ------- + List of b-vectors as floats. + """ + import numpy as np + + ras_b_mat = np.genfromtxt(in_rasb, delimiter="\t") + bvec = [vec for vec in ras_b_mat[:, 0:3] if not np.isclose(all(vec), 0)] + return list(bvec) + + +def _rasb_to_bval_floats(in_rasb): + """ + Create a list of b-values from a rasb gradient table. + + Parameters + ---------- + in_rasb : str or os.pathlike + File path to a RAS-B gradient table. + + Returns + ------- + List of b-values as floats. + """ + import numpy as np + + ras_b_mat = np.genfromtxt(in_rasb, delimiter="\t") + return [float(bval) for bval in ras_b_mat[:, 3] if bval > 0]