Skip to content

Commit 3004ff7

Browse files
committed
Add tests for inplace plans
1 parent 315b9ae commit 3004ff7

File tree

1 file changed

+82
-59
lines changed

1 file changed

+82
-59
lines changed

src/TestUtils.jl

+82-59
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,22 @@ using LinearAlgebra
1111
using Test
1212

1313
"""
14-
test_fft_backend(array_constructor)
14+
test_fft_backend(array_constructor; test_real=true, test_inplace=true)
1515
16-
Run tests to verify correctness of all FFT functions based on a particular
16+
Run tests to verify correctness of FFT functions using a particular
1717
backend plan implementation. The backend implementation is assumed to be loaded
1818
prior to calling this function.
1919
20-
The input `array_constructor` determines the `AbstractArray` implementation for
20+
# Arguments
21+
22+
- `array_constructor`: determines the `AbstractArray` implementation for
2123
which the correctness tests are run. It is assumed to be a callable object that
2224
takes in input arrays of type `Array` and return arrays of the desired type for
2325
testing: this would most commonly be a constructor such as `Array` or `CuArray`.
26+
- `test_real=true`: whether to test real-to-complex and complex-to-real FFTs.
27+
- `test_inplace=true`: whether to test in-place plans.
2428
"""
25-
function test_fft_backend(array_constructor)
29+
function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
2630
@testset "fft correctness" begin
2731
# DFT along last dimension, results computed using FFTW
2832
for (_x, _fftw_fft) in (
@@ -51,82 +55,101 @@ function test_fft_backend(array_constructor)
5155
15.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
5256
18.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im]),
5357
)
54-
x = array_constructor(_x)
55-
xcopy_float = array_constructor(copy(float.(x)))
56-
xcopy_complex = array_constructor(copy(complex.(xcopy_float)))
58+
x = array_constructor(_x) # dummy array that will be passed to plans
59+
x_real = float.(x) # for testing real FFTs
60+
x_complex = complex.(x_real) # for testing complex FFTs
5761
fftw_fft = array_constructor(_fftw_fft)
5862

63+
dims = ndims(x) # TODO: this is a single dimension, should check multidimensional FFTs too
64+
5965
# FFT
60-
dims = ndims(x)
61-
y = AbstractFFTs.fft(x, dims)
62-
ycopy = array_constructor(copy(y))
66+
y = AbstractFFTs.fft(x_complex, dims)
6367
@test y fftw_fft
68+
test_inplace && (@test AbstractFFTs.fft!(copy(x_complex), dims) fftw_fft)
6469
# test plan_fft and also inv and plan_inv of plan_ifft, which should all give
6570
# functionally identical plans
66-
for P in [plan_fft(x, dims), inv(plan_ifft(x, dims)),
67-
AbstractFFTs.plan_inv(plan_ifft(x, dims))]
71+
plans_to_test = [plan_fft(x, dims), inv(plan_ifft(x, dims)),
72+
AbstractFFTs.plan_inv(plan_ifft(x, dims))]
73+
for P in plans_to_test
74+
@test mul!(similar(y), P, x_complex) fftw_fft
75+
end
76+
test_inplace && (plans_to_test = vcat(plans_to_test, plan_fft!(similar(x_complex), dims)))
77+
for P in plans_to_test
6878
@test eltype(P) <: Complex
69-
@test P * x fftw_fft
70-
@test mul!(ycopy, P, x) fftw_fft
71-
@test P \ (P * x) x
79+
@test P * copy(x_complex) fftw_fft
80+
@test P \ (P * copy(x_complex)) x_complex
7281
@test fftdims(P) == dims
7382
end
7483

7584
# BFFT
76-
fftw_bfft = complex.(size(x, dims) .* x)
85+
fftw_bfft = size(x_complex, dims) .* x_complex
7786
@test AbstractFFTs.bfft(y, dims) fftw_bfft
78-
P = plan_bfft(x, dims)
79-
@test P * y fftw_bfft
80-
@test P \ (P * y) y
81-
@test mul!(xcopy_complex, P, y) fftw_bfft
82-
@test fftdims(P) == dims
87+
test_inplace && (@test AbstractFFTs.bfft!(copy(y), dims) fftw_bfft)
88+
plans_to_test = [plan_bfft(similar(y), dims)]
89+
for P in plans_to_test
90+
@test mul!(similar(x_complex), P, y) fftw_bfft
91+
end
92+
test_inplace && (plans_to_test = vcat(plans_to_test, plan_bfft!(similar(y), dims)))
93+
for P in plans_to_test
94+
@test eltype(P) <: Complex
95+
@test P * copy(y) fftw_bfft
96+
@test P \ (P * copy(y)) y
97+
@test fftdims(P) == dims
98+
end
8399

