Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Put tests of FFT backends into TestUtils submodule #78

Merged
merged 19 commits into from
Jul 29, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ version = "1.4.0"
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[extensions]
AbstractFFTsChainRulesCoreExt = "ChainRulesCore"
AbstractFFTsTestExt = "Test"

[compat]
ChainRulesCore = "1"
Expand Down
14 changes: 14 additions & 0 deletions docs/src/implementations.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,17 @@ To define a new FFT implementation in your own module, you should

The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of
length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``.

## Testing implementations

`AbstractFFTs.jl` provides a `TestUtils` module to help with testing downstream implementations.
The following functions test that all FFT functionality has been correctly implemented:
```@docs
AbstractFFTs.TestUtils.test_complex_ffts
AbstractFFTs.TestUtils.test_real_ffts
```
`TestUtils` also exposes lower level functions for generically testing particular plans:
```@docs
AbstractFFTs.TestUtils.test_plan
AbstractFFTs.TestUtils.test_plan_adjoint
```
226 changes: 226 additions & 0 deletions ext/AbstractFFTsTestExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license

module AbstractFFTsTestExt

using AbstractFFTs
using AbstractFFTs: TestUtils
using AbstractFFTs.LinearAlgebra
using Test

# Ground truth x_fft computed using FFTW library
const TEST_CASES = (
(; x = collect(1:7), dims = 1,
x_fft = [28.0 + 0.0im,
-3.5 + 7.267824888003178im,
-3.5 + 2.7911568610884143im,
-3.5 + 0.7988521603655248im,
-3.5 - 0.7988521603655248im,
-3.5 - 2.7911568610884143im,
-3.5 - 7.267824888003178im]),
(; x = collect(1:8), dims = 1,
x_fft = [36.0 + 0.0im,
-4.0 + 9.65685424949238im,
-4.0 + 4.0im,
-4.0 + 1.6568542494923806im,
-4.0 + 0.0im,
-4.0 - 1.6568542494923806im,
-4.0 - 4.0im,
-4.0 - 9.65685424949238im]),
(; x = collect(reshape(1:8, 2, 4)), dims = 2,
x_fft = [16.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im;
20.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im]),
(; x = collect(reshape(1:9, 3, 3)), dims = 2,
x_fft = [12.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
15.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
18.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im]),
(; x = collect(reshape(1:8, 2, 2, 2)), dims = 1:2,
x_fft = cat([10.0 + 0.0im -4.0 + 0.0im; -2.0 + 0.0im 0.0 + 0.0im],
[26.0 + 0.0im -4.0 + 0.0im; -2.0 + 0.0im 0.0 + 0.0im],
dims=3)),
(; x = collect(1:7) + im * collect(8:14), dims = 1,
x_fft = [28.0 + 77.0im,
-10.76782488800318 + 3.767824888003175im,
-6.291156861088416 - 0.7088431389115883im,
-4.298852160365525 - 2.7011478396344746im,
-2.7011478396344764 - 4.298852160365524im,
-0.7088431389115866 - 6.291156861088417im,
3.767824888003177 - 10.76782488800318im]),
(; x = collect(reshape(1:8, 2, 2, 2)) + im * reshape(9:16, 2, 2, 2), dims = 1:2,
x_fft = cat([10.0 + 42.0im -4.0 - 4.0im; -2.0 - 2.0im 0.0 + 0.0im],
[26.0 + 58.0im -4.0 - 4.0im; -2.0 - 2.0im 0.0 + 0.0im],
dims=3)),
)


function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transformed::AbstractArray; inplace_plan=false, copy_input=false)
_copy = copy_input ? copy : identity
if !inplace_plan
@test P * _copy(x) ≈ x_transformed
@test P \ (P * _copy(x)) ≈ x
_x_out = similar(P * _copy(x))
@test mul!(_x_out, P, _copy(x)) ≈ x_transformed
@test _x_out ≈ x_transformed
else
_x = copy(x)
@test P * _copy(_x) ≈ x_transformed
@test _x ≈ x_transformed
@test P \ _copy(_x) ≈ x
@test _x ≈ x
end
end

