Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Indexing images returns cropped images #626

Closed
wants to merge 17 commits into from
140 changes: 102 additions & 38 deletions ants/core/ants_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,47 +44,74 @@

class ANTsImage(object):

def __init__(self, pixeltype='float', dimension=3, components=1, pointer=None, is_rgb=False):
def __init__(self, pointer):
"""
Initialize an ANTsImage
Initialize an ANTsImage.

Creating an ANTsImage requires a pointer to an underlying ITK image that
is stored via a nanobind class wrapping a AntsImage struct.

Arguments
---------
pixeltype : string
ITK pixeltype of image

dimension : integer
number of image dimension. Does NOT include components dimension

components : integer
number of pixel components in the image

pointer : py::capsule (optional)
pybind11 capsule holding the pointer to the underlying ITK image object
pointer : nb::class
nanobind class wrapping the struct holding the pointer to the underlying ITK image object

"""
## Attributes which cant change without creating a new ANTsImage object
self.pointer = pointer
self.pixeltype = pixeltype
self.dimension = dimension
self.components = components
self.has_components = self.components > 1
self.dtype = _itk_to_npy_map[self.pixeltype]
self.is_rgb = is_rgb

self._pixelclass = 'vector' if self.has_components else 'scalar'
self._shortpclass = 'V' if self._pixelclass == 'vector' else ''
if is_rgb:
self._pixelclass = 'rgb'
self._shortpclass = 'RGB'

self._libsuffix = '%s%s%i' % (self._shortpclass, utils.short_ptype(self.pixeltype), self.dimension)

self.shape = tuple(utils.get_lib_fn('getShape')(self.pointer))
self.physical_shape = tuple([round(sh*sp,3) for sh,sp in zip(self.shape, self.spacing)])

self._array = None

@property
def _libsuffix(self):
return str(type(self.pointer)).split('AntsImage')[-1].split("'")[0]

@property
def shape(self):
return tuple(utils.get_lib_fn('getShape')(self.pointer))

@property
def physical_shape(self):
return tuple([round(sh*sp,3) for sh,sp in zip(self.shape, self.spacing)])

@property
def is_rgb(self):
return 'RGB' in self._libsuffix

@property
def has_components(self):
suffix = self._libsuffix
return suffix.startswith('V') or suffix.startswith('RGB')

@property
def components(self):
if not self.has_components:
return 1

libfn = utils.get_lib_fn('getComponents')
return libfn(self.pointer)

@property
def pixeltype(self):
ptype = self._libsuffix[:-1]
if self.has_components:
if self.is_rgb:
ptype = ptype[3:]
else:
ptype = ptype[1:]

ptype_map = {'UC': 'unsigned char',
'UI': 'unsigned int',
'F': 'float',
'D': 'double'}
return ptype_map[ptype]

@property
def dtype(self):
return _itk_to_npy_map[self.pixeltype]

@property
def dimension(self):
return int(self._libsuffix[-1])

