Skip to content
8 changes: 8 additions & 0 deletions fireants/registration/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from functools import partial
from fireants.utils.imageutils import is_torch_float_type
from fireants.interpolator import fireants_interpolator
from fireants.utils.util import get_min_dim
import logging

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -157,6 +158,13 @@ def __init__(self,
if hasattr(self.loss_fn, 'set_scales'):
logger.info("Setting scales for loss function")
self.loss_fn.set_scales(self.scales)

# set min dim for img_size
fixed_arrays = self.fixed_images()
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to get fixed_arrays and moving_arrays? There is a shape attribute of the BatchedImages object that you can use directly

moving_arrays = self.moving_images()
fixed_size = fixed_arrays.shape[2:]
moving_size = moving_arrays.shape[2:]
self.min_dim = get_min_dim(fixed_size + moving_size)

self.print_init_msg()

Expand Down
4 changes: 2 additions & 2 deletions fireants/registration/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,8 @@ def optimize(self):
self.convergence_monitor.reset()
prev_loss = np.inf
# downsample fixed array and retrieve coords
size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in fixed_size]
mov_size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in moving_arrays.shape[2:]]
size_down = [max(int(s / scale), self.min_dim) for s in fixed_size]
mov_size_down = [max(int(s / scale), self.min_dim) for s in moving_arrays.shape[2:]]
# downsample
if self.blur and scale > 1:
sigmas = 0.5 * torch.tensor([sz/szdown for sz, szdown in zip(fixed_size, size_down)], device=fixed_arrays.device, dtype=moving_arrays.dtype)
Expand Down
9 changes: 5 additions & 4 deletions fireants/registration/deformation/compositive.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
from fireants.utils.imageutils import scaling_and_squaring, _find_integrator_n
from fireants.types import devicetype
from fireants.losses.cc import gaussian_1d, separable_filtering
from fireants.utils.util import grad_smoothing_hook
from fireants.utils.util import grad_smoothing_hook, get_min_dim
from fireants.utils.imageutils import jacobian
from fireants.registration.optimizers.sgd import WarpSGD
from fireants.registration.optimizers.adam import WarpAdam
from fireants.utils.globals import MIN_IMG_SIZE
from fireants.utils.globals import MIN_IMG_SIZE, MIN_IMG_SHARDED_SIZE

from logging import getLogger
from copy import deepcopy
Expand Down Expand Up @@ -58,11 +58,12 @@ def __init__(self, fixed_images: BatchedImages, moving_images: BatchedImages,
self.device = fixed_images.device
if optimizer_lr > 1:
getLogger("CompositiveWarp").warning(f'optimizer_lr is {optimizer_lr}, which is very high. Unexpected registration may occur.')

self.min_dim = get_min_dim(spatial_dims)

# define warp and register it as a parameter
# set size
if init_scale > 1:
spatial_dims = [max(int(s / init_scale), MIN_IMG_SIZE) for s in spatial_dims]
spatial_dims = [max(int(s / init_scale), self.min_dim) for s in spatial_dims]
warp = torch.zeros([num_images, *spatial_dims, self.n_dims], dtype=dtype, device=fixed_images.device) # [N, HWD, dims]
self.register_parameter('warp', nn.Parameter(warp))

Expand Down
8 changes: 4 additions & 4 deletions fireants/registration/greedy.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def get_inverse_warp_parameters(self, fixed_images: Union[BatchedImages, FakeBat
if shape is None:
shape = moving_images.shape if use_moving_shape else fixed_images.shape
else:
shape = [moving_arrays.shape[0], 1] + list(shape) if use_moving_shape else [fixed_arrays.shape[0], 1] + list(shape)
shape = [moving_arrays.shape[0], 1] + list(shape) if use_moving_shape else [fixed_images.shape[0], 1] + list(shape)

warp = self.warp.get_warp().detach().clone()
warp_inv = compositive_warp_inverse(moving_images if use_moving_shape else fixed_images, warp, displacement=True)
Expand Down Expand Up @@ -277,9 +277,9 @@ def optimize(self):
# multi-scale optimization
for scale, iters in zip(self.scales, self.iterations):
self.convergence_monitor.reset()
# resize images
size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in fixed_size]
moving_size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in moving_size]
# resize images
size_down = [max(int(s / scale), self.min_dim) for s in fixed_size]
moving_size_down = [max(int(s / scale), self.min_dim) for s in moving_size]
if self.blur and scale > 1:
sigmas = 0.5 * torch.tensor([sz/szdown for sz, szdown in zip(fixed_size, size_down)], device=fixed_arrays.device, dtype=fixed_arrays.dtype)
gaussians = [gaussian_1d(s, truncated=2) for s in sigmas]
Expand Down
4 changes: 2 additions & 2 deletions fireants/registration/moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ def downsample_images(self, fixed_arrays, moving_arrays):
fixed_size = fixed_arrays.shape[2:]
moving_size = moving_arrays.shape[2:]
scale = self.scales[0]
size_down_f = [max(int(s / scale), MIN_IMG_SIZE) for s in fixed_size]
size_down_m = [max(int(s / scale), MIN_IMG_SIZE) for s in moving_size]
size_down_f = [max(int(s / scale), self.min_dim) for s in fixed_size]
size_down_m = [max(int(s / scale), self.min_dim) for s in moving_size]
# downsample
if self.blur:
# blur and downsample for higher scale
Expand Down
4 changes: 2 additions & 2 deletions fireants/registration/rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,8 @@ def optimize(self):
self.convergence_monitor.reset()
prev_loss = np.inf
# downsample fixed array and retrieve coords
size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in fixed_size]
mov_size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in moving_arrays.shape[2:]]
size_down = [max(int(s / scale), self.min_dim) for s in fixed_size]
mov_size_down = [max(int(s / scale), self.min_dim) for s in moving_arrays.shape[2:]]
# downsample
if self.blur and scale > 1:
sigmas = 0.5 * torch.tensor([sz/szdown for sz, szdown in zip(fixed_size, size_down)], device=fixed_arrays.device, dtype=moving_arrays.dtype)
Expand Down
2 changes: 1 addition & 1 deletion fireants/registration/syn.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def optimize(self):
for scale, iters in zip(self.scales, self.iterations):
self.convergence_monitor.reset()
# resize images
size_down = [max(int(s / scale), MIN_IMG_SIZE) for s in fixed_size]
size_down = [max(int(s / scale), self.min_dim) for s in fixed_size]
if self.blur and scale > 1:
sigmas = 0.5 * torch.tensor([sz/szdown for sz, szdown in zip(fixed_size, size_down)], device=fixed_arrays.device, dtype=fixed_arrays.dtype)
gaussians = [gaussian_1d(s, truncated=2) for s in sigmas]
Expand Down
8 changes: 8 additions & 0 deletions fireants/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import logging
from typing import Optional
from fireants.interpolator import fireants_interpolator
from fireants.utils.globals import MIN_IMG_SIZE, MIN_IMG_SHARDED_SIZE
import SimpleITK as sitk
logging.basicConfig(level=logging.INFO)

