Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move more packages to weak dependencies #693

Closed
wants to merge 3 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Part I: RecursiveArrayTools as weak dep.
kellertuer committed Dec 14, 2023
commit 302a0a59d95bb50ad28b93e814d6b96208a5752e
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -16,7 +16,6 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
MatrixEquations = "99c1a7ee-ab34-5fd5-8076-27c950a045f4"
Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
@@ -31,13 +30,15 @@ DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[extensions]
ManifoldsBoundaryValueDiffEqExt = "BoundaryValueDiffEq"
ManifoldsNLsolveExt = "NLsolve"
ManifoldsOrdinaryDiffEqDiffEqCallbacksExt = ["DiffEqCallbacks", "OrdinaryDiffEq"]
ManifoldsOrdinaryDiffEqDiffEqCallbacksRecursiveArrayToolsExt = ["DiffEqCallbacks", "OrdinaryDiffEq", "RecursiveArrayTools"]
ManifoldsOrdinaryDiffEqExt = "OrdinaryDiffEq"
ManifoldsRecursiveArrayToolsExt = ["RecursiveArrayTools"]
ManifoldsRecipesBaseExt = ["Colors", "RecipesBase"]
ManifoldsTestExt = "Test"

Original file line number Diff line number Diff line change
@@ -10,7 +10,7 @@ if isdefined(Base, :get_extension)
transition_map_diff!
import Manifolds: solve_chart_exp_ode, solve_chart_parallel_transport_ode
using ManifoldsBase

using RecursiveArrayTools: ArrayPartition
using DiffEqCallbacks
using OrdinaryDiffEq: OrdinaryDiffEq, SciMLBase, Rodas5, AutoVern9, ODEProblem, solve
else
@@ -27,6 +27,7 @@ else
using ..ManifoldsBase

using ..DiffEqCallbacks
using ..RecursiveArrayTools: ArrayPartition
using ..OrdinaryDiffEq: OrdinaryDiffEq, SciMLBase, Rodas5, AutoVern9, ODEProblem, solve
end

91 changes: 91 additions & 0 deletions ext/ManifoldsRecursiveArrayToolsExt/FiberBundleRATExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@

@inline function allocate_result(M::FiberBundle, f::TF) where {TF}
p = allocate_result(M.manifold, f)
X = allocate_result(Fiber(M.manifold, p, M.type), f)
return ArrayPartition(p, X)
end

function get_vector(M::FiberBundle, p, X, B::AbstractBasis)
n = manifold_dimension(M.manifold)
xp1, xp2 = submanifold_components(M, p)
F = Fiber(M.manifold, xp1, M.type)
return ArrayPartition(
get_vector(M.manifold, xp1, X[1:n], B),
get_vector(F, xp2, X[(n + 1):end], B),
)
end
function get_vector(
M::FiberBundle,
p,
X,
B::CachedBasis{𝔽,<:AbstractBasis{𝔽},<:FiberBundleBasisData},
) where {𝔽}
n = manifold_dimension(M.manifold)
xp1, xp2 = submanifold_components(M, p)
F = Fiber(M.manifold, xp1, M.type)
return ArrayPartition(
get_vector(M.manifold, xp1, X[1:n], B.data.base_basis),
get_vector(F, xp2, X[(n + 1):end], B.data.fiber_basis),
)
end

function get_vectors(
M::FiberBundle,
p::ArrayPartition,
B::CachedBasis{𝔽,<:AbstractBasis{𝔽},<:FiberBundleBasisData},
) where {𝔽}
xp1, xp2 = submanifold_components(M, p)
zero_m = zero_vector(M.manifold, xp1)
F = Fiber(M.manifold, xp1, M.type)
zero_f = zero_vector(F, xp1)
vs = typeof(ArrayPartition(zero_m, zero_f))[]
for bv in get_vectors(M.manifold, xp1, B.data.base_basis)
push!(vs, ArrayPartition(bv, zero_f))
end
for bv in get_vectors(F, xp2, B.data.fiber_basis)
push!(vs, ArrayPartition(zero_m, bv))
end
return vs
end


"""
getindex(p::ArrayPartition, M::FiberBundle, s::Symbol)
p[M::FiberBundle, s]

Access the element(s) at index `s` of a point `p` on a [`FiberBundle`](@ref) `M` by
using the symbols `:point` and `:vector` or `:fiber` for the base and vector or fiber
component, respectively.
"""
@inline function getindex(p::ArrayPartition, M::FiberBundle, s::Symbol)
(s === :point) && return p.x[1]
(s === :vector || s === :fiber) && return p.x[2]
return throw(DomainError(s, "unknown component $s on $M."))
end

"""
setindex!(p::ArrayPartition, val, M::FiberBundle, s::Symbol)
p[M::VectorBundle, s] = val

Set the element(s) at index `s` of a point `p` on a [`FiberBundle`](@ref) `M` to `val` by
using the symbols `:point` and `:fiber` or `:vector` for the base and fiber or vector
component, respectively.

!!! note

The *content* of element of `p` is replaced, not the element itself.
"""
@inline function setindex!(x::ArrayPartition, val, M::FiberBundle, s::Symbol)
if s === :point
return copyto!(x.x[1], val)
elseif s === :vector || s === :fiber
return copyto!(x.x[2], val)
else
throw(DomainError(s, "unknown component $s on $M."))
end
end
@inline function view(x::ArrayPartition, M::FiberBundle, s::Symbol)
(s === :point) && return x.x[1]
(s === :vector || s === :fiber) && return x.x[2]
throw(DomainError(s, "unknown component $s on $M."))
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
module ManifoldsRecursiveArrayToolsExt

