|
| 1 | +# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license |
| 2 | + |
| 3 | +module AbstractFFTsTestExt |
| 4 | + |
| 5 | +using AbstractFFTs |
| 6 | +using AbstractFFTs: TestUtils |
| 7 | +using AbstractFFTs.LinearAlgebra |
| 8 | +using Test |
| 9 | + |
| 10 | +# Ground truth x_fft computed using FFTW library |
| 11 | +const TEST_CASES = ( |
| 12 | + (; x = collect(1:7), dims = 1, |
| 13 | + x_fft = [28.0 + 0.0im, |
| 14 | + -3.5 + 7.267824888003178im, |
| 15 | + -3.5 + 2.7911568610884143im, |
| 16 | + -3.5 + 0.7988521603655248im, |
| 17 | + -3.5 - 0.7988521603655248im, |
| 18 | + -3.5 - 2.7911568610884143im, |
| 19 | + -3.5 - 7.267824888003178im]), |
| 20 | + (; x = collect(1:8), dims = 1, |
| 21 | + x_fft = [36.0 + 0.0im, |
| 22 | + -4.0 + 9.65685424949238im, |
| 23 | + -4.0 + 4.0im, |
| 24 | + -4.0 + 1.6568542494923806im, |
| 25 | + -4.0 + 0.0im, |
| 26 | + -4.0 - 1.6568542494923806im, |
| 27 | + -4.0 - 4.0im, |
| 28 | + -4.0 - 9.65685424949238im]), |
| 29 | + (; x = collect(reshape(1:8, 2, 4)), dims = 2, |
| 30 | + x_fft = [16.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im; |
| 31 | + 20.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im]), |
| 32 | + (; x = collect(reshape(1:9, 3, 3)), dims = 2, |
| 33 | + x_fft = [12.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im; |
| 34 | + 15.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im; |
| 35 | + 18.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im]), |
| 36 | + (; x = collect(reshape(1:8, 2, 2, 2)), dims = 1:2, |
| 37 | + x_fft = cat([10.0 + 0.0im -4.0 + 0.0im; -2.0 + 0.0im 0.0 + 0.0im], |
| 38 | + [26.0 + 0.0im -4.0 + 0.0im; -2.0 + 0.0im 0.0 + 0.0im], |
| 39 | + dims=3)), |
| 40 | + (; x = collect(1:7) + im * collect(8:14), dims = 1, |
| 41 | + x_fft = [28.0 + 77.0im, |
| 42 | + -10.76782488800318 + 3.767824888003175im, |
| 43 | + -6.291156861088416 - 0.7088431389115883im, |
| 44 | + -4.298852160365525 - 2.7011478396344746im, |
| 45 | + -2.7011478396344764 - 4.298852160365524im, |
| 46 | + -0.7088431389115866 - 6.291156861088417im, |
| 47 | + 3.767824888003177 - 10.76782488800318im]), |
| 48 | + (; x = collect(reshape(1:8, 2, 2, 2)) + im * reshape(9:16, 2, 2, 2), dims = 1:2, |
| 49 | + x_fft = cat([10.0 + 42.0im -4.0 - 4.0im; -2.0 - 2.0im 0.0 + 0.0im], |
| 50 | + [26.0 + 58.0im -4.0 - 4.0im; -2.0 - 2.0im 0.0 + 0.0im], |
| 51 | + dims=3)), |
| 52 | + ) |
| 53 | + |
| 54 | + |
| 55 | +function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transformed::AbstractArray; inplace_plan=false, copy_input=false) |
| 56 | + _copy = copy_input ? copy : identity |
| 57 | + if !inplace_plan |
| 58 | + @test P * _copy(x) ≈ x_transformed |
| 59 | + @test P \ (P * _copy(x)) ≈ x |
| 60 | + _x_out = similar(P * _copy(x)) |
| 61 | + @test mul!(_x_out, P, _copy(x)) ≈ x_transformed |
| 62 | + @test _x_out ≈ x_transformed |
| 63 | + else |
| 64 | + _x = copy(x) |
| 65 | + @test P * _copy(_x) ≈ x_transformed |
| 66 | + @test _x ≈ x_transformed |
| 67 | + @test P \ _copy(_x) ≈ x |
| 68 | + @test _x ≈ x |
| 69 | + end |
| 70 | +end |
| 71 | + |
| 72 | +function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; real_plan=false, copy_input=false) |
| 73 | + _copy = copy_input ? copy : identity |
| 74 | + y = rand(eltype(P * _copy(x)), size(P * _copy(x))) |
| 75 | + # test basic properties |
| 76 | + @test_skip eltype(P') === typeof(y) # (AbstractFFTs.jl#110) |
| 77 | + @test (P')' === P # test adjoint of adjoint |
| 78 | + @test size(P') == AbstractFFTs.output_size(P) # test size of adjoint |
| 79 | + # test correctness of adjoint and its inverse via the dot test |
| 80 | + if !real_plan |
| 81 | + @test dot(y, P * _copy(x)) ≈ dot(P' * _copy(y), x) |
| 82 | + @test dot(y, P \ _copy(x)) ≈ dot(P' \ _copy(y), x) |
| 83 | + else |
| 84 | + _component_dot(x, y) = dot(real.(x), real.(y)) + dot(imag.(x), imag.(y)) |
| 85 | + @test _component_dot(y, P * _copy(x)) ≈ _component_dot(P' * _copy(y), x) |
| 86 | + @test _component_dot(x, P \ _copy(y)) ≈ _component_dot(P' \ _copy(x), y) |
| 87 | + end |
| 88 | + @test_throws MethodError mul!(x, P', y) |
| 89 | +end |
| 90 | + |
| 91 | +function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true) |
| 92 | + @testset "correctness of fft, bfft, ifft" begin |
| 93 | + for test_case in TEST_CASES |
| 94 | + _x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft) |
| 95 | + x = convert(ArrayType, _x) # dummy array that will be passed to plans |
| 96 | + x_complexf = convert(ArrayType, complex.(float.(x))) # for testing mutating complex FFTs |
| 97 | + x_fft = convert(ArrayType, _x_fft) |
| 98 | + |
| 99 | + # FFT |
| 100 | + @test fft(x, dims) ≈ x_fft |
| 101 | + if test_inplace |
| 102 | + _x_complexf = copy(x_complexf) |
| 103 | + @test fft!(_x_complexf, dims) ≈ x_fft |
| 104 | + @test _x_complexf ≈ x_fft |
| 105 | + end |
| 106 | + # test OOP plans, checking plan_fft and also inv and plan_inv of plan_ifft, |
| 107 | + # which should give functionally identical plans |
| 108 | + for P in (plan_fft(similar(x_complexf), dims), |
| 109 | + (_inv(plan_ifft(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...) |
| 110 | + @test eltype(P) <: Complex |
| 111 | + @test fftdims(P) == dims |
| 112 | + TestUtils.test_plan(P, x_complexf, x_fft) |
| 113 | + if test_adjoint |
| 114 | + @test fftdims(P') == fftdims(P) |
| 115 | + TestUtils.test_plan_adjoint(P, x_complexf) |
| 116 | + end |
| 117 | + end |
| 118 | + if test_inplace |
| 119 | + # test IIP plans |
| 120 | + for P in (plan_fft!(similar(x_complexf), dims), |
| 121 | + (_inv(plan_ifft!(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...) |
| 122 | + TestUtils.test_plan(P, x_complexf, x_fft; inplace_plan=true) |
| 123 | + end |
| 124 | + end |
| 125 | + |
| 126 | + # BFFT |
| 127 | + x_scaled = prod(size(x, d) for d in dims) .* x |
| 128 | + @test bfft(x_fft, dims) ≈ x_scaled |
| 129 | + if test_inplace |
| 130 | + _x_fft = copy(x_fft) |
| 131 | + @test bfft!(_x_fft, dims) ≈ x_scaled |
| 132 | + @test _x_fft ≈ x_scaled |
| 133 | + end |
| 134 | + # test OOP plans. Just 1 plan to test, but we use a for loop for consistent style |
| 135 | + for P in (plan_bfft(similar(x_fft), dims),) |
| 136 | + @test eltype(P) <: Complex |
| 137 | + @test fftdims(P) == dims |
| 138 | + TestUtils.test_plan(P, x_fft, x_scaled) |
| 139 | + if test_adjoint |
| 140 | + TestUtils.test_plan_adjoint(P, x_fft) |
| 141 | + end |
| 142 | + end |
| 143 | + # test IIP plans |
| 144 | + for P in (plan_bfft!(similar(x_fft), dims),) |
| 145 | + @test eltype(P) <: Complex |
| 146 | + @test fftdims(P) == dims |
| 147 | + TestUtils.test_plan(P, x_fft, x_scaled; inplace_plan=true) |
| 148 | + end |
| 149 | + |
| 150 | + # IFFT |
| 151 | + @test ifft(x_fft, dims) ≈ x |
| 152 | + if test_inplace |
| 153 | + _x_fft = copy(x_fft) |
| 154 | + @test ifft!(_x_fft, dims) ≈ x |
| 155 | + @test _x_fft ≈ x |
| 156 | + end |
| 157 | + # test OOP plans |
| 158 | + for P in (plan_ifft(similar(x_complexf), dims), |
| 159 | + (_inv(plan_fft(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...) |
| 160 | + @test eltype(P) <: Complex |
| 161 | + @test fftdims(P) == dims |
| 162 | + TestUtils.test_plan(P, x_fft, x) |
| 163 | + if test_adjoint |
| 164 | + TestUtils.test_plan_adjoint(P, x_fft) |
| 165 | + end |
| 166 | + end |
| 167 | + # test IIP plans |
| 168 | + if test_inplace |
| 169 | + for P in (plan_ifft!(similar(x_complexf), dims), |
| 170 | + (_inv(plan_fft!(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...) |
| 171 | + @test eltype(P) <: Complex |
| 172 | + @test fftdims(P) == dims |
| 173 | + TestUtils.test_plan(P, x_fft, x; inplace_plan=true) |
| 174 | + end |
| 175 | + end |
| 176 | + end |
| 177 | + end |
| 178 | +end |
| 179 | + |
| 180 | +function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false) |
| 181 | + @testset "correctness of rfft, brfft, irfft" begin |
| 182 | + for test_case in TEST_CASES |
| 183 | + _x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft) |
| 184 | + x = convert(ArrayType, _x) # dummy array that will be passed to plans |
| 185 | + x_real = float.(x) # for testing mutating real FFTs |
| 186 | + x_fft = convert(ArrayType, _x_fft) |
| 187 | + x_rfft = collect(selectdim(x_fft, first(dims), 1:(size(x_fft, first(dims)) ÷ 2 + 1))) |
| 188 | + |
| 189 | + if !(eltype(x) <: Real) |
| 190 | + continue |
| 191 | + end |
| 192 | + |
| 193 | + # RFFT |
| 194 | + @test rfft(x, dims) ≈ x_rfft |
| 195 | + for P in (plan_rfft(similar(x_real), dims), |
| 196 | + (_inv(plan_irfft(similar(x_rfft), size(x, first(dims)), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...) |
| 197 | + @test eltype(P) <: Real |
| 198 | + @test fftdims(P) == dims |
| 199 | + TestUtils.test_plan(P, x_real, x_rfft; copy_input=copy_input) |
| 200 | + if test_adjoint |
| 201 | + TestUtils.test_plan_adjoint(P, x_real; real_plan=true, copy_input=copy_input) |
| 202 | + end |
| 203 | + end |
| 204 | + |
| 205 | + # BRFFT |
| 206 | + x_scaled = prod(size(x, d) for d in dims) .* x |
| 207 | + @test brfft(x_rfft, size(x, first(dims)), dims) ≈ x_scaled |
| 208 | + for P in (plan_brfft(similar(x_rfft), size(x, first(dims)), dims),) |
| 209 | + @test eltype(P) <: Complex |
| 210 | + @test fftdims(P) == dims |
| 211 | + TestUtils.test_plan(P, x_rfft, x_scaled; copy_input=copy_input) |
| 212 | + if test_adjoint |
| 213 | + TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input) |
| 214 | + end |
| 215 | + end |
| 216 | + |
| 217 | + # IRFFT |
| 218 | + @test irfft(x_rfft, size(x, first(dims)), dims) ≈ x |
| 219 | + for P in (plan_irfft(similar(x_rfft), size(x, first(dims)), dims), |
| 220 | + (_inv(plan_rfft(similar(x_real), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...) |
| 221 | + @test eltype(P) <: Complex |
| 222 | + @test fftdims(P) == dims |
| 223 | + TestUtils.test_plan(P, x_rfft, x; copy_input=copy_input) |
| 224 | + if test_adjoint |
| 225 | + TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input) |
| 226 | + end |
| 227 | + end |
| 228 | + end |
| 229 | + end |
| 230 | +end |
| 231 | + |
| 232 | +end |
0 commit comments