Skip to content

Commit 3537f76

Browse files
committed
Apply code review suggestions and refactor TestUtils
1 parent 98fdcde commit 3537f76

File tree

6 files changed

+87
-86
lines changed

6 files changed

+87
-86
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1313

1414
[extensions]
1515
AbstractFFTsChainRulesCoreExt = "ChainRulesCore"
16-
AbstractFFTsTestUtilsExt = "Test"
16+
AbstractFFTsTestExt = "Test"
1717

1818
[compat]
1919
ChainRulesCore = "1"

docs/src/implementations.md

+7-3
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,13 @@ length ``n``, and the "backwards" (unnormalized inverse) transform computes the
4343
## Testing implementations
4444

4545
`AbstractFFTs.jl` provides a `TestUtils` module to help with testing downstream implementations.
46-
46+
The following functions test that all FFT functionality has been correctly implemented:
47+
```@docs
48+
AbstractFFTs.TestUtils.test_complex_ffts
49+
AbstractFFTs.TestUtils.test_real_ffts
50+
```
51+
`TestUtils` also exposes lower level functions for generically testing particular plans:
4752
```@docs
48-
AbstractFFTs.TestUtils.test_complex_fft
49-
AbstractFFTs.TestUtils.test_real_fft
53+
AbstractFFTs.TestUtils.test_plan
5054
AbstractFFTs.TestUtils.test_plan_adjoint
5155
```

ext/AbstractFFTsTestUtilsExt.jl ext/AbstractFFTsTestExt.jl

+46-65
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license
22

3-
module AbstractFFTsTestUtilsExt
3+
module AbstractFFTsTestExt
44

55
using AbstractFFTs
66
using AbstractFFTs: TestUtils
77
using AbstractFFTs.LinearAlgebra
88
using Test
99