Expand Down Expand Up @@ -262,3 +263,10 @@ def collate_fireants_fn(batch):
def check_and_raise_cond(cond: bool, msg: str, error_type: Exception = ValueError):
    """Raise ``error_type(msg)`` unless ``cond`` holds; no-op otherwise."""
    if cond:
        return
    raise error_type(msg)

def get_min_dim(sizes: List[int]):
    """Return a safe minimum dimension size for multi-scale downsampling.

    For each positive entry in ``sizes``, take the largest power of two that is
    ``<=`` that entry; the smallest of these across all dimensions, capped at
    ``MIN_IMG_SIZE`` (the recommended minimum), is returned.

    Args:
        sizes: spatial dimensions of the fixed and moving images; non-positive
            entries are ignored.

    Returns:
        ``min(MIN_IMG_SIZE, smallest power-of-two floor over all sizes)``.

    Raises:
        ValueError: if no positive sizes are provided, or if the smallest
            power-of-two floor falls below ``MIN_IMG_SHARDED_SIZE``.
    """
    positive_sizes = [x for x in sizes if x > 0]
    if not positive_sizes:
        # previously min() over an empty sequence raised a cryptic ValueError here
        raise ValueError("get_min_dim requires at least one positive dimension size")
    # largest power of two <= x for each dimension; keep the smallest across dims
    minimax_dim = min(2 ** int(np.log2(x)) for x in positive_sizes)
    if minimax_dim < MIN_IMG_SHARDED_SIZE:
        raise ValueError(f"One of fixed or moving image dimensions is too small, absolute min dimension size is {MIN_IMG_SHARDED_SIZE}, recommended min dimension size is {MIN_IMG_SIZE}")
    return min(MIN_IMG_SIZE, minimax_dim)
119 changes: 119 additions & 0 deletions tests/test_lowdim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import pytest
import torch
import numpy as np
from pathlib import Path
import SimpleITK as sitk

from fireants.registration.affine import AffineRegistration
from fireants.registration.greedy import GreedyRegistration
from fireants.registration.syn import SyNRegistration
from fireants.io.image import Image, BatchedImages
from fireants.io import FakeBatchedImages
from .test_moments_registration import generate_3d_ellipse
try:
from .conftest import dice_loss
except ImportError:
from conftest import dice_loss

