Skip to content

Commit abca3f2

Browse files
authored
Batched nufft (#24)
* First batched nufft implementation * Performance improvements * Fixes for toep * Type fix * Bug fix, doc updates * Fix dcomp it num * Update batch docs * Add new docs * Change the performance tips name * Try to change list * Increment version * Fix perf tips * Code quality, remove test script * Update performance doc * Update doc * Update docs
1 parent b5f6581 commit abca3f2

File tree

15 files changed

+967
-136
lines changed

15 files changed

+967
-136
lines changed

docs/source/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ References
6969
:caption: User Guide
7070

7171
basic
72+
performance
7273

7374
.. toctree::
7475
:hidden:

docs/source/performance.rst

+134
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
Performance Tips
2+
================
3+
4+
:py:mod:`torchkbnufft` is primarily written for the goal of scaling parallelism within
5+
the PyTorch framework. The performance bottleneck of the package comes from two sources:
6+
1) advanced indexing and 2) multiplications. Multiplications are handled in a way that
7+
scales well, but advanced indexing is not due to
8+
`limitations with PyTorch <https://github.com/pytorch/pytorch/issues/29973>`_.
9+
As a result, growth in problem size that is independent of the indexing bottleneck is
10+
handled very well by the package, such as:
11+
12+
1. Scaling the batch dimension.
13+
2. Scaling the coil dimension.
14+
15+
Generally, you can just add to these dimensions and the package will perform well
16+
without adding much compute time. If you're chasing more speed, some strategies that
17+
might be helpful are listed below.
18+
19+
Using Batched K-space Trajectories
20+
----------------------------------
21+
22+
As of version ``1.1.0``, :py:mod:`torchkbnufft` can use batched k-space trajectories.
23+
If you pass in a variable for ``omega`` with dimensions
24+
``(N, length(im_size), klength)``, the package will parallelize the execution of all
25+
trajectories in the ``N`` dimension. This is useful when ``N`` is very large, as might
26+
occur in dynamic imaging settings. The following shows an example:
27+
28+
.. code-block:: python
29+
30+
import torch
31+
import torchkbnufft as tkbn
32+
import numpy as np
33+
from skimage.data import shepp_logan_phantom
34+
35+
batch_size = 12
36+
37+
x = shepp_logan_phantom().astype(np.complex)
38+
im_size = x.shape
39+
# convert to tensor, unsqueeze batch and coil dimension
40+
# output size: (batch_size, 1, ny, nx)
41+
x = torch.tensor(x).unsqueeze(0).unsqueeze(0).to(torch.complex64)
42+
x = x.repeat(batch_size, 1, 1, 1)
43+
44+
klength = 64
45+
ktraj = np.stack(
46+
(np.zeros(64), np.linspace(-np.pi, np.pi, klength))
47+
)
48+
# convert to tensor, unsqueeze batch dimension
49+
# output size: (batch_size, 2, klength)
50+
ktraj = torch.tensor(ktraj).to(torch.float)
51+
ktraj = ktraj.unsqueeze(0).repeat(batch_size, 1, 1)
52+
53+
nufft_ob = tkbn.KbNufft(im_size=im_size)
54+
# outputs a (batch_size, 1, klength) vector of k-space data
55+
kdata = nufft_ob(x, ktraj)
56+
57+
This code will then compute the 12 different radial spokes while parallelizing as much
58+
as possible.
59+
60+
Lowering the Precision
61+
----------------------
62+
63+
A simple way to save both memory and compute time is to decrease the precision. PyTorch
64+
normally operates at a default 32-bit floating point precision, but if you're converting
65+
data from NumPy then you might have some data at 64-bit floating precision. To use
66+
32-bit precision, simply do the following:
67+
68+
.. code-block:: python
69+
70+
image = image.to(dtype=torch.complex64)
71+
ktraj = ktraj.to(dtype=torch.float32)
72+
forw_ob = forw_ob.to(image)
73+
74+
data = forw_ob(image, ktraj)
75+
76+
The ``forw_ob.to(image)`` command will automagically determine the type for both real
77+
and complex tensors registered as buffers under ``forw_ob``, so you should be able to
78+
do this safely in your code.
79+
80+
In many cases, the tradeoff for going from 64-bit to 32-bit is not severe, so you can
81+
securely use 32-bit precision.
82+
83+
Lowering the Oversampling Ratio
84+
-------------------------------
85+
86+
If you create a :py:class:`~torchkbnufft.KbNufft` object using the following code:
87+
88+
.. code-block:: python
89+
90+
forw_ob = tkbn.KbNufft(im_size=im_size)
91+
92+
then by default it will use a 2-factor oversampled grid. For some applications, this can
93+
be overkill. If you can sacrifice some accuracy for your application, you can use a
94+
smaller grid with 1.25-factor oversampling by altering how you initialize NUFFT objects
95+
like :py:class:`~torchkbnufft.KbNufft`:
96+
97+
.. code-block:: python
98+
99+
grid_size = tuple([int(el * 1.25) for el in im_size])
100+
forw_ob = tkbn.KbNufft(im_size=im_size, grid_size=grid_size)
101+
102+
Using Fewer Interpolation Neighbors
103+
-----------------------------------
104+
105+
Another major speed factor is how many neighbors you use for interpolation. By default,
106+
:py:mod:`torchkbnufft` uses 6 nearest neighbors in each dimension. If you can sacrifice
107+
accuracy, you can get more speed by using fewer neighbors by altering how you initialize
108+
NUFFT objects like :py:class:`~torchkbnufft.KbNufft`:
109+
110+
.. code-block:: python
111+
112+
forw_ob = tkbn.KbNufft(im_size=im_size, numpoints=4)
113+
114+
If you know that you can be less accurate in one dimension (e.g., the z-dimension), then
115+
you can use less neighbors in only that dimension:
116+
117+
.. code-block:: python
118+
119+
forw_ob = tkbn.KbNufft(im_size=im_size, numpoints=(4, 6, 6))
120+
121+
Package Limitations
122+
-------------------
123+
124+
As mentioned earlier, batches and coils scale well, primarily due to the fact that they
125+
don't impact the bottlenecks of the package around advanced indexing. Where
126+
:py:mod:`torchkbnufft` does not scale well is:
127+
128+
1. Very long k-space trajectories.
129+
2. More imaging dimensions (e.g., 3D).
130+
131+
For these settings, you can first try to use some of the strategies here (lowering
132+
precision, fewer neighbors, smaller grid). In some cases, lowering the precision a bit
133+
and using a GPU can still give strong performance. If you're still waiting too long for
134+
compute after trying all of these, you may be running into the limits of the package.

