Skip to content

Commit 780f206

Browse files
committed
Implement in-place test plans
1 parent 3004ff7 commit 780f206

File tree

1 file changed

+34
-22
lines changed

1 file changed

+34
-22
lines changed

test/TestPlans.jl

+34-22
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,25 @@ import AbstractFFTs
44
import LinearAlgebra.mul!
55
using AbstractFFTs: Plan
66

7-
mutable struct TestPlan{T,N} <: Plan{T}
7+
mutable struct TestPlan{T,N,inplace} <: Plan{T}
88
region
99
sz::NTuple{N,Int}
1010
pinv::Plan{T}
11-
function TestPlan{T}(region, sz::NTuple{N,Int}) where {T,N}
12-
return new{T,N}(region, sz)
11+
function TestPlan{T,inplace}(region, sz::NTuple{N,Int}) where {T,N,inplace}
12+
return new{T,N,inplace}(region, sz)
1313
end
1414
end
15+
TestPlan{T}(region, sz) where {T} = TestPlan{T,false}(region, sz)
1516

16-
mutable struct InverseTestPlan{T,N} <: Plan{T}
17+
mutable struct InverseTestPlan{T,N,inplace} <: Plan{T}
1718
region
1819
sz::NTuple{N,Int}
1920
pinv::Plan{T}
20-
function InverseTestPlan{T}(region, sz::NTuple{N,Int}) where {T,N}
21-
return new{T,N}(region, sz)
21+
function InverseTestPlan{T,inplace}(region, sz::NTuple{N,Int}) where {T,N,inplace}
22+
return new{T,N,inplace}(region, sz)
2223
end
2324
end
25+
InverseTestPlan{T}(region, sz) where {T} = InverseTestPlan{T,false}(region, sz)
2426

2527
Base.size(p::TestPlan) = p.sz
2628
Base.ndims(::TestPlan{T,N}) where {T,N} = N
@@ -34,18 +36,25 @@ function AbstractFFTs.plan_bfft(x::AbstractArray{T}, region; kwargs...) where {T
3436
return InverseTestPlan{T}(region, size(x))
3537
end
3638

37-
function AbstractFFTs.plan_inv(p::TestPlan{T}) where {T}
38-
unscaled_pinv = InverseTestPlan{T}(p.region, p.sz)
39-
N = AbstractFFTs.normalization(T, p.sz, p.region)
40-
unscaled_pinv.pinv = AbstractFFTs.ScaledPlan(p, N)
41-
pinv = AbstractFFTs.ScaledPlan(unscaled_pinv, N)
39+
function AbstractFFTs.plan_fft!(x::AbstractArray{T}, region; kwargs...) where {T}
40+
return TestPlan{T,true}(region, size(x))
41+
end
42+
function AbstractFFTs.plan_bfft!(x::AbstractArray{T}, region; kwargs...) where {T}
43+
return InverseTestPlan{T,true}(region, size(x))
44+
end
45+
46+
function AbstractFFTs.plan_inv(p::TestPlan{T,N,inplace}) where {T,N,inplace}
47+
unscaled_pinv = InverseTestPlan{T,inplace}(p.region, p.sz)
48+
_N = AbstractFFTs.normalization(T, p.sz, p.region)
49+
unscaled_pinv.pinv = AbstractFFTs.ScaledPlan(p, _N)
50+
pinv = AbstractFFTs.ScaledPlan(unscaled_pinv, _N)
4251
return pinv
4352
end
44-
function AbstractFFTs.plan_inv(pinv::InverseTestPlan{T}) where {T}
45-
unscaled_p = TestPlan{T}(pinv.region, pinv.sz)
46-
N = AbstractFFTs.normalization(T, pinv.sz, pinv.region)
47-
unscaled_p.pinv = AbstractFFTs.ScaledPlan(pinv, N)
48-
p = AbstractFFTs.ScaledPlan(unscaled_p, N)
53+
function AbstractFFTs.plan_inv(pinv::InverseTestPlan{T,N,inplace}) where {T,N,inplace}
54+
unscaled_p = TestPlan{T,inplace}(pinv.region, pinv.sz)
55+
_N = AbstractFFTs.normalization(T, pinv.sz, pinv.region)
56+
unscaled_p.pinv = AbstractFFTs.ScaledPlan(pinv, _N)
57+
p = AbstractFFTs.ScaledPlan(unscaled_p, _N)
4958
return p
5059
end
5160

@@ -80,20 +89,23 @@ function dft!(
8089
end
8190

8291
function mul!(
83-
y::AbstractArray{<:Complex,N}, p::TestPlan, x::AbstractArray{<:Union{Complex,Real},N}
84-
) where {N}
92+
y::AbstractArray{<:Complex,N}, p::TestPlan{T,N,false}, x::AbstractArray{<:Union{Complex,Real},N}
93+
) where {T,N}
8594
size(y) == size(p) == size(x) || throw(DimensionMismatch())
8695
dft!(y, x, p.region, -1)
8796
end
8897
function mul!(
89-
y::AbstractArray{<:Complex,N}, p::InverseTestPlan, x::AbstractArray{<:Union{Complex,Real},N}
90-
) where {N}
98+
y::AbstractArray{<:Complex,N}, p::InverseTestPlan{T,N,false}, x::AbstractArray{<:Union{Complex,Real},N}
99+
) where {T,N}
91100
size(y) == size(p) == size(x) || throw(DimensionMismatch())
92101
dft!(y, x, p.region, 1)
93102
end
94103

95-
Base.:*(p::TestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x)
96-
Base.:*(p::InverseTestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x)
104+
Base.:*(p::TestPlan{T,N,false}, x::AbstractArray) where {T,N} = mul!(similar(x, complex(float(eltype(x)))), p, x)
105+
Base.:*(p::InverseTestPlan{T,N,false}, x::AbstractArray) where {T,N} = mul!(similar(x, complex(float(eltype(x)))), p, x)
106+
107+
Base.:*(p::TestPlan{T,N,true}, x::AbstractArray) where {T,N} = copy!(x, dft!(similar(x), x, p.region, -1))
108+
Base.:*(p::InverseTestPlan{T,N,true}, x::AbstractArray) where {T,N} = copy!(x, dft!(similar(x), x, p.region, 1))
97109

98110
mutable struct TestRPlan{T,N} <: Plan{T}
99111
region

0 commit comments

Comments
 (0)