def create_synthetic_data_np(size):
    """Build a fixed/moving pair of synthetic 3D ellipse volumes.

    Both volumes share the same ellipse semi-axes but differ in center and
    rotation, giving a known misalignment for registration tests.  A fixed
    seed keeps the pair deterministic across runs.
    """
    rng = np.random.RandomState(42)

    # Random semi-axes, ordered a > b > c with at least a 5-voxel gap
    a = rng.uniform(size // 4, 3 * size // 8)
    b = rng.uniform(3 * size // 16, a - 5)
    c = rng.uniform(2 * size // 16, b - 5)

    # Fixed volume: one pose of the ellipse
    fixed_arr = generate_3d_ellipse(
        size=size,
        axes=(a, b, c),
        center=(2 * size // 32, -3 * size // 32, size // 32),
        angles=(np.pi / 6, np.pi / 4, -np.pi / 3),
        rng=rng,
    )

    # Moving volume: same axes, different center and rotation
    moving_arr = generate_3d_ellipse(
        size=size,
        axes=(a, b, c),
        center=(-4 * size // 32, size // 32, -2 * size // 32),
        angles=(-np.pi / 3, -np.pi / 6, np.pi / 2),
        rng=rng,
    )

    return fixed_arr, moving_arr


def test_lowdim():
    """End-to-end check that registration handles low-resolution (anisotropic) volumes.

    Downsamples the z-axis of a fixed/moving image pair by several factors and
    runs affine, greedy, and SyN registration on each, asserting that the
    adaptive ``min_dim`` computed by each registration never exceeds the
    smallest image dimension.
    """
    # Prefer on-disk test volumes; fall back to synthetic ellipses below if absent.
    test_data_dir = Path(__file__).parent / "test_data"
    fixed_image_path = str(test_data_dir / "deformable_image_1.nii.gz")
    moving_image_path = str(test_data_dir / "deformable_image_2.nii.gz")

    # z-resolutions to test; synthetic/default volumes are assumed 128 voxels
    # along z — TODO confirm the on-disk images match this.
    downscales = [24, 12, 6]
    expected_zdims = [128 // (128 // zdim) for zdim in downscales]
    # NOTE(review): expected_zdim is computed (and recomputed in the else-branch
    # below) but never used in any assertion — either assert against the
    # downsampled shape or drop it.
    for downscale, expected_zdim in zip(downscales, expected_zdims):
        if any(not Path(f).exists() for f in [fixed_image_path, moving_image_path]):
            # No test data on disk: generate a deterministic synthetic pair.
            fixed_np, moving_np = create_synthetic_data_np(128)
        else:
            fixed_img = Image.load_file(fixed_image_path)
            moving_img = Image.load_file(moving_image_path)
            fixed_np = sitk.GetArrayFromImage(fixed_img.itk_image)
            moving_np = sitk.GetArrayFromImage(moving_img.itk_image)
            # Recompute from the actual z-extent of the loaded image.
            expected_zdim = fixed_np.shape[2] // (fixed_np.shape[2] // downscale)

        # Scale down: keep every (128 // downscale)-th slice along z.
        fixed_np = fixed_np[:,:,::128//downscale]
        moving_np = moving_np[:,:,::128//downscale]

        fixed_dims = fixed_np.shape
        moving_dims = moving_np.shape

        fixed_itk = sitk.GetImageFromArray(fixed_np)
        moving_itk = sitk.GetImageFromArray(moving_np)

        # Compensate the slice-dropping with a proportionally larger z-spacing
        # so physical extent is preserved.
        fixed_itk.SetSpacing((1.0, 1.0, 128//downscale))
        moving_itk.SetSpacing((1.0, 1.0, 128//downscale))

        # NOTE(review): hard requirement on a CUDA device — this test will fail
        # on CPU-only machines; consider a skip marker when cuda is unavailable.
        fixed_img = Image(fixed_itk, device='cuda')
        moving_img = Image(moving_itk, device='cuda')

        fixed_batch = BatchedImages([fixed_img])
        moving_batch = BatchedImages([moving_img])

        # Test AffineRegistration
        reg = AffineRegistration(
            scales=[4, 2, 1],
            iterations=[200, 100, 50],
            fixed_images=fixed_batch,
            moving_images=moving_batch,
            loss_type='mse',
            optimizer='Adam',
            optimizer_lr=3e-2,
        )
        # min_dim must adapt to the smallest (downsampled) image dimension.
        assert reg.min_dim <= min(fixed_dims), f"Min dimension is {reg.min_dim}, expected {min(fixed_dims)}"
        reg.optimize()

        # Test GreedyRegistration
        reg = GreedyRegistration(
            scales=[4, 2, 1],
            iterations=[200, 100, 50],
            fixed_images=fixed_batch,
            moving_images=moving_batch,
            loss_type='cc',
            optimizer='Adam',
            optimizer_lr=0.2,
            smooth_warp_sigma=0.25,
            smooth_grad_sigma=0.5
        )
        assert reg.min_dim <= min(fixed_dims), f"Min dimension is {reg.min_dim}, expected {min(fixed_dims)}"
        reg.optimize()

        # Test SynRegistration
        reg = SyNRegistration(
            scales=[4, 2, 1],
            iterations=[200, 100, 50],
            fixed_images=fixed_batch,
            moving_images=moving_batch,
            loss_type='cc',
            optimizer='Adam',
            optimizer_lr=0.2,
            smooth_warp_sigma=0.25,
            smooth_grad_sigma=0.5
        )
        assert reg.min_dim <= min(fixed_dims), f"Min dimension is {reg.min_dim}, expected {min(fixed_dims)}"
        reg.optimize()