@property
def spacing(self):
"""
Expand Down Expand Up @@ -560,19 +587,56 @@ def __ne__(self, other):
return self.new_image_like(new_array.astype('uint8'))

def __getitem__(self, idx):
if self._array is None:
self._array = self.numpy()

if self.has_components:
return utils.merge_channels([
img[idx] for img in utils.split_channels(self)
])

if isinstance(idx, ANTsImage):
if not image_physical_space_consistency(self, idx):
raise ValueError('images do not occupy same physical space')
return self._array.__getitem__(idx.numpy().astype('bool'))
else:
return self._array.__getitem__(idx)
return self.numpy().__getitem__(idx.numpy().astype('bool'))

ndim = len(idx)
sizes = list(self.shape)
starts = [0] * ndim

for i in range(ndim):
ti = idx[i]
if isinstance(ti, slice):
if ti.start:
starts[i] = ti.start
if ti.stop:
sizes[i] = ti.stop - starts[i]
else:
sizes[i] = self.shape[i] - starts[i]

if ti.stop and ti.start:
if ti.stop < ti.start:
raise Exception('Reverse indexing is not supported.')

elif isinstance(ti, int):
starts[i] = ti
sizes[i] = 0

if sizes[i] == 0:
ndim -= 1

if ndim < 2:
return self.numpy().__getitem__(idx)

libfn = utils.get_lib_fn('getItem%i' % ndim)
new_ptr = libfn(self.pointer, starts, sizes)
new_image = ANTsImage(pixeltype=self.pixeltype, dimension=ndim,
components=self.components, pointer=new_ptr)
return new_image


def __setitem__(self, idx, value):
arr = self.view()
if isinstance(value, ANTsImage):
value = value.numpy()

if isinstance(idx, ANTsImage):
if not image_physical_space_consistency(self, idx):
raise ValueError('images do not occupy same physical space')
Expand Down
2 changes: 1 addition & 1 deletion ants/core/ants_image_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,4 +646,4 @@ def image_write(image, filename, ri=False):
image.to_file(filename)

if ri:
return image
return image
2 changes: 1 addition & 1 deletion ants/utils/weingarten_image_curvature.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def weingarten_image_curvature(image, sigma=1.0, opt='mean'):
temp = np.zeros(list(d)+[10])
for k in range(1,7):
voxvals = image[:d[0],:d[1]]
temp[:d[0],:d[1],k] = voxvals
temp[:d[0],:d[1],k] = voxvals.numpy()
temp = core.from_numpy(temp)
myspc = image.spacing
myspc = list(myspc) + [min(myspc)]
Expand Down
12 changes: 6 additions & 6 deletions ants/viz/plot_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,17 +158,17 @@ def reorient_slice(x, axis):

def slice_image(img, axis, idx):
if axis == 0:
return img[idx, :, :]
return img[idx, :, :].numpy()
elif axis == 1:
return img[:, idx, :]
return img[:, idx, :].numpy()
elif axis == 2:
return img[:, :, idx]
return img[:, :, idx].numpy()
elif axis == -1:
return img[:, :, idx]
return img[:, :, idx].numpy()
elif axis == -2:
return img[:, idx, :]
return img[:, idx, :].numpy()
elif axis == -3:
return img[idx, :, :]
return img[idx, :, :].numpy()
else:
raise ValueError("axis %i not valid" % axis)

Expand Down
78 changes: 78 additions & 0 deletions src/antsGetItem.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/vector.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/tuple.h>
#include <nanobind/stl/list.h>
#include <nanobind/ndarray.h>
#include <nanobind/stl/shared_ptr.h>

#include "itkImage.h"
#include <itkExtractImageFilter.h>

#include "antsImage.h"

namespace nb = nanobind;
using namespace nb::literals;

template <typename ImageType, class PixelType, unsigned int ndim>
AntsImage<itk::Image<PixelType, ndim>> getItem( AntsImage<ImageType> & antsImage,
std::vector<unsigned long> starts,
std::vector<unsigned long> sizes )
{
typename ImageType::Pointer image = antsImage.ptr;

using OutImageType = itk::Image<PixelType, ndim>;

typename ImageType::IndexType desiredStart;
typename ImageType::SizeType desiredSize;

for( int i = 0 ; i < starts.size(); ++i )
{
desiredStart[i] = starts[i];
desiredSize[i] = sizes[i];
}

typename ImageType::RegionType desiredRegion(desiredStart, desiredSize);

using FilterType = itk::ExtractImageFilter<ImageType, OutImageType>;
typename FilterType::Pointer filter = FilterType::New();
filter->SetExtractionRegion(desiredRegion);
filter->SetInput(image);
filter->SetDirectionCollapseToIdentity(); // This is required.
filter->Update();

FixNonZeroIndex<OutImageType>( filter->GetOutput() );
AntsImage<OutImageType> outImage = { filter->GetOutput() };
return outImage;
}


void local_antsGetItem(nb::module_ &m) {
m.def("getItem2", &getItem<itk::Image<float,2>, float, 2>);
m.def("getItem2", &getItem<itk::Image<float,3>, float, 2>);
m.def("getItem2", &getItem<itk::Image<float,4>, float, 2>);
m.def("getItem3", &getItem<itk::Image<float,3>, float, 3>);
m.def("getItem3", &getItem<itk::Image<float,4>, float, 3>);
m.def("getItem4", &getItem<itk::Image<float,4>, float, 4>);

m.def("getItem2", &getItem<itk::Image<unsigned char,2>, unsigned char, 2>);
m.def("getItem2", &getItem<itk::Image<unsigned char,3>, unsigned char, 2>);
m.def("getItem2", &getItem<itk::Image<unsigned char,4>, unsigned char, 2>);
m.def("getItem3", &getItem<itk::Image<unsigned char,3>, unsigned char, 3>);
m.def("getItem3", &getItem<itk::Image<unsigned char,4>, unsigned char, 3>);
m.def("getItem4", &getItem<itk::Image<unsigned char,4>, unsigned char, 4>);

m.def("getItem2", &getItem<itk::Image<unsigned int,2>, unsigned int, 2>);
m.def("getItem2", &getItem<itk::Image<unsigned int,3>, unsigned int, 2>);
m.def("getItem2", &getItem<itk::Image<unsigned int,4>, unsigned int, 2>);
m.def("getItem3", &getItem<itk::Image<unsigned int,3>, unsigned int, 3>);
m.def("getItem3", &getItem<itk::Image<unsigned int,4>, unsigned int, 3>);
m.def("getItem4", &getItem<itk::Image<unsigned int,4>, unsigned int, 4>);

m.def("getItem2", &getItem<itk::Image<double,2>, double, 2>);
m.def("getItem2", &getItem<itk::Image<double,3>, double, 2>);
m.def("getItem2", &getItem<itk::Image<double,4>, double, 2>);
m.def("getItem3", &getItem<itk::Image<double,3>, double, 3>);
m.def("getItem3", &getItem<itk::Image<double,4>, double, 3>);
m.def("getItem4", &getItem<itk::Image<double,4>, double, 4>);
}
3 changes: 3 additions & 0 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "antiAlias.cxx"
#include "antsImage.cxx"
#include "antsImageClone.cxx"
#include "antsGetItem.cxx"
#include "antsImageHeaderInfo.cxx"
#include "antsImageMutualInformation.cxx"
#include "antsImageToImageMetric.cxx"
Expand Down Expand Up @@ -62,6 +63,7 @@ void local_addNoiseToImage(nb::module_ &);
void local_antiAlias(nb::module_ &);
void local_antsImage(nb::module_ &);
void local_antsImageClone(nb::module_ &);
void local_antsGetItem(nb::module_ &);
void local_antsImageHeaderInfo(nb::module_ &);
void local_antsImageMutualInformation(nb::module_ &);
void local_antsImageToImageMetric(nb::module_ &);
Expand Down Expand Up @@ -119,6 +121,7 @@ NB_MODULE(lib, m) {
local_antiAlias(m);
local_antsImage(m);
local_antsImageClone(m);
local_antsGetItem(m);
local_antsImageHeaderInfo(m);
local_antsImageMutualInformation(m);
local_antsImageToImageMetric(m);
Expand Down
1 change: 1 addition & 0 deletions tests/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pushd "$(dirname "$0")"

echo "Running core tests"
$PYCMD test_core_ants_image.py $@
$PYCMD test_core_ants_image_indexing.py $@
$PYCMD test_core_ants_image_io.py $@
$PYCMD test_core_ants_transform.py $@
$PYCMD test_core_ants_transform_io.py $@
Expand Down
36 changes: 17 additions & 19 deletions tests/test_core_ants_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import ants


class TestClass_ANTsImage(unittest.TestCase):
"""
Test ants.ANTsImage class
Expand Down Expand Up @@ -484,26 +484,24 @@ def test__ne__(self):
img2.set_spacing([2.31]*img.dimension)
img3 = img != img2

