Skip to content

Commit 5c23f4b

Browse files
vpuri3gaurav-arya
andauthored
Rename ProjectionStyle -> AdjointStyle, _mul -> AdjointMul, and improve docs (#109)
* make ProjectionStyle abstract type so we can subtype in downstream packages. add a few lines of docs * Rename ProjectionStyle -> AdjointStyle, _mul -> AdjointMul, improve docs * Clarify normalization * Clarify documentation, rename _output_size -> output_size * Remove unnecessary def * Remove confusing commas * Tweak docstring wording * Reposition and improve size/output_size docstrings * Note that size needs to be implemented in docs --------- Co-authored-by: Gaurav Arya <[email protected]>
1 parent 1cc9ca0 commit 5c23f4b

File tree

4 files changed

+128
-42
lines changed

4 files changed

+128
-42
lines changed

docs/src/api.md

+18-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Public Interface
22

3+
## FFT and FFT planning functions
4+
35
```@docs
46
AbstractFFTs.fft
57
AbstractFFTs.fft!
@@ -20,11 +22,26 @@ AbstractFFTs.plan_rfft
2022
AbstractFFTs.plan_brfft
2123
AbstractFFTs.plan_irfft
2224
AbstractFFTs.fftdims
23-
Base.adjoint
2425
AbstractFFTs.fftshift
2526
AbstractFFTs.fftshift!
2627
AbstractFFTs.ifftshift
2728
AbstractFFTs.ifftshift!
2829
AbstractFFTs.fftfreq
2930
AbstractFFTs.rfftfreq
31+
Base.size
32+
```
33+
34+
## Adjoint functionality
35+
36+
The following API is supported by plans that support adjoint functionality.
37+
It is also relevant to implementers of FFT plans that wish to support adjoints.
38+
```@docs
39+
Base.adjoint
40+
AbstractFFTs.AdjointStyle
41+
AbstractFFTs.output_size
42+
AbstractFFTs.adjoint_mul
43+
AbstractFFTs.FFTAdjointStyle
44+
AbstractFFTs.RFFTAdjointStyle
45+
AbstractFFTs.IRFFTAdjointStyle
46+
AbstractFFTs.UnitaryAdjointStyle
3047
```

docs/src/implementations.md

+5-5
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ To define a new FFT implementation in your own module, you should
1818
inverse plan.
1919

2020
* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of
21-
`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`).
21+
`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)`
22+
(which defaults to `p.region`), and the input size `size(x)` should be accessible via `size(p::MyPlan)`.
2223

2324
* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` that computes the transform `p` of `x` and stores the result in `y`.
2425

@@ -32,10 +33,9 @@ To define a new FFT implementation in your own module, you should
3233

3334
* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs.
3435

35-
* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can return:
36-
* `AbstractFFTs.NoProjectionStyle()`,
37-
* `AbstractFFTs.RealProjectionStyle()`, for plans that halve one of the output's dimensions analogously to [`rfft`](@ref),
38-
* `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans that expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension.
36+
* To support adjoints in a new plan, define the trait [`AbstractFFTs.AdjointStyle`](@ref).
37+
`AbstractFFTs` implements the following adjoint styles: [`AbstractFFTs.FFTAdjointStyle`](@ref), [`AbstractFFTs.RFFTAdjointStyle`](@ref), [`AbstractFFTs.IRFFTAdjointStyle`](@ref), and [`AbstractFFTs.UnitaryAdjointStyle`](@ref).
38+
To define a new adjoint style, define the methods [`AbstractFFTs.adjoint_mul`](@ref) and [`AbstractFFTs.output_size`](@ref).
3939

4040
The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of
4141
length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``.

src/definitions.jl

+100-31
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@ abstract type Plan{T} end
1010

1111
eltype(::Type{<:Plan{T}}) where {T} = T
1212

13-
# size(p) should return the size of the input array for p
14-
size(p::Plan, d) = size(p)[d]
15-
output_size(p::Plan, d) = output_size(p)[d]
13+
"""
14+
size(p::Plan, [dim])
15+
16+
Return the size of the input of a plan `p`, optionally at a specified dimenion `dim`.
17+
"""
18+
size(p::Plan, dim) = size(p)[dim]
1619
ndims(p::Plan) = length(size(p))
1720
length(p::Plan) = prod(size(p))::Int
1821

@@ -583,17 +586,73 @@ plan_brfft
583586

584587
##############################################################################
585588

