Skip to content

Commit 6f64876

Browse files
committed
Added support for different Kaiser-Bessel orders.
1 parent 06a759d commit 6f64876

12 files changed

+583
-256
lines changed

docs/source/torchkbnufft.rst

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,3 @@ torchkbnufft.math
1313
:members:
1414
:undoc-members:
1515
:show-inheritance:
16-
17-
torchkbnufft.nufft\_utils
18-
-------------------------
19-
20-
.. automodule:: torchkbnufft.nufft_utils
21-
:members:
22-
:undoc-members:
23-
:show-inheritance:

notebooks/Basic Example.ipynb

Lines changed: 77 additions & 8 deletions
Large diffs are not rendered by default.

notebooks/SENSE Example.ipynb

Lines changed: 51 additions & 5 deletions
Large diffs are not rendered by default.

notebooks/Sparse Matrix Example.ipynb

Lines changed: 55 additions & 9 deletions
Large diffs are not rendered by default.

tests/test_kb_construction.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
import torch
2+
import numpy as np
3+
4+
from torchkbnufft import (AdjKbNufft, AdjMriSenseNufft, KbInterpBack,
5+
KbInterpForw, KbNufft, MriSenseNufft)
6+
7+
norm_tol = 1e-10
8+
9+
10+
def test_kb_matching():
11+
def check_tables(table1, table2):
12+
for ind, table in enumerate(table1):
13+
assert np.linalg.norm(table - table2[ind]) < norm_tol
14+
15+
im_szs = [(256, 256), (10, 256, 256)]
16+
17+
kbwidths = [2.34, 5]
18+
orders = [0, 2]
19+
20+
for kbwidth in kbwidths:
21+
for order in orders:
22+
for im_sz in im_szs:
23+
smap = torch.randn(*((1,) + im_sz))
24+
25+
base_table = AdjKbNufft(
26+
im_sz, order=order, kbwidth=kbwidth).table
27+
28+
cur_table = KbNufft(im_sz, order=order, kbwidth=kbwidth).table
29+
check_tables(base_table, cur_table)
30+
31+
cur_table = KbInterpBack(
32+
im_sz, order=order, kbwidth=kbwidth).table
33+
check_tables(base_table, cur_table)
34+
35+
cur_table = KbInterpForw(
36+
im_sz, order=order, kbwidth=kbwidth).table
37+
check_tables(base_table, cur_table)
38+
39+
cur_table = MriSenseNufft(
40+
smap, im_sz, order=order, kbwidth=kbwidth).table
41+
check_tables(base_table, cur_table)
42+
43+
cur_table = AdjMriSenseNufft(
44+
smap, im_sz, order=order, kbwidth=kbwidth).table
45+
check_tables(base_table, cur_table)
46+
47+
48+
def test_2d_init_inputs():
49+
# all object initializations have assertions
50+
# this should result in an error if any dimensions don't match
51+
52+
# test 2d scalar inputs
53+
im_sz = (256, 256)
54+
smap = torch.randn(*((1,) + im_sz))
55+
grid_sz = (512, 512)
56+
n_shift = (128, 128)
57+
numpoints = 6
58+
table_oversamp = 2**10
59+
kbwidth = 2.34
60+
order = 0
61+
norm = 'None'
62+
63+
ob = KbInterpForw(
64+
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
65+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order)
66+
ob = KbInterpBack(
67+
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
68+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order)
69+
70+
ob = KbNufft(
71+
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
72+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
73+
ob = AdjKbNufft(
74+
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
75+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
76+
77+
ob = MriSenseNufft(
78+
smap=smap, im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
79+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
80+
ob = AdjMriSenseNufft(
81+
smap=smap, im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
82+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
83+
84+
# test 2d tuple inputs
85+
im_sz = (256, 256)
86+
smap = torch.randn(*((1,) + im_sz))
87+
grid_sz = (512, 512)
88+
n_shift = (128, 128)
89+
numpoints = (6, 6)
90+
table_oversamp = (2**10, 2**10)
91+
kbwidth = (2.34, 2.34)
92+
order = (0, 0)
93+
norm = 'None'
94+
95+
ob = KbInterpForw(
96+
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
97+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order)
98+
ob = KbInterpBack(
99+
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
100+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order)
101+
102+
ob = KbNufft(
103+
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
104+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
105+
ob = AdjKbNufft(
106+
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
107+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
108+
109+
ob = MriSenseNufft(
110+
smap=smap, im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
111+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
112+
ob = AdjMriSenseNufft(
113+
smap=smap, im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
114+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
115+
116+
117+
def test_3d_init_inputs():
118+
# all object initializations have assertions
119+
# this should result in an error if any dimensions don't match
120+
121+
# test 3d scalar inputs
122+
im_sz = (10, 256, 256)
123+
smap = torch.randn(*((1,) + im_sz))
124+
grid_sz = (10, 512, 512)
125+
n_shift = (5, 128, 128)
126+
numpoints = 6
127+
table_oversamp = 2**10
128+
kbwidth = 2.34
129+
order = 0
130+
norm = 'None'
131+
132+
ob = KbInterpForw(
133+
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
134+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order)
135+
ob = KbInterpBack(
136+
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
137+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order)
138+
139+
ob = KbNufft(
140+
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
141+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
142+
ob = AdjKbNufft(
143+
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
144+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
145+
146+
ob = MriSenseNufft(
147+
smap=smap, im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
148+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
149+
ob = AdjMriSenseNufft(
150+
smap=smap, im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
151+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
152+
153+
# test 3d tuple inputs
154+
im_sz = (10, 256, 256)
155+
smap = torch.randn(*((1,) + im_sz))
156+
grid_sz = (10, 512, 512)
157+
n_shift = (5, 128, 128)
158+
numpoints = (6, 6, 6)
159+
table_oversamp = (2**10, 2**10, 2**10)
160+
kbwidth = (2.34, 2.34, 2.34)
161+
order = (0, 0, 0)
162+
norm = 'None'
163+
164+
ob = KbInterpForw(
165+
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
166+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order)
167+
ob = KbInterpBack(
168+
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
169+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order)
170+
171+
ob = KbNufft(
172+
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
173+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
174+
ob = AdjKbNufft(
175+
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
176+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
177+
178+
ob = MriSenseNufft(
179+
smap=smap, im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
180+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
181+
ob = AdjMriSenseNufft(
182+
smap=smap, im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
183+
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)

torchkbnufft/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Package info"""
22

3-
__version__ = '0.2.0'
3+
__version__ = '0.2.1'
44
__author__ = 'Matthew Muckley'
55
__author_email__ = '[email protected]'
66
__license__ = 'MIT'
@@ -21,6 +21,7 @@
2121
from .kbinterp import KbInterpBack, KbInterpForw
2222
from .kbnufft import KbNufft, AdjKbNufft
2323
from .mrisensenufft import MriSenseNufft, AdjMriSenseNufft
24+
from .nufft import utils as nufft_utils
2425

2526
__all__ = [
2627
'KbInterpForw',

torchkbnufft/functional/kbinterp.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ class KbInterpFunction(Function):
88
def forward(ctx, x, om, interpob, interp_mats=None):
99
"""Apply table interpolation.
1010
11-
This is a wrapper for mri.interp_functions.kbinterp for PyTorch autograd.
11+
This is a wrapper for nufft.interp_functions.kbinterp for PyTorch autograd.
1212
"""
1313
y = kbinterp(x, om, interpob, interp_mats)
1414

@@ -22,7 +22,7 @@ def forward(ctx, x, om, interpob, interp_mats=None):
2222
def backward(ctx, y):
2323
"""Apply table interpolation adjoint for gradient calculation.
2424
25-
This is a wrapper for mri.interp_functions.adjkbinterp for PyTorch autograd.
25+
This is a wrapper for nufft.interp_functions.adjkbinterp for PyTorch autograd.
2626
"""
2727
om, = ctx.saved_tensors
2828
interpob = ctx.interpob
@@ -38,7 +38,7 @@ class AdjKbInterpFunction(Function):
3838
def forward(ctx, y, om, interpob, interp_mats=None):
3939
"""Apply table interpolation adjoint.
4040
41-
This is a wrapper for mri.interp_functions.adjkbinterp for PyTorch autograd.
41+
This is a wrapper for nufft.interp_functions.adjkbinterp for PyTorch autograd.
4242
"""
4343
x = adjkbinterp(y, om, interpob, interp_mats)
4444

@@ -52,7 +52,7 @@ def forward(ctx, y, om, interpob, interp_mats=None):
5252
def backward(ctx, x):
5353
"""Apply table interpolation for gradient calculation.
5454
55-
This is a wrapper for mri.interp_functions.kbinterp for PyTorch autograd.
55+
This is a wrapper for nufft.interp_functions.kbinterp for PyTorch autograd.
5656
"""
5757
om, = ctx.saved_tensors
5858
interpob = ctx.interpob

0 commit comments

Comments
 (0)