Skip to content

Commit 313a180

Browse files
Put tests of FFT backends into TestUtils submodule (#78)
* Add TestUtils submodule/extension * Fix typo * Support Julia 1.0 * Add missing test deps * Add adjoint testing to test utilities * Remove mul! method from inplace test plan (consistent with fftw) * Fix typo * Document test utilities * Apply code review suggestions and refactor TestUtils * Support Julia 1.0 * Reorder kwargs in doc string * Also explicitly test AbstractFFTs.plan_inv * Lift isdefined checks out of __init__ * Update src/definitions.jl Co-authored-by: David Widmann <[email protected]> * Note TestUtils is a weak extension * Update function names in error handler * Add missing test_adjoint's for BRFFT, IRFFT * Collect x_rfft so as to not hit #112 --------- Co-authored-by: David Widmann <[email protected]>
1 parent 5c23f4b commit 313a180

8 files changed

+384
-199
lines changed

Project.toml

+3
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@ version = "1.4.0"
55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
89

910
[weakdeps]
1011
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
12+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1113

1214
[extensions]
1315
AbstractFFTsChainRulesCoreExt = "ChainRulesCore"
16+
AbstractFFTsTestExt = "Test"
1417

1518
[compat]
1619
ChainRulesCore = "1"

docs/src/implementations.md

+15
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,18 @@ To define a new FFT implementation in your own module, you should
3939

4040
The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of
4141
length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``.
42+
43+
## Testing implementations
44+
45+
`AbstractFFTs.jl` provides an experimental `TestUtils` module to help with testing downstream implementations,
46+
available as a [weak extension](https://pkgdocs.julialang.org/v1.9/creating-packages/#Conditional-loading-of-code-in-packages-(Extensions)) of `Test`.
47+
The following functions test that all FFT functionality has been correctly implemented:
48+
```@docs
49+
AbstractFFTs.TestUtils.test_complex_ffts
50+
AbstractFFTs.TestUtils.test_real_ffts
51+
```
52+
`TestUtils` also exposes lower level functions for generically testing particular plans:
53+
```@docs
54+
AbstractFFTs.TestUtils.test_plan
55+
AbstractFFTs.TestUtils.test_plan_adjoint
56+
```

ext/AbstractFFTsTestExt.jl

+232
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license
2+
3+
module AbstractFFTsTestExt
4+
5+
using AbstractFFTs
6+
using AbstractFFTs: TestUtils
7+
using AbstractFFTs.LinearAlgebra
8+
using Test
9+
10+
# Ground truth x_fft computed using FFTW library
11+
const TEST_CASES = (
12+
(; x = collect(1:7), dims = 1,
13+
x_fft = [28.0 + 0.0im,
14+
-3.5 + 7.267824888003178im,
15+
-3.5 + 2.7911568610884143im,
16+
-3.5 + 0.7988521603655248im,
17+
-3.5 - 0.7988521603655248im,
18+
-3.5 - 2.7911568610884143im,
19+
-3.5 - 7.267824888003178im]),
20+
(; x = collect(1:8), dims = 1,
21+
x_fft = [36.0 + 0.0im,
22+
-4.0 + 9.65685424949238im,
23+
-4.0 + 4.0im,
24+
-4.0 + 1.6568542494923806im,
25+
-4.0 + 0.0im,
26+
-4.0 - 1.6568542494923806im,
27+
-4.0 - 4.0im,
28+
-4.0 - 9.65685424949238im]),
29+
(; x = collect(reshape(1:8, 2, 4)), dims = 2,
30+
x_fft = [16.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im;
31+
20.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im]),
32+
(; x = collect(reshape(1:9, 3, 3)), dims = 2,
33+
x_fft = [12.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
34+
15.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
35+
18.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im]),
36+
(; x = collect(reshape(1:8, 2, 2, 2)), dims = 1:2,
37+
x_fft = cat([10.0 + 0.0im -4.0 + 0.0im; -2.0 + 0.0im 0.0 + 0.0im],
38+
[26.0 + 0.0im -4.0 + 0.0im; -2.0 + 0.0im 0.0 + 0.0im],
39+
dims=3)),
40+
(; x = collect(1:7) + im * collect(8:14), dims = 1,
41+
x_fft = [28.0 + 77.0im,
42+
-10.76782488800318 + 3.767824888003175im,
43+
-6.291156861088416 - 0.7088431389115883im,
44+
-4.298852160365525 - 2.7011478396344746im,
45+
-2.7011478396344764 - 4.298852160365524im,
46+
-0.7088431389115866 - 6.291156861088417im,
47+
3.767824888003177 - 10.76782488800318im]),
48+
(; x = collect(reshape(1:8, 2, 2, 2)) + im * reshape(9:16, 2, 2, 2), dims = 1:2,
49+
x_fft = cat([10.0 + 42.0im -4.0 - 4.0im; -2.0 - 2.0im 0.0 + 0.0im],
50+
[26.0 + 58.0im -4.0 - 4.0im; -2.0 - 2.0im 0.0 + 0.0im],
51+
dims=3)),
52+
)
53+
54+
55+
function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transformed::AbstractArray; inplace_plan=false, copy_input=false)
56+
_copy = copy_input ? copy : identity
57+
if !inplace_plan
58+
@test P * _copy(x) x_transformed
59+
@test P \ (P * _copy(x)) x
60+
_x_out = similar(P * _copy(x))
61+
@test mul!(_x_out, P, _copy(x)) x_transformed
62+
@test _x_out x_transformed
63+
else
64+
_x = copy(x)
65+
@test P * _copy(_x) x_transformed
66+
@test _x x_transformed
67+
@test P \ _copy(_x) x
68+
@test _x x
69+
end
70+
end
71+
72+
function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; real_plan=false, copy_input=false)
73+
_copy = copy_input ? copy : identity
74+
y = rand(eltype(P * _copy(x)), size(P * _copy(x)))
75+
# test basic properties
76+
@test_skip eltype(P') === typeof(y) # (AbstractFFTs.jl#110)
77+
@test (P')' === P # test adjoint of adjoint
78+
@test size(P') == AbstractFFTs.output_size(P) # test size of adjoint
79+
# test correctness of adjoint and its inverse via the dot test
80+
if !real_plan
81+
@test dot(y, P * _copy(x)) dot(P' * _copy(y), x)
82+
@test dot(y, P \ _copy(x)) dot(P' \ _copy(y), x)
83+
else
84+
_component_dot(x, y) = dot(real.(x), real.(y)) + dot(imag.(x), imag.(y))
85+
@test _component_dot(y, P * _copy(x)) _component_dot(P' * _copy(y), x)
86+
@test _component_dot(x, P \ _copy(y)) _component_dot(P' \ _copy(x), y)
87+
end
88+
@test_throws MethodError mul!(x, P', y)
89+
end
90+
91+
function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true)
92+
@testset "correctness of fft, bfft, ifft" begin
93+
for test_case in TEST_CASES
94+
_x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft)
95+
x = convert(ArrayType, _x) # dummy array that will be passed to plans
96+
x_complexf = convert(ArrayType, complex.(float.(x))) # for testing mutating complex FFTs
97+
x_fft = convert(ArrayType, _x_fft)
98+
99+
# FFT
100+
@test fft(x, dims) x_fft
101+
if test_inplace
102+
_x_complexf = copy(x_complexf)
103+
@test fft!(_x_complexf, dims) x_fft
104+
@test _x_complexf x_fft
105+
end
106+
# test OOP plans, checking plan_fft and also inv and plan_inv of plan_ifft,
107+
# which should give functionally identical plans
108+
for P in (plan_fft(similar(x_complexf), dims),
109+
(_inv(plan_ifft(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
110+
@test eltype(P) <: Complex
111+
@test fftdims(P) == dims
112+
TestUtils.test_plan(P, x_complexf, x_fft)
113+
if test_adjoint
114+
@test fftdims(P') == fftdims(P)
115+
TestUtils.test_plan_adjoint(P, x_complexf)
116+
end
117+
end
118+
if test_inplace
119+
# test IIP plans
120+
for P in (plan_fft!(similar(x_complexf), dims),
121+
(_inv(plan_ifft!(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
122+
TestUtils.test_plan(P, x_complexf, x_fft; inplace_plan=true)
123+
end
124+
end
125+
126+
# BFFT
127+
x_scaled = prod(size(x, d) for d in dims) .* x
128+
@test bfft(x_fft, dims) x_scaled
129+
if test_inplace
130+
_x_fft = copy(x_fft)
131+
@test bfft!(_x_fft, dims) x_scaled
132+
@test _x_fft x_scaled
133+
end
134+
# test OOP plans. Just 1 plan to test, but we use a for loop for consistent style
135+
for P in (plan_bfft(similar(x_fft), dims),)
136+
@test eltype(P) <: Complex
137+
@test fftdims(P) == dims
138+
TestUtils.test_plan(P, x_fft, x_scaled)
139+
if test_adjoint
140+
TestUtils.test_plan_adjoint(P, x_fft)
141+
end
142+
end
143+
# test IIP plans
144+
for P in (plan_bfft!(similar(x_fft), dims),)
145+
@test eltype(P) <: Complex
146+
@test fftdims(P) == dims
147+
TestUtils.test_plan(P, x_fft, x_scaled; inplace_plan=true)
148+
end
149+
150+
# IFFT
151+
@test ifft(x_fft, dims) x
152+
if test_inplace
153+
_x_fft = copy(x_fft)
154+
@test ifft!(_x_fft, dims) x
155+
@test _x_fft x
156+
end
157+
# test OOP plans
158+
for P in (plan_ifft(similar(x_complexf), dims),
159+
(_inv(plan_fft(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
160+
@test eltype(P) <: Complex
161+
@test fftdims(P) == dims
162+
TestUtils.test_plan(P, x_fft, x)
163+
if test_adjoint
164+
TestUtils.test_plan_adjoint(P, x_fft)
165+
end
166+
end
167+
# test IIP plans
168+
if test_inplace
169+
for P in (plan_ifft!(similar(x_complexf), dims),
170+
(_inv(plan_fft!(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
171+
@test eltype(P) <: Complex
172+
@test fftdims(P) == dims
173+
TestUtils.test_plan(P, x_fft, x; inplace_plan=true)
174+
end
175+
end
176+
end
177+
end
178+
end
179+
180+
function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false)
181+
@testset "correctness of rfft, brfft, irfft" begin
182+
for test_case in TEST_CASES
183+
_x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft)
184+
x = convert(ArrayType, _x) # dummy array that will be passed to plans
185+
x_real = float.(x) # for testing mutating real FFTs
186+
x_fft = convert(ArrayType, _x_fft)
187+
x_rfft = collect(selectdim(x_fft, first(dims), 1:(size(x_fft, first(dims)) ÷ 2 + 1)))
188+
189+
if !(eltype(x) <: Real)
190+
continue
191+
end
192+
193+
# RFFT
194+
@test rfft(x, dims) x_rfft
195+
for P in (plan_rfft(similar(x_real), dims),
196+
(_inv(plan_irfft(similar(x_rfft), size(x, first(dims)), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
197+
@test eltype(P) <: Real
198+
@test fftdims(P) == dims
199+
TestUtils.test_plan(P, x_real, x_rfft; copy_input=copy_input)
200+
if test_adjoint
201+
TestUtils.test_plan_adjoint(P, x_real; real_plan=true, copy_input=copy_input)
202+
end
203+
end
204+
205+
# BRFFT
206+
x_scaled = prod(size(x, d) for d in dims) .* x
207+
@test brfft(x_rfft, size(x, first(dims)), dims) x_scaled
208+
for P in (plan_brfft(similar(x_rfft), size(x, first(dims)), dims),)
209+
@test eltype(P) <: Complex
210+
@test fftdims(P) == dims
211+
TestUtils.test_plan(P, x_rfft, x_scaled; copy_input=copy_input)
212+
if test_adjoint
213+
TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input)
214+
end
215+
end
216+
217+
# IRFFT
218+
@test irfft(x_rfft, size(x, first(dims)), dims) x
219+
for P in (plan_irfft(similar(x_rfft), size(x, first(dims)), dims),
220+
(_inv(plan_rfft(similar(x_real), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
221+
@test eltype(P) <: Complex
222+
@test fftdims(P) == dims
223+
TestUtils.test_plan(P, x_rfft, x; copy_input=copy_input)
224+
if test_adjoint
225+
TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input)
226+
end
227+
end
228+
end
229+
end
230+
end
231+
232+
end

src/AbstractFFTs.jl

+2
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ export fft, ifft, bfft, fft!, ifft!, bfft!,
66
fftdims, fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq
77

88
include("definitions.jl")
9+
include("TestUtils.jl")
910

1011
if !isdefined(Base, :get_extension)
1112
include("../ext/AbstractFFTsChainRulesCoreExt.jl")
13+
include("../ext/AbstractFFTsTestExt.jl")
1214
end
1315

1416
end # module

src/TestUtils.jl

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
module TestUtils
2+
3+
"""
4+
TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true)
5+
6+
Run tests to verify correctness of FFT, BFFT, and IFFT functionality using a particular backend plan implementation.
7+
The backend implementation is assumed to be loaded prior to calling this function.
8+
9+
# Arguments
10+
11+
- `ArrayType`: determines the `AbstractArray` implementation for
12+
which the correctness tests are run. Arrays are constructed via
13+
`convert(ArrayType, ...)`.
14+
- `test_inplace=true`: whether to test in-place plans.
15+
- `test_adjoint=true`: whether to test [plan adjoints](api.md#Base.adjoint).
16+
"""
17+
function test_complex_ffts end
18+
19+
"""
20+
TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false)
21+
22+
Run tests to verify correctness of RFFT, BRFFT, and IRFFT functionality using a particular backend plan implementation.
23+
The backend implementation is assumed to be loaded prior to calling this function.
24+
25+
# Arguments
26+
27+
- `ArrayType`: determines the `AbstractArray` implementation for
28+
which the correctness tests are run. Arrays are constructed via
29+
`convert(ArrayType, ...)`.
30+
- `test_adjoint=true`: whether to test [plan adjoints](api.md#Base.adjoint).
31+
- `copy_input=false`: whether to copy the input before applying the plan in tests, to accomodate for
32+
[input-mutating behaviour of real FFTW plans](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101).
33+
"""
34+
function test_real_ffts end
35+
36+
# Always copy input before application due to FFTW real plans possibly mutating input (AbstractFFTs.jl#101)
37+
"""
38+
TestUtils.test_plan(P::Plan, x::AbstractArray, x_transformed::AbstractArray;
39+
inplace_plan=false, copy_input=false)
40+
41+
Test basic properties of a plan `P` given an input array `x` and expected output `x_transformed`.
42+
43+
Because [real FFTW plans may mutate their input in some cases](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101),
44+
we allow specifying `copy_input=true` to allow for this behaviour in tests by copying the input before applying the plan.
45+
"""
46+
function test_plan end
47+
48+
"""
49+
TestUtils.test_plan_adjoint(P::Plan, x::AbstractArray; real_plan=false, copy_input=false)
50+
51+
Test basic properties of the [adjoint](api.md#Base.adjoint) `P'` of a particular plan given an input array `x`,
52+
including its accuracy via the dot test.
53+
54+
Real-to-complex and complex-to-real plans require a slightly modified dot test, in which case `real_plan=true` should be provided.
55+
The plan is assumed out-of-place, as adjoints are not yet supported for in-place plans.
56+
Because [real FFTW plans may mutate their input in some cases](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101),
57+
we allow specifying `copy_input=true` to allow for this behaviour in tests by copying the input before applying the plan.
58+
"""
59+
function test_plan_adjoint end
60+
61+
if isdefined(Base, :get_extension) && isdefined(Base.Experimental, :register_error_hint)
62+
function __init__()
63+
# Better error message if users forget to load Test
64+
Base.Experimental.register_error_hint(MethodError) do io, exc, _, _
65+
if any(f -> (f === exc.f), (test_real_ffts, test_complex_ffts, test_plan, test_plan_adjoint)) &&
66+
(Base.get_extension(AbstractFFTs, :AbstractFFTsTestExt) === nothing)
67+
print(io, "\nDid you forget to load Test?")
68+
end
69+
end
70+
end
71+
end
72+
73+
end

src/definitions.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ end
665665
666666
Return a plan that performs the adjoint operation of the original plan.
667667
668-
!!! note
668+
!!! warning
669669
Adjoint plans do not currently support `LinearAlgebra.mul!`. Further, as a new addition to `AbstractFFTs`,
670670
coverage of `Base.adjoint` in downstream implementations may be limited.
671671
"""
@@ -676,6 +676,7 @@ Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)
676676

677677
size(p::AdjointPlan) = output_size(p.p)
678678
output_size(p::AdjointPlan) = size(p.p)
679+
fftdims(p::AdjointPlan) = fftdims(p.p)
679680

680681
Base.:*(p::AdjointPlan, x::AbstractArray) = adjoint_mul(p.p, x)
681682

0 commit comments

Comments
 (0)