Skip to content

Add ComplexZernike #98

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 72 additions & 2 deletions src/ModalInterlace.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
abstract type AbstractModalInterlace{T} <: AbstractBandedBlockBandedMatrix{T} end

axes(Z::AbstractModalInterlace) = blockedrange.(oneto.(Z.MN))

"""
ModalInterlace
"""
struct ModalInterlace{T, MMNN<:Tuple} <: AbstractBandedBlockBandedMatrix{T}
struct ModalInterlace{T, MMNN<:Tuple} <: AbstractModalInterlace{T}
ops
MN::MMNN
bandwidths::NTuple{2,Int}
Expand All @@ -10,7 +14,7 @@ end
ModalInterlace{T}(ops, MN::NTuple{2,Integer}, bandwidths::NTuple{2,Int}) where T = ModalInterlace{T,typeof(MN)}(ops, MN, bandwidths)
ModalInterlace(ops::AbstractVector{<:AbstractMatrix{T}}, MN::NTuple{2,Integer}, bandwidths::NTuple{2,Int}) where T = ModalInterlace{T}(ops, MN, bandwidths)

axes(Z::ModalInterlace) = blockedrange.(oneto.(Z.MN))


blockbandwidths(R::ModalInterlace) = R.bandwidths
subblockbandwidths(::ModalInterlace) = (0,0)
Expand Down Expand Up @@ -56,3 +60,69 @@ end

# act like lazy array
Base.BroadcastStyle(::Type{<:ModalInterlace{<:Any,NTuple{2,InfiniteCardinal{0}}}}) = LazyArrayStyle{2}()




"""
ShiftModalInterlace

for operators that shift the mode
"""
struct ShiftModalInterlace{T, MMNN<:Tuple} <: AbstractModalInterlace{T}
ops
MN::MMNN
bandwidths::NTuple{2,Int}
subbandwidth::Int
end

ShiftModalInterlace{T}(ops, MN::NTuple{2,Integer}, bandwidths::NTuple{2,Int}, subbandwidth::Int) where T = ShiftModalInterlace{T,typeof(MN)}(ops, MN, bandwidths, subbandwidth)
ShiftModalInterlace(ops::AbstractVector{<:AbstractMatrix{T}}, MN::NTuple{2,Integer}, bandwidths::NTuple{2,Int}, subbandwidth::Int) where T = ShiftModalInterlace{T}(ops, MN, bandwidths, subbandwidth)



blockbandwidths(R::ShiftModalInterlace) = R.bandwidths
subblockbandwidths(R::ShiftModalInterlace) = (-R.subbandwidth,R.subbandwidth)


function Base.view(R::ShiftModalInterlace{T}, KJ::Block{2}) where T
K,J = KJ.n
dat = Matrix{T}(undef,1,J)
l,u = blockbandwidths(R)
λ = R.subbandwidth
if isodd(J-K) && -l ≤ J - K ≤ u
sh = (J-K-1)÷2
if iseven(K)
k = K÷2+1
dat[1,1] = R.ops[1][k,k+sh]
end
for m in range(2-iseven(K); step=2, length=J÷2-max(0,sh))
k = K÷2-m÷2+isodd(K)
dat[1,m] = dat[1,m+1] = R.ops[m+1][k,k+sh]
end
else
fill!(dat, zero(T))
end
_BandedMatrix(dat, K, 0, 0)
end

getindex(R::ShiftModalInterlace, k::Integer, j::Integer) = R[findblockindex.(axes(R),(k,j))...]

struct ShiftModalInterlaceLayout <: AbstractBandedBlockBandedLayout end
struct LazyShiftModalInterlaceLayout <: AbstractLazyBandedBlockBandedLayout end

MemoryLayout(::Type{<:ShiftModalInterlace}) = ShiftModalInterlaceLayout()
MemoryLayout(::Type{<:ShiftModalInterlace{<:Any,NTuple{2,InfiniteCardinal{0}}}}) = LazyShiftModalInterlaceLayout()
sublayout(::Union{ShiftModalInterlaceLayout,LazyShiftModalInterlaceLayout}, ::Type{<:NTuple{2,BlockSlice{<:BlockOneTo}}}) = ShiftModalInterlaceLayout()