586-
struct NoProjectionStyle end
587-
struct RealProjectionStyle end
588-
struct RealInverseProjectionStyle
589+
"""
590+
AbstractFFTs.AdjointStyle(::Plan)
591+
592+
Return the adjoint style of a plan, enabling automatic computation of adjoint plans via
593+
[`Base.adjoint`](@ref). Instructions for supporting adjoint styles are provided in the
594+
[implementation instructions](implementations.md#Defining-a-new-implementation).
595+
"""
596+
abstract type AdjointStyle end
597+
598+
"""
599+
FFTAdjointStyle()
600+
601+
Adjoint style for complex to complex discrete Fourier transforms that normalize
602+
the output analogously to [`fft`](@ref).
603+
604+
Since the Fourier transform is unitary up to a scaling, the adjoint simply applies
605+
the transform's inverse with an appropriate scaling.
606+
"""
607+
struct FFTAdjointStyle <: AdjointStyle end
608+
609+
"""
610+
RFFTAdjointStyle()
611+
612+
Adjoint style for real to complex discrete Fourier transforms that halve one of
613+
the output's dimensions and normalize the output analogously to [`rfft`](@ref).
614+
615+
Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's
616+
inverse, but with appropriate scaling and additional logic to handle the fact that the
617+
output is projected to exploit its conjugate symmetry (see [`rfft`](@ref)).
618+
"""
619+
struct RFFTAdjointStyle <: AdjointStyle end
620+
621+
"""
622+
IRFFTAdjointStyle(d::Dim)
623+
624+
Adjoint style for complex to real discrete Fourier transforms that expect an input
625+
with a halved dimension and normalize the output analogously to [`irfft`](@ref),
626+
where `d` is the original length of the dimension.
627+
628+
Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's
629+
inverse, but with appropriate scaling and additional logic to handle the fact that the
630+
input is projected to exploit its conjugate symmetry (see [`irfft`](@ref)).
631+
"""
632+
struct IRFFTAdjointStyle <: AdjointStyle
589633
dim::Int
590634
end
591-
const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle}
592635

593-
output_size(p::Plan) = _output_size(p, ProjectionStyle(p))
594-
_output_size(p::Plan, ::NoProjectionStyle) = size(p)
595-
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p))
596-
_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
636+
"""
637+
UnitaryAdjointStyle()
638+
639+
Adjoint style for unitary transforms, whose adjoint equals their inverse.
640+
"""
641+
struct UnitaryAdjointStyle <: AdjointStyle end
642+
643+
"""
644+
output_size(p::Plan, [dim])
645+
646+
Return the size of the output of a plan `p`, optionally at a specified dimension `dim`.
647+
648+
Implementations of a new adjoint style `AS <: AbstractFFTs.AdjointStyle` should define `output_size(::Plan, ::AS)`.
649+
"""
650+
output_size(p::Plan) = output_size(p, AdjointStyle(p))
651+
output_size(p::Plan, dim) = output_size(p)[dim]
652+
output_size(p::Plan, ::FFTAdjointStyle) = size(p)
653+
output_size(p::Plan, ::RFFTAdjointStyle) = rfft_output_size(size(p), fftdims(p))
654+
output_size(p::Plan, s::IRFFTAdjointStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
655+
output_size(p::Plan, ::UnitaryAdjointStyle) = size(p)
597656

598657
struct AdjointPlan{T,P<:Plan} <: Plan{T}
599658
p::P
@@ -604,9 +663,7 @@ end
604663
(p::Plan)'
605664
adjoint(p::Plan)
606665
607-
Form the adjoint operator of an FFT plan. Returns a plan that performs the adjoint operation of
608-
the original plan. Note that this differs from the corresponding backwards plan in the case of real
609-
FFTs due to the halving of one of the dimensions of the FFT output, as described in [`rfft`](@ref).
666+
Return a plan that performs the adjoint operation of the original plan.
610667
611668
!!! note
612669
Adjoint plans do not currently support `LinearAlgebra.mul!`. Further, as a new addition to `AbstractFFTs`,
@@ -620,40 +677,52 @@ Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)
620677
size(p::AdjointPlan) = output_size(p.p)
621678
output_size(p::AdjointPlan) = size(p.p)
622679

623-
Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))
680+
Base.:*(p::AdjointPlan, x::AbstractArray) = adjoint_mul(p.p, x)
681+
682+
"""
683+
adjoint_mul(p::Plan, x::AbstractArray)
684+
685+
Multiply an array `x` by the adjoint of a plan `p`. This is equivalent to `p' * x`.
686+
687+
Implementations of a new adjoint style `AS <: AbstractFFTs.AdjointStyle` should define
688+
`adjoint_mul(::Plan, ::AbstractArray, ::AS)`.
689+
"""
690+
adjoint_mul(p::Plan, x::AbstractArray) = adjoint_mul(p, x, AdjointStyle(p))
624691