if isdefined(Base, :get_extension)
using Base: @propagate_inbounds
using Manifolds
using Manifolds: submanifold_components
using RecursiveArrayTools: ArrayPartition
import Base: getindex, setindex!, view
import Manifolds:
ProductFVectorDistribution,
adjoint_Jacobi_field,
allocate,
allocate_result,
apply,
apply!,
apply_diff,
apply_diff_group!,
compose,
_compose,
exp,
exp_lie,
get_coordinates,
get_vector,
get_vectors,
hat,
identity_element,
inverse_apply,
inverse_apply_diff,
inverse_translate,
inverse_translate_diff,
isapprox,
jacobi_field,
lie_bracket,
log,
optimal_alignment,
project,
rand,
_rand!,
translate,
translate_diff,
_vector_transport_direction,
_vector_transport_to,
vee,

else
# imports need to be relative for Requires.jl-based workflows:
# https://github.com/JuliaArrays/ArrayInterface.jl/pull/387
using ..Manifolds
using ..RecursiveArrayTools
import Base: getindex, setindex!, view
import Manifolds:
ProductFVectorDistribution,
adjoint_Jacobi_field,
allocate,
allocate_result,
apply,
apply!,
apply_diff,
apply_diff_group!,
_compose,
exp,
exp_lie,
get_vector,
get_vectors,
identity_element,
inverse_apply,
inverse_apply_diff,
inverse_translate,
inverse_translate_diff,
isapprox,
jacobi_field,
log,
optimal_alignment,
rand,
_rand!,
translate,
translate_diff,
_vector_transport_direction,
_vector_transport_to
end

function allocate(
::PowerManifoldNestedReplacing,
x::AbstractArray{<:ArrayPartition{T,<:NTuple{N,SArray}}},
) where {T,N}
return similar(x)
end

include("FiberBundleRATExt.jl")
include("ProductGroupRATExt.jl")
include("ProductManifoldRATExt.jl")
include("rotation_translation_actionRATExt.jl")
include("semidirect_product_groupRATExt.jl")
include("VectorBundleRATExt.jl")
end
131 changes: 131 additions & 0 deletions ext/ManifoldsRecursiveArrayToolsExt/ProductGroupRATExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
function _compose(M::ProductManifold, p::ArrayPartition, q::ArrayPartition)
return ArrayPartition(
map(
compose,
M.manifolds,
submanifold_components(M, p),
submanifold_components(M, q),
)...,
)
end

function exp(M::ProductGroup, p::Identity{ProductOperation}, X::ArrayPartition)
return ArrayPartition(
map(
exp,
M.manifold.manifolds,
submanifold_components(M, p),
submanifold_components(M, X),
)...,
)
end

function exp_lie(G::ProductGroup, X)
M = G.manifold
return ArrayPartition(map(exp_lie, M.manifolds, submanifold_components(G, X))...)
end

Base.@propagate_inbounds function Base.getindex(
p::ArrayPartition,
M::ProductGroup,
i::Union{Integer,Colon,AbstractVector,Val},
)
return getindex(p, base_manifold(M), i)
end

function identity_element(G::ProductGroup)
M = G.manifold
return ArrayPartition(map(identity_element, M.manifolds))
end

function inverse_translate(G::ProductGroup, p, q, conv::ActionDirectionAndSide)
M = G.manifold
return ArrayPartition(
map(
inverse_translate,
M.manifolds,
submanifold_components(G, p),
submanifold_components(G, q),
repeated(conv),
)...,
)
end

function inverse_translate_diff(G::ProductGroup, p, q, X, conv::ActionDirectionAndSide)
M = G.manifold
return ArrayPartition(
map(
inverse_translate_diff,
M.manifolds,
submanifold_components(G, p),
submanifold_components(G, q),
submanifold_components(G, X),
repeated(conv),
)...,
)
end

# these isapprox methods are here just to reduce time-to-first-isapprox
function isapprox(G::ProductGroup, p::ArrayPartition, q::ArrayPartition; kwargs...)
return isapprox(G.manifold, p, q; kwargs...)
end
function isapprox(
G::ProductGroup,
p::ArrayPartition,
X::ArrayPartition,
Y::ArrayPartition;
kwargs...,
)
return isapprox(G.manifold, p, X, Y; kwargs...)
end

function Base.log(M::ProductGroup, p::Identity{ProductOperation}, q::ArrayPartition)
return ArrayPartition(
map(
log,
M.manifold.manifolds,
submanifold_components(M, p),
submanifold_components(M, q),
)...,
)
end

Base.@propagate_inbounds function Base.setindex!(
q::ArrayPartition,
p,
M::ProductGroup,
i::Union{Integer,Colon,AbstractVector,Val},
)
return setindex!(q, p, base_manifold(M), i)
end

function translate(
M::ProductGroup,
p::ArrayPartition,
q::ArrayPartition,
conv::ActionDirectionAndSide,
)
return ArrayPartition(
map(
translate,
M.manifold.manifolds,
submanifold_components(M, p),
submanifold_components(M, q),
repeated(conv),
)...,
)
end

function translate_diff(G::ProductGroup, p, q, X, conv::ActionDirectionAndSide)
M = G.manifold
return ArrayPartition(
map(
translate_diff,
M.manifolds,
submanifold_components(G, p),
submanifold_components(G, q),
submanifold_components(G, X),
repeated(conv),
)...,
)
end
Loading