function sub_materialize(::ShiftModalInterlaceLayout, V::AbstractMatrix{T}) where T
kr,jr = parentindices(V)
KR,JR = kr.block,jr.block
M,N = Int(last(KR)), Int(last(JR))
R = parent(V)
ShiftModalInterlace{T}([R.ops[m][1:(M-m+2)÷2,1:(N-m+2)÷2] for m=1:min(N,M)], (M,N), R.bandwidths)
end

# act like lazy array
Base.BroadcastStyle(::Type{<:ShiftModalInterlace{<:Any,NTuple{2,InfiniteCardinal{0}}}}) = LazyArrayStyle{2}()
6 changes: 3 additions & 3 deletions src/MultivariateOrthogonalPolynomials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ClassicalOrthogonalPolynomials, FastTransforms, BlockBandedMatrices, Block
LazyArrays, SpecialFunctions, LinearAlgebra, BandedMatrices, LazyBandedMatrices, ArrayLayouts,
HarmonicOrthogonalPolynomials

import Base: axes, in, ==, *, ^, \, copy, OneTo, getindex, size, oneto
import Base: axes, in, ==, *, ^, \, copy, OneTo, getindex, size, oneto, conj
import DomainSets: boundary

import QuasiArrays: LazyQuasiMatrix, LazyQuasiArrayStyle
Expand All @@ -26,8 +26,8 @@ export MultivariateOrthogonalPolynomial, BivariateOrthogonalPolynomial,
UnitTriangle, UnitDisk,
JacobiTriangle, TriangleWeight, WeightedTriangle,
DunklXuDisk, DunklXuDiskWeight, WeightedDunklXuDisk,
Zernike, ZernikeWeight, zerniker, zernikez,
PartialDerivative, Laplacian, AbsLaplacianPower, AngularMomentum,
Zernike, ComplexZernike, ZernikeWeight, zerniker, zernikez,
PartialDerivative, ComplexDerivative, Laplacian, AbsLaplacianPower, AngularMomentum,
RadialCoordinate, Weighted, Block

include("ModalInterlace.jl")
Expand Down
148 changes: 114 additions & 34 deletions src/disk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,36 +101,62 @@ function getindex(w::ZernikeWeight, xy::StaticVector{2})
end


abstract type AbstractZernike{T} <: BivariateOrthogonalPolynomial{T} end

"""
Zernike(a, b)

is a quasi-matrix orthogonal `r^(2a) * (1-r^2)^b`
"""
struct Zernike{T} <: BivariateOrthogonalPolynomial{T}
struct Zernike{T} <: AbstractZernike{T}
a::T
b::T
Zernike{T}(a::T, b::T) where T = new{T}(a, b)
end
Zernike{T}(a, b) where T = Zernike{T}(convert(T,a), convert(T,b))
Zernike(a::T, b::V) where {T,V} = Zernike{float(promote_type(T,V))}(a, b)
Zernike{T}(b) where T = Zernike{T}(zero(b), b)
Zernike{T}() where T = Zernike{T}(zero(T))


"""
ComplexZernike(a, b)

is a complex-valued quasi-matrix orthogonal `r^(2a) * (1-r^2)^b`
"""
struct ComplexZernike{T} <: AbstractZernike{T}
a::T
b::T
ComplexZernike{T}(a::T, b::T) where T = new{T}(a, b)
end

for Zer in (:Zernike, :ComplexZernike)
@eval begin
$Zer{T}(a, b) where T = $Zer{T}(convert(T,a), convert(T,b))
$Zer(a::T, b::V) where {T,V} = $Zer{float(promote_type(T,V))}(a, b)
$Zer{T}(b) where T = $Zer{T}(zero(b), b)
$Zer{T}() where T = $Zer{T}(zero(T))

$Zer() = $Zer{Float64}()

==(w::$Zer, v::$Zer) = w.a == v.a && w.b == v.b
end
end

"""
Zernike(b)

is a quasi-matrix orthogonal `(1-r^2)^b`
"""
Zernike(b) = Zernike(zero(b), b)
Zernike() = Zernike{Float64}()

axes(P::Zernike{T}) where T = (Inclusion(UnitDisk{T}()),blockedrange(oneto(∞)))
"""
ComplexZernike(b)

==(w::Zernike, v::Zernike) = w.a == v.a && w.b == v.b
is a complex quasi-matrix orthogonal `(1-r^2)^b`
"""
ComplexZernike(b) = ComplexZernike(zero(b), b)

copy(A::Zernike) = A
axes(P::AbstractZernike{T}) where T = (Inclusion(UnitDisk{T}()),blockedrange(oneto(∞)))
copy(A::AbstractZernike) = A

