Skip to content

Commit

Permalink
mostly convert to retraction_t
Browse files Browse the repository at this point in the history
  • Loading branch information
mateuszbaran committed Jan 25, 2025
1 parent 6d8555d commit 255623f
Show file tree
Hide file tree
Showing 23 changed files with 72 additions and 40 deletions.
2 changes: 1 addition & 1 deletion ext/ManifoldsTestExt/tests_general.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ function test_manifold(
Test.@test isapprox(
M,
p,
retract(M, p, X, 0, retr_method);
retract_t(M, p, X, 0, retr_method);
atol=epsx * retraction_atol_multiplier,
rtol=retraction_atol_multiplier == 0 ?
sqrt(epsx) * retraction_rtol_multiplier : 0,
Expand Down
4 changes: 4 additions & 0 deletions src/Manifolds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ import ManifoldsBase:
retract_qr!,
retract_sasaki!,
retract_softmax!,
retract_t,
retract_t!,
riemann_tensor,
riemann_tensor!,
sectional_curvature,
Expand Down Expand Up @@ -920,6 +922,8 @@ export ×,
representation_size,
retract,
retract!,
retract_t,
retract_t!,
riemannian_gradient,
riemannian_gradient!,
riemannian_Hessian,
Expand Down
2 changes: 1 addition & 1 deletion src/manifolds/Elliptope.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ compute a projection based retraction by projecting ``q+Y`` back onto the manifo
"""
retract(::Elliptope, ::Any, ::Any, ::ProjectionRetraction)

function retract_project!(M::Elliptope, r, q, Y, t::Number)
function ManifoldsBase.retract_project_t!(M::Elliptope, r, q, Y, t::Number)
r .= q .+ t .* Y
project!(M, r, r)
return r
Expand Down
12 changes: 6 additions & 6 deletions src/manifolds/FixedRankMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ Compute the in-place variant of the [`OrthographicInverseRetraction`](@ref).
inverse_retract_orthographic!(M::AbstractManifold, X, p, q)

## Layer II
function _retract!(
function ManifoldsBase._retract_t!(
M::AbstractManifold,
q,
p,
Expand All @@ -173,16 +173,16 @@ function _retract!(
::OrthographicRetraction;
kwargs...,
)
return retract_orthographic!(M, q, p, X, t; kwargs...)
return retract_orthographic_t!(M, q, p, X, t; kwargs...)
end

## Layer III
"""
retract_orthographic!(M::AbstractManifold, q, p, X, t::Number)
retract_orthographic!(M::AbstractManifold, q, p, X)
Compute the in-place variant of the [`OrthographicRetraction`](@ref).
"""
retract_orthographic!(M::AbstractManifold, q, p, X, t::Number)
retract_orthographic!(M::AbstractManifold, q, p, X)

# \|---

Expand Down Expand Up @@ -655,7 +655,7 @@ For more details, see [AbsilOseledets:2014](@cite).
"""
retract(::FixedRankMatrices, ::Any, ::Any, ::OrthographicRetraction)

function retract_orthographic!(
function retract_orthographic_t!(
M::FixedRankMatrices,
q::SVDMPoint,
p::SVDMPoint,
Expand Down Expand Up @@ -689,7 +689,7 @@ singular values and ``U`` and ``V`` are shortened accordingly.
"""
retract(::FixedRankMatrices, ::Any, ::Any, ::PolarRetraction)

function retract_polar!(
function ManifoldsBase.retract_polar_t!(
M::FixedRankMatrices,
q::SVDMPoint,
p::SVDMPoint,
Expand Down
2 changes: 1 addition & 1 deletion src/manifolds/FlagOrthogonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ as the first order approximation to the exponential map. Similar to QR retractio
"""
retract(M::Flag, p::OrthogonalPoint, X::OrthogonalTangentVector, ::QRRetraction)

function retract_qr!(
function ManifoldsBase.retract_qr_t!(
::Flag,
q::OrthogonalPoint{<:AbstractMatrix{T}},
p::OrthogonalPoint,
Expand Down
2 changes: 1 addition & 1 deletion src/manifolds/FlagStiefel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ where $\cdot^{\mathrm{H}}$ denotes the complex conjugate transposed or Hermitian
"""
retract(::Flag, ::Any, ::Any, ::PolarRetraction)

function retract_polar!(::Flag, q, p, X, t::Number)
function ManifoldsBase.retract_polar_t!(::Flag, q, p, X, t::Number)
s = svd(p .+ t .* X)
return mul!(q, s.U, s.Vt)
end
4 changes: 2 additions & 2 deletions src/manifolds/GeneralUnitaryMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ This is also the default retraction on these manifolds.
"""
retract(::GeneralUnitaryMatrices, ::Any, ::Any, ::QRRetraction)

function retract_qr!(
function ManifoldsBase.retract_qr_t!(
::GeneralUnitaryMatrices,
q::AbstractArray{T},
p,
Expand All @@ -1042,7 +1042,7 @@ function retract_qr!(
D = Diagonal(sign.(d .+ convert(T, 0.5)))
return copyto!(q, qr_decomp.Q * D)
end
function retract_polar!(M::GeneralUnitaryMatrices, q, p, X, t::Number)
function ManifoldsBase.retract_polar_t!(M::GeneralUnitaryMatrices, q, p, X, t::Number)
A = p + p * (t * X)
return project!(M, q, A; check_det=false)
end
Expand Down
4 changes: 2 additions & 2 deletions src/manifolds/GeneralizedGrassmann.jl
Original file line number Diff line number Diff line change
Expand Up @@ -373,12 +373,12 @@ Compute the SVD-based retraction [`PolarRetraction`](@extref `ManifoldsBase.Pola
"""
retract(::GeneralizedGrassmann, ::Any, ::Any, ::PolarRetraction)

function retract_polar!(M::GeneralizedGrassmann, q, p, X, t::Number)
function ManifoldsBase.retract_polar_t!(M::GeneralizedGrassmann, q, p, X, t::Number)
q .= p .+ t .* X
project!(M, q, q)
return q
end
function retract_project!(M::GeneralizedGrassmann, q, p, X, t::Number)
function ManifoldsBase.retract_project_t!(M::GeneralizedGrassmann, q, p, X, t::Number)
q .= p .+ t .* X
project!(M, q, q)
return q
Expand Down
4 changes: 2 additions & 2 deletions src/manifolds/GeneralizedStiefel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,12 +243,12 @@ retract(::GeneralizedStiefel, ::Any...)

default_retraction_method(::GeneralizedStiefel) = ProjectionRetraction()

function retract_polar!(M::GeneralizedStiefel, q, p, X, t::Number)
function ManifoldsBase.retract_polar_t!(M::GeneralizedStiefel, q, p, X, t::Number)
q .= p .+ t .* X
project!(M, q, q)
return q
end
function retract_project!(M::GeneralizedStiefel, q, p, X, t::Number)
function ManifoldsBase.retract_project_t!(M::GeneralizedStiefel, q, p, X, t::Number)
q .= p .+ t .* X
project!(M, q, q)
return q
Expand Down
4 changes: 2 additions & 2 deletions src/manifolds/GrassmannStiefel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ where ``⋅^{\mathrm{H}}`` denotes the complex conjugate transposed or Hermitian
"""
retract(::Grassmann, ::Any, ::Any, ::PolarRetraction)

function retract_polar!(M::Grassmann, q, p, X, t::Number)
function ManifoldsBase.retract_polar_t!(M::Grassmann, q, p, X, t::Number)
q .= p .+ t .* X
project!(M, q, q)
return q
Expand All @@ -340,7 +340,7 @@ D = \operatorname{diag}\left( \operatorname{sgn}\left(R_{ii}+\frac{1}{2}\right)_
"""
retract(::Grassmann, ::Any, ::Any, ::QRRetraction)

function retract_qr!(::Grassmann, q, p, X, t::Number)
function ManifoldsBase.retract_qr_t!(::Grassmann, q, p, X, t::Number)
q .= p .+ t .* X
qrfac = qr(q)
d = diag(qrfac.R)
Expand Down
2 changes: 1 addition & 1 deletion src/manifolds/Hyperrectangle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ Return the array dimensions required to represent an element on the
"""
representation_size(M::Hyperrectangle) = size(M.lb)

function retract_project!(M::Hyperrectangle, r, q, Y, t::Number)
function ManifoldsBase.retract_project_t!(M::Hyperrectangle, r, q, Y, t::Number)
r .= q .+ t .* Y
project!(M, r, r)
return r
Expand Down
8 changes: 7 additions & 1 deletion src/manifolds/MultinomialDoublyStochastic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,13 @@ refers to the elementwise exponentiation.
"""
retract(::MultinomialDoubleStochastic, ::Any, ::Any, ::ProjectionRetraction)

function retract_project!(M::MultinomialDoubleStochastic, q, p, X, t::Number)
function ManifoldsBase.retract_project_t!(
M::MultinomialDoubleStochastic,
q,
p,
X,
t::Number,
)
return project!(M, q, p .* exp.(t .* X ./ p))
end

Expand Down
2 changes: 1 addition & 1 deletion src/manifolds/MultinomialSymmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ refers to the elementwise exponentiation.
"""
retract(::MultinomialSymmetric, ::Any, ::Any, ::ProjectionRetraction)

function retract_project!(M::MultinomialSymmetric, q, p, X, t::Number)
function ManifoldsBase.retract_project_t!(M::MultinomialSymmetric, q, p, X, t::Number)
return project!(M, q, p .* exp.(t .* X ./ p))
end

Expand Down
2 changes: 1 addition & 1 deletion src/manifolds/ProbabilitySimplex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ where multiplication, exponentiation and division are meant elementwise.
"""
retract(::ProbabilitySimplex, ::Any, ::Any, ::SoftmaxRetraction)

function retract_softmax!(::ProbabilitySimplex, q, p, X, t::Number)
function ManifoldsBase.retract_softmax_t!(::ProbabilitySimplex, q, p, X, t::Number)
s = zero(eltype(q))
@inbounds for i in eachindex(q, p, X)
q[i] = p[i] * exp(t * X[i])
Expand Down
6 changes: 3 additions & 3 deletions src/manifolds/ProjectiveSpace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -506,15 +506,15 @@ retract(
::Union{ProjectionRetraction,PolarRetraction,QRRetraction},
)

function retract_polar!(M::AbstractProjectiveSpace, q, p, X, t::Number)
function ManifoldsBase.retract_polar_t!(M::AbstractProjectiveSpace, q, p, X, t::Number)
q .= p .+ t .* X
return project!(M, q, q)
end
function retract_project!(M::AbstractProjectiveSpace, q, p, X, t::Number)
function ManifoldsBase.retract_project_t!(M::AbstractProjectiveSpace, q, p, X, t::Number)
q .= p .+ t .* X
return project!(M, q, q)
end
function retract_qr!(M::AbstractProjectiveSpace, q, p, X, t::Number)
function ManifoldsBase.retract_qr_t!(M::AbstractProjectiveSpace, q, p, X, t::Number)
q .= p .+ t .* X
return project!(M, q, q)
end
Expand Down
4 changes: 3 additions & 1 deletion src/manifolds/Spectrahedron.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ compute a projection based retraction by projecting ``q+Y`` back onto the manifo
"""
retract(::Spectrahedron, ::Any, ::Any, ::ProjectionRetraction)

retract_project!(M::Spectrahedron, r, q, Y, t::Number) = project!(M, r, q .+ t .* Y)
function ManifoldsBase.retract_project_t!(M::Spectrahedron, r, q, Y, t::Number)
return project!(M, r, q .+ t .* Y)
end

@doc raw"""
representation_size(M::Spectrahedron)
Expand Down
2 changes: 1 addition & 1 deletion src/manifolds/Sphere.jl
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ Compute the retraction that is based on projection, i.e.
"""
retract(::AbstractSphere, ::Any, ::Any, ::ProjectionRetraction)

function retract_project!(M::AbstractSphere, q, p, X, t::Number)
function ManifoldsBase.retract_project_t!(M::AbstractSphere, q, p, X, t::Number)
q .= p .+ t .* X
return project!(M, q, q)
end
Expand Down
13 changes: 10 additions & 3 deletions src/manifolds/Stiefel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,14 @@ retract(::Stiefel, ::Any, ::Any, ::QRRetraction)
_qrfac_to_q(qrfac) = Matrix(qrfac.Q)
_qrfac_to_q(qrfac::StaticArrays.QR) = qrfac.Q

function retract_pade!(::Stiefel, q, p, X, t::Number, ::PadeRetraction{m}) where {m}
function ManifoldsBase.retract_pade_t!(
::Stiefel,
q,
p,
X,
t::Number,
::PadeRetraction{m},
) where {m}
tX = t * X
Pp = I - 1 // 2 * p * p'
WpX = Pp * tX * p' - p * tX' * Pp
Expand All @@ -493,12 +500,12 @@ function retract_pade!(::Stiefel, q, p, X, t::Number, ::PadeRetraction{m}) where
end
return copyto!(q, (qm \ pm) * p)
end
function retract_polar!(::Stiefel, q, p, X, t::Number)
function ManifoldsBase.retract_polar_t!(::Stiefel, q, p, X, t::Number)
q .= p .+ t .* X
s = svd(q)
return mul!(q, s.U, s.Vt)
end
function retract_qr!(::Stiefel, q, p, X, t::Number)
function ManifoldsBase.retract_qr_t!(::Stiefel, q, p, X, t::Number)
q .= p .+ t .* X
qrfac = qr(q)
d = diag(qrfac.R)
Expand Down
2 changes: 1 addition & 1 deletion src/manifolds/StiefelEuclideanMetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ The retraction is computed by projecting the exponential map in the embedding to
"""
retract(::Stiefel, ::Any, ::Any, ::ProjectionRetraction)

function retract_project!(M::Stiefel, q, p, X, t::Number)
function ManifoldsBase.retract_project_t!(M::Stiefel, q, p, X, t::Number)
q .= p .+ t .* X
project!(M, q, q)
return q
Expand Down
2 changes: 1 addition & 1 deletion src/manifolds/SymplecticStiefel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ This expression is computed inplace of `q`.
"""
retract(::SymplecticStiefel, p, X, ::CayleyRetraction)

function retract_cayley!(M::SymplecticStiefel, q, p, X, t::Number)
function ManifoldsBase.retract_cayley_t!(M::SymplecticStiefel, q, p, X, t::Number)
tX = t * X
# Define intermediate matrices for later use:
A = symplectic_inverse_times(M, p, tX)
Expand Down
2 changes: 1 addition & 1 deletion src/manifolds/Tucker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ retraction produces a boundary point, which is outside the manifold.
"""
retract(::Tucker, ::Any, ::Any, ::PolarRetraction)

function retract_polar!(
function ManifoldsBase.retract_polar_t!(
::Tucker,
q::TuckerPoint,
p::TuckerPoint{T,D},
Expand Down
23 changes: 18 additions & 5 deletions src/manifolds/VectorBundle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ function project!(B::VectorBundle, Y, p, X)
return Y
end

function _retract!(M::VectorBundle, q, p, X, t::Number, ::FiberBundleProductRetraction)
function _retract_t!(M::VectorBundle, q, p, X, t::Number, ::FiberBundleProductRetraction)
return retract_product!(M, q, p, X, t)
end

Expand All @@ -249,16 +249,22 @@ which by default allocates and calls `retract_product!`.
"""
retract(::VectorBundle, p, q, t::Number, ::FiberBundleProductRetraction)

function _retract(M::VectorBundle, p, X, t::Number, ::FiberBundleProductRetraction)
function ManifoldsBase._retract_t(
M::VectorBundle,
p,
X,
t::Number,
::FiberBundleProductRetraction,
)
return retract_product(M, p, X, t)
end

function retract_product(M::VectorBundle, p, X, t::Number)
function retract_product_t(M::VectorBundle, p, X, t::Number)
q = allocate_result(M, retract, p, X)
return retract_product!(M, q, p, X, t)
end

function retract_product!(B::VectorBundle, q, p, X, t::Number)
function retract_product_t!(B::VectorBundle, q, p, X, t::Number)
tX = t * X
xp, Xp = submanifold_components(B.manifold, p)
xq, Xq = submanifold_components(B.manifold, q)
Expand All @@ -277,7 +283,14 @@ function retract_product!(B::VectorBundle, q, p, X, t::Number)
return q
end

function retract_sasaki!(B::TangentBundle, q, p, X, t::Number, m::SasakiRetraction)
function ManifoldsBase.retract_sasaki_t!(
B::TangentBundle,
q,
p,
X,
t::Number,
m::SasakiRetraction,
)
tX = t * X
xp, Xp = submanifold_components(B.manifold, p)
xq, Xq = submanifold_components(B.manifold, q)
Expand Down
4 changes: 2 additions & 2 deletions src/statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ function Statistics.mean!(
inverse_retract!(M, vtmp, yold, x[j], inverse_retraction)
v .+= α[j] .* (vtmp .- v)
end
retract!(M, y, yold, v, 0.5, retraction)
retract_t!(M, y, yold, v, 0.5, retraction)
isapprox(M, y, yold; kwargs...) && break
end
return y
Expand Down Expand Up @@ -283,7 +283,7 @@ function Statistics.mean!(
s += w[j]
t = w[j] / s
inverse_retract!(M, v, q, x[j], inverse_retraction)
retract!(M, ytmp, q, v, t, retraction)
retract_t!(M, ytmp, q, v, t, retraction)
copyto!(q, ytmp)
end
return q
Expand Down

0 comments on commit 255623f

Please sign in to comment.