Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .cursorignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Add directories or file patterns to ignore during indexing (e.g. foo/ or *.csv)
**/*.pkl
**/*tar.gz
**/*zip
**/*whl
**/*whl.part
**/*.tar
**/*.tar.gz
**/*.tar.bz2
**/*.whl
**/*.whl.part
**/*.whl.part.tar
45 changes: 45 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
name: Run Pytest Tests

on:
  workflow_dispatch: # This enables manual triggering

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ['3.9']

    steps:
    - uses: actions/checkout@v4
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v5
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install -e .
        pip install pytest

    - name: Create test data directory
      run: mkdir -p tests/test_data

    - name: Download test data from Google Drive
      env:
        GDRIVE_FILE_ID: ${{ secrets.GDRIVE_FILE_ID }}
        GDRIVE_OUTPUT_PATH: tests/test_data/pytest.zip
      run: |
        if [ -z "$GDRIVE_FILE_ID" ]; then
          echo "GDRIVE_FILE_ID secret is not set. Skipping data download."
        else
          echo "Downloading data from Google Drive..."
          pip install gdown
          # `gdown --id` is deprecated; pass the file id positionally.
          # Quote variables so ids/paths with special characters survive.
          gdown "$GDRIVE_FILE_ID" -O "$GDRIVE_OUTPUT_PATH"
          echo "Data download complete."
          # The downloaded archive is a zip, so extract with unzip;
          # `tar -xvf` cannot read zip archives and would fail here.
          unzip -o "$GDRIVE_OUTPUT_PATH" -d tests/test_data
        fi

    - name: Run pytest
      run: pytest tests/
9 changes: 8 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
*pdf
*svg
dist/
build/
**/build/
**pkl
**/*.so
**/*.egg-info
**/__pycache__
**/baselines
**/*.pkl
Expand All @@ -22,3 +24,8 @@ docs/site
.vscode
fireants/scripts/template/evaltemplate/
fireants/scripts/template/saved*/
**/saved_results
notepad
**/*pkl
tests/test_results/
**/*.zip
125 changes: 121 additions & 4 deletions fireants/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from time import time
from fireants.types import devicetype
from fireants.utils.imageutils import integer_to_onehot
from fireants.utils.util import check_and_raise_cond
from fireants.utils.util import check_and_raise_cond, augment_filenames
import logging
from copy import deepcopy
import os
from fireants.utils.globals import PERMITTED_ANTS_WARP_EXT
# logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -69,11 +71,19 @@ def __init__(self, itk_image: sitk.SimpleITK.Image,
# if `is_segmentation` is False, then just treat this as an image with given dtype
if not is_segmentation:
self.array = torch.from_numpy(sitk.GetArrayFromImage(itk_image).astype(float)).to(device, dtype)
self.array.unsqueeze_(0).unsqueeze_(0)
# self.array = self.array[None, None] # TODO: Change it to support multichannel images, right now just batchify and add a dummy channel to it
channels = itk_image.GetNumberOfComponentsPerPixel()
self.channels = channels
assert channels == 1, "Only single channel images supported"
# assert channels == 1, "Only single channel images supported"
if channels > 1:
logger.warning("Image has multiple channels, make sure its not a spatial dimension")
# permute the channel dimension to the front
ndim = self.array.ndim
self.array = self.array.permute([ndim-1] + list(range(ndim-1))).contiguous() # permute to [C, H, W, D] from [H, W, D, C]
else:
self.array.unsqueeze_(0)
# add batch dimension
self.array.unsqueeze_(0)
else:
array = torch.from_numpy(sitk.GetArrayFromImage(itk_image).astype(int)).to(device).long()
# preprocess segmentation if provided by user
Expand Down Expand Up @@ -167,7 +177,6 @@ def concatenate(self, *others, optimize_memory: bool = True):
t2 = Image.load_file(t2_path)
flair = Image.load_file(flair_path)
t1.concatenate(t2, flair, optimize_memory=True) # deletes the arrays of t2 and flair after concatenation