tests/test_dcomp.py

+38
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import pytest
23
import torch
34
import torchkbnufft as tkbn
@@ -37,3 +38,40 @@ def test_dcomp_run(shape, kdata_shape, is_complex):
3738
_ = adj_ob(kdata * dcomp, ktraj)
3839

3940
torch.set_default_dtype(default_dtype)
41+
42+
43+
@pytest.mark.parametrize(
44+
"shape, kdata_shape",
45+
[
46+
([2, 1, 19], [2, 1, 25]),
47+
([3, 1, 13], [3, 1, 18]),
48+
([6, 1, 32, 16], [6, 1, 83]),
49+
([5, 1, 15, 12], [5, 1, 83]),
50+
([3, 2, 13, 18, 12], [3, 2, 112]),
51+
([2, 2, 17, 19, 12], [2, 2, 112]),
52+
],
53+
)
54+
def test_batched_dcomp(shape, kdata_shape):
55+
default_dtype = torch.get_default_dtype()
56+
torch.set_default_dtype(torch.double)
57+
torch.manual_seed(123)
58+
im_size = shape[2:]
59+
60+
ktraj = (
61+
torch.rand(size=(shape[0], len(im_size), kdata_shape[2])) * 2 * np.pi - np.pi
62+
)
63+
64+
forloop_dcomp = []
65+
for ktraj_it in ktraj:
66+
res = tkbn.calc_density_compensation_function(ktraj=ktraj_it, im_size=im_size)
67+
forloop_dcomp.append(
68+
tkbn.calc_density_compensation_function(ktraj=ktraj_it, im_size=im_size)
69+
)
70+
71+
batched_dcomp = tkbn.calc_density_compensation_function(
72+
ktraj=ktraj, im_size=im_size
73+
)
74+
75+
assert torch.allclose(torch.cat(forloop_dcomp), batched_dcomp)
76+
77+
torch.set_default_dtype(default_dtype)

tests/test_interp.py

+49
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pickle
22

