Skip to content

Commit

Permalink
Save all files as per brainreg, save the forward and inverse transfor…
Browse files Browse the repository at this point in the history
…mation parameters for elastix
  • Loading branch information
IgorTatarnikov committed Dec 13, 2024
1 parent e8f745a commit 2b96d46
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 3 deletions.
99 changes: 97 additions & 2 deletions brainglobe_registration/elastix/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def run_registration(

elastix_object.SetParameterObject(parameter_object)

if output_directory:
elastix_object.SetOutputDirectory(str(output_directory))
# if output_directory:
# elastix_object.SetOutputDirectory(str(output_directory))

# update filter object
elastix_object.UpdateLargestPossibleRegion()
Expand All @@ -81,6 +81,15 @@ def run_registration(
result_image = elastix_object.GetOutput()
result_transform_parameters = elastix_object.GetTransformParameterObject()

file_names = [
f"{output_directory}/TransformParameters.{i}.txt"
for i in range(len(parameter_lists))
]

itk.ParameterObject.WriteParameterFile(
result_transform_parameters, file_names
)

return (
np.asarray(result_image),
result_transform_parameters,
Expand Down Expand Up @@ -139,6 +148,36 @@ def transform_annotation_image(
return transformed_annotation_array


def transform_image(
image: npt.NDArray,
transform_parameters: itk.ParameterObject,
) -> npt.NDArray:
"""
Transform the image using the given transform parameters.
Parameters
----------
image: npt.NDArray
The image to transform.
transform_parameters: itk.ParameterObject
The transform parameters.
Returns
-------
npt.NDArray
The transformed image.
"""
image = itk.GetImageViewFromArray(image).astype(itk.F)

transformix_object = itk.TransformixFilter.New(image)
transformix_object.SetTransformParameterObject(transform_parameters)
transformix_object.UpdateLargestPossibleRegion()

transformed_image = transformix_object.GetOutput()

return np.asarray(transformed_image)


def calculate_deformation_field(
moving_image: npt.NDArray,
transform_parameters: itk.ParameterObject,
Expand Down Expand Up @@ -172,9 +211,65 @@ def calculate_deformation_field(
transformix_object.GetOutputDeformationField()
)[..., ::-1]

# Cleanup files generated by elastix
(Path.cwd() / "DeformationField.tiff").unlink(missing_ok=True)

return deformation_field


def invert_transformation(
fixed_image: npt.NDArray,
parameter_list: List[Tuple[str, dict]],
transform_parameters: itk.ParameterObject,
output_directory: Optional[Path] = None,
) -> itk.ParameterObject:

fixed_image = itk.GetImageFromArray(fixed_image).astype(itk.F)

elastix_object = itk.ElastixRegistrationMethod.New(
fixed_image, fixed_image
)

parameter_object_inverse = setup_parameter_object(parameter_list)

elastix_object.SetInitialTransformParameterObject(transform_parameters)

elastix_object.SetParameterObject(parameter_object_inverse)

elastix_object.UpdateLargestPossibleRegion()

num_initial_transforms = transform_parameters.GetNumberOfParameterMaps()

result_image = elastix_object.GetOutput()
out_parameters = elastix_object.GetTransformParameterObject()
result_transform_parameters = itk.ParameterObject.New()

for i in range(
num_initial_transforms, out_parameters.GetNumberOfParameterMaps()
):
result_transform_parameters.AddParameterMap(
out_parameters.GetParameterMap(i)
)

result_transform_parameters.SetParameter(
0, "InitialTransformParameterFileName", "NoInitialTransform"
)

file_names = [
f"{output_directory}/InverseTransformParameters.{i}.txt"
for i in range(len(parameter_list))
]

itk.ParameterObject.WriteParameterFiles(
result_transform_parameters, file_names
)

return (
np.asarray(result_image),
result_transform_parameters,
)


def setup_parameter_object(parameter_lists: List[tuple[str, dict]]):
"""
Set up the parameter object for the registration process.
Expand Down
18 changes: 17 additions & 1 deletion brainglobe_registration/registration_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,10 @@ def _on_open_file_dialog_clicked(self) -> None:
def _on_run_button_click(self):
from brainglobe_registration.elastix.register import (
calculate_deformation_field,
invert_transformation,
run_registration,
transform_annotation_image,
transform_image,
)

if self._atlas_data_layer is None:
Expand Down Expand Up @@ -417,6 +419,15 @@ def _on_run_button_click(self):
self.output_directory,
)

inverse_result, inverse_parameters = invert_transformation(
atlas_layer,
self.transform_selections,
parameters,
self.output_directory,
)

data_in_atlas_space = transform_image(moving_image, inverse_parameters)

registered_annotation_image = transform_annotation_image(
self._atlas_annotations_layer.data[current_atlas_slice, :, :],
parameters,
Expand Down Expand Up @@ -453,6 +464,11 @@ def _on_run_button_click(self):
blending="additive",
opacity=0.8,
)
self._viewer.add_image(
data_in_atlas_space,
name="Inverse Registered Image",
visible=False,
)

self._viewer.grid.enabled = False

Expand All @@ -461,7 +477,7 @@ def _on_run_button_click(self):
boundaries,
deformation_field,
moving_image,
result,
data_in_atlas_space,
result,
registered_annotation_image,
registered_hemisphere,
Expand Down

0 comments on commit 2b96d46

Please sign in to comment.