'''
check_and_raise_cond(self.is_array_present, "Image must have a PyTorch tensor representation to concatenate", ValueError)
if isinstance(others[0], list) and len(others) == 1:
Expand Down Expand Up @@ -314,6 +323,114 @@ def get_torch2phy(self):
def get_phy2torch(self):
return self.phy2torch

class FakeBatchedImages:
    '''Lightweight wrapper pairing a raw tensor with BatchedImages metadata.

    Used when the user passes a plain ``torch.Tensor`` to a registration class
    instead of a :class:`BatchedImages`: the tensor supplies the voxel data,
    while the ``BatchedImages`` supplies the coordinate transforms and the ITK
    metadata (spacing, direction, origin) needed to write images to disk.
    '''

    def __init__(self, tensor: torch.Tensor, batched_images: BatchedImages) -> None:
        '''
        Args:
            tensor (torch.Tensor): Image data, assumed [N, C, *spatial]; batch and
                spatial dims must match ``batched_images`` (channels may differ).
            batched_images (BatchedImages): Source of coordinate/ITK metadata.

        Raises:
            ValueError: If the shapes (ignoring channels) do not match.
        '''
        # torch.Size is immutable, so list() is already a safe copy — no deepcopy needed.
        batched_size = list(batched_images().shape)
        tensor_size = list(tensor.shape)
        # Neutralize the channel dimension so only batch/spatial dims are compared.
        batched_size[1] = 1
        tensor_size[1] = 1
        check_and_raise_cond(tuple(batched_size) == tuple(tensor_size), "Tensor size must match the size of the batched images", ValueError)
        self.tensor = tensor
        self.batched_images = batched_images

    def __call__(self):
        '''Return the wrapped tensor (mirrors the BatchedImages call API).'''
        return self.tensor

    def get_torch2phy(self):
        # Coordinate transforms are delegated to the underlying BatchedImages.
        return self.batched_images.torch2phy

    def get_phy2torch(self):
        return self.batched_images.phy2torch

    @property
    def device(self):
        '''Device of the wrapped tensor.'''
        return self.tensor.device

    @property
    def dims(self):
        '''Number of spatial dimensions (assumes [N, C, *spatial] layout).'''
        return self.tensor.ndim - 2

    @property
    def shape(self):
        '''Full tensor shape including batch and channel dimensions.'''
        return self.tensor.shape

    def write_image(self, filenames: Union[str, List[str]], permitted_ext: List[str] = PERMITTED_ANTS_WARP_EXT):
        """
        Save tensor elements to disk as SimpleITK images.

        For each image in the batch:
        - If multi-channel, the channel dimension is permuted to the end
        - If single-channel, the channel dimension is squeezed
        - Metadata is copied from the corresponding BatchedImages itk_image

        Args:
            filenames (str or List[str]): A single filename or a list of filenames.
                - If one filename is provided for multiple images, they will be saved as
                  filename_img0.ext, filename_img1.ext, etc.
                - If the number of filenames equals the number of images, they are mapped one-to-one.
                - Otherwise, an error is raised.
            permitted_ext (List[str]): Extensions accepted by ``augment_filenames``.

        Raises:
            ValueError: If the number of filenames doesn't match the number of images
                and is not 1, or if no source metadata exists for an image.
        """
        batch_size = self.tensor.shape[0]

        # Normalize a single filename into a list.
        if isinstance(filenames, str):
            filenames = [filenames]

        # Either one filename for the whole batch or exactly one per image.
        check_and_raise_cond(len(filenames) == 1 or len(filenames) == batch_size, "Number of filenames must match the number of images or be 1", ValueError)
        filenames = augment_filenames(filenames, batch_size, permitted_ext)

        for i in range(batch_size):
            img_tensor = self.tensor[i]

            # Channel dim is at index 0 once the batch dim is stripped.
            channels = img_tensor.shape[0]
            isVector = channels > 1

            if isVector:
                # Move channels last for ITK vector images:
                # 2D: [C, H, W] -> [H, W, C]; 3D: [C, H, W, D] -> [H, W, D, C]
                dims = len(img_tensor.shape)
                perm = list(range(1, dims)) + [0]
                img_tensor = img_tensor.permute(*perm)
            else:
                # Single channel: drop the channel dimension entirely.
                img_tensor = img_tensor.squeeze(0)

            np_array = img_tensor.detach().cpu().numpy()
            itk_image = sitk.GetImageFromArray(np_array, isVector=isVector)

            # Copy spatial metadata from the matching source ITK image.
            if hasattr(self.batched_images, 'images') and i < len(self.batched_images.images):
                src_itk = self.batched_images.images[i].itk_image
                itk_image.SetSpacing(src_itk.GetSpacing())
                itk_image.SetDirection(src_itk.GetDirection())
                itk_image.SetOrigin(src_itk.GetOrigin())
            else:
                raise ValueError("No corresponding BatchedImages object found for image {}".format(i))

            save_filename = filenames[i]
            sitk.WriteImage(itk_image, save_filename)
            logger.info(f"Saved image to {save_filename}")

if __name__ == '__main__':
from fireants.utils.util import get_tensor_memory_details
Expand Down
57 changes: 53 additions & 4 deletions fireants/registration/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from fireants.utils.util import _assert_check_scales_decreasing
from fireants.losses import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss, NoOp, MeanSquaredError
from torch.optim import SGD, Adam
from fireants.io.image import BatchedImages
from typing import Optional
from fireants.io.image import BatchedImages, FakeBatchedImages
from typing import Optional, Union
from fireants.utils.util import ConvergenceMonitor
from torch.nn import functional as F
from functools import partial
Expand Down Expand Up @@ -126,7 +126,7 @@ def optimize(self):
pass

@abstractmethod
def get_warped_coordinates(self, fixed_images: BatchedImages, moving_images: BatchedImages, shape=None):
def get_warped_coordinates(self, fixed_images: Union[BatchedImages, FakeBatchedImages], moving_images: Union[BatchedImages, FakeBatchedImages], shape=None):
'''Get the transformed coordinates for warping the moving image.