3+
import numpy as np
34
import pytest
45
import torch
56
import torchkbnufft as tkbn
@@ -277,3 +278,51 @@ def test_interp_autograd_gpu(shape, kdata_shape, is_complex):
277278
nufft_autograd_test(image, kdata, ktraj, forw_ob, adj_ob, spmat)
278279

279280
torch.set_default_dtype(default_dtype)
281+
282+
283+
@pytest.mark.parametrize(
284+
"shape, kdata_shape, is_complex",
285+
[
286+
([3, 1, 19], [3, 1, 25], True),
287+
([3, 1, 13, 2], [3, 1, 18, 2], False),
288+
([4, 1, 32, 16], [4, 1, 83], True),
289+
([5, 1, 15, 12, 2], [5, 1, 83, 2], False),
290+
([3, 2, 13, 18, 12], [3, 2, 112], True),
291+
([2, 2, 17, 19, 12, 2], [2, 2, 112, 2], False),
292+
],
293+
)
294+
def test_interp_batches(shape, kdata_shape, is_complex):
295+
default_dtype = torch.get_default_dtype()
296+
torch.set_default_dtype(torch.double)
297+
torch.manual_seed(123)
298+
if is_complex:
299+
im_size = shape[2:]
300+
else:
301+
im_size = shape[2:-1]
302+
303+
image = create_input_plus_noise(shape, is_complex)
304+
kdata = create_input_plus_noise(kdata_shape, is_complex)
305+
ktraj = (
306+
torch.rand(size=(shape[0], len(im_size), kdata_shape[2])) * 2 * np.pi - np.pi
307+
)
308+
309+
forw_ob = tkbn.KbInterp(im_size=im_size, grid_size=im_size)
310+
adj_ob = tkbn.KbInterpAdjoint(im_size=im_size, grid_size=im_size)
311+
312+
forloop_test_forw = []
313+
for image_it, ktraj_it in zip(image, ktraj):
314+
forloop_test_forw.append(forw_ob(image_it.unsqueeze(0), ktraj_it))
315+
316+
batched_test_forw = forw_ob(image, ktraj)
317+
318+
assert torch.allclose(torch.cat(forloop_test_forw), batched_test_forw)
319+
320+
forloop_test_adj = []
321+
for data_it, ktraj_it in zip(kdata, ktraj):
322+
forloop_test_adj.append(adj_ob(data_it.unsqueeze(0), ktraj_it))
323+
324+
batched_test_adj = adj_ob(kdata, ktraj)
325+
326+
assert torch.allclose(torch.cat(forloop_test_adj), batched_test_adj)
327+
328+
torch.set_default_dtype(default_dtype)

tests/test_toep.py

+59
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import pytest
23
import torch
34
import torchkbnufft as tkbn
@@ -37,6 +38,64 @@ def test_toeplitz_nufft_accuracy(shape, kdata_shape, is_complex):
3738
toep_ob = tkbn.ToepNufft()
3839

3940
kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size, norm="ortho")
41+
if not is_complex:
42+
kernel = torch.view_as_real(kernel)
43+
44+
fbn = adj_ob(
45+
forw_ob(image, ktraj, smaps=smaps, norm="ortho"),
46+
ktraj,
47+
smaps=smaps,
48+
norm="ortho",
49+
)
50+
fbt = toep_ob(image, kernel, smaps=smaps, norm="ortho")
51+
52+
if is_complex:
53+
fbn = torch.view_as_real(fbn)
54+
fbt = torch.view_as_real(fbt)
55+
56+
norm_diff = torch.norm(fbn - fbt) / torch.norm(fbn)
57+
58+
assert norm_diff < norm_diff_tol
59+
60+
torch.set_default_dtype(default_dtype)
61+
62+
63+
@pytest.mark.parametrize(
64+
"shape, kdata_shape, is_complex",
65+
[
66+
([4, 3, 19], [4, 3, 25], True),
67+
([3, 5, 13, 2], [3, 5, 18, 2], False),
68+
([2, 4, 32, 16], [2, 4, 83], True),
69+
([5, 8, 15, 12, 2], [5, 8, 83, 2], False),
70+
([3, 10, 13, 18, 12], [3, 10, 112], True),
71+
([2, 12, 17, 19, 12, 2], [2, 12, 112, 2], False),
72+
],
73+
)
74+
def test_batched_toeplitz_nufft_accuracy(shape, kdata_shape, is_complex):
75+
norm_diff_tol = 1e-4 # toeplitz is only approximate
76+
default_dtype = torch.get_default_dtype()
77+
torch.set_default_dtype(torch.double)
78+
torch.manual_seed(123)
79+
if is_complex:
80+
im_size = shape[2:]
81+
else:
82+
im_size = shape[2:-1]
83+
im_shape = [s for s in shape]
84+
im_shape[1] = 1
85+
86+
image = create_input_plus_noise(im_shape, is_complex)
87+
smaps = create_input_plus_noise(shape, is_complex)
88+
ktraj = (
89+
torch.rand(size=(shape[0], len(im_size), kdata_shape[2])) * 2 * np.pi - np.pi
90+
)
91+
92+
forw_ob = tkbn.KbNufft(im_size=im_size)
93+
adj_ob = tkbn.KbNufftAdjoint(im_size=im_size)
94+
toep_ob = tkbn.ToepNufft()
95+
96+
kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size, norm="ortho")
97+
if not is_complex:
98+
kernel = torch.view_as_real(kernel)
4099