84100
# IFFT
85-
fftw_ifft = complex.(x)
101+
fftw_ifft = x_complex
86102
@test AbstractFFTs.ifft(y, dims) fftw_ifft
87-
for P in [plan_ifft(x, dims), inv(plan_fft(x, dims)),
88-
AbstractFFTs.plan_inv(plan_fft(x, dims))]
89-
@test P * y fftw_ifft
90-
@test mul!(xcopy_complex, P, y) fftw_ifft
91-
@test P \ (P * y) y
92-
@test fftdims(P) == dims
103+
test_inplace && (@test AbstractFFTs.ifft!(copy(y), dims) fftw_ifft)
104+
plans_to_test = [plan_ifft(x, dims), inv(plan_fft(x, dims)),
105+
AbstractFFTs.plan_inv(plan_fft(x, dims))]
106+
for P in plans_to_test
107+
@test mul!(similar(x_complex), P, y) fftw_ifft
93108
end
94-
95-
# RFFT
96-
fftw_rfft = fftw_fft[
97-
(Colon() for _ in 1:(ndims(fftw_fft) - 1))...,
98-
1:(size(fftw_fft, ndims(fftw_fft)) ÷ 2 + 1)
99-
]
100-
ry = AbstractFFTs.rfft(x, dims)
101-
rycopy = array_constructor(copy(ry))
102-
@test ry fftw_rfft
103-
for P in [plan_rfft(x, dims), inv(plan_irfft(ry, size(x, dims), dims)),
104-
AbstractFFTs.plan_inv(plan_irfft(ry, size(x, dims), dims))]
105-
@test eltype(P) <: Real
106-
@test P * x fftw_rfft
107-
@test mul!(rycopy, P, x) fftw_rfft
108-
@test P \ (P * x) x
109+
test_inplace && (plan_to_test = vcat(plans_to_test, plan_ifft!(similar(x_complex), dims)))
110+
for P in plans_to_test
111+
@test eltype(P) <: Complex
112+
@test P * copy(y) fftw_ifft
113+
@test P \ (P * copy(y)) y
109114
@test fftdims(P) == dims
110115
end
111116

112-
# BRFFT
113-
fftw_brfft = complex.(size(x, dims) .* x)
114-
@test AbstractFFTs.brfft(ry, size(x, dims), dims) fftw_brfft
115-
P = plan_brfft(ry, size(x, dims), dims)
116-
@test P * ry fftw_brfft
117-
@test mul!(xcopy_float, P, ry) fftw_brfft
118-
@test P \ (P * ry) ry
119-
@test fftdims(P) == dims
120-
121-
# IRFFT
122-
fftw_irfft = complex.(x)
123-
@test AbstractFFTs.irfft(ry, size(x, dims), dims) fftw_irfft
124-
for P in [plan_irfft(ry, size(x, dims), dims), inv(plan_rfft(x, dims)),
125-
AbstractFFTs.plan_inv(plan_rfft(x, dims))]
126-
@test P * ry fftw_irfft
127-
@test mul!(xcopy_float, P, ry) fftw_irfft
117+
if test_real
118+
# RFFT
119+
fftw_rfft = fftw_fft[
120+
(Colon() for _ in 1:(ndims(fftw_fft) - 1))...,
121+
1:(size(fftw_fft, ndims(fftw_fft)) ÷ 2 + 1)
122+
]
123+
ry = AbstractFFTs.rfft(x_real, dims)
124+
@test ry fftw_rfft
125+
for P in [plan_rfft(x_real, dims), inv(plan_irfft(ry, size(x, dims), dims)),
126+
AbstractFFTs.plan_inv(plan_irfft(ry, size(x, dims), dims))]
127+
@test eltype(P) <: Real
128+
@test P * x_real fftw_rfft
129+
@test mul!(similar(ry), P, x_real) fftw_rfft
130+
@test P \ (P * x_real) x_real
131+
@test fftdims(P) == dims
132+
end
133+
134+
# BRFFT
135+
fftw_brfft = complex.(size(x, dims) .* x_real)
136+
@test AbstractFFTs.brfft(ry, size(x_real, dims), dims) fftw_brfft
137+
P = plan_brfft(ry, size(x_real, dims), dims)
138+
@test P * ry fftw_brfft
139+
@test mul!(similar(x_real), P, ry) fftw_brfft
128140
@test P \ (P * ry) ry
129141
@test fftdims(P) == dims
142+
143+
# IRFFT
144+
fftw_irfft = x_complex
145+
@test AbstractFFTs.irfft(ry, size(x, dims), dims) fftw_irfft
146+
for P in [plan_irfft(ry, size(x, dims), dims), inv(plan_rfft(x_real, dims)),
147+
AbstractFFTs.plan_inv(plan_rfft(x_real, dims))]
148+
@test P * ry fftw_irfft
149+
@test mul!(similar(x_real), P, ry) fftw_irfft
150+
@test P \ (P * ry) ry
151+
@test fftdims(P) == dims
152+
end
130153
end
131154
end
132155
end

0 commit comments

Comments
 (0)