Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ants read warp transform #638

Merged
merged 3 commits into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions ants/core/ants_transform_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,10 @@ def transform_from_displacement_field(field):
"""
if not isinstance(field, iio.ANTsImage):
raise ValueError("field must be ANTsImage type")
if field.dimension < 2 or field.dimension > 3:
raise ValueError("Unsupported displacement field dimension: %i" % field.dimension)
if field.components != field.dimension:
raise ValueError("Displacement field must have same number of components as the image dimension")
libfn = utils.get_lib_fn("antsTransformFromDisplacementField")
field = field.clone("float")
txptr = libfn(field.pointer)
Expand All @@ -263,6 +267,7 @@ def transform_from_displacement_field(field):
pointer=txptr,
)


def transform_to_displacement_field(xfrm, ref):
"""
Convert displacement field ANTsTransform to displacement field
Expand Down Expand Up @@ -299,6 +304,7 @@ def transform_to_displacement_field(xfrm, ref):
field_ptr = libfn(xfrm.pointer, ref.pointer)
return iio2.from_pointer(field_ptr)


def read_transform(filename, precision="float"):
"""
Read a transform from file
Expand Down Expand Up @@ -329,7 +335,9 @@ def read_transform(filename, precision="float"):
if not os.path.exists(filename):
raise ValueError("filename does not exist!")

# intentionally ignore dimension
if filename.endswith('.nii') or filename.endswith('.nii.gz'):
return transform_from_displacement_field(iio2.image_read(filename))

libfn1 = utils.get_lib_fn("getTransformDimensionFromFile")
dimensionUse = libfn1(filename)

Expand Down Expand Up @@ -377,7 +385,7 @@ def write_transform(transform, filename):
"""
if not isinstance(transform, tio.ANTsTransform):
raise Exception('Only ANTsTransform instances can be written to file. Check that you are not passing in a filepath to a saved transform.')

filename = os.path.expanduser(filename)
libfn = utils.get_lib_fn("writeTransform")
libfn(transform.pointer, filename)
31 changes: 21 additions & 10 deletions tests/test_core_ants_transform_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,17 @@ def setUp(self):
self.txs = [tx2d, tx3d]
self.pixeltypes = ['unsigned char', 'unsigned int', 'float']

self.matrix_offset_types = ['AffineTransform',
self.matrix_offset_types = ['AffineTransform',
'CenteredAffineTransform',
'Euler2DTransform',
'Euler3DTransform',
'Rigid2DTransform',
'QuaternionRigidTransform',
'Similarity2DTransform',
'CenteredSimilarity2DTransform',
'Similarity3DTransform',
'Similarity3DTransform',
'CenteredRigid2DTransform',
'CenteredEuler3DTransform',
'CenteredEuler3DTransform',
'Rigid3DTransform']

def tearDown(self):
Expand Down Expand Up @@ -95,16 +95,27 @@ def test_read_write_transform(self):
# file doesnt exist
with self.assertRaises(Exception):
ants.read_transform('blah-blah.mat')


def test_from_displacement_components(self):
vec_np = np.ndarray((2,2,3), dtype=np.float32)
vec = ants.from_numpy(vec_np, origin=(0,0), spacing=(1,1), has_components=True)
# should get ValueError here because the 2D vector field has 3 components
with self.assertRaises(ValueError):
ants.transform_from_displacement_field(vec)
vec_np = np.ndarray((2,2,2,3), dtype=np.float32)
vec = ants.from_numpy(vec_np, origin=(0,0,0), spacing=(1,1,1), has_components=True)
# should work here because the 3D vector field has 3 components
tx = ants.transform_from_displacement_field(vec)

def test_from_displacement(self):
fi = ants.image_read(ants.get_ants_data('r16') )
mi = ants.image_read(ants.get_ants_data('r64') )
fi = ants.resample_image(fi,(60,60),1,0)
mi = ants.resample_image(mi,(60,60),1,0) # speed up
mytx = ants.registration(fixed=fi, moving=mi, type_of_transform = ('SyN') )
vec = ants.image_read( mytx['fwdtransforms'][0] )
atx = ants.transform_from_displacement_field( vec )
# read transform, which calls transform_from_displacement_field
atx = ants.read_transform( mytx['fwdtransforms'][0] )

def test_to_displacement(self):
fi = ants.image_read(ants.get_ants_data('r16') )
mi = ants.image_read(ants.get_ants_data('r64') )
Expand All @@ -113,11 +124,11 @@ def test_to_displacement(self):
mytx = ants.registration(fixed=fi, moving=mi, type_of_transform = ('SyN') )
vec = ants.image_read( mytx['fwdtransforms'][0] )
atx = ants.transform_from_displacement_field( vec )
field = ants.transform_to_displacement_field( atx, fi )
field = ants.transform_to_displacement_field( atx, fi )

def test_catch_error(self):
with self.assertRaises(Exception):
ants.write_transform(123, 'test.mat')
ants.write_transform(123, 'test.mat')


if __name__ == '__main__':
Expand Down