625-
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T}
626-
dims = fftdims(p.p)
627-
N = normalization(T, size(p.p), dims)
628-
return (p.p \ x) / N
692+
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T}
693+
dims = fftdims(p)
694+
N = normalization(T, size(p), dims)
695+
return (p \ x) / N
629696
end
630697

631-
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T<:Real}
632-
dims = fftdims(p.p)
633-
N = normalization(T, size(p.p), dims)
698+
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real}
699+
dims = fftdims(p)
700+
N = normalization(T, size(p), dims)
634701
halfdim = first(dims)
635-
d = size(p.p, halfdim)
636-
n = output_size(p.p, halfdim)
702+
d = size(p, halfdim)
703+
n = output_size(p, halfdim)
637704
scale = reshape(
638705
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n],
639706
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
640707
)
641-
return p.p \ (x ./ convert(typeof(x), scale))
708+
return p \ (x ./ convert(typeof(x), scale))
642709
end
643710

644-
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
645-
dims = fftdims(p.p)
646-
N = normalization(real(T), output_size(p.p), dims)
711+
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T}
712+
dims = fftdims(p)
713+
N = normalization(real(T), output_size(p), dims)
647714
halfdim = first(dims)
648-
n = size(p.p, halfdim)
649-
d = output_size(p.p, halfdim)
715+
n = size(p, halfdim)
716+
d = output_size(p, halfdim)
650717
scale = reshape(
651718
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
652719
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
653720
)
654-
return (convert(typeof(x), scale) ./ N) .* (p.p \ x)
721+
return (convert(typeof(x), scale) ./ N) .* (p \ x)
655722
end
656723

724+
adjoint_mul(p::Plan, x::AbstractArray, ::UnitaryAdjointStyle) = p \ x
725+
657726
# Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only).
658727
plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p))
659728
inv(p::AdjointPlan) = adjoint(inv(p.p))

test/testplans.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ Base.ndims(::TestPlan{T,N}) where {T,N} = N
2121
Base.size(p::InverseTestPlan) = p.sz
2222
Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N
2323

24-
AbstractFFTs.ProjectionStyle(::TestPlan) = AbstractFFTs.NoProjectionStyle()
25-
AbstractFFTs.ProjectionStyle(::InverseTestPlan) = AbstractFFTs.NoProjectionStyle()
24+
AbstractFFTs.AdjointStyle(::TestPlan) = AbstractFFTs.FFTAdjointStyle()
25+
AbstractFFTs.AdjointStyle(::InverseTestPlan) = AbstractFFTs.FFTAdjointStyle()
2626

2727
function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T}
2828
return TestPlan{T}(region, size(x))
@@ -110,8 +110,8 @@ mutable struct InverseTestRPlan{T,N,G} <: Plan{Complex{T}}
110110
end
111111
end
112112

113-
AbstractFFTs.ProjectionStyle(::TestRPlan) = AbstractFFTs.RealProjectionStyle()
114-
AbstractFFTs.ProjectionStyle(p::InverseTestRPlan) = AbstractFFTs.RealInverseProjectionStyle(p.d)
113+
AbstractFFTs.AdjointStyle(::TestRPlan) = AbstractFFTs.RFFTAdjointStyle()
114+
AbstractFFTs.AdjointStyle(p::InverseTestRPlan) = AbstractFFTs.IRFFTAdjointStyle(p.d)
115115

116116
function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T<:Real}
117117
return TestRPlan{T}(region, size(x))
@@ -241,7 +241,7 @@ end
241241

242242
Base.size(p::InplaceTestPlan) = size(p.plan)
243243
Base.ndims(p::InplaceTestPlan) = ndims(p.plan)
244-
AbstractFFTs.ProjectionStyle(p::InplaceTestPlan) = AbstractFFTs.ProjectionStyle(p.plan)
244+
AbstractFFTs.AdjointStyle(p::InplaceTestPlan) = AbstractFFTs.AdjointStyle(p.plan)
245245

246246
function AbstractFFTs.plan_fft!(x::AbstractArray, region; kwargs...)
247247
return InplaceTestPlan(plan_fft(x, region; kwargs...))

0 commit comments

Comments
 (0)