Skip to content

Commit f3575aa

Browse files
committed
Rename ProjectionStyle -> AdjointStyle, _mul -> AdjointMul, improve docs
1 parent d53f57d commit f3575aa

File tree

4 files changed

+69
-40
lines changed

4 files changed

+69
-40
lines changed

docs/src/api.md

+4
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ AbstractFFTs.plan_brfft
2121
AbstractFFTs.plan_irfft
2222
AbstractFFTs.fftdims
2323
Base.adjoint
24+
AbstractFFTs.FFTAdjointStyle
25+
AbstractFFTs.RFFTAdjointStyle
26+
AbstractFFTs.IRFFTAdjointStyle
27+
AbstractFFTs.UnitaryAdjointStyle
2428
AbstractFFTs.fftshift
2529
AbstractFFTs.fftshift!
2630
AbstractFFTs.ifftshift

docs/src/implementations.md

+5-4
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@ To define a new FFT implementation in your own module, you should
3232

3333
* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs.
3434

35-
* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can return:
36-
* `AbstractFFTs.NoProjectionStyle()`,
37-
* `AbstractFFTs.RealProjectionStyle()`, for plans that halve one of the output's dimensions analogously to [`rfft`](@ref),
38-
* `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans that expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension.
35+
* We offer an experimental `AdjointStyle` trait to enable automatic computation of adjoint plans via [`Base.adjoint`](@ref).
36+
To support adjoints in a new plan, define the trait `AbstractFFTs.AdjointStyle(::MyPlan)`. This should return a subtype of `AS <: AbstractFFTs.AdjointStyle` supporting `AbstractFFTs.adjoint_mul(::Plan, ::AbstractArray, ::AS)` and
37+
`AbstractFFTs._output_size(::Plan, ::AS)`.
38+
39+
`AbstractFFTs` pre-implements the following adjoint styles: [`AbstractFFTs.FFTAdjointStyle`](@ref), [`AbstractFFTs.RFFTAdjointStyle`](@ref), [`AbstractFFTs.IRFFTAdjointStyle`](@ref), and [`AbstractFFTs.UnitaryAdjointStyle`](@ref).
3940

4041
The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of
4142
length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``.

src/definitions.jl

+55-31
Original file line numberDiff line numberDiff line change
@@ -583,35 +583,57 @@ plan_brfft
583583

584584
##############################################################################
585585

586-
abstract type ProjectionStyle end
586+
abstract type AdjointStyle end
587587

588588
"""
589-
NoProjectionStyle()
589+
FFTAdjointStyle()
590590
591-
Projection style for complex to complex discrete Fourier transform
591+
Projection style for complex to complex discrete Fourier transforms.
592+
593+
Since the Fourier transform is unitary up to a scaling, the adjoint simply applies
594+
the transform's inverse with an appropriate scaling.
592595
"""
593-
struct NoProjectionStyle <: ProjectionStyle end
596+
struct FFTAdjointStyle <: AdjointStyle end
594597

595598
"""
596-
RealProjectionStyle()
599+
RFFTAdjointStyle()
597600
598-
Projection style for complex to real discrete Fourier transform
601+
Projection style for real to complex discrete Fourier transforms, for plans that
602+
halve one of the output's dimensions analogously to [`rfft`](@ref).
603+
604+
Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's
605+
inverse, but with additional logic to handle the fact that the output is projected
606+
to exploit its conjugate symmetry (see [`rfft`](@ref)).
599607
"""
600-
struct RealProjectionStyle <: ProjectionStyle end
608+
struct RFFTAdjointStyle <: AdjointStyle end
601609

602610
"""
603-
RealInverseProjectionStyle()
611+
IRFFTAdjointStyle(d::Dim)
604612
605-
Projection style for inverse of complex to real discrete Fourier transform
613+
Projection style for complex to real discrete Fourier transforms, for plans that
614+
expect an input with a halved dimension analogously to [`irfft`](@ref), where `d`
615+
is the original length of the dimension.
616+
617+
Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's
618+
inverse, but with additional logic to handle the fact that the input is projected
619+
to exploit its conjugate symmetry (see [`irfft`](@ref)).
606620
"""
607-
struct RealInverseProjectionStyle <: ProjectionStyle
621+
struct IRFFTAdjointStyle <: AdjointStyle
608622
dim::Int
609623
end
610624

611-
output_size(p::Plan) = _output_size(p, ProjectionStyle(p))
612-
_output_size(p::Plan, ::NoProjectionStyle) = size(p)
613-
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p))
614-
_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
625+
"""
626+
UnitaryAdjointStyle()
627+
628+
Projection style for unitary transforms, whose adjoint equals their inverse.
629+
"""
630+
struct UnitaryAdjointStyle <: AdjointStyle end
631+
632+
output_size(p::Plan) = _output_size(p, AdjointStyle(p))
633+
_output_size(p::Plan, ::FFTAdjointStyle) = size(p)
634+
_output_size(p::Plan, ::RFFTAdjointStyle) = rfft_output_size(size(p), fftdims(p))
635+
_output_size(p::Plan, s::IRFFTAdjointStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
636+
_output_size(p::Plan, ::UnitaryAdjointStyle) = size(p)
615637

616638
struct AdjointPlan{T,P<:Plan} <: Plan{T}
617639
p::P
@@ -638,40 +660,42 @@ Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)
638660
size(p::AdjointPlan) = output_size(p.p)
639661
output_size(p::AdjointPlan) = size(p.p)
640662

641-
Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))
663+
Base.:*(p::AdjointPlan, x::AbstractArray) = adjoint_mul(p.p, x, AdjointStyle(p.p))
642664

643-
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T}
644-
dims = fftdims(p.p)
645-
N = normalization(T, size(p.p), dims)
646-
return (p.p \ x) / N
665+
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T}
666+
dims = fftdims(p)
667+
N = normalization(T, size(p), dims)
668+
return (p \ x) / N
647669
end
648670

649-
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T<:Real}
650-
dims = fftdims(p.p)
651-
N = normalization(T, size(p.p), dims)
671+
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real}
672+
dims = fftdims(p)
673+
N = normalization(T, size(p), dims)
652674
halfdim = first(dims)
653-
d = size(p.p, halfdim)
654-
n = output_size(p.p, halfdim)
675+
d = size(p, halfdim)
676+
n = output_size(p, halfdim)
655677
scale = reshape(
656678
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n],
657679
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
658680
)
659-
return p.p \ (x ./ convert(typeof(x), scale))
681+
return p \ (x ./ convert(typeof(x), scale))
660682
end
661683

662-
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
663-
dims = fftdims(p.p)
664-
N = normalization(real(T), output_size(p.p), dims)
684+
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T}
685+
dims = fftdims(p)
686+
N = normalization(real(T), output_size(p), dims)
665687
halfdim = first(dims)
666-
n = size(p.p, halfdim)
667-
d = output_size(p.p, halfdim)
688+
n = size(p, halfdim)
689+
d = output_size(p, halfdim)
668690
scale = reshape(
669691
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
670692
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
671693
)
672-
return (convert(typeof(x), scale) ./ N) .* (p.p \ x)
694+
return (convert(typeof(x), scale) ./ N) .* (p \ x)
673695
end
674696

697+
adjoint_mul(p::Plan, x::AbstractArray, ::UnitaryAdjointStyle) = p \ x
698+
675699
# Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only).
676700
plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p))
677701
inv(p::AdjointPlan) = adjoint(inv(p.p))

test/testplans.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ Base.ndims(::TestPlan{T,N}) where {T,N} = N
2121
Base.size(p::InverseTestPlan) = p.sz
2222
Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N
2323

24-
AbstractFFTs.ProjectionStyle(::TestPlan) = AbstractFFTs.NoProjectionStyle()
25-
AbstractFFTs.ProjectionStyle(::InverseTestPlan) = AbstractFFTs.NoProjectionStyle()
24+
AbstractFFTs.AdjointStyle(::TestPlan) = AbstractFFTs.FFTAdjointStyle()
25+
AbstractFFTs.AdjointStyle(::InverseTestPlan) = AbstractFFTs.FFTAdjointStyle()
2626

2727
function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T}
2828
return TestPlan{T}(region, size(x))
@@ -110,8 +110,8 @@ mutable struct InverseTestRPlan{T,N,G} <: Plan{Complex{T}}
110110
end
111111
end
112112

113-
AbstractFFTs.ProjectionStyle(::TestRPlan) = AbstractFFTs.RealProjectionStyle()
114-
AbstractFFTs.ProjectionStyle(p::InverseTestRPlan) = AbstractFFTs.RealInverseProjectionStyle(p.d)
113+
AbstractFFTs.AdjointStyle(::TestRPlan) = AbstractFFTs.RFFTAdjointStyle()
114+
AbstractFFTs.AdjointStyle(p::InverseTestRPlan) = AbstractFFTs.IRFFTAdjointStyle(p.d)
115115

116116
function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T<:Real}
117117
return TestRPlan{T}(region, size(x))
@@ -241,7 +241,7 @@ end
241241

242242
Base.size(p::InplaceTestPlan) = size(p.plan)
243243
Base.ndims(p::InplaceTestPlan) = ndims(p.plan)
244-
AbstractFFTs.ProjectionStyle(p::InplaceTestPlan) = AbstractFFTs.ProjectionStyle(p.plan)
244+
AbstractFFTs.AdjointStyle(p::InplaceTestPlan) = AbstractFFTs.AdjointStyle(p.plan)
245245

246246
function AbstractFFTs.plan_fft!(x::AbstractArray, region; kwargs...)
247247
return InplaceTestPlan(plan_fft(x, region; kwargs...))

0 commit comments

Comments
 (0)