function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; real_plan=false, copy_input=false)
_copy = copy_input ? copy : identity
y = rand(eltype(P * _copy(x)), size(P * _copy(x)))
# test basic properties
@test_skip eltype(P') === typeof(y) # (AbstractFFTs.jl#110)
@test (P')' === P # test adjoint of adjoint
@test size(P') == AbstractFFTs.output_size(P) # test size of adjoint
# test correctness of adjoint and its inverse via the dot test
if !real_plan
@test dot(y, P * _copy(x)) ≈ dot(P' * _copy(y), x)
@test dot(y, P \ _copy(x)) ≈ dot(P' \ _copy(y), x)
else
_component_dot(x, y) = dot(real.(x), real.(y)) + dot(imag.(x), imag.(y))
@test _component_dot(y, P * _copy(x)) ≈ _component_dot(P' * _copy(y), x)
@test _component_dot(x, P \ _copy(y)) ≈ _component_dot(P' \ _copy(x), y)
end
@test_throws MethodError mul!(x, P', y)
end

function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true)
@testset "correctness of fft, bfft, ifft" begin
for test_case in TEST_CASES
_x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft)
x = convert(ArrayType, _x) # dummy array that will be passed to plans
x_complexf = convert(ArrayType, complex.(float.(x))) # for testing mutating complex FFTs
x_fft = convert(ArrayType, _x_fft)

# FFT
@test fft(x, dims) ≈ x_fft
if test_inplace
_x_complexf = copy(x_complexf)
@test fft!(_x_complexf, dims) ≈ x_fft
@test _x_complexf ≈ x_fft
end
# test OOP plans, checking plan_fft and also inv and plan_inv of plan_ifft,
# which should give functionally identical plans
for P in (plan_fft(similar(x_complexf), dims),
(_inv(plan_ifft(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_complexf, x_fft)
if test_adjoint
@test fftdims(P') == fftdims(P)
TestUtils.test_plan_adjoint(P, x_complexf)
end
end
if test_inplace
# test IIP plans
for P in (plan_fft!(similar(x_complexf), dims),
(_inv(plan_ifft!(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
TestUtils.test_plan(P, x_complexf, x_fft; inplace_plan=true)
end
end

# BFFT
x_scaled = prod(size(x, d) for d in dims) .* x
@test bfft(x_fft, dims) ≈ x_scaled
if test_inplace
_x_fft = copy(x_fft)
@test bfft!(_x_fft, dims) ≈ x_scaled
@test _x_fft ≈ x_scaled
end
# test OOP plans. Just 1 plan to test, but we use a for loop for consistent style
for P in (plan_bfft(similar(x_fft), dims),)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_fft, x_scaled)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_fft)
end
end
# test IIP plans
for P in (plan_bfft!(similar(x_fft), dims),)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_fft, x_scaled; inplace_plan=true)
end

# IFFT
@test ifft(x_fft, dims) ≈ x
if test_inplace
_x_fft = copy(x_fft)
@test ifft!(_x_fft, dims) ≈ x
@test _x_fft ≈ x
end
# test OOP plans
for P in (plan_ifft(similar(x_complexf), dims),
(_inv(plan_fft(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_fft, x)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_fft)
end
end
# test IIP plans
if test_inplace
for P in (plan_ifft!(similar(x_complexf), dims),
(_inv(plan_fft!(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_fft, x; inplace_plan=true)
end
end
end
end
end

function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false)
@testset "correctness of rfft, brfft, irfft" begin
for test_case in TEST_CASES
_x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft)
x = convert(ArrayType, _x) # dummy array that will be passed to plans
x_real = float.(x) # for testing mutating real FFTs
x_fft = convert(ArrayType, _x_fft)
x_rfft = selectdim(x_fft, first(dims), 1:(size(x_fft, first(dims)) ÷ 2 + 1))

if !(eltype(x) <: Real)
continue
end

# RFFT
@test rfft(x, dims) ≈ x_rfft
for P in (plan_rfft(similar(x_real), dims),
(_inv(plan_irfft(similar(x_rfft), size(x, first(dims)), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
@test eltype(P) <: Real
@test fftdims(P) == dims
TestUtils.test_plan(P, x_real, x_rfft; copy_input=copy_input)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_real; real_plan=true, copy_input=copy_input)
end
end

# BRFFT
x_scaled = prod(size(x, d) for d in dims) .* x
@test brfft(x_rfft, size(x, first(dims)), dims) ≈ x_scaled
for P in (plan_brfft(similar(x_rfft), size(x, first(dims)), dims),)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_rfft, x_scaled; copy_input=copy_input)
end

# IRFFT
@test irfft(x_rfft, size(x, first(dims)), dims) ≈ x
for P in (plan_irfft(similar(x_rfft), size(x, first(dims)), dims),
(_inv(plan_rfft(similar(x_real), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_rfft, x; copy_input=copy_input)
end
end
end
end

end
2 changes: 2 additions & 0 deletions src/AbstractFFTs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ export fft, ifft, bfft, fft!, ifft!, bfft!,
fftdims, fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq

include("definitions.jl")
include("TestUtils.jl")

if !isdefined(Base, :get_extension)
include("../ext/AbstractFFTsChainRulesCoreExt.jl")
include("../ext/AbstractFFTsTestExt.jl")
end

end # module
72 changes: 72 additions & 0 deletions src/TestUtils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
module TestUtils

"""
TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true)

Run tests to verify correctness of FFT, BFFT, and IFFT functionality using a particular backend plan implementation.
The backend implementation is assumed to be loaded prior to calling this function.

# Arguments

- `ArrayType`: determines the `AbstractArray` implementation for
which the correctness tests are run. Arrays are constructed via
`convert(ArrayType, ...)`.
- `test_inplace=true`: whether to test in-place plans.
- `test_adjoint=true`: whether to test [plan adjoints](api.md#Base.adjoint).
"""
function test_complex_ffts end

"""
TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false)

Run tests to verify correctness of RFFT, BRFFT, and IRFFT functionality using a particular backend plan implementation.
The backend implementation is assumed to be loaded prior to calling this function.

# Arguments

- `ArrayType`: determines the `AbstractArray` implementation for
which the correctness tests are run. Arrays are constructed via
`convert(ArrayType, ...)`.
- `test_adjoint=true`: whether to test [plan adjoints](api.md#Base.adjoint).
- `copy_input=false`: whether to copy the input before applying the plan in tests, to accomodate for
[input-mutating behaviour of real FFTW plans](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101).
"""
function test_real_ffts end

# Always copy input before application due to FFTW real plans possibly mutating input (AbstractFFTs.jl#101)
"""
TestUtils.test_plan(P::Plan, x::AbstractArray, x_transformed::AbstractArray;
inplace_plan=false, copy_input=false)

Test basic properties of a plan `P` given an input array `x` and expected output `x_transformed`.

Because [real FFTW plans may mutate their input in some cases](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101),
we allow specifying `copy_input=true` to allow for this behaviour in tests by copying the input before applying the plan.
"""
function test_plan end

"""
TestUtils.test_plan_adjoint(P::Plan, x::AbstractArray; real_plan=false, copy_input=false)

Test basic properties of the [adjoint](api.md#Base.adjoint) `P'` of a particular plan given an input array `x`,
including its accuracy via the dot test.

Real-to-complex and complex-to-real plans require a slightly modified dot test, in which case `real_plan=true` should be provided.
The plan is assumed out-of-place, as adjoints are not yet supported for in-place plans.
Because [real FFTW plans may mutate their input in some cases](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101),
we allow specifying `copy_input=true` to allow for this behaviour in tests by copying the input before applying the plan.
"""
function test_plan_adjoint end

function __init__()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can wrap it inside a if isdefined(Base, :get_extension).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Outside of the whole __init__ function I meant, not in its function body.

if isdefined(Base, :get_extension) && isdefined(Base.Experimental, :register_error_hint)
# Better error message if users forget to load Test
Base.Experimental.register_error_hint(MethodError) do io, exc, _, _
if (exc.f === test_real_fft || exc.f === test_complex_fft) && Base.get_extension(AbstractFFTs, :AbstractFFTsTestExt) === nothing
print(io, "\nDid you forget to load Test?")
end
end
end
end

end
3 changes: 2 additions & 1 deletion src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ Form the adjoint operator of an FFT plan. Returns a plan that performs the adjoi
the original plan. Note that this differs from the corresponding backwards plan in the case of real
FFTs due to the halving of one of the dimensions of the FFT output, as described in [`rfft`](@ref).

!!! note
!!! warning
Adjoint plans do not currently support `LinearAlgebra.mul!`. Further, as a new addition to `AbstractFFTs`,
coverage of `Base.adjoint` in downstream implementations may be limited.
"""
Expand All @@ -619,6 +619,7 @@ Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)

size(p::AdjointPlan) = output_size(p.p)
output_size(p::AdjointPlan) = size(p.p)
fftdims(p::AdjointPlan) = fftdims(p.p)

Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))

Expand Down
Loading