def test__getitem__(self):
#self.setUp()
for img in self.imgs:
if img.dimension == 2:
img2 = img[6:9,6:9]
nptest.assert_allclose(img2, img.numpy()[6:9,6:9])
elif img.dimension == 3:
img2 = img[6:9,6:9,6:9]
nptest.assert_allclose(img2, img.numpy()[6:9,6:9,6:9])

# get from another image
img2 = img.clone()
xx = img[img2]
with self.assertRaises(Exception):
# different physical space
img2.set_direction(img.direction*2)
xx = img[img2]
#def test__getitem__(self):
# for img in self.imgs:
# if img.dimension == 2:
# img2 = img[6:9,6:9]
# nptest.assert_allclose(img2, img.numpy()[6:9,6:9])
# elif img.dimension == 3:
# img2 = img[6:9,6:9,6:9]
# nptest.assert_allclose(img2, img.numpy()[6:9,6:9,6:9])
#
# # get from another image
# img2 = img.clone()
# xx = img[img2]
# with self.assertRaises(Exception):
# # different physical space
# img2.set_direction(img.direction*2)
# xx = img[img2]

def test__setitem__(self):
#self.setUp()
for img in self.imgs:
if img.dimension == 2:
img[6:9,6:9] = 6.9
Expand Down
Loading
Loading