orthogonalityweight(Z::Zernike) = ZernikeWeight(Z.a, Z.b)
orthogonalityweight(Z::AbstractZernike) = ZernikeWeight(Z.a, Z.b)

zerniker(ℓ, m, a, b, r::T) where T = sqrt(convert(T,2)^(m+a+b+2-iszero(m))/π) * r^m * normalizedjacobip((ℓ-m) ÷ 2, b, m+a, 2r^2-1)
zerniker(ℓ, m, b, r) = zerniker(ℓ, m, zero(b), b, r)
Expand All @@ -142,9 +168,19 @@ function zernikez(ℓ, ms, a, b, rθ::RadialCoordinate{T}) where T
zerniker(ℓ, m, a, b, r) * (signbit(ms) ? sin(m*θ) : cos(m*θ))
end

zernikez(ℓ, ms, a, b, xy::StaticVector{2}) = zernikez(ℓ, ms, a, b, RadialCoordinate(xy))
zernikez(ℓ, ms, b, xy::StaticVector{2}) = zernikez(ℓ, ms, zero(b), b, xy)
zernikez(ℓ, ms, xy::StaticVector{2,T}) where T = zernikez(ℓ, ms, zero(T), xy)
function complexzernikez(ℓ, ms, a, b, rθ::RadialCoordinate{T}) where T
r,θ = rθ.r,rθ.θ
m = abs(ms)
zerniker(ℓ, m, a, b, r) * exp(im*m*θ)
end

for func in (:zernikez, :complexzernikez)
@eval begin
$func(ℓ, ms, a, b, xy::StaticVector{2}) = $func(ℓ, ms, a, b, RadialCoordinate(xy))
$func(ℓ, ms, b, xy::StaticVector{2}) = $func(ℓ, ms, zero(b), b, xy)
$func(ℓ, ms, xy::StaticVector{2,T}) where T = $func(ℓ, ms, zero(T), xy)
end
end

function getindex(Z::Zernike{T}, rθ::RadialCoordinate, B::BlockIndex{1}) where T
ℓ = Int(block(B))-1
Expand All @@ -153,10 +189,17 @@ function getindex(Z::Zernike{T}, rθ::RadialCoordinate, B::BlockIndex{1}) where
zernikez(ℓ, (isodd(k+ℓ) ? 1 : -1) * m, Z.a, Z.b, rθ)
end

function getindex(Z::ComplexZernike{T}, rθ::RadialCoordinate, B::BlockIndex{1}) where T
ℓ = Int(block(B))-1
k = blockindex(B)
m = iseven(ℓ) ? k-isodd(k) : k-iseven(k)
complexzernikez(ℓ, (isodd(k+ℓ) ? 1 : -1) * m, Z.a, Z.b, rθ)
end


getindex(Z::Zernike, xy::StaticVector{2}, B::BlockIndex{1}) = Z[RadialCoordinate(xy), B]
getindex(Z::Zernike, xy::StaticVector{2}, B::Block{1}) = [Z[xy, B[j]] for j=1:Int(B)]
getindex(Z::Zernike, xy::StaticVector{2}, JR::BlockOneTo) = mortar([Z[xy,Block(J)] for J = 1:Int(JR[end])])
getindex(Z::AbstractZernike, xy::StaticVector{2}, B::BlockIndex{1}) = Z[RadialCoordinate(xy), B]
getindex(Z::AbstractZernike, xy::StaticVector{2}, B::Block{1}) = [Z[xy, B[j]] for j=1:Int(B)]
getindex(Z::AbstractZernike, xy::StaticVector{2}, JR::BlockOneTo) = mortar([Z[xy,Block(J)] for J = 1:Int(JR[end])])



Expand Down Expand Up @@ -220,7 +263,11 @@ end

factorize(S::FiniteZernike{T}) where T = TransformFactorization(grid(S), ZernikeTransform{T}(blocksize(S,2), parent(S).a, parent(S).b))

# gives the entries for the Laplacian times (1-r^2) * Zernike(1)
"""
WeightedZernikeLaplacianDiag{T}()

gives the entries for the Laplacian times (1-r^2) * Zernike(1)
"""
struct WeightedZernikeLaplacianDiag{T} <: AbstractBlockVector{T} end

axes(::WeightedZernikeLaplacianDiag) = (blockedrange(oneto(∞)),)
Expand All @@ -243,7 +290,7 @@ end

