Skip to content

Commit 45dc7aa

Browse files
authored
Add PyTorch 1.7 FFT compatibility (#11)
1 parent 178ba4b commit 45dc7aa

File tree

3 files changed

+57
-8
lines changed

3 files changed

+57
-8
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import torch
2+
from torch import Tensor
3+
from packaging import version
4+
5+
if version.parse(torch.__version__) >= version.parse("1.7.0"):
6+
import torch.fft # type: ignore
7+
8+
9+
def fft_old(image: Tensor, ndim: int, normalized: bool = False) -> Tensor:
10+
return torch.fft(image, ndim, normalized)
11+
12+
13+
def ifft_old(image: Tensor, ndim: int, normalized: bool = False) -> Tensor:
14+
return torch.ifft(image, ndim, normalized)
15+
16+
17+
def fft_new(image: Tensor, ndim: int, normalized: bool = False) -> Tensor:
18+
norm = "ortho" if normalized else None
19+
dims = tuple(range(-ndim, 0))
20+
21+
image = torch.view_as_real(
22+
torch.fft.fftn( # type: ignore
23+
torch.view_as_complex(image.contiguous()), dim=dims, norm=norm
24+
)
25+
)
26+
27+
return image
28+
29+
30+
def ifft_new(image: Tensor, ndim: int, normalized: bool = False) -> Tensor:
31+
norm = "ortho" if normalized else None
32+
dims = tuple(range(-ndim, 0))
33+
image = torch.view_as_real(
34+
torch.fft.ifftn( # type: ignore
35+
torch.view_as_complex(image.contiguous()), dim=dims, norm=norm
36+
)
37+
)
38+
39+
return image

torchkbnufft/nufft/fft_functions.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
import numpy as np
22
import torch
33
import torch.nn.functional as F
4+
from packaging import version
45

56
from ..math import complex_mult, conj_complex_mult
67

8+
if version.parse(torch.__version__) >= version.parse("1.7.0"):
9+
from .fft_compatibility import fft_new as fft_fn
10+
from .fft_compatibility import ifft_new as ifft_fn
11+
else:
12+
from .fft_compatibility import fft_old as fft_fn
13+
from .fft_compatibility import ifft_old as ifft_fn
14+
715

816
def scale_and_fft_on_image_volume(x, scaling_coef, grid_size, im_size, norm):
917
"""Applies the FFT and any relevant scaling factors to x.
@@ -43,7 +51,7 @@ def scale_and_fft_on_image_volume(x, scaling_coef, grid_size, im_size, norm):
4351
# zero pad and fft
4452
x = F.pad(x, pad_sizes)
4553
x = x.permute(permute_dims)
46-
x = torch.fft(x, grid_size.numel())
54+
x = fft_fn(x, grid_size.numel())
4755
if norm == "ortho":
4856
x = x / torch.sqrt(torch.prod(grid_size))
4957
x = x.permute(inv_permute_dims)
@@ -78,7 +86,7 @@ def ifft_and_scale_on_gridded_data(x, scaling_coef, grid_size, im_size, norm):
7886

7987
# do the inverse fft
8088
x = x.permute(permute_dims)
81-
x = torch.ifft(x, grid_size.numel())
89+
x = ifft_fn(x, grid_size.numel())
8290
x = x.permute(inv_permute_dims)
8391

8492
# crop to output size
@@ -140,7 +148,7 @@ def fft_filter(x, kern, norm=None):
140148
# zero pad and fft
141149
x = F.pad(x, pad_sizes)
142150
x = x.permute(permute_dims)
143-
x = torch.fft(x, grid_size.numel())
151+
x = fft_fn(x, grid_size.numel())
144152
if norm == "ortho":
145153
x = x / torch.sqrt(torch.prod(grid_size.to(torch.double)))
146154
x = x.permute(inv_permute_dims)
@@ -150,7 +158,7 @@ def fft_filter(x, kern, norm=None):
150158

151159
# inverse fft
152160
x = x.permute(permute_dims)
153-
x = torch.ifft(x, grid_size.numel())
161+
x = ifft_fn(x, grid_size.numel())
154162
x = x.permute(inv_permute_dims)
155163

156164
# crop to input size

torchkbnufft/nufft/toep_functions.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@
33

44
import numpy as np
55
import torch
6+
from packaging import version
67

7-
from ..math import absolute
8+
if version.parse(torch.__version__) >= version.parse("1.7.0"):
9+
from .fft_compatibility import fft_new as fft_fn
10+
else:
11+
from .fft_compatibility import fft_old as fft_fn
812

913

1014
def calc_toep_kernel(adj_ob, om, weights=None):
@@ -117,9 +121,7 @@ def _get_kern(om, weights, flip_list, base_flip, adj_ob):
117121
inv_permute_dims = (0, 1, kern.ndim - 1) + tuple(range(2, kern.ndim - 1))
118122

119123
# put the kernel in fft space
120-
kern = torch.fft(kern.permute(permute_dims), kern.ndim - 3).permute(
121-
inv_permute_dims
122-
)
124+
kern = fft_fn(kern.permute(permute_dims), kern.ndim - 3).permute(inv_permute_dims)
123125

124126
if adj_ob.norm == "ortho":
125127
kern = kern / torch.sqrt(torch.prod(torch.tensor(kern.shape[3:], dtype=dtype)))

0 commit comments

Comments
 (0)