10-
# Ground truth _x_fft computed using FFTW library
10+
# Ground truth x_fft computed using FFTW library
1111
const TEST_CASES = (
1212
(; x = collect(1:7), dims = 1,
1313
x_fft = [28.0 + 0.0im,
@@ -51,29 +51,47 @@ const TEST_CASES = (
5151
dims=3)),
5252
)
5353

54-
function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; real_plan=false)
55-
y = rand(eltype(P * x), size(P * x))
54+
55+
function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transformed::AbstractArray; inplace_plan=false, copy_input=false)
56+
_copy = copy_input ? copy : identity
57+
if !inplace_plan
58+
@test P * _copy(x) x_transformed
59+
@test P \ (P * _copy(x)) x
60+
_x_out = similar(P * _copy(x))
61+
@test mul!(_x_out, P, _copy(x)) x_transformed
62+
@test _x_out x_transformed
63+
else
64+
_x = copy(x)
65+
@test P * _copy(_x) x_transformed
66+
@test _x x_transformed
67+
@test P \ _copy(_x) x
68+
@test _x x
69+
end
70+
end
71+
72+
function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; real_plan=false, copy_input=false)
73+
_copy = copy_input ? copy : identity
74+
y = rand(eltype(P * _copy(x)), size(P * _copy(x)))
5675
# test basic properties
57-
@test_broken eltype(P') === typeof(y) # (AbstractFFTs.jl#110)
58-
@test fftdims(P') == fftdims(P)
76+
@test_skip eltype(P') === typeof(y) # (AbstractFFTs.jl#110)
5977
@test (P')' === P # test adjoint of adjoint
6078
@test size(P') == AbstractFFTs.output_size(P) # test size of adjoint
6179
# test correctness of adjoint and its inverse via the dot test
6280
if !real_plan
63-
@test dot(y, P * x) dot(P' * y, x)
64-
@test dot(y, P \ x) dot(P' \ y, x)
81+
@test dot(y, P * _copy(x)) dot(P' * _copy(y), x)
82+
@test dot(y, P \ _copy(x)) dot(P' \ _copy(y), x)
6583
else
6684
_component_dot(x, y) = dot(real.(x), real.(y)) + dot(imag.(x), imag.(y))
67-
@test _component_dot(y, P * copy(x)) _component_dot(P' * copy(y), x)
68-
@test _component_dot(x, P \ copy(y)) _component_dot(P' \ copy(x), y)
85+
@test _component_dot(y, P * _copy(x)) _component_dot(P' * _copy(y), x)
86+
@test _component_dot(x, P \ _copy(y)) _component_dot(P' \ _copy(x), y)
6987
end
7088
@test_throws MethodError mul!(x, P', y)
7189
end
7290

73-
function TestUtils.test_complex_fft(ArrayType=Array; test_inplace=true, test_adjoint=true)
91+
function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true)
7492
@testset "correctness of fft, bfft, ifft" begin
7593
for test_case in TEST_CASES
76-
_x, dims, _x_fft = test_case.x, test_case.dims, test_case.x_fft
94+
_x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft)
7795
x = convert(ArrayType, _x) # dummy array that will be passed to plans
7896
x_complexf = convert(ArrayType, complex.(float.(x))) # for testing mutating complex FFTs
7997
x_fft = convert(ArrayType, _x_fft)
@@ -90,25 +108,16 @@ function TestUtils.test_complex_fft(ArrayType=Array; test_inplace=true, test_adj
90108
for P in (plan_fft(similar(x_complexf), dims), inv(plan_ifft(similar(x_complexf), dims)))
91109
@test eltype(P) <: Complex
92110
@test fftdims(P) == dims
93-
@test P * x x_fft
94-
@test P \ (P * x) x
95-
_x_out = similar(x_fft)
96-
@test mul!(_x_out, P, x_complexf) x_fft
97-
@test _x_out x_fft
111+
TestUtils.test_plan(P, x_complexf, x_fft)
98112
if test_adjoint
113+
@test fftdims(P') == fftdims(P)
99114
TestUtils.test_plan_adjoint(P, x_complexf)
100115
end
101116
end
102117
if test_inplace
103118
# test IIP plans
104119
for P in (plan_fft!(similar(x_complexf), dims), inv(plan_ifft!(similar(x_complexf), dims)))
105-
@test eltype(P) <: Complex
106-
@test fftdims(P) == dims
107-
_x_complexf = copy(x_complexf)
108-
@test P * _x_complexf x_fft
109-
@test _x_complexf x_fft
110-
@test P \ _x_complexf x
111-
@test _x_complexf x
120+
TestUtils.test_plan(P, x_complexf, x_fft; inplace_plan=true)
112121
end
113122
end
114123

@@ -124,24 +133,16 @@ function TestUtils.test_complex_fft(ArrayType=Array; test_inplace=true, test_adj
124133
for P in (plan_bfft(similar(x_fft), dims),)
125134
@test eltype(P) <: Complex
126135
@test fftdims(P) == dims
127-
@test P * x_fft x_scaled
128-
@test P \ (P * x_fft) x_fft
129-
_x_complexf = similar(x_complexf)
130-
@test mul!(_x_complexf, P, x_fft) x_scaled
131-
@test _x_complexf x_scaled
136+
TestUtils.test_plan(P, x_fft, x_scaled)
132137
if test_adjoint
133-
TestUtils.test_plan_adjoint(P, x_complexf)
138+
TestUtils.test_plan_adjoint(P, x_fft)
134139
end
135140
end
136141
# test IIP plans
137142
for P in (plan_bfft!(similar(x_fft), dims),)
138143
@test eltype(P) <: Complex
139144
@test fftdims(P) == dims
140-
_x_fft = copy(x_fft)
141-
@test P * _x_fft x_scaled
142-
@test _x_fft x_scaled
143-
@test P \ _x_fft x_fft
144-
@test _x_fft x_fft
145+
TestUtils.test_plan(P, x_fft, x_scaled; inplace_plan=true)
145146
end
146147

147148
# IFFT
@@ -155,35 +156,27 @@ function TestUtils.test_complex_fft(ArrayType=Array; test_inplace=true, test_adj
155156
for P in (plan_ifft(similar(x_complexf), dims), inv(plan_fft(similar(x_complexf), dims)))
156157
@test eltype(P) <: Complex
157158
@test fftdims(P) == dims
158-
@test P * x_fft x
159-
@test P \ (P * x_fft) x_fft
160-
_x_complexf = similar(x_complexf)
161-
@test mul!(_x_complexf, P, x_fft) x
162-
@test _x_complexf x
159+
TestUtils.test_plan(P, x_fft, x)
163160
if test_adjoint
164-
TestUtils.test_plan_adjoint(P, x_complexf)
161+
TestUtils.test_plan_adjoint(P, x_fft)
165162
end
166163
end
167164
# test IIP plans
168165
if test_inplace
169166
for P in (plan_ifft!(similar(x_complexf), dims), inv(plan_fft!(similar(x_complexf), dims)))
170167
@test eltype(P) <: Complex
171168
@test fftdims(P) == dims
172-
_x_fft = copy(x_fft)
173-
@test P * _x_fft x
174-
@test _x_fft x
175-
@test P \ _x_fft x_fft
176-
@test _x_fft x_fft
169+
TestUtils.test_plan(P, x_fft, x; inplace_plan=true)
177170
end
178171
end
179172
end
180173
end
181174
end
182175

183-
function TestUtils.test_real_fft(ArrayType=Array; test_inplace=true, test_adjoint=true)
176+
function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false)
184177
@testset "correctness of rfft, brfft, irfft" begin
185178
for test_case in TEST_CASES
186-
_x, dims, _x_fft = test_case.x, test_case.dims, test_case.x_fft
179+
_x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft)
187180
x = convert(ArrayType, _x) # dummy array that will be passed to plans
188181
x_real = float.(x) # for testing mutating real FFTs
189182
x_fft = convert(ArrayType, _x_fft)
@@ -198,14 +191,9 @@ function TestUtils.test_real_fft(ArrayType=Array; test_inplace=true, test_adjoin
198191
for P in (plan_rfft(similar(x_real), dims), inv(plan_irfft(similar(x_rfft), size(x, first(dims)), dims)))
199192
@test eltype(P) <: Real
200193
@test fftdims(P) == dims
201-
# Always copy input before application due to FFTW real plans possibly mutating input (AbstractFFTs.jl#101)
202-
@test P * copy(x) x_rfft
203-
@test P \ (P * copy(x)) x
204-
_x_rfft = similar(x_rfft)
205-
@test mul!(_x_rfft, P, copy(x_real)) x_rfft
206-
@test _x_rfft x_rfft
194+
TestUtils.test_plan(P, x_real, x_rfft; copy_input)
207195
if test_adjoint
208-
TestUtils.test_plan_adjoint(P, x_real; real_plan=true)
196+
TestUtils.test_plan_adjoint(P, x_real; real_plan=true, copy_input)
209197
end
210198
end
211199

@@ -215,25 +203,18 @@ function TestUtils.test_real_fft(ArrayType=Array; test_inplace=true, test_adjoin
215203
for P in (plan_brfft(similar(x_rfft), size(x, first(dims)), dims),)
216204
@test eltype(P) <: Complex
217205
@test fftdims(P) == dims
218-
@test P * copy(x_rfft) x_scaled
219-
@test P \ (P * copy(x_rfft)) x_rfft
220-
_x_scaled = similar(x_real)
221-
@test mul!(_x_scaled, P, copy(x_rfft)) x_scaled
222-
@test _x_scaled x_scaled
206+
TestUtils.test_plan(P, x_rfft, x_scaled; copy_input)
223207
end
224208

225209
# IRFFT
226210
@test irfft(x_rfft, size(x, first(dims)), dims) x
227211
for P in (plan_irfft(similar(x_rfft), size(x, first(dims)), dims), inv(plan_rfft(similar(x_real), dims)))
228212
@test eltype(P) <: Complex
229213
@test fftdims(P) == dims
230-
@test P * copy(x_rfft) x
231-
@test P \ (P * copy(x_rfft)) x_rfft
232-
_x_real = similar(x_real)
233-
@test mul!(_x_real, P, copy(x_rfft)) x_real
214+
TestUtils.test_plan(P, x_rfft, x; copy_input)
234215
end
235216
end
236217
end
237218
end
238219

239-
end
220+
end

src/AbstractFFTs.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ include("TestUtils.jl")
1010

1111
if !isdefined(Base, :get_extension)
1212
include("../ext/AbstractFFTsChainRulesCoreExt.jl")
13-
include("../ext/AbstractFFTsTestUtilsExt.jl")
13+
include("../ext/AbstractFFTsTestExt.jl")
1414
end
1515

1616
end # module

src/TestUtils.jl

+30-14
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
module TestUtils
22

33
"""
4-
TestUtils.test_complex_fft(ArrayType=Array; test_real=true, test_inplace=true)
4+
TestUtils.test_complex_ffts(ArrayType=Array; test_adjoint=true, test_inplace=true)
55
6-
Run tests to verify correctness of FFT/BFFT/IFFT functionality using a particular backend plan implementation.
6+
Run tests to verify correctness of FFT, BFFT, and IFFT functionality using a particular backend plan implementation.
77
The backend implementation is assumed to be loaded prior to calling this function.
88
99
# Arguments
@@ -14,43 +14,59 @@ The backend implementation is assumed to be loaded prior to calling this functio
1414
- `test_inplace=true`: whether to test in-place plans.
1515
- `test_adjoint=true`: whether to test [plan adjoints](api.md#Base.adjoint).
1616
"""
17-
function test_complex_fft end
17+
function test_complex_ffts end
1818

1919
"""
20-
TestUtils.test_real_fft(ArrayType=Array; test_real=true, test_inplace=true)
20+
TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false)
2121
22-
Run tests to verify correctness of RFFT/BRFFT/IRFFT functionality using a particular backend plan implementation.
22+
Run tests to verify correctness of RFFT, BRFFT, and IRFFT functionality using a particular backend plan implementation.
2323
The backend implementation is assumed to be loaded prior to calling this function.
2424
2525
# Arguments
2626
2727
- `ArrayType`: determines the `AbstractArray` implementation for
2828
which the correctness tests are run. Arrays are constructed via
2929
`convert(ArrayType, ...)`.
30-
- `test_inplace=true`: whether to test in-place plans.
3130
- `test_adjoint=true`: whether to test [plan adjoints](api.md#Base.adjoint).
31+
- `copy_input=false`: whether to copy the input before applying the plan in tests, to accomodate for
32+
[input-mutating behaviour of real FFTW plans](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101).
3233
"""
33-
function test_real_fft end
34+
function test_real_ffts end
3435

36+
# Always copy input before application due to FFTW real plans possibly mutating input (AbstractFFTs.jl#101)
3537
"""
36-
TestUtils.test_plan_adjoint(P::Plan, x::AbstractArray; real_plan=false)
38+
TestUtils.test_plan(P::Plan, x::AbstractArray, x_transformed::AbstractArray;
39+
inplace_plan=false, copy_input=false)
40+
41+
Test basic properties of a plan `P` given an input array `x` and expected output `x_transformed`.
3742
38-
Test basic properties of the adjoint `P'` of a particular plan given an input array to the plan `x`,
39-
including its accuracy via the dot test. Real-to-complex and complex-to-real plans require
40-
a slightly modified dot test, in which case `real_plan=true` should be provided.
43+
Because [real FFTW plans may mutate their input in some cases](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101),
44+
we allow specifying `copy_input=true` to allow for this behaviour in tests by copying the input before applying the plan.
45+
"""
46+
function test_plan end
4147

48+
"""
49+
TestUtils.test_plan_adjoint(P::Plan, x::AbstractArray; real_plan=false, copy_input=false)
50+
51+
Test basic properties of the [adjoint](api.md#Base.adjoint) `P'` of a particular plan given an input array `x`,
52+
including its accuracy via the dot test.
53+
54+
Real-to-complex and complex-to-real plans require a slightly modified dot test, in which case `real_plan=true` should be provided.
55+
The plan is assumed out-of-place, as adjoints are not yet supported for in-place plans.
56+
Because [real FFTW plans may mutate their input in some cases](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101),
57+
we allow specifying `copy_input=true` to allow for this behaviour in tests by copying the input before applying the plan.
4258
"""
4359
function test_plan_adjoint end
4460

4561
function __init__()
46-
if isdefined(Base, :Experimental)
62+
if isdefined(Base, :get_extension) && isdefined(Base.Experimental, :register_error_hint)
4763
# Better error message if users forget to load Test
4864
Base.Experimental.register_error_hint(MethodError) do io, exc, _, _
49-
if exc.f in (test_real_fft, test_complex_fft)
65+
if (exc.f === test_real_fft || exc.f === test_complex_fft) && Base.get_extension(AbstractFFTs, :AbstractFFTsTestExt) === nothing
5066
print(io, "\nDid you forget to load Test?")
5167
end
5268
end
5369
end
5470
end
5571

56-
end
72+
end

test/runtests.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ Random.seed!(1234)
1313
include("TestPlans.jl")
1414

1515
# Run interface tests for TestPlans
16-
AbstractFFTs.TestUtils.test_complex_fft(Array)
17-
AbstractFFTs.TestUtils.test_real_fft(Array)
16+
AbstractFFTs.TestUtils.test_complex_ffts(Array)
17+
AbstractFFTs.TestUtils.test_real_ffts(Array)
1818

1919
@testset "rfft sizes" begin
2020
A = rand(11, 10)

0 commit comments

Comments
 (0)