41100
fbn = adj_ob(
42101
forw_ob(image, ktraj, smaps=smaps, norm="ortho"),

torchkbnufft/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Package info"""
22

3-
__version__ = "1.0.1"
3+
__version__ = "1.1.0"
44
__author__ = "Matthew Muckley"
55
__author_email__ = "[email protected]"
66
__license__ = "MIT"

torchkbnufft/_nufft/dcomp.py

+18-6
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ def calc_density_compensation_function(
2323
This function has optional parameters for initializing a NUFFT object. See
2424
:py:class:`~torchkbnufft.KbInterp` for details.
2525
26-
* :attr:`ktraj` should be of size ``(len(im_size), klength)``,
27-
where ``klength`` is the length of the k-space trajectory.
26+
* :attr:`ktraj` should be of size ``(len(grid_size), klength)`` or
27+
``(N, len(grid_size), klength)``, where ``klength`` is the length of the
28+
k-space trajectory.
2829
2930
Based on the `method of Pipe
3031
<https://doi.org/10.1002/(SICI)1522-2594(199901)41:1%3C179::AID-MRM25%3E3.0.CO;2-V>`_.
@@ -56,6 +57,16 @@ def calc_density_compensation_function(
5657
>>> image = adjkb_ob(data * dcomp, omega)
5758
"""
5859
device = ktraj.device
60+
batch_size = 1
61+
62+
if ktraj.ndim not in (2, 3):
63+
raise ValueError("ktraj must have 2 or 3 dimensions")
64+
65+
if ktraj.ndim == 3:
66+
if ktraj.shape[0] == 1:
67+
ktraj = ktraj[0]
68+
else:
69+
batch_size = ktraj.shape[0]
5970

6071
# init nufft variables
6172
(
@@ -80,10 +91,12 @@ def calc_density_compensation_function(
8091
device=device,
8192
)
8293

83-
test_sig = torch.ones([1, 1, ktraj.shape[-1]], dtype=tables[0].dtype, device=device)
94+
test_sig = torch.ones(
95+
[batch_size, 1, ktraj.shape[-1]], dtype=tables[0].dtype, device=device
96+
)
8497
for _ in range(num_iterations):
8598
new_sig = tkbnF.kb_table_interp(
86-
tkbnF.kb_table_interp_adjoint(
99+
image=tkbnF.kb_table_interp_adjoint(
87100
data=test_sig,
88101
omega=ktraj,
89102
tables=tables,
@@ -101,7 +114,6 @@ def calc_density_compensation_function(
101114
offsets=offsets_t,
102115
)
103116

104-
norm_new_sig = torch.abs(new_sig)
105-
test_sig = test_sig / norm_new_sig
117+
test_sig = test_sig / torch.abs(new_sig)
106118

107119
return test_sig

torchkbnufft/_nufft/fft.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,9 @@ def fft_filter(image: Tensor, kernel: Tensor, norm: Optional[str] = "ortho") ->
146146
raise ValueError("Only option for norm is 'ortho'.")
147147

148148
im_size = torch.tensor(image.shape[2:], dtype=torch.long, device=image.device)
149-
grid_size = torch.tensor(kernel.shape[2:], dtype=torch.long, device=image.device)
149+
grid_size = torch.tensor(
150+
kernel.shape[-len(image.shape[2:]) :], dtype=torch.long, device=image.device
151+
)
150152

151153
# set up n-dimensional zero pad
152154
# zero pad for oversampled nufft

0 commit comments

Comments
 (0)