getindex(W::WeightedZernikeLaplacianDiag, k::Integer) = W[findblockindex(axes(W,1),k)]

@simplify function *(Δ::Laplacian, WZ::Weighted{<:Any,<:Zernike})
@simplify function *(Δ::Laplacian, WZ::Weighted{<:Any,<:AbstractZernike})
@assert WZ.P.a == 0 && WZ.P.b == 1
WZ.P * Diagonal(WeightedZernikeLaplacianDiag{eltype(eltype(WZ))}())
end
Expand Down Expand Up @@ -288,22 +335,55 @@ end
# 2 dimensional special case, again without the 2^(2*β) factor
fractionalcfs2d(l::Integer, m::Integer, β) = fractionalcfs(l,m,β,2)

function \(A::Zernike{T}, B::Zernike{V}) where {T,V}
TV = promote_type(T,V)
A.a == B.a && A.b == B.b && return Eye{TV}(∞)
@assert A.a == 0 && A.b == 1
@assert B.a == 0 && B.b == 0
ModalInterlace{TV}((Normalized.(Jacobi{TV}.(1,0:∞)) .\ Normalized.(Jacobi{TV}.(0,0:∞))) ./ sqrt(convert(TV, 2)), (ℵ₀,ℵ₀), (0,2))
for Zer in (:Zernike, :ComplexZernike)
@eval begin
function \(A::$Zer{T}, B::$Zer{V}) where {T,V}
TV = promote_type(T,V)
A.a == B.a && A.b == B.b && return Eye{TV}(∞)
@assert A.a == 0 && A.b == 1
@assert B.a == 0 && B.b == 0
ModalInterlace{TV}((Normalized.(Jacobi{TV}.(1,0:∞)) .\ Normalized.(Jacobi{TV}.(0,0:∞))) ./ sqrt(convert(TV, 2)), (ℵ₀,ℵ₀), (0,2))
end

function \(A::$Zer{T}, B::Weighted{V,$Zer{V}}) where {T,V}
TV = promote_type(T,V)
A.a == B.P.a == A.b == B.P.b == 0 && return Eye{TV}(∞)
if A.a == A.b == 0
@assert B.P.a == 0 && B.P.b == 1
ModalInterlace{TV}((Normalized.(Jacobi{TV}.(0, 0:∞)) .\ HalfWeighted{:a}.(Normalized.(Jacobi{TV}.(1, 0:∞)))) ./ sqrt(convert(TV, 2)), (ℵ₀,ℵ₀), (2,0))
else
Z = $Zer{TV}()
(A \ Z) * (Z \ B)
end
end
end
end

function \(A::Zernike{T}, B::Weighted{V,Zernike{V}}) where {T,V}
TV = promote_type(T,V)
A.a == B.P.a == A.b == B.P.b == 0 && return Eye{TV}(∞)
if A.a == A.b == 0
@assert B.P.a == 0 && B.P.b == 1
ModalInterlace{TV}((Normalized.(Jacobi{TV}.(0, 0:∞)) .\ HalfWeighted{:a}.(Normalized.(Jacobi{TV}.(1, 0:∞)))) ./ sqrt(convert(TV, 2)), (ℵ₀,ℵ₀), (2,0))
else
Z = Zernike{TV}()
(A \ Z) * (Z \ B)

#########
# ComplexDerivative
# is the complex differential (∂ˣ - im*∂ʸ)/2
#########


for Der in (:ComplexDerivative, :ConjComplexDerivative)
@eval begin
struct $Der{T,Ax<:Inclusion} <: LazyQuasiMatrix{Complex{T}}
axis::Ax
end

$Der{T}(axis::Inclusion) where {k,T} = $Der{T,typeof(axis)}(axis)
$Der{T}(domain) where {k,T} = $Der{T}(Inclusion(domain))
$Der(axis) where k = $Der{eltype(eltype(axis))}(axis)

axes(D::$Der) = (D.axis, D.axis)
==(a::$Der, b::$Der) where k = a.axis == b.axis
copy(D::$Der) where k = $Der(copy(D.axis))

^(D::$Der, k::Integer) = ApplyQuasiArray(^, D, k)
end
end
end

conj(D::ComplexDerivative{T}) where T = ConjComplexDerivative{T}(D.axis)
conj(D::ConjComplexDerivative{T}) where T = ComplexDerivative{T}(D.axis)

Loading