From ad718161b1c3ad116bcbb41bd96304f45e7d7330 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Wed, 8 Jun 2022 23:34:17 -0700 Subject: [PATCH 01/26] Implement AdjointPlans --- src/definitions.jl | 66 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/src/definitions.jl b/src/definitions.jl index ac9a4ba5..e140233f 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -12,6 +12,7 @@ eltype(::Type{<:Plan{T}}) where {T} = T # size(p) should return the size of the input array for p size(p::Plan, d) = size(p)[d] +output_size(p::Plan, d) = output_size(p)[d] ndims(p::Plan) = length(size(p)) length(p::Plan) = prod(size(p))::Int @@ -255,6 +256,7 @@ ScaledPlan(p::Plan{T}, scale::Number) where {T} = ScaledPlan{T}(p, scale) ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α) size(p::ScaledPlan) = size(p.p) +output_size(p::ScaledPlan) = output_size(p.p) fftdims(p::ScaledPlan) = fftdims(p.p) @@ -576,3 +578,67 @@ Pre-plan an optimized real-input unnormalized transform, similar to the same as for [`brfft`](@ref). """ plan_brfft + +############################################################################## + +struct NoProjectionStyle end +struct RealProjectionStyle end +struct RealInverseProjectionStyle end +const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle} + +function irfft_dim end + +output_size(p::Plan) = _output_size(p, ProjectionStyle(p)) +_output_size(p::Plan, ::NoProjectionStyle) = size(p) +_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), region(p)) +_output_size(p::Plan, ::RealInverseProjectionStyle) = brfft_output_size(size(p), irfft_dim(p), region(p)) + +mutable struct AdjointPlan{T,P} <: Plan{T} + p::P + pinv::Plan + AdjointPlan{T,P}(p) where {T,P} = new(p) +end + +Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T, typeof(p)}(p) +Base.adjoint(p::AdjointPlan{T}) where {T} = p.p +# always have AdjointPlan inside ScaledPlan. 
+Base.adjoint(p::ScaledPlan{T}) where {T} = ScaledPlan{T}(p.p', p.scale) + +size(p::AdjointPlan) = output_size(p.p) +output_size(p::AdjointPlan) = size(p.p) + +Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p)) + +function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T} + dims = region(p.p) + N = normalization(T, size(p.p), dims) + return 1/N * (p.p \ x) +end + +function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T} + dims = region(p.p) + N = normalization(T, size(p.p), dims) + halfdim = first(dims) + d = size(p.p, halfdim) + n = output_size(p.p, halfdim) + scale = reshape( + [(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n], + ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))) + ) + return 1/N * (p.p \ (x ./ scale)) +end + +function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T} + dims = region(p.p) + N = normalization(real(T), output_size(p.p), dims) + halfdim = first(dims) + n = size(p.p, halfdim) + d = output_size(p.p, halfdim) + scale = reshape( + [(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n], + ntuple(i -> i == first(dims) ? 
n : 1, Val(ndims(x))) + ) + return 1/N * scale .* (p.p \ x) +end + +plan_inv(p::AdjointPlan) = AdjointPlan(plan_inv(p.p)) From c91ad5017712768e04d46c8c0f6ad2c480ff6596 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Wed, 8 Jun 2022 23:35:10 -0700 Subject: [PATCH 02/26] Implement chain rules for FFT plans --- src/chainrules.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/chainrules.jl b/src/chainrules.jl index 97d4d229..610bd59d 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -150,3 +150,20 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims) end return y, ifftshift_pullback end + +# plans +function ChainRulesCore.frule((_, _, Δx), ::typeof(*), P::Plan, x::AbstractArray) + y = P * x + Δy = P * Δx + return y, Δy +end +function ChainRulesCore.rrule(::typeof(*), P::Plan, x::AbstractArray) + y = P * x + project_x = ChainRulesCore.ProjectTo(x) + Pt = P' + function mul_plan_pullback(ȳ) + x̄ = project_x(Pt * ȳ) + return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), x̄ + end + return y, mul_plan_pullback +end From 061eef9dbf840c85845fc4204f0d3d41d06c1af3 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Wed, 8 Jun 2022 23:35:55 -0700 Subject: [PATCH 03/26] Test plan adjoints and AD rules --- test/runtests.jl | 93 ++++++++++++++++++++++++++++++++++++++++++++++- test/testplans.jl | 37 +++++++++++-------- 2 files changed, 114 insertions(+), 16 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 623d6256..f06a270e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using AbstractFFTs using AbstractFFTs: Plan using ChainRulesTestUtils +using ChainRulesCore: NoTangent using LinearAlgebra using Random @@ -197,6 +198,79 @@ end @test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10 end +@testset "output size" begin + @testset "complex fft output size" begin + for x in (randn(3), randn(3, 4), randn(3, 4, 5)) + N = ndims(x) + y = randn(size(x)) + for dims in unique((1, 1:N, N)) + 
P = plan_fft(x, dims) + @test AbstractFFTs.output_size(P) == size(x) + @test AbstractFFTs.output_size(P') == size(x) + Pinv = plan_ifft(x) + @test AbstractFFTs.output_size(Pinv) == size(x) + @test AbstractFFTs.output_size(Pinv') == size(x) + end + end + end + @testset "real fft output size" begin + for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths + N = ndims(x) + for dims in unique((1, 1:N, N)) + P = plan_rfft(x, dims) + Px_sz = size(P * x) + @test AbstractFFTs.output_size(P) == Px_sz + @test AbstractFFTs.output_size(P') == size(x) + y = randn(Px_sz) .+ randn(Px_sz) * im + Pinv = plan_irfft(y, size(x)[first(dims)], dims) + @test AbstractFFTs.output_size(Pinv) == size(Pinv * y) + @test AbstractFFTs.output_size(Pinv') == size(y) + end + end + end +end + +@testset "adjoint" begin + @testset "complex fft adjoint" begin + for x in (randn(3), randn(3, 4), randn(3, 4, 5)) + N = ndims(x) + y = randn(size(x)) + for dims in unique((1, 1:N, N)) + P = plan_fft(x, dims) + @test (P')' * x == P * x # test adjoint of adjoint + @test size(P') == AbstractFFTs.output_size(P) # test size of adjoint + @test dot(y, P * x) ≈ dot(P' * y, x) # test validity of adjoint + @test_broken dot(y, P \ x) ≈ dot(P' \ y, x) + Pinv = plan_ifft(y) + @test (Pinv')' * y == Pinv * y + @test size(Pinv') == AbstractFFTs.output_size(Pinv) + @test dot(x, Pinv * y) ≈ dot(Pinv' * x, y) + @test_broken dot(x, Pinv \ y) ≈ dot(Pinv' \ x, y) + end + end + end + @testset "real fft adjoint" begin + for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths + N = ndims(x) + for dims in unique((1, 1:N, N)) + P = plan_rfft(x, dims) + y_real = randn(size(P * x)) + y_imag = randn(size(P * x)) + y = y_real .+ y_imag .* im + @test (P')' * x == P * x + @test size(P') == AbstractFFTs.output_size(P) + @test dot(y_real, real.(P * x)) + dot(y_imag, imag.(P * x)) ≈ dot(P' * y, x) + @test_broken dot(y_real, real.(P \ x)) + dot(y_imag, imag.(P \ x)) ≈ dot(P' * y, x) + 
Pinv = plan_irfft(y, size(x)[first(dims)], dims) + @test (Pinv')' * y == Pinv * y + @test size(Pinv') == AbstractFFTs.output_size(Pinv) + @test dot(x, Pinv * y) ≈ dot(y_real, real.(Pinv' * x)) + dot(y_imag, imag.(Pinv' * x)) + @test_broken dot(x, Pinv \ y) ≈ dot(y_real, real.(Pinv' \ x)) + dot(y_imag, imag.(Pinv' \ x)) + end + end + end +end + @testset "ChainRules" begin @testset "shift functions" begin for x in (randn(3), randn(3, 4), randn(3, 4, 5)) @@ -218,20 +292,31 @@ end end @testset "fft" begin - for x in (randn(3), randn(3, 4), randn(3, 4, 5)) + for x in (randn(2), randn(2, 3), randn(3, 4, 5)) N = ndims(x) complex_x = complex.(x) for dims in unique((1, 1:N, N)) + # fft, ifft, bfft for f in (fft, ifft, bfft) test_frule(f, x, dims) test_rrule(f, x, dims) test_frule(f, complex_x, dims) test_rrule(f, complex_x, dims) end + for pf in (plan_fft, plan_ifft, plan_bfft) + test_frule(*, pf(x, dims) ⊢ NoTangent(), x) + test_rrule(*, pf(x, dims) ⊢ NoTangent(), x) + test_frule(*, pf(complex_x, dims) ⊢ NoTangent(), complex_x) + test_rrule(*, pf(complex_x, dims) ⊢ NoTangent(), complex_x) + end + # rfft test_frule(rfft, x, dims) test_rrule(rfft, x, dims) + test_frule(*, plan_rfft(x, dims) ⊢ NoTangent(), x) + test_rrule(*, plan_rfft(x, dims) ⊢ NoTangent(), x) + # irfft, brfft for f in (irfft, brfft) for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2) test_frule(f, x, d, dims) @@ -240,6 +325,12 @@ end test_rrule(f, complex_x, d, dims) end end + for pf in (plan_irfft, plan_brfft) + for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2) + test_frule(*, pf(complex_x, d, dims) ⊢ NoTangent(), complex_x) + test_rrule(*, pf(complex_x, d, dims) ⊢ NoTangent(), complex_x) + end + end end end end diff --git a/test/testplans.jl b/test/testplans.jl index 7abecfeb..fa07729a 100644 --- a/test/testplans.jl +++ b/test/testplans.jl @@ -1,18 +1,18 @@ -mutable struct TestPlan{T,N} <: Plan{T} - region +mutable struct TestPlan{T,N,G} <: Plan{T} + region::G 
sz::NTuple{N,Int} pinv::Plan{T} - function TestPlan{T}(region, sz::NTuple{N,Int}) where {T,N} - return new{T,N}(region, sz) + function TestPlan{T}(region::G, sz::NTuple{N,Int}) where {T,N,G} + return new{T,N,G}(region, sz) end end -mutable struct InverseTestPlan{T,N} <: Plan{T} - region +mutable struct InverseTestPlan{T,N,G} <: Plan{T} + region::G sz::NTuple{N,Int} pinv::Plan{T} - function InverseTestPlan{T}(region, sz::NTuple{N,Int}) where {T,N} - return new{T,N}(region, sz) + function InverseTestPlan{T}(region::G, sz::NTuple{N,Int}) where {T,N,G} + return new{T,N,G}(region, sz) end end @@ -21,6 +21,9 @@ Base.ndims(::TestPlan{T,N}) where {T,N} = N Base.size(p::InverseTestPlan) = p.sz Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N +AbstractFFTs.ProjectionStyle(::TestPlan) = AbstractFFTs.NoProjectionStyle() +AbstractFFTs.ProjectionStyle(::InverseTestPlan) = AbstractFFTs.NoProjectionStyle() + function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T} return TestPlan{T}(region, size(x)) end @@ -89,24 +92,28 @@ end Base.:*(p::TestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x) Base.:*(p::InverseTestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x) -mutable struct TestRPlan{T,N} <: Plan{T} - region +mutable struct TestRPlan{T,N,G} <: Plan{T} + region::G sz::NTuple{N,Int} pinv::Plan{T} - TestRPlan{T}(region, sz::NTuple{N,Int}) where {T,N} = new{T,N}(region, sz) + TestRPlan{T}(region::G, sz::NTuple{N,Int}) where {T,N,G} = new{T,N,G}(region, sz) end -mutable struct InverseTestRPlan{T,N} <: Plan{T} +mutable struct InverseTestRPlan{T,N,G} <: Plan{T} d::Int - region + region::G sz::NTuple{N,Int} pinv::Plan{T} - function InverseTestRPlan{T}(d::Int, region, sz::NTuple{N,Int}) where {T,N} + function InverseTestRPlan{T}(d::Int, region::G, sz::NTuple{N,Int}) where {T,N,G} sz[first(region)::Int] == d ÷ 2 + 1 || error("incompatible dimensions") - return new{T,N}(d, region, sz) + return new{T,N,G}(d, 
region, sz) end end +AbstractFFTs.ProjectionStyle(::TestRPlan) = AbstractFFTs.RealProjectionStyle() +AbstractFFTs.ProjectionStyle(::InverseTestRPlan) = AbstractFFTs.RealInverseProjectionStyle() +AbstractFFTs.irfft_dim(p::InverseTestRPlan) = p.d + function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T} return TestRPlan{T}(region, size(x)) end From 497ff4db2aae5f0acd43b099c7d721ee26f3b911 Mon Sep 17 00:00:00 2001 From: gaurav-arya Date: Thu, 9 Jun 2022 09:35:17 -0700 Subject: [PATCH 04/26] Apply suggestions from adjoint plan code review Co-authored-by: David Widmann --- src/definitions.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/definitions.jl b/src/definitions.jl index e140233f..380ac9ca 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -593,16 +593,16 @@ _output_size(p::Plan, ::NoProjectionStyle) = size(p) _output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), region(p)) _output_size(p::Plan, ::RealInverseProjectionStyle) = brfft_output_size(size(p), irfft_dim(p), region(p)) -mutable struct AdjointPlan{T,P} <: Plan{T} +mutable struct AdjointPlan{T,P<:Plan} <: Plan{T} p::P pinv::Plan AdjointPlan{T,P}(p) where {T,P} = new(p) end Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T, typeof(p)}(p) -Base.adjoint(p::AdjointPlan{T}) where {T} = p.p +Base.adjoint(p::AdjointPlan) = p.p # always have AdjointPlan inside ScaledPlan. 
-Base.adjoint(p::ScaledPlan{T}) where {T} = ScaledPlan{T}(p.p', p.scale) +Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale) size(p::AdjointPlan) = output_size(p.p) output_size(p::AdjointPlan) = size(p.p) @@ -612,7 +612,7 @@ Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p)) function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T} dims = region(p.p) N = normalization(T, size(p.p), dims) - return 1/N * (p.p \ x) + return (p.p \ x) / N end function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T} @@ -622,10 +622,10 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where d = size(p.p, halfdim) n = output_size(p.p, halfdim) scale = reshape( - [(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n], - ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))) + [(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n], + ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))) ) - return 1/N * (p.p \ (x ./ scale)) + return p.p \ (x ./ scale) end function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T} @@ -636,9 +636,9 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) d = output_size(p.p, halfdim) scale = reshape( [(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n], - ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))) + ntuple(i -> i == halfdim ? 
n : 1, Val(ndims(x))) + ) - return 1/N * scale .* (p.p \ x) + return scale ./ N .* (p.p \ x) end -plan_inv(p::AdjointPlan) = AdjointPlan(plan_inv(p.p)) +plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p)) From 5d5c06c61a23e6e0bef48434a600ed5e9eaf90d9 Mon Sep 17 00:00:00 2001 From: gaurav-arya Date: Thu, 9 Jun 2022 09:39:55 -0700 Subject: [PATCH 05/26] Include irfft_dim in RealInverseProjectionStyle Co-authored-by: David Widmann --- src/definitions.jl | 6 ++++-- test/testplans.jl | 3 +-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/definitions.jl b/src/definitions.jl index 380ac9ca..0abe30db 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -583,7 +583,9 @@ plan_brfft struct NoProjectionStyle end struct RealProjectionStyle end -struct RealInverseProjectionStyle end +struct RealInverseProjectionStyle + dim::Int +end const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle} function irfft_dim end @@ -591,7 +593,7 @@ function irfft_dim end output_size(p::Plan) = _output_size(p, ProjectionStyle(p)) _output_size(p::Plan, ::NoProjectionStyle) = size(p) _output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), region(p)) -_output_size(p::Plan, ::RealInverseProjectionStyle) = brfft_output_size(size(p), irfft_dim(p), region(p)) +_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, region(p)) mutable struct AdjointPlan{T,P<:Plan} <: Plan{T} p::P diff --git a/test/testplans.jl b/test/testplans.jl index fa07729a..27f5f00c 100644 --- a/test/testplans.jl +++ b/test/testplans.jl @@ -111,8 +111,7 @@ mutable struct InverseTestRPlan{T,N,G} <: Plan{T} end AbstractFFTs.ProjectionStyle(::TestRPlan) = AbstractFFTs.RealProjectionStyle() -AbstractFFTs.ProjectionStyle(::InverseTestRPlan) = AbstractFFTs.RealInverseProjectionStyle() -AbstractFFTs.irfft_dim(p::InverseTestRPlan) = p.d +AbstractFFTs.ProjectionStyle(p::InverseTestRPlan) = 
AbstractFFTs.RealInverseProjectionStyle(p.d) function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T} return TestRPlan{T}(region, size(x)) From ef84edfb10587c2deca69ecee3bbac101f882b80 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Thu, 30 Jun 2022 20:13:32 -0400 Subject: [PATCH 06/26] update to new fftdims interface --- src/definitions.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/definitions.jl b/src/definitions.jl index 0abe30db..497007e6 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -592,8 +592,8 @@ function irfft_dim end output_size(p::Plan) = _output_size(p, ProjectionStyle(p)) _output_size(p::Plan, ::NoProjectionStyle) = size(p) -_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), region(p)) -_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, region(p)) +_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p)) +_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p)) mutable struct AdjointPlan{T,P<:Plan} <: Plan{T} p::P @@ -612,13 +612,13 @@ output_size(p::AdjointPlan) = size(p.p) Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p)) function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T} - dims = region(p.p) + dims = fftdims(p.p) N = normalization(T, size(p.p), dims) return (p.p \ x) / N end function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T} - dims = region(p.p) + dims = fftdims(p.p) N = normalization(T, size(p.p), dims) halfdim = first(dims) d = size(p.p, halfdim) @@ -631,7 +631,7 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where end function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T} - dims = region(p.p) + dims = fftdims(p.p) N = normalization(real(T), output_size(p.p), dims) halfdim = first(dims) 
n = size(p.p, halfdim) From d7ff39433c1dd3ab3be266f9a4993818d23a1571 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Thu, 30 Jun 2022 21:35:00 -0400 Subject: [PATCH 07/26] fix broken tests --- test/runtests.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index f06a270e..93c14625 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -240,12 +240,12 @@ end @test (P')' * x == P * x # test adjoint of adjoint @test size(P') == AbstractFFTs.output_size(P) # test size of adjoint @test dot(y, P * x) ≈ dot(P' * y, x) # test validity of adjoint - @test_broken dot(y, P \ x) ≈ dot(P' \ y, x) + @test dot(y, P \ x) ≈ dot(P' \ y, x) Pinv = plan_ifft(y) @test (Pinv')' * y == Pinv * y @test size(Pinv') == AbstractFFTs.output_size(Pinv) @test dot(x, Pinv * y) ≈ dot(Pinv' * x, y) - @test_broken dot(x, Pinv \ y) ≈ dot(Pinv' \ x, y) + @test dot(x, Pinv \ y) ≈ dot(Pinv' \ x, y) end end end @@ -260,12 +260,12 @@ end @test (P')' * x == P * x @test size(P') == AbstractFFTs.output_size(P) @test dot(y_real, real.(P * x)) + dot(y_imag, imag.(P * x)) ≈ dot(P' * y, x) - @test_broken dot(y_real, real.(P \ x)) + dot(y_imag, imag.(P \ x)) ≈ dot(P' * y, x) + @test dot(y_real, real.(P' \ x)) + dot(y_imag, imag.(P' \ x)) ≈ dot(P \ y, x) Pinv = plan_irfft(y, size(x)[first(dims)], dims) @test (Pinv')' * y == Pinv * y @test size(Pinv') == AbstractFFTs.output_size(Pinv) @test dot(x, Pinv * y) ≈ dot(y_real, real.(Pinv' * x)) + dot(y_imag, imag.(Pinv' * x)) - @test_broken dot(x, Pinv \ y) ≈ dot(y_real, real.(Pinv' \ x)) + dot(y_imag, imag.(Pinv' \ x)) + @test dot(x, Pinv' \ y) ≈ dot(y_real, real.(Pinv \ x)) + dot(y_imag, imag.(Pinv \ x)) end end end From aa8e5759d8a07598e7a7716d13566ff5b295b119 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Fri, 1 Jul 2022 00:54:43 -0400 Subject: [PATCH 08/26] Explicitly don't support mul! 
for adjoint plans --- src/definitions.jl | 3 +++ test/runtests.jl | 1 + 2 files changed, 4 insertions(+) diff --git a/src/definitions.jl b/src/definitions.jl index 497007e6..3cf9d1d8 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -644,3 +644,6 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) end plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p)) +function LinearAlgebra.mul!(y::AbstractArray, p::AdjointPlan, x::AbstractArray) + throw(MethodError(LinearAlgebra.mul!, "mul! is not supported for adjoint plans")) +end diff --git a/test/runtests.jl b/test/runtests.jl index 93c14625..a2fcf427 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -246,6 +246,7 @@ end @test size(Pinv') == AbstractFFTs.output_size(Pinv) @test dot(x, Pinv * y) ≈ dot(Pinv' * x, y) @test dot(x, Pinv \ y) ≈ dot(Pinv' \ x, y) + @test_throws MethodError mul!(x, P', y) end end end From 9d998863064d94e8b619687918665a50c46b6378 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Fri, 1 Jul 2022 01:05:04 -0400 Subject: [PATCH 09/26] Document adjoint plans --- src/definitions.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/definitions.jl b/src/definitions.jl index 3cf9d1d8..d812e477 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -601,6 +601,14 @@ mutable struct AdjointPlan{T,P<:Plan} <: Plan{T} AdjointPlan{T,P}(p) where {T,P} = new(p) end +""" + Base.adjoint(p::Plan) + +Form the adjoint operator of an FFT plan. Returns a plan `p'` which performs the adjoint operation +the original plan. Note that this differs from the corresponding backwards plan in the case of real +FFTs due to the halving of one of the dimensions of the FFT output, as described in [`rfft`](@ref). +Adjoint plans do not currently support `mul!`. +""" Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T, typeof(p)}(p) Base.adjoint(p::AdjointPlan) = p.p # always have AdjointPlan inside ScaledPlan. 
From ac7c78c3c05504ea860c733bd75d9270c4602166 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Fri, 1 Jul 2022 01:23:17 -0400 Subject: [PATCH 10/26] remove incorrectly thrown error --- src/definitions.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/definitions.jl b/src/definitions.jl index d812e477..66ebb2e0 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -652,6 +652,3 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) end plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p)) -function LinearAlgebra.mul!(y::AbstractArray, p::AdjointPlan, x::AbstractArray) - throw(MethodError(LinearAlgebra.mul!, "mul! is not supported for adjoint plans")) -end From 8474141abffdb1117c7222d0b273e73b30b78595 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Thu, 14 Jul 2022 11:18:50 -0700 Subject: [PATCH 11/26] Update adjoint plan docs --- README.md | 23 +------------------- docs/Project.toml | 1 + docs/src/implementations.md | 42 +++++++++++++++++++++++++------------ 3 files changed, 31 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index 5b33c59c..df63a20d 100644 --- a/README.md +++ b/README.md @@ -16,25 +16,4 @@ This allows multiple FFT packages to co-exist with the same underlying `fft(x)` ## Developer information -To define a new FFT implementation in your own module, you should - -* Define a new subtype (e.g. `MyPlan`) of `AbstractFFTs.Plan{T}` for FFTs and related transforms on arrays of `T`. - This must have a `pinv::Plan` field, initially undefined when a `MyPlan` is created, that is used for caching the - inverse plan. - -* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of - `x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`). 
- -* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` (or `A_mul_B!(y, p::MyPlan, x)` on Julia prior to - 0.7.0-DEV.3204) that computes the transform `p` of `x` and stores the result in `y`. - -* Define a method of `*(p::MyPlan, x)`, which can simply call your `mul!` (or `A_mul_B!`) method. - This is not defined generically in this package due to subtleties that arise for in-place and real-input FFTs. - -* If the inverse transform is implemented, you should also define `plan_inv(p::MyPlan)`, which should construct the - inverse plan to `p`, and `plan_bfft(x, region; kws...)` for an unnormalized inverse ("backwards") transform of `x`. - -* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs. - -The normalization convention for your FFT should be that it computes yₖ = ∑ⱼ xⱼ exp(-2πi jk/n) for a transform of -length n, and the "backwards" (unnormalized inverse) transform computes the same thing but with exp(+2πi jk/n). +To define a new FFT implementation in your own module, see [defining a new implementation](https://juliamath.github.io/AbstractFFTs.jl/stable/implementations/#Defining-a-new-implementation). diff --git a/docs/Project.toml b/docs/Project.toml index ed025f5a..4ca9eda1 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" [compat] diff --git a/docs/src/implementations.md b/docs/src/implementations.md index 632a6026..09044cd9 100644 --- a/docs/src/implementations.md +++ b/docs/src/implementations.md @@ -11,16 +11,32 @@ The following packages extend the functionality provided by AbstractFFTs: ## Defining a new implementation -Implementations should implement `LinearAlgebra.mul!(Y, plan, X)` (or -`A_mul_B!(y, p::MyPlan, x)` on Julia prior to 0.7.0-DEV.3204) so as to support -pre-allocated output arrays. 
-We don't define `*` in terms of `mul!` generically here, however, because -of subtleties for in-place and real FFT plans. - -To support `inv`, `\`, and `ldiv!(y, plan, x)`, we require `Plan` subtypes -to have a `pinv::Plan` field, which caches the inverse plan, and which should be -initially undefined. -They should also implement `plan_inv(p)` to construct the inverse of a plan `p`. - -Implementations only need to provide the unnormalized backwards FFT, -similar to FFTW, and we do the scaling generically to get the inverse FFT. +To define a new FFT implementation in your own module, you should + +* Define a new subtype (e.g. `MyPlan`) of `AbstractFFTs.Plan{T}` for FFTs and related transforms on arrays of `T`. + This must have a `pinv::Plan` field, initially undefined when a `MyPlan` is created, that is used for caching the + inverse plan. + +* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of + `x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`). + +* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` (or `A_mul_B!(y, p::MyPlan, x)` on Julia prior to + 0.7.0-DEV.3204) that computes the transform `p` of `x` and stores the result in `y`. + +* Define a method of `*(p::MyPlan, x)`, which can simply call your `mul!` (or `A_mul_B!`) method. + This is not defined generically in this package due to subtleties that arise for in-place and real-input FFTs. + +* If the inverse transform is implemented, you should also define `plan_inv(p::MyPlan)`, which should construct the + inverse plan to `p`, and `plan_bfft(x, region; kws...)` for an unnormalized inverse ("backwards") transform of `x`. + Implementations only need to provide the unnormalized backwards FFT, similar to FFTW, and we do the scaling generically + to get the inverse FFT. 
+ +* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs. + +* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can take values: + * `AbstractFFTs.NoProjectionStyle()`, + * `AbstractFFTs.RealProjectionStyle()`, for plans which halve one of the output's dimensions analogously to [`rfft`](@ref), + * `AbstractFFTs.RealInverseProjectionStyle(d::Integer)`, for plans which expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension. + +The normalization convention for your FFT should be that it computes yₖ = ∑ⱼ xⱼ exp(-2πi jk/n) for a transform of +length n, and the "backwards" (unnormalized inverse) transform computes the same thing but with exp(+2πi jk/n). From 769c090a901be5a41db1a3e2495fa96d6506038c Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Thu, 14 Jul 2022 11:43:25 -0700 Subject: [PATCH 12/26] Update adjoint docs --- docs/src/api.md | 1 + src/definitions.jl | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index 5d8316b2..bb3b8492 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -20,6 +20,7 @@ AbstractFFTs.plan_rfft AbstractFFTs.plan_brfft AbstractFFTs.plan_irfft AbstractFFTs.fftdims +Base.adjoint AbstractFFTs.fftshift AbstractFFTs.fftshift! AbstractFFTs.ifftshift diff --git a/src/definitions.jl b/src/definitions.jl index 66ebb2e0..5de78d4e 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -602,7 +602,7 @@ mutable struct AdjointPlan{T,P<:Plan} <: Plan{T} end """ - Base.adjoint(p::Plan) + adjoint(p::Plan) Form the adjoint operator of an FFT plan. Returns a plan `p'` which performs the adjoint operation the original plan. 
Note that this differs from the corresponding backwards plan in the case of real From 3ed83dfd7be5ce4e7b0f746679092ca43056b223 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Thu, 14 Jul 2022 12:16:47 -0700 Subject: [PATCH 13/26] Fix typos --- docs/src/implementations.md | 2 +- src/definitions.jl | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/src/implementations.md b/docs/src/implementations.md index 09044cd9..04901be5 100644 --- a/docs/src/implementations.md +++ b/docs/src/implementations.md @@ -36,7 +36,7 @@ To define a new FFT implementation in your own module, you should * To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can take values: * `AbstractFFTs.NoProjectionStyle()`, * `AbstractFFTs.RealProjectionStyle()`, for plans which halve one of the output's dimensions analogously to [`rfft`](@ref), - * `AbstractFFTs.RealInverseProjectionStyle(d::Integer)`, for plans which expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension. + * `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans which expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension. The normalization convention for your FFT should be that it computes yₖ = ∑ⱼ xⱼ exp(-2πi jk/n) for a transform of length n, and the "backwards" (unnormalized inverse) transform computes the same thing but with exp(+2πi jk/n). 
diff --git a/src/definitions.jl b/src/definitions.jl index 5de78d4e..d4778410 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -588,8 +588,6 @@ struct RealInverseProjectionStyle end const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle} -function irfft_dim end - output_size(p::Plan) = _output_size(p, ProjectionStyle(p)) _output_size(p::Plan, ::NoProjectionStyle) = size(p) _output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p)) From 552d49f93b3de8a877fec8ed0e19331e759c6836 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Thu, 14 Jul 2022 12:25:23 -0700 Subject: [PATCH 14/26] tweak adjoint doc string --- src/definitions.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/definitions.jl b/src/definitions.jl index d4778410..a92a0be0 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -600,12 +600,15 @@ mutable struct AdjointPlan{T,P<:Plan} <: Plan{T} end """ + p' adjoint(p::Plan) -Form the adjoint operator of an FFT plan. Returns a plan `p'` which performs the adjoint operation +Form the adjoint operator of an FFT plan. Returns a plan which performs the adjoint operation the original plan. Note that this differs from the corresponding backwards plan in the case of real FFTs due to the halving of one of the dimensions of the FFT output, as described in [`rfft`](@ref). -Adjoint plans do not currently support `mul!`. + +!!! note + Adjoint plans do not currently support `LinearAlgebra.mul!`. 
""" Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T, typeof(p)}(p) Base.adjoint(p::AdjointPlan) = p.p From 1e9ece2e5ed92ef857ffe66d57233e19c8d31010 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Thu, 14 Jul 2022 17:21:12 -0700 Subject: [PATCH 15/26] Tweaks to adjoint description --- src/definitions.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/definitions.jl b/src/definitions.jl index a92a0be0..a0e7f781 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -608,7 +608,8 @@ the original plan. Note that this differs from the corresponding backwards plan FFTs due to the halving of one of the dimensions of the FFT output, as described in [`rfft`](@ref). !!! note - Adjoint plans do not currently support `LinearAlgebra.mul!`. + Adjoint plans do not currently support `LinearAlgebra.mul!`. Further, as a new addition to `AbstractFFTs`, + coverage of `Base.adjoint` in downstream implementations may be limited. """ Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T, typeof(p)}(p) Base.adjoint(p::AdjointPlan) = p.p From 8ddfa9750afc9fe98ba8de90d01e241cd7ecf99b Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sat, 16 Jul 2022 01:36:16 -0400 Subject: [PATCH 16/26] Immutable AdjointPlan --- src/definitions.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/definitions.jl b/src/definitions.jl index a0e7f781..dfdfc00b 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -593,9 +593,8 @@ _output_size(p::Plan, ::NoProjectionStyle) = size(p) _output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p)) _output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p)) -mutable struct AdjointPlan{T,P<:Plan} <: Plan{T} +struct AdjointPlan{T,P<:Plan} <: Plan{T} p::P - pinv::Plan AdjointPlan{T,P}(p) where {T,P} = new(p) end @@ -653,4 +652,4 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) return scale ./ N .* (p.p \ x) end 
-plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p)) +inv(p::AdjointPlan) = adjoint(inv(p.p)) From 87758c83508ee80e9102dfb57b3737894c4f5c8d Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sat, 6 Aug 2022 18:13:24 -0400 Subject: [PATCH 17/26] Add rules and tests for ScaledPlan --- Project.toml | 3 ++- src/chainrules.jl | 20 ++++++++++++++++++++ test/runtests.jl | 39 ++++++++++++++++++++++++++------------- 3 files changed, 48 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index a639c5dd..1f3355fb 100644 --- a/Project.toml +++ b/Project.toml @@ -12,9 +12,10 @@ julia = "^1.0" [extras] ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["ChainRulesTestUtils", "Random", "Test", "Unitful"] +test = ["ChainRulesTestUtils", "FiniteDifferences", "Random", "Test", "Unitful"] diff --git a/src/chainrules.jl b/src/chainrules.jl index 610bd59d..80c51142 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -167,3 +167,23 @@ function ChainRulesCore.rrule(::typeof(*), P::Plan, x::AbstractArray) end return y, mul_plan_pullback end + +function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::ScaledPlan, x::AbstractArray) + y = P * x + Δy = P * Δx + ΔP.scale / P.scale * y + return y, Δy +end +function ChainRulesCore.rrule(::typeof(*), P::ScaledPlan, x::AbstractArray) + y = P * x + project_x = ChainRulesCore.ProjectTo(x) + project_scale = ChainRulesCore.ProjectTo(P.scale) + Pt = P' + scale = P.scale + function mul_plan_pullback(ȳ) + x̄ = ChainRulesCore.@thunk(project_x(Pt * ȳ)) + scale_tangent = ChainRulesCore.@thunk(project_scale(sum(conj(y) .* ȳ) / conj(scale))) + plan_tangent = ChainRulesCore.Tangent{typeof(P)}(;p=ChainRulesCore.NoTangent(), scale=scale_tangent) + return ChainRulesCore.NoTangent(), plan_tangent, x̄ + end + 
return y, mul_plan_pullback +end diff --git a/test/runtests.jl b/test/runtests.jl index a2fcf427..edbb1b0b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,9 +1,10 @@ # This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license using AbstractFFTs -using AbstractFFTs: Plan +using AbstractFFTs: Plan, ScaledPlan using ChainRulesTestUtils -using ChainRulesCore: NoTangent +using FiniteDifferences +import ChainRulesCore using LinearAlgebra using Random @@ -293,9 +294,21 @@ end end @testset "fft" begin - for x in (randn(2), randn(2, 3), randn(3, 4, 5)) - N = ndims(x) - complex_x = complex.(x) + # Overloads to allow ChainRulesTestUtils to test rules w.r.t. ScaledPlan's. See https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/256 + InnerPlan = Union{TestPlan, InverseTestPlan, TestRPlan, InverseTestRPlan} + function FiniteDifferences.to_vec(x::InnerPlan) + function FFTPlan_from_vec(x_vec::Vector) + return x + end + return Bool[], FFTPlan_from_vec + end + ChainRulesTestUtils.test_approx(::ChainRulesCore.AbstractZero, x::InnerPlan, msg=""; kwargs...) 
= true + ChainRulesTestUtils.rand_tangent(::AbstractRNG, x::InnerPlan) = ChainRulesCore.NoTangent() + + for x_shape in ((2,), (2, 3), (3, 4, 5)) + N = length(x_shape) + x = randn(x_shape) + complex_x = x + randn(x_shape) * im for dims in unique((1, 1:N, N)) # fft, ifft, bfft for f in (fft, ifft, bfft) @@ -305,17 +318,17 @@ end test_rrule(f, complex_x, dims) end for pf in (plan_fft, plan_ifft, plan_bfft) - test_frule(*, pf(x, dims) ⊢ NoTangent(), x) - test_rrule(*, pf(x, dims) ⊢ NoTangent(), x) - test_frule(*, pf(complex_x, dims) ⊢ NoTangent(), complex_x) - test_rrule(*, pf(complex_x, dims) ⊢ NoTangent(), complex_x) + test_frule(*, pf(x, dims), x) + test_rrule(*, pf(x, dims), x) + test_frule(*, pf(complex_x, dims), complex_x) + test_rrule(*, pf(complex_x, dims), complex_x) end # rfft test_frule(rfft, x, dims) test_rrule(rfft, x, dims) - test_frule(*, plan_rfft(x, dims) ⊢ NoTangent(), x) - test_rrule(*, plan_rfft(x, dims) ⊢ NoTangent(), x) + test_frule(*, plan_rfft(x, dims), x) + test_rrule(*, plan_rfft(x, dims), x) # irfft, brfft for f in (irfft, brfft) @@ -328,8 +341,8 @@ end end for pf in (plan_irfft, plan_brfft) for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2) - test_frule(*, pf(complex_x, d, dims) ⊢ NoTangent(), complex_x) - test_rrule(*, pf(complex_x, d, dims) ⊢ NoTangent(), complex_x) + test_frule(*, pf(complex_x, d, dims), complex_x) + test_rrule(*, pf(complex_x, d, dims), complex_x) end end end From 09b8b38ab9a63989a202254f8653774afe216106 Mon Sep 17 00:00:00 2001 From: gaurav-arya Date: Tue, 16 Aug 2022 17:49:23 -0400 Subject: [PATCH 18/26] Apply suggestions from code review Co-authored-by: David Widmann --- docs/src/implementations.md | 7 +++---- src/chainrules.jl | 9 ++++----- test/runtests.jl | 8 +++----- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/docs/src/implementations.md b/docs/src/implementations.md index 04901be5..a8bf6db0 100644 --- a/docs/src/implementations.md +++ b/docs/src/implementations.md @@ -20,10 
+20,9 @@ To define a new FFT implementation in your own module, you should * Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of `x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`). -* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` (or `A_mul_B!(y, p::MyPlan, x)` on Julia prior to - 0.7.0-DEV.3204) that computes the transform `p` of `x` and stores the result in `y`. +* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` that computes the transform `p` of `x` and stores the result in `y`. -* Define a method of `*(p::MyPlan, x)`, which can simply call your `mul!` (or `A_mul_B!`) method. +* Define a method of `*(p::MyPlan, x)`, which can simply call your `mul!` method. This is not defined generically in this package due to subtleties that arise for in-place and real-input FFTs. * If the inverse transform is implemented, you should also define `plan_inv(p::MyPlan)`, which should construct the @@ -33,7 +32,7 @@ To define a new FFT implementation in your own module, you should * You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs. 
-* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can take values: +* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can take values: * `AbstractFFTs.NoProjectionStyle()`, * `AbstractFFTs.RealProjectionStyle()`, for plans which halve one of the output's dimensions analogously to [`rfft`](@ref), * `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans which expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension. diff --git a/src/chainrules.jl b/src/chainrules.jl index 80c51142..9e7e46b1 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -170,20 +170,19 @@ end function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::ScaledPlan, x::AbstractArray) y = P * x - Δy = P * Δx + ΔP.scale / P.scale * y + Δy = P * Δx .+ (ΔP.scale / P.scale) .* y return y, Δy end function ChainRulesCore.rrule(::typeof(*), P::ScaledPlan, x::AbstractArray) y = P * x project_x = ChainRulesCore.ProjectTo(x) - project_scale = ChainRulesCore.ProjectTo(P.scale) Pt = P' scale = P.scale - function mul_plan_pullback(ȳ) + function mul_scaledplan_pullback(ȳ) x̄ = ChainRulesCore.@thunk(project_x(Pt * ȳ)) - scale_tangent = ChainRulesCore.@thunk(project_scale(sum(conj(y) .* ȳ) / conj(scale))) + scale_tangent = ChainRulesCore.@thunk(dot(y, ȳ) / conj(scale)) plan_tangent = ChainRulesCore.Tangent{typeof(P)}(;p=ChainRulesCore.NoTangent(), scale=scale_tangent) return ChainRulesCore.NoTangent(), plan_tangent, x̄ end - return y, mul_plan_pullback + return y, mul_scaledplan_pullback end diff --git a/test/runtests.jl b/test/runtests.jl index edbb1b0b..7c191c20 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -206,7 +206,7 @@ end y = randn(size(x)) for 
dims in unique((1, 1:N, N)) P = plan_fft(x, dims) - @test AbstractFFTs.output_size(P) == size(x) + @test @inferred(AbstractFFTs.output_size(P)) == size(x) @test AbstractFFTs.output_size(P') == size(x) Pinv = plan_ifft(x) @test AbstractFFTs.output_size(Pinv) == size(x) @@ -222,7 +222,7 @@ end Px_sz = size(P * x) @test AbstractFFTs.output_size(P) == Px_sz @test AbstractFFTs.output_size(P') == size(x) - y = randn(Px_sz) .+ randn(Px_sz) * im + y = randn(Complex{Float64}, Px_sz) Pinv = plan_irfft(y, size(x)[first(dims)], dims) @test AbstractFFTs.output_size(Pinv) == size(Pinv * y) @test AbstractFFTs.output_size(Pinv') == size(y) @@ -256,9 +256,7 @@ end N = ndims(x) for dims in unique((1, 1:N, N)) P = plan_rfft(x, dims) - y_real = randn(size(P * x)) - y_imag = randn(size(P * x)) - y = y_real .+ y_imag .* im + y = randn(Complex{Float64}, size(P * x)) @test (P')' * x == P * x @test size(P') == AbstractFFTs.output_size(P) @test dot(y_real, real.(P * x)) + dot(y_imag, imag.(P * x)) ≈ dot(P' * y, x) From d967aa224690e482164410888dafea021574f78a Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Tue, 16 Aug 2022 18:13:11 -0400 Subject: [PATCH 19/26] More tweaks to address code review --- src/chainrules.jl | 5 +++-- test/runtests.jl | 10 +++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 9e7e46b1..905ae00d 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -175,12 +175,13 @@ function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::ScaledPlan, x::Abst end function ChainRulesCore.rrule(::typeof(*), P::ScaledPlan, x::AbstractArray) y = P * x - project_x = ChainRulesCore.ProjectTo(x) Pt = P' scale = P.scale + project_x = ChainRulesCore.ProjectTo(x) + project_scale = ChainRulesCore.ProjectTo(scale) function mul_scaledplan_pullback(ȳ) x̄ = ChainRulesCore.@thunk(project_x(Pt * ȳ)) - scale_tangent = ChainRulesCore.@thunk(dot(y, ȳ) / conj(scale)) + scale_tangent = ChainRulesCore.@thunk(project_scale(dot(y, ȳ) / 
conj(scale))) plan_tangent = ChainRulesCore.Tangent{typeof(P)}(;p=ChainRulesCore.NoTangent(), scale=scale_tangent) return ChainRulesCore.NoTangent(), plan_tangent, x̄ end diff --git a/test/runtests.jl b/test/runtests.jl index 7c191c20..fa081c2a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -238,7 +238,7 @@ end y = randn(size(x)) for dims in unique((1, 1:N, N)) P = plan_fft(x, dims) - @test (P')' * x == P * x # test adjoint of adjoint + @test (P')' === P # test adjoint of adjoint @test size(P') == AbstractFFTs.output_size(P) # test size of adjoint @test dot(y, P * x) ≈ dot(P' * y, x) # test validity of adjoint @test dot(y, P \ x) ≈ dot(P' \ y, x) @@ -259,13 +259,13 @@ end y = randn(Complex{Float64}, size(P * x)) @test (P')' * x == P * x @test size(P') == AbstractFFTs.output_size(P) - @test dot(y_real, real.(P * x)) + dot(y_imag, imag.(P * x)) ≈ dot(P' * y, x) - @test dot(y_real, real.(P' \ x)) + dot(y_imag, imag.(P' \ x)) ≈ dot(P \ y, x) + @test dot(real.(y), real.(P * x)) + dot(imag.(y), imag.(P * x)) ≈ dot(P' * y, x) + @test dot(real.(y), real.(P' \ x)) + dot(imag.(y), imag.(P' \ x)) ≈ dot(P \ y, x) Pinv = plan_irfft(y, size(x)[first(dims)], dims) @test (Pinv')' * y == Pinv * y @test size(Pinv') == AbstractFFTs.output_size(Pinv) - @test dot(x, Pinv * y) ≈ dot(y_real, real.(Pinv' * x)) + dot(y_imag, imag.(Pinv' * x)) - @test dot(x, Pinv' \ y) ≈ dot(y_real, real.(Pinv \ x)) + dot(y_imag, imag.(Pinv \ x)) + @test dot(x, Pinv * y) ≈ dot(real.(y), real.(Pinv' * x)) + dot(imag.(y), imag.(Pinv' * x)) + @test dot(x, Pinv' \ y) ≈ dot(real.(y), real.(Pinv \ x)) + dot(imag.(y), imag.(Pinv \ x)) end end end From 2a423e2897c934cf553168eeef7c55db062275d2 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Tue, 16 Aug 2022 19:03:00 -0400 Subject: [PATCH 20/26] Restrict to T<:Real for rfft adjoint --- src/definitions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/definitions.jl b/src/definitions.jl index dfdfc00b..0a243013 100644 --- 
a/src/definitions.jl +++ b/src/definitions.jl @@ -626,7 +626,7 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T return (p.p \ x) / N end -function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T} +function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T<:Real} dims = fftdims(p.p) N = normalization(T, size(p.p), dims) halfdim = first(dims) From eedba14590638c1af3ac1cd76f7098a00ebf8484 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Tue, 16 Aug 2022 19:03:24 -0400 Subject: [PATCH 21/26] Get type T correct for test irfft --- test/testplans.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/testplans.jl b/test/testplans.jl index 27f5f00c..d9b336b3 100644 --- a/test/testplans.jl +++ b/test/testplans.jl @@ -95,11 +95,11 @@ Base.:*(p::InverseTestPlan, x::AbstractArray) = mul!(similar(x, complex(float(el mutable struct TestRPlan{T,N,G} <: Plan{T} region::G sz::NTuple{N,Int} - pinv::Plan{T} + pinv::Plan{Complex{T}} TestRPlan{T}(region::G, sz::NTuple{N,Int}) where {T,N,G} = new{T,N,G}(region, sz) end -mutable struct InverseTestRPlan{T,N,G} <: Plan{T} +mutable struct InverseTestRPlan{T,N,G} <: Plan{Complex{T}} d::Int region::G sz::NTuple{N,Int} @@ -113,10 +113,10 @@ end AbstractFFTs.ProjectionStyle(::TestRPlan) = AbstractFFTs.RealProjectionStyle() AbstractFFTs.ProjectionStyle(p::InverseTestRPlan) = AbstractFFTs.RealInverseProjectionStyle(p.d) -function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T} +function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T<:Real} return TestRPlan{T}(region, size(x)) end -function AbstractFFTs.plan_brfft(x::AbstractArray{T}, d, region; kwargs...) where {T} +function AbstractFFTs.plan_brfft(x::AbstractArray{Complex{T}}, d, region; kwargs...) 
where {T} return InverseTestRPlan{T}(d, region, size(x)) end function AbstractFFTs.plan_inv(p::TestRPlan{T,N}) where {T,N} From 25bb86bfb9aeed9b79b7079b69e539fc885025e1 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Tue, 16 Aug 2022 19:10:30 -0400 Subject: [PATCH 22/26] Test complex input when appropriate for adjoint tests --- test/runtests.jl | 63 +++++++++++++++++++++++++++--------------------- 1 file changed, 35 insertions(+), 28 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index fa081c2a..10b3c275 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -201,16 +201,19 @@ end @testset "output size" begin @testset "complex fft output size" begin - for x in (randn(3), randn(3, 4), randn(3, 4, 5)) - N = ndims(x) - y = randn(size(x)) - for dims in unique((1, 1:N, N)) - P = plan_fft(x, dims) - @test @inferred(AbstractFFTs.output_size(P)) == size(x) - @test AbstractFFTs.output_size(P') == size(x) - Pinv = plan_ifft(x) - @test AbstractFFTs.output_size(Pinv) == size(x) - @test AbstractFFTs.output_size(Pinv') == size(x) + for x_shape in ((3,), (3, 4), (3, 4, 5)) + N = length(x_shape) + real_x = randn(x_shape) + complex_x = randn(ComplexF64, x_shape) + for x in (real_x, complex_x) + for dims in unique((1, 1:N, N)) + P = plan_fft(x, dims) + @test @inferred(AbstractFFTs.output_size(P)) == size(x) + @test AbstractFFTs.output_size(P') == size(x) + Pinv = plan_ifft(x) + @test AbstractFFTs.output_size(Pinv) == size(x) + @test AbstractFFTs.output_size(Pinv') == size(x) + end end end end @@ -222,7 +225,7 @@ end Px_sz = size(P * x) @test AbstractFFTs.output_size(P) == Px_sz @test AbstractFFTs.output_size(P') == size(x) - y = randn(Complex{Float64}, Px_sz) + y = randn(ComplexF64, Px_sz) Pinv = plan_irfft(y, size(x)[first(dims)], dims) @test AbstractFFTs.output_size(Pinv) == size(Pinv * y) @test AbstractFFTs.output_size(Pinv') == size(y) @@ -233,21 +236,25 @@ end @testset "adjoint" begin @testset "complex fft adjoint" begin - for x in (randn(3), randn(3, 4), 
randn(3, 4, 5)) - N = ndims(x) - y = randn(size(x)) - for dims in unique((1, 1:N, N)) - P = plan_fft(x, dims) - @test (P')' === P # test adjoint of adjoint - @test size(P') == AbstractFFTs.output_size(P) # test size of adjoint - @test dot(y, P * x) ≈ dot(P' * y, x) # test validity of adjoint - @test dot(y, P \ x) ≈ dot(P' \ y, x) - Pinv = plan_ifft(y) - @test (Pinv')' * y == Pinv * y - @test size(Pinv') == AbstractFFTs.output_size(Pinv) - @test dot(x, Pinv * y) ≈ dot(Pinv' * x, y) - @test dot(x, Pinv \ y) ≈ dot(Pinv' \ x, y) - @test_throws MethodError mul!(x, P', y) + for x_shape in ((3,), (3, 4), (3, 4, 5)) + N = length(x_shape) + real_x = randn(x_shape) + complex_x = randn(ComplexF64, x_shape) + y = randn(ComplexF64, x_shape) + for x in (real_x, complex_x) + for dims in unique((1, 1:N, N)) + P = plan_fft(x, dims) + @test (P')' === P # test adjoint of adjoint + @test size(P') == AbstractFFTs.output_size(P) # test size of adjoint + @test dot(y, P * x) ≈ dot(P' * y, x) # test validity of adjoint + @test dot(y, P \ x) ≈ dot(P' \ y, x) + Pinv = plan_ifft(y) + @test (Pinv')' * y == Pinv * y + @test size(Pinv') == AbstractFFTs.output_size(Pinv) + @test dot(x, Pinv * y) ≈ dot(Pinv' * x, y) + @test dot(x, Pinv \ y) ≈ dot(Pinv' \ x, y) + @test_throws MethodError mul!(x, P', y) + end end end end @@ -256,7 +263,7 @@ end N = ndims(x) for dims in unique((1, 1:N, N)) P = plan_rfft(x, dims) - y = randn(Complex{Float64}, size(P * x)) + y = randn(ComplexF64, size(P * x)) @test (P')' * x == P * x @test size(P') == AbstractFFTs.output_size(P) @test dot(real.(y), real.(P * x)) + dot(imag.(y), imag.(P * x)) ≈ dot(P' * y, x) @@ -306,7 +313,7 @@ end for x_shape in ((2,), (2, 3), (3, 4, 5)) N = length(x_shape) x = randn(x_shape) - complex_x = x + randn(x_shape) * im + complex_x = randn(ComplexF64, x_shape) for dims in unique((1, 1:N, N)) # fft, ifft, bfft for f in (fft, ifft, bfft) From fe3b06a49cb58cfd9a56af0cba7f2466233d69dd Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sun, 28 Aug 
2022 13:31:20 -0400 Subject: [PATCH 23/26] Add plan_inv implementation for adjoint plan and test it --- src/definitions.jl | 2 ++ test/runtests.jl | 9 +++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/definitions.jl b/src/definitions.jl index 4e7982d0..209eb485 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -654,4 +654,6 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) return scale ./ N .* (p.p \ x) end +# Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only). +plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p)) inv(p::AdjointPlan) = adjoint(inv(p.p)) diff --git a/test/runtests.jl b/test/runtests.jl index 6c9cb333..34906fc1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -263,12 +263,14 @@ end @test (P')' === P # test adjoint of adjoint @test size(P') == AbstractFFTs.output_size(P) # test size of adjoint @test dot(y, P * x) ≈ dot(P' * y, x) # test validity of adjoint - @test dot(y, P \ x) ≈ dot(P' \ y, x) + @test dot(y, P \ x) ≈ dot(P' \ y, x) # test inv of adjoint + @test dot(y, P \ x) ≈ dot(AbstractFFTs.plan_inv(P') * y, x) # test plan_inv of adjoint Pinv = plan_ifft(y) @test (Pinv')' * y == Pinv * y @test size(Pinv') == AbstractFFTs.output_size(Pinv) @test dot(x, Pinv * y) ≈ dot(Pinv' * x, y) @test dot(x, Pinv \ y) ≈ dot(Pinv' \ x, y) + @test dot(x, Pinv \ y) ≈ dot(AbstractFFTs.plan_inv(Pinv') * x, y) @test_throws MethodError mul!(x, P', y) end end @@ -281,14 +283,17 @@ end P = plan_rfft(x, dims) y = randn(ComplexF64, size(P * x)) @test (P')' * x == P * x - @test size(P') == AbstractFFTs.output_size(P) + @test size(P') == AbstractFFTs.output_size(P) @test dot(real.(y), real.(P * x)) + dot(imag.(y), imag.(P * x)) ≈ dot(P' * y, x) @test dot(real.(y), real.(P' \ x)) + dot(imag.(y), imag.(P' \ x)) ≈ dot(P \ y, x) + @test dot(real.(y), real.(AbstractFFTs.plan_inv(P') * x)) + + dot(imag.(y), imag.(AbstractFFTs.plan_inv(P') * x)) ≈ dot(P \ y, x) 
Pinv = plan_irfft(y, size(x)[first(dims)], dims) @test (Pinv')' * y == Pinv * y @test size(Pinv') == AbstractFFTs.output_size(Pinv) @test dot(x, Pinv * y) ≈ dot(real.(y), real.(Pinv' * x)) + dot(imag.(y), imag.(Pinv' * x)) @test dot(x, Pinv' \ y) ≈ dot(real.(y), real.(Pinv \ x)) + dot(imag.(y), imag.(Pinv \ x)) + @test dot(x, AbstractFFTs.plan_inv(Pinv') * y) ≈ dot(real.(y), real.(Pinv \ x)) + dot(imag.(y), imag.(Pinv \ x)) end end end From 403ce47e79ee8d1b46c433a7a008d8732e9350fe Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 5 Jul 2023 01:12:16 +0200 Subject: [PATCH 24/26] Apply suggestions from code review Co-authored-by: Seth Axen --- docs/src/implementations.md | 10 +++++----- src/definitions.jl | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/src/implementations.md b/docs/src/implementations.md index a8bf6db0..7367fd4c 100644 --- a/docs/src/implementations.md +++ b/docs/src/implementations.md @@ -32,10 +32,10 @@ To define a new FFT implementation in your own module, you should * You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs. -* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can take values: +* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can return: * `AbstractFFTs.NoProjectionStyle()`, - * `AbstractFFTs.RealProjectionStyle()`, for plans which halve one of the output's dimensions analogously to [`rfft`](@ref), - * `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans which expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension. 
+ * `AbstractFFTs.RealProjectionStyle()`, for plans that halve one of the output's dimensions analogously to [`rfft`](@ref), + * `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans that expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension. -The normalization convention for your FFT should be that it computes yₖ = ∑ⱼ xⱼ exp(-2πi jk/n) for a transform of -length n, and the "backwards" (unnormalized inverse) transform computes the same thing but with exp(+2πi jk/n). +The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of +length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``. diff --git a/src/definitions.jl b/src/definitions.jl index 9ce8f0d3..04d518c2 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -601,10 +601,10 @@ struct AdjointPlan{T,P<:Plan} <: Plan{T} end """ - p' + (p::Plan)' adjoint(p::Plan) -Form the adjoint operator of an FFT plan. Returns a plan which performs the adjoint operation +Form the adjoint operator of an FFT plan. Returns a plan that performs the adjoint operation of the original plan. Note that this differs from the corresponding backwards plan in the case of real FFTs due to the halving of one of the dimensions of the FFT output, as described in [`rfft`](@ref). From e137ae33d8493d8ac3fced4bef645c0d0a4782a8 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 5 Jul 2023 13:54:58 +0200 Subject: [PATCH 25/26] Apply suggestions from code review --- src/definitions.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/definitions.jl b/src/definitions.jl index 04d518c2..4ec176eb 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -638,7 +638,7 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where [(i == 1 || (i == n && 2 * (i - 1)) == d) ? 
N : 2 * N for i in 1:n], ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))) ) - return p.p \ (x ./ scale) + return p.p \ (x ./ convert(typeof(x), scale)) end function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T} @@ -651,7 +651,7 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) [(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n], ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))) ) - return scale ./ N .* (p.p \ x) + return (convert(typeof(x), scale) ./ N) .* (p.p \ x) end # Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only). From e601347e64e932b2177841d14bd22b3a78496003 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 5 Jul 2023 15:08:01 +0200 Subject: [PATCH 26/26] Test in-place plans --- ext/AbstractFFTsChainRulesCoreExt.jl | 14 +++++++++++- test/runtests.jl | 33 ++++++++++++++++++++++++---- test/testplans.jl | 22 +++++++++++++++++++ 3 files changed, 64 insertions(+), 5 deletions(-) diff --git a/ext/AbstractFFTsChainRulesCoreExt.jl b/ext/AbstractFFTsChainRulesCoreExt.jl index aa19724b..5ab5d2ee 100644 --- a/ext/AbstractFFTsChainRulesCoreExt.jl +++ b/ext/AbstractFFTsChainRulesCoreExt.jl @@ -161,12 +161,18 @@ end # plans function ChainRulesCore.frule((_, _, Δx), ::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray) - y = P * x + y = P * x + if Base.mightalias(y, x) + throw(ArgumentError("differentiation rules are not supported for in-place plans")) + end Δy = P * Δx return y, Δy end function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray) y = P * x + if Base.mightalias(y, x) + throw(ArgumentError("differentiation rules are not supported for in-place plans")) + end project_x = ChainRulesCore.ProjectTo(x) Pt = P' function mul_plan_pullback(ȳ) @@ -178,11 +184,17 @@ end function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray) y = P * x + if Base.mightalias(y, x) + 
throw(ArgumentError("differentiation rules are not supported for in-place plans")) + end Δy = P * Δx .+ (ΔP.scale / P.scale) .* y return y, Δy end function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray) y = P * x + if Base.mightalias(y, x) + throw(ArgumentError("differentiation rules are not supported for in-place plans")) + end Pt = P' scale = P.scale project_x = ChainRulesCore.ProjectTo(x) diff --git a/test/runtests.jl b/test/runtests.jl index e7c255cd..c5f0659b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -68,6 +68,13 @@ end @test fftdims(P) == dims end + # in-place plan + P = plan_fft!(x, dims) + @test eltype(P) === ComplexF64 + xc64 = ComplexF64.(x) + @test P * xc64 ≈ fftw_fft + @test xc64 ≈ fftw_fft + fftw_bfft = complex.(size(x, dims) .* x) @test AbstractFFTs.bfft(y, dims) ≈ fftw_bfft P = plan_bfft(x, dims) @@ -75,6 +82,13 @@ end @test P \ (P * y) ≈ y @test fftdims(P) == dims + # in-place plan + P = plan_bfft!(x, dims) + @test eltype(P) === ComplexF64 + yc64 = ComplexF64.(y) + @test P * yc64 ≈ fftw_bfft + @test yc64 ≈ fftw_bfft + fftw_ifft = complex.(x) @test AbstractFFTs.ifft(y, dims) ≈ fftw_ifft # test plan_ifft and also inv and plan_inv of plan_fft, which should all give @@ -86,6 +100,13 @@ end @test fftdims(P) == dims end + # in-place plan + P = plan_ifft!(x, dims) + @test eltype(P) === ComplexF64 + yc64 = ComplexF64.(y) + @test P * yc64 ≈ fftw_ifft + @test yc64 ≈ fftw_ifft + # real FFT fftw_rfft = fftw_fft[ (Colon() for _ in 1:(ndims(fftw_fft) - 1))..., @@ -361,7 +382,8 @@ end for x_shape in ((2,), (2, 3), (3, 4, 5)) N = length(x_shape) x = randn(x_shape) - complex_x = randn(ComplexF64, x_shape) + complex_x = randn(ComplexF64, x_shape) + Δ = (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), ChainRulesTestUtils.rand_tangent(complex_x)) for dims in unique((1, 1:N, N)) # fft, ifft, bfft for f in (fft, ifft, bfft) @@ -370,11 +392,14 @@ end test_frule(f, complex_x, dims) test_rrule(f, complex_x, dims) end - 
for pf in (plan_fft, plan_ifft, plan_bfft) + for (pf, pf!) in ((plan_fft, plan_fft!), (plan_ifft, plan_ifft!), (plan_bfft, plan_bfft!)) test_frule(*, pf(x, dims), x) test_rrule(*, pf(x, dims), x) test_frule(*, pf(complex_x, dims), complex_x) test_rrule(*, pf(complex_x, dims), complex_x) + + @test_throws ArgumentError ChainRulesCore.frule(Δ, *, pf!(complex_x, dims), complex_x) + @test_throws ArgumentError ChainRulesCore.rrule(*, pf!(complex_x, dims), complex_x) end # rfft @@ -392,10 +417,10 @@ end test_rrule(f, complex_x, d, dims) end end - for pf in (plan_irfft, plan_brfft) + for pf in (plan_irfft, plan_brfft) for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2) test_frule(*, pf(complex_x, d, dims), complex_x) - test_rrule(*, pf(complex_x, d, dims), complex_x) + test_rrule(*, pf(complex_x, d, dims), complex_x) end end end diff --git a/test/testplans.jl b/test/testplans.jl index d9b336b3..09b3f671 100644 --- a/test/testplans.jl +++ b/test/testplans.jl @@ -232,3 +232,25 @@ function Base.:*(p::InverseTestRPlan, x::AbstractArray) return y end + +# In-place plans +# (simple wrapper of out-of-place plans that does not support inverses) +struct InplaceTestPlan{T,P<:Plan{T}} <: Plan{T} + plan::P +end + +Base.size(p::InplaceTestPlan) = size(p.plan) +Base.ndims(p::InplaceTestPlan) = ndims(p.plan) +AbstractFFTs.ProjectionStyle(p::InplaceTestPlan) = AbstractFFTs.ProjectionStyle(p.plan) + +function AbstractFFTs.plan_fft!(x::AbstractArray, region; kwargs...) + return InplaceTestPlan(plan_fft(x, region; kwargs...)) +end +function AbstractFFTs.plan_bfft!(x::AbstractArray, region; kwargs...) + return InplaceTestPlan(plan_bfft(x, region; kwargs...)) +end + +function LinearAlgebra.mul!(y::AbstractArray, p::InplaceTestPlan, x::AbstractArray) + return mul!(y, p.plan, x) +end +Base.:*(p::InplaceTestPlan, x::AbstractArray) = copyto!(x, p.plan * x)