This abstract method must be implemented by all registration classes to define how
Expand Down Expand Up @@ -156,7 +156,51 @@ def get_warped_coordinates(self, fixed_images: BatchedImages, moving_images: Bat
'''
pass

def evaluate(self, fixed_images: BatchedImages, moving_images: BatchedImages, shape=None):
@abstractmethod
def get_inverse_warped_coordinates(self, fixed_images: Union[BatchedImages, FakeBatchedImages], moving_images: Union[BatchedImages, FakeBatchedImages], shape=None):
    ''' Get inverse warped coordinates for the moving image.

    This method is useful to analyse the effect of how the moving coordinates (fixed images) are transformed

    Abstract: each registration class defines how the inverse of its learned
    transform maps coordinates. Counterpart of ``get_warped_coordinates``.

    Args:
        fixed_images: Fixed image batch (or tensor wrapper) supplying metadata.
        moving_images: Moving image batch (or tensor wrapper) supplying metadata.
        shape (optional): Desired output spatial shape; semantics are
            implementation-defined by the subclass.

    Returns:
        A coordinate grid suitable for ``F.grid_sample`` (see ``evaluate_inverse``);
        exact shape/layout is defined by the implementing subclass.
    '''
    pass

def save_moved_images(self, moved_images: Union[BatchedImages, FakeBatchedImages, torch.Tensor], filenames: Union[str, List[str]], moving_to_fixed: bool = True):
    '''
    Save the moved images to disk.

    Normalizes the input into a ``FakeBatchedImages`` (which knows how to write
    itself with the proper ITK metadata) and delegates to its ``write_image``.

    Args:
        moved_images (Union[BatchedImages, FakeBatchedImages, torch.Tensor]): The moved images to save.
        filenames (Union[str, List[str]]): The filenames to save the moved images to.
        moving_to_fixed (bool, optional): If True, the moving images are saved to the fixed image space. Defaults to True.
            if False, we are dealing with an image that is moved from fixed space to moving space
    '''
    if isinstance(moved_images, torch.Tensor):
        # Raw tensor: borrow metadata from whichever space it now lives in.
        reference = self.fixed_images if moving_to_fixed else self.moving_images
        writable = FakeBatchedImages(moved_images, reference)
    elif isinstance(moved_images, BatchedImages):
        # Wrap the BatchedImages' own tensor with its own metadata.
        writable = FakeBatchedImages(moved_images(), moved_images)
    else:
        # Already a FakeBatchedImages — use as-is.
        writable = moved_images
    writable.write_image(filenames)


def evaluate_inverse(self, fixed_images: Union[BatchedImages, torch.Tensor], moving_images: Union[BatchedImages, torch.Tensor], shape=None, **kwargs):
    ''' Apply the inverse of the learned transformation to new images.

    This method is useful to analyse the effect of how the moving coordinates (fixed images) are transformed

    Args:
        fixed_images: Fixed images (or a raw tensor, which is wrapped using the
            metadata of ``self.fixed_images``).
        moving_images: Moving images (or a raw tensor, wrapped using
            ``self.moving_images`` metadata).
        shape (optional): Passed through to ``get_inverse_warped_coordinates``.
        **kwargs: Forwarded to ``get_inverse_warped_coordinates``.

    Returns:
        torch.Tensor: The resampled image, shape [N, C, H, W, [D]].
    '''
    # Raw tensors are wrapped so the metadata of the stored images applies.
    if isinstance(fixed_images, torch.Tensor):
        fixed_images = FakeBatchedImages(fixed_images, self.fixed_images)
    if isinstance(moving_images, torch.Tensor):
        moving_images = FakeBatchedImages(moving_images, self.moving_images)

    # NOTE(review): the local is named `fixed_arrays` but holds moving_images() —
    # confirm whether the inverse warp is meant to resample the moving or the
    # fixed image here; the name and the value disagree.
    fixed_arrays = moving_images()
    fixed_moved_coords = self.get_inverse_warped_coordinates(fixed_images, moving_images, shape=shape, **kwargs)
    # bilinear + align_corners=True matches the sampling used during optimization.
    fixed_moved_image = F.grid_sample(fixed_arrays, fixed_moved_coords, mode='bilinear', align_corners=True) # [N, C, H, W, [D]]
    return fixed_moved_image


def evaluate(self, fixed_images: Union[BatchedImages, torch.Tensor], moving_images: Union[BatchedImages, torch.Tensor], shape=None):
'''Apply the learned transformation to new images.

This method applies the registration transformation learned during optimization
Expand Down Expand Up @@ -185,6 +229,11 @@ def evaluate(self, fixed_images: BatchedImages, moving_images: BatchedImages, sh
The transformation is applied using bilinear interpolation with align_corners=True
to maintain consistency with the optimization process.
'''
if isinstance(fixed_images, torch.Tensor):
fixed_images = FakeBatchedImages(fixed_images, self.fixed_images)
if isinstance(moving_images, torch.Tensor):
moving_images = FakeBatchedImages(moving_images, self.moving_images)

moving_arrays = moving_images()
moved_coords = self.get_warped_coordinates(fixed_images, moving_images, shape=shape)
moved_image = F.grid_sample(moving_arrays, moved_coords, mode='bilinear', align_corners=True) # [N, C, H, W, [D]]
Expand Down
Loading