diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2e3a0a6108..0b40047c36 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,12 +7,15 @@ on: jobs: test: - name: Julia ${{ matrix.julia-version }} - ${{ matrix.os }} + name: Julia ${{ matrix.julia-version }} - ${{ matrix.group }} - ${{ matrix.os }} runs-on: ${{ matrix.os }} strategy: matrix: - julia-version: ["1.5", "1.6", "~1.7.0-0"] + julia-version: ["1.6", "1.7"] os: [ubuntu-latest, macOS-latest] + group: + - 'test_manifolds' + - 'test_lie_groups' steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 @@ -23,8 +26,9 @@ jobs: - uses: julia-actions/julia-runtest@latest env: PYTHON: "" + MANIFOLDS_TEST_GROUP: ${{ matrix.group }} - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v1 + - uses: codecov/codecov-action@v2 with: fail_ci_if_error: false - if: ${{ matrix.julia-version == '1.6' && matrix.os =='ubuntu-latest' }} + if: ${{ matrix.julia-version == '1.7' && matrix.os =='ubuntu-latest' }} diff --git a/.github/workflows/clear_preview.yml b/.github/workflows/clear_preview.yml index 129d85c1a9..22d90435ef 100644 --- a/.github/workflows/clear_preview.yml +++ b/.github/workflows/clear_preview.yml @@ -1,4 +1,4 @@ -name: Documentation Preview Cleanup +name: Doc Preview Cleanup on: pull_request: @@ -12,17 +12,15 @@ jobs: uses: actions/checkout@v2 with: ref: gh-pages - - - name: Delete preview and history + - name: Delete preview and history + push changes run: | - git config user.name "Documenter.jl" - git config user.email "documenter@juliadocs.github.io" - git rm -rf "previews/PR$PRNUM" - git commit -m "delete preview" - git branch gh-pages-new $(echo "delete history" | git commit-tree HEAD^{tree}) + if [ -d "previews/PR$PRNUM" ]; then + git config user.name "Documenter.jl" + git config user.email "documenter@juliadocs.github.io" + git rm -rf "previews/PR$PRNUM" + git commit -m "delete preview" + git branch gh-pages-new $(echo "delete history" | git commit-tree HEAD^{tree}) + git push --force origin gh-pages-new:gh-pages + fi env: - PRNUM: ${{ github.event.number }} - - - name: Push changes - run: | - git push --force origin gh-pages-new:gh-pages \ No newline at end of file + PRNUM: ${{ github.event.number }} \ No newline at end of file diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 2cd8d5552c..3168849500 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -4,12 +4,15 @@ on: jobs: test: - name: Julia nightly - ${{ matrix.os }} + name: Julia nightly - ${{ matrix.group }} - ${{ matrix.os }} runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: os: [ubuntu-latest, macOS-latest, windows-latest] + group: + - 'test_manifolds' + - 'test_lie_groups' steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 @@ -20,3 +23,4 @@ jobs: - uses: julia-actions/julia-runtest@latest env: PYTHON: "" + MANIFOLDS_TEST_GROUP: ${{ matrix.group }} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index fe7d18f329..0bd92c98bb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -38,7 +38,7 @@ Even providing a single new method is a good contribution. A main contribution you can provide is another manifold that is not yet included in the package. -A manifold is a concrete type of [`AbstractManifold`](https://juliamanifolds.github.io/Manifolds.jl/latest/interface.html#ManifoldsBase.AbstractManifold) from [`ManifoldsBase.jl`](https://juliamanifolds.github.io/Manifolds.jl/latest/interface.html). +A manifold is a concrete type of [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) from [`ManifoldsBase.jl`](https://juliamanifolds.github.io/Manifolds.jl/latest/interface.html). This package also provides the main set of functions a manifold can/should implement. Don't worry if you can only implement some of the functions. If the application you have in mind only requires a subset of these functions, implement those. @@ -49,7 +49,7 @@ See for example [exp!](https://juliamanifolds.github.io/Manifolds.jl/latest/inte The non-mutating one (e.g. `exp`) always falls back to use the mutating one, so in most cases it should suffice to implement the mutating one (e.g. `exp!`). -Note that since the first argument is _always_ the [`AbstractManifold`](https://juliamanifolds.github.io/Manifolds.jl/latest/interface.html#ManifoldsBase.AbstractManifold), the mutated argument is always the second one in the signature. +Note that since the first argument is _always_ the [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold), the mutated argument is always the second one in the signature. In the example we have `exp(M, p, X)` for the exponential map and `exp!(M, q, X, p)` for the mutating one, which stores the result in `q`. On the other hand, the user will most likely look for the documentation of the non-mutating version, so we recommend adding the docstring for the non-mutating one, where all different signatures should be collected in one string when reasonable. diff --git a/Project.toml b/Project.toml index 6a96b49843..24a748bbf7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Manifolds" uuid = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" authors = ["Seth Axen ", "Mateusz Baran ", "Ronny Bergmann ", "Antoine Levitt "] -version = "0.7.8" +version = "0.8.0" [deps] Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" @@ -27,11 +27,10 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Colors = "0.12" Distributions = "0.22.6, 0.23, 0.24, 0.25" Einsum = "0.4" -FiniteDiff = "2" Graphs = "1.4" HybridArrays = "0.4" Kronecker = "0.4, 0.5" -ManifoldsBase = "0.12.13" +ManifoldsBase = "0.13" Plots = "1" RecipesBase = "1.1" RecursiveArrayTools = "2" @@ -40,14 +39,12 @@ SimpleWeightedGraphs = "1.2" SpecialFunctions = "0.8, 0.9, 0.10, 1.0, 2" StaticArrays = "1.0" StatsBase = "0.32, 0.33" -julia = "1.5" +julia = "1.6" [extras] Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78" -FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Gtk = "4c0ca9eb-093a-5379-98c5-f87ac0bbbf44" ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19" ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1" @@ -58,10 +55,8 @@ PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee" QuartzImageIO = "dca85d43-d64c-5e67-8c65-017450d5d020" Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" VisualRegressionTests = "34922c18-7c2a-561c-bac1-01e79b2c4c92" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "Colors", "DoubleFloats", "FiniteDiff", "FiniteDifferences", "ForwardDiff", "Gtk", "ImageIO", "ImageMagick", "OrdinaryDiffEq", "NLsolve", "Plots", "PyPlot", "Quaternions", "QuartzImageIO", "RecipesBase", "ReverseDiff", "Zygote"] +test = ["Test", "Colors", "DoubleFloats", "FiniteDifferences", "Gtk", "ImageIO", "ImageMagick", "OrdinaryDiffEq", "NLsolve", "Plots", "PyPlot", "Quaternions", "QuartzImageIO", "RecipesBase"] diff --git a/README.md b/README.md index 2bc0b8552b..962148ba00 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@
- Manifolds.jl + Manifolds.jl Logo with text
| **Documentation** | **Source** | **Citation** | @@ -21,8 +21,8 @@ To install the package just type ] add Manifolds ``` -Then you can directly start, for example to stop half way from the north pole on the [`Sphere`](https://juliamanifolds.github.io/Manifolds.jl/stable/manifolds/sphere.html) to a point on the equator, you can generate the [`shortest_geodesic`](https://juliamanifolds.github.io/Manifolds.jl/stable/interface.html#ManifoldsBase.shortest_geodesic-Tuple{AbstractManifold,Any,Any}). -It internally employs [`exp`](https://juliamanifolds.github.io/Manifolds.jl/latest/interface.html#Base.exp-Tuple{AbstractManifold,Any,Any}) and [`log`](https://juliamanifolds.github.io/Manifolds.jl/latest/interface.html#Base.log-Tuple{AbstractManifold,Any,Any}). +Then you can directly start, for example to stop half way from the north pole on the [`Sphere`](https://juliamanifolds.github.io/Manifolds.jl/stable/manifolds/sphere.html) to a point on the equator, you can generate the [`shortest_geodesic`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#ManifoldsBase.shortest_geodesic-Tuple{AbstractManifold,%20Any,%20Any}). +It internally employs [`exp`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#Base.exp-Tuple{AbstractManifold,%20Any,%20Any}) and [`log`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#Base.log-Tuple{AbstractManifold,%20Any,%20Any}). ```julia using Manifolds diff --git a/REQUIRE b/REQUIRE deleted file mode 100644 index 05b5ab4c7d..0000000000 --- a/REQUIRE +++ /dev/null @@ -1 +0,0 @@ -julia 1.0 diff --git a/assets/retraction_illustration.jl b/assets/retraction_illustration.jl index c7d9f76bdb..0ff804b8a6 100644 --- a/assets/retraction_illustration.jl +++ b/assets/retraction_illustration.jl @@ -95,7 +95,7 @@ push!( ) push!( tp, - raw"\node[label ={[label distance=.05cm]below:{\color{gray}$q=\exp_pX$}}] at (axis cs:" * + raw"\node[label ={[label distance=.05cm]left:{\color{gray}$q=\exp_pX$}}] at (axis cs:" * "$(real(qE)),$(imag(qE))) {};", ) diff --git a/docs/Project.toml b/docs/Project.toml index b6d9d80706..5103b08964 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -2,7 +2,6 @@ Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" HybridArrays = "1baab800-613f-4b0a-84e4-9cd3431bfbb9" Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" @@ -10,16 +9,15 @@ ManifoldsBase = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -Documenter = "0.24, 0.25, 0.26, 0.27" +Documenter = "0.27" FiniteDifferences = "0.12" Graphs = "1.4" HybridArrays = "0.4" -ManifoldsBase = "0.12.9" +ManifoldsBase = "0.13" Plots = "1" PyPlot = "2.9" StaticArrays = "1.0" diff --git a/docs/make.jl b/docs/make.jl index 1ca4b187a2..9b819a2fdd 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,6 +1,6 @@ using Plots, RecipesBase, Manifolds, ManifoldsBase, Documenter, PyPlot # required for loading the manifold tests functios -using Test, ForwardDiff, ReverseDiff, FiniteDifferences +using Test, FiniteDifferences ENV["GKSwstype"] = "100" generated_path = joinpath(@__DIR__, "src", "misc") @@ -26,13 +26,11 @@ end makedocs( # for development, we disable prettyurls format=Documenter.HTML(prettyurls=false, assets=["assets/favicon.ico"]), - modules=[Manifolds, ManifoldsBase], + modules=[Manifolds], authors="Seth Axen, Mateusz Baran, Ronny Bergmann, and contributors.", sitename="Manifolds.jl", pages=[ "Home" => "index.md", - "ManifoldsBase.jl" => "interface.md", - "Examples" => ["How to implement a Manifold" => "examples/manifold.md"], "Manifolds" => [ "Basic manifolds" => [ "Centered matrices" => "manifolds/centeredmatrices.md", diff --git a/docs/src/assets/images/projection_illustration.png b/docs/src/assets/images/projection_illustration.png index 7bea7fce03..9ddd73601a 100644 Binary files a/docs/src/assets/images/projection_illustration.png and b/docs/src/assets/images/projection_illustration.png differ diff --git a/docs/src/assets/images/projection_illustration_600.png b/docs/src/assets/images/projection_illustration_600.png index 49ace7e661..80d4612cf4 100644 Binary files a/docs/src/assets/images/projection_illustration_600.png and b/docs/src/assets/images/projection_illustration_600.png differ diff --git a/docs/src/assets/images/retraction_illustration.png b/docs/src/assets/images/retraction_illustration.png index 2c985e81c9..a3d42d874d 100644 Binary files a/docs/src/assets/images/retraction_illustration.png and b/docs/src/assets/images/retraction_illustration.png differ diff --git a/docs/src/assets/images/retraction_illustration_600.png b/docs/src/assets/images/retraction_illustration_600.png index 6b51d1b3c9..e379593a46 100644 Binary files a/docs/src/assets/images/retraction_illustration_600.png and b/docs/src/assets/images/retraction_illustration_600.png differ diff --git a/docs/src/examples/manifold.md b/docs/src/examples/manifold.md deleted file mode 100644 index d153399bf9..0000000000 --- a/docs/src/examples/manifold.md +++ /dev/null @@ -1,217 +0,0 @@ -# [How to implement your own manifold](@id manifold-tutorial) - -```@meta -CurrentModule = ManifoldsBase -DocTestSetup = quote - using Manifolds -end -``` - -This tutorial demonstrates how to easily set your own manifold up within `Manifolds.jl`. - -## Introduction - -If you looked around a little and saw the [interface](../interface.md), the amount of functions and possibilities, it might seem that a manifold might take some time to implement. -This tutorial demonstrates that you can get your first own manifold quite fast and you only have to implement the functions you actually need. For this tutorial it would be helpful if you take a look at our [notation](../misc/notation.md). -This tutorial assumes that you heard of the exponential map, tangent vectors and the dimension of a manifold. If not, please read for example [[do Carmo, 1992](#doCarmo1992)], -Chapter 3, first. - -In general you need just a datatype (`struct`) that inherits from [`AbstractManifold`](@ref) to define a manifold. No function is _per se_ required to be implemented. -However, it is a good idea to provide functions that might be useful to others, for example [`check_point`](@ref) and [`check_vector`](@ref), as we do in this tutorial. - -We start with two technical preliminaries. If you want to start directly, you can [skip](@ref manifold-tutorial-task) this paragraph and revisit it for two of the implementation details. - -After that, we will - -* [model](@ref manifold-tutorial-task) the manifold -* [implement](@ref manifold-tutorial-checks) two tests, so that points and tangent vectors can be checked for validity, for example also within [`ValidationManifold`](@ref), -* [implement](@ref manifold-tutorial-fn) two functions, the exponential map and the manifold dimension. - -## [Technical preliminaries](@id manifold-tutorial-prel) - -There are only two small technical things we need to explain at this point. -First of all our [`AbstractManifold`](@ref)`{𝔽}` has a parameter `𝔽`. -This parameter indicates the [`number_system`](@ref) the manifold is based on, for example `ℝ` for real manifolds. It is important primarily for defining bases of tangent spaces. -See [`SymmetricMatrices`](@ref Main.Manifolds.SymmetricMatrices) as an example of defining both a real-valued and a complex-valued symmetric manifolds using one type. - -Second, a main design decision of `Manifolds.jl` is that most functions are implemented as mutating functions, i.e. as in-place-computations. There usually exists a non-mutating version that falls back to allocating memory and calling the mutating one. This means you only have to implement the mutating version, _unless_ there is a good reason to provide a special case for the non-mutating one, i.e. because in that case you know a far better performing implementation. - -Let's look at an example. The exponential map $\exp_p\colon T_p\mathcal M \to \mathcal M$ that maps a tangent vector $X\in T_p\mathcal M$ from the tangent space at $p\in \mathcal M$ to the manifold. -The function [`exp`](@ref exp(M::AbstractManifold, p, X)) has to know the manifold `M`, the point `p` and the tangent vector `X` as input, so to compute the resulting point `q` you need to call - -```julia -q = exp(M, p, X) -``` - -If you already have allocated memory for the variable that should store the result, it is better to perform the computations directly in that memory and avoid reallocations. For example - -```julia -q = similar(p) -exp!(M, q, p, X) -``` - -calls [`exp!`](@ref exp!(M::AbstractManifold, q, p, X)), which modifies its input `q` and returns the resulting point in there. -Actually these two lines are (almost) the default implementation for [`exp`](@ref exp(M::AbstractManifold, p, X)). [`allocate_result`](@ref) that is actually used there just calls `similar` for simple `Array`s. -Note that for a unified interface, the manifold `M` is _always_ the first parameter, and the variable the result will be stored to in the mutating variants is _always_ the second parameter. - -Long story short: if possible, implement the mutating version [`exp!`](@ref exp!(M::AbstractManifold, q, p, X)), you get the [`exp`](@ref exp(M::AbstractManifold, p, X)) for free. -Many functions that build upon basic functions employ the mutating variant, too, to avoid reallocations. - -## [Startup](@id manifold-tutorial-startup) - -As a start, let's load `ManifoldsBase.jl` and import the functions we consider throughout this tutorial. -For implementing a manifold, loading the interface should suffice for quite some time. - -```@example manifold-tutorial -using ManifoldsBase, LinearAlgebra, Test -import ManifoldsBase: check_point, check_vector, manifold_dimension, exp! -``` - -## [Goal](@id manifold-tutorial-task) - -As an example, let's implement the sphere, but with a radius $r$. -Since this radius is a property inherent to the manifold, it will become a field of the manifold. -The second information, we want to store is the dimension of the sphere, for example whether it's the 1-sphere, i.e. the circle, represented by vectors $p\in\mathbb R^2$ or the 2-sphere in $\mathbb R^3$. -Since the latter might be something we want to [dispatch](https://en.wikipedia.org/wiki/Multiple_dispatch) on, we model it as a parameter of the type. - -In general the `struct` of a manifold should provide information about the manifold, which are inherent to the manifold or has to be available without a specific point or tangent vector present. -This is -- most prominently -- a way to determine the manifold dimension. - -For our example we define - -```@example manifold-tutorial -""" - MySphere{N} <: AbstractManifold{ℝ} - -Define an `n`-sphere of radius `r`. Construct by `MySphere(radius,n)` -""" -struct MySphere{N} <: AbstractManifold{ManifoldsBase.ℝ} where {N} - radius::Float64 -end -MySphere(radius, n) = MySphere{n}(radius) -Base.show(io::IO, M::MySphere{n}) where {n} = print(io, "MySphere($(M.radius),$n)") -nothing #hide -``` - -Here, the last line just provides a nicer print of a variable of that type -Now we can already initialize our manifold that we will use later, the $2$-sphere of radius $1.5$. - -```@example manifold-tutorial -S = MySphere(1.5, 2) -``` - -## [Checking points and tangents](@id manifold-tutorial-checks) - -If we have now a point, represented as an array, we would first like to check, that it is a valid point on the manifold. -For this one can use the easy interface [`is_point`](@ref is_point(M::AbstractManifold, p; kwargs...)). This internally uses [`check_point`](@ref check_point(M, p; kwargs...)). -This is what we want to implement. -We have to return the error if `p` is not on `M` and `nothing` otherwise. - -We have to check two things: that a point `p` is a vector with `N+1` entries and its norm is the desired radius. -To spare a few lines, we can use [short-circuit evaluation](https://docs.julialang.org/en/v1/manual/control-flow/#Short-Circuit-Evaluation-1) instead of `if` statements. -If something has to only hold up to precision, we can pass that down, too using the `kwargs...`. - -```@example manifold-tutorial -function check_point(M::MySphere{N}, p; kwargs...) where {N} - (size(p)) == (N+1,) || return DomainError(size(p),"The size of $p is not $((N+1,)).") - if !isapprox(norm(p), M.radius; kwargs...) - return DomainError(norm(p), "The norm of $p is not $(M.radius).") - end - return nothing -end -nothing #hide -``` - -Similarly, we can verify, whether a tangent vector `X` is valid. It has to fulfill the same size requirements and it has to be orthogonal to `p`. We can again use the `kwargs`, but also provide a way to check `p`, too. - -```@example manifold-tutorial -function check_vector(M::MySphere, p, X; kwargs...) - size(X) != size(p) && return DomainError(size(X), "The size of $X is not $(size(p)).") - if !isapprox(dot(p,X), 0.0; kwargs...) - return DomainError(dot(p,X), "The tangent $X is not orthogonal to $p.") - end - return nothing -end -nothing #hide -``` - -to test points we can now use - -```@example manifold-tutorial -is_point(S, [1.0,0.0,0.0]) # norm 1, so not on S, returns false -@test_throws DomainError is_point(S, [1.5,0.0], true) # only on R^2, throws an error. -p = [1.5,0.0,0.0] -X = [0.0,1.0,0.0] -# The following two tests return true -[ is_point(S, p); is_vector(S,p,X) ] -``` - -## [Functions on the manifold](@id manifold-tutorial-fn) - -For the [`manifold_dimension`](@ref manifold_dimension(M::AbstractManifold)) we have to just return the `N` parameter - -```@example manifold-tutorial -manifold_dimension(::MySphere{N}) where {N} = N -manifold_dimension(S) -``` - -Note that we can even omit the variable name in the first line since we do not have to access any field or use the variable otherwise. - -To implement the exponential map, we have to implement the formula for great arcs, given a start point `p` and a direction `X` on the $n$-sphere of radius $r$ the formula reads - -````math -\exp_p X = \cos(\frac{1}{r}\lVert X \rVert)p + \sin(\frac{1}{r}\lVert X \rVert)\frac{r}{\lVert X \rVert}X. -```` - -Note that with this choice we for example implicitly assume a certain metric. This is completely fine. We only have to think about specifying a metric explicitly, when we have (at least) two different metrics on the same manifold. - -An implementation of the mutation version, see the [technical note](@ref manifold-tutorial-prel), reads - -```@example manifold-tutorial -function exp!(M::MySphere{N}, q, p, X) where {N} - nX = norm(X) - if nX == 0 - q .= p - else - q .= cos(nX/M.radius)*p + M.radius*sin(nX/M.radius) .* (X./nX) - end - return q -end -nothing #hide -``` - -A first easy check can be done taking `p` from above and any vector `X` of length `1.5Ο€` from its tangent space. The resulting point is opposite of `p`, i.e. `-p` - -```@example manifold-tutorial -q = exp(S,p, [0.0,1.5Ο€,0.0]) -[isapprox(p,-q); is_point(S,q)] -``` - -## [Conclusion](@id manifold-tutorial-outlook) - -You can now just continue implementing further functions from the [interface](../interface.md), -but with just [`exp!`](@ref exp!(M::AbstractManifold, q, p, X)) you for example already have - -* [`geodesic`](@ref geodesic(M::AbstractManifold, p, X)) the (not necessarily shortest) geodesic emanating from `p` in direction `X`. -* the [`ExponentialRetraction`](@ref), that the [`retract`](@ref retract(M::AbstractManifold, p, X)) function uses by default. - -For the [`shortest_geodesic`](@ref shortest_geodesic(M::AbstractManifold, p, q)) the implementation of a logarithm [`log`](@ref ManifoldsBase.log(M::AbstractManifold, p, q)), again better a [`log!`](@ref log!(M::AbstractManifold, X, p, q)) is necessary. - -Sometimes a default implementation is provided; for example if you implemented [`inner`](@ref inner(M::AbstractManifold, p, X, Y)), the [`norm`](@ref norm(M, p, X)) is defined. You should overwrite it, if you can provide a more efficient version. For a start the default should suffice. -With [`log!`](@ref log!(M::AbstractManifold, X, p, q)) and [`inner`](@ref inner(M::AbstractManifold, p, X, Y)) you get the [`distance`](@ref distance(M::AbstractManifold, p, q)), and so. - -In summary with just these few functions you can already explore the first things on your own manifold. Whenever a function from `Manifolds.jl` requires another function to be specifically implemented, you get a reasonable error message. - -## Literature - -```@raw html -
    -
  • - [doCarmo, 1992] - M. P. do Carmo, - Riemannian Geometry, - BirkhΓ€user Boston, 1992, - ISBN: 0-8176-3490-8. -
  • -
-``` diff --git a/docs/src/features/differentiation.md b/docs/src/features/differentiation.md index 393ccc3f06..e0163d7cf3 100644 --- a/docs/src/features/differentiation.md +++ b/docs/src/features/differentiation.md @@ -10,13 +10,7 @@ Pages = ["differentiation/differentiation.jl"] Order = [:type, :function, :constant] ``` -### ForwardDiff.jl - -```@autodocs -Modules = [Manifolds] -Pages = ["differentiation/forward_diff.jl"] -Order = [:type, :function, :constant] -``` +Further differentiation backends and features are available in [`ManifoldDiff.jl`](https://github.com/JuliaManifolds/ManifoldDiff.jl). ### FiniteDifferenes.jl diff --git a/docs/src/features/utilities.md b/docs/src/features/utilities.md index b6bba006a3..b6d4c7d9d9 100644 --- a/docs/src/features/utilities.md +++ b/docs/src/features/utilities.md @@ -1,6 +1,6 @@ # Ease of notation -The following terms introduce a nicer notation for some operations, for example using the ∈ operator, $p ∈ \mathcal M$, to determine whether $p$ is a point on the [`AbstractManifold`](@ref) $\mathcal M$. +The following terms introduce a nicer notation for some operations, for example using the ∈ operator, $p ∈ \mathcal M$, to determine whether $p$ is a point on the [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) $\mathcal M$. ````@docs in diff --git a/docs/src/index.md b/docs/src/index.md index 50511b1288..122a4d1c8f 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -3,7 +3,7 @@ The package __Manifolds__ aims to provide a library of manifolds to be used within your project. The implemented manifolds are accompanied by their mathematical formulae. -The manifolds are implemented using the interface for manifolds given in [`ManifoldsBase.jl`](interface.md). +The manifolds are implemented using the interface for manifolds given in [`ManifoldsBase.jl`](https://juliamanifolds.github.io/ManifoldsBase.jl/). You can use that interface to implement your own software on manifolds, such that all manifolds based on that interface can be used within your code. @@ -17,7 +17,7 @@ To install the package just type ] add Manifolds ``` -Then you can directly start, for example to stop half way from the north pole on the [`Sphere`](@ref) to a point on the the equator, you can generate the [`shortest_geodesic`](@ref). +Then you can directly start, for example to stop half way from the north pole on the [`Sphere`](@ref) to a point on the the equator, you can generate the [`shortest_geodesic`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#ManifoldsBase.shortest_geodesic-Tuple{AbstractManifold,%20Any,%20Any}). It internally employs [`log`](@ref log(::Sphere,::Any,::Any)) and [`exp`](@ref exp(::Sphere,::Any,::Any)). ```@example diff --git a/docs/src/interface.md b/docs/src/interface.md deleted file mode 100644 index e170520306..0000000000 --- a/docs/src/interface.md +++ /dev/null @@ -1,346 +0,0 @@ -# `ManifoldsBase.jl` – an interface for manifolds - -The interface for a manifold is provided in the lightweight package [ManifoldsBase.jl](https://github.com/JuliaManifolds/ManifoldsBase.jl). -You can easily implement your algorithms and even your own manifolds just using the interface. -All manifolds from the package here are also based on this interface, so any project based on the interface can benefit from all manifolds, as soon as a certain manifold provides implementations of the functions a project requires. - -```@contents -Pages = ["interface.md"] -Depth = 2 -``` - -Additionally the [`AbstractDecoratorManifold`](@ref) is provided as well as the [`ValidationManifold`](@ref) as a specific example of such a decorator. - -## [Types and functions](@id interface-types-and-functions) - -The following functions are currently available from the interface. -If a manifold that you implement for your own package fits this interface, we happily look forward to a [Pull Request](https://github.com/JuliaManifolds/Manifolds.jl/compare) to add it here. - -We would like to highlight a few of the types and functions in the next two sections before listing the remaining types and functions alphabetically. - -### The Manifold Type - -Besides the most central type, that of an [`AbstractManifold`](@ref) accompanied by [`AbstractManifoldPoint`](@ref) to represent points thereon, note that the point type is meant in a lazy fashion. -This is mean as follows: if you implement a new manifold and your points are represented by matrices, vectors or arrays, then it is best to not restrict types of the points `p` in functions, such that the methods work for example for other array representation types as well. -You should subtype your new points on a manifold, if the structure you use is more structured, see for example [`FixedRankMatrices`](@ref). -Another reason is, if you want to distinguish (and hence dispatch on) different representation of points on the manifold. -For an example, see the [Hyperbolic](@ref HyperbolicSpace) manifold, which has different models to be represented. - -```@autodocs -Modules = [Manifolds, ManifoldsBase] -Pages = ["maintypes.jl"] -Order = [:type, :function] -``` - -### The exponential and the logarithmic map, and geodesics - -Geodesics are the generalizations of a straight line to manifolds, i.e. their intrinsic acceleration is zero. -Together with geodesics one also obtains the exponential map and its inverse, the logarithmic map. -Informally speaking, the exponential map takes a vector (think of a direction and a length) at one point and returns another point, -which lies towards this direction at distance of the specified length. The logarithmic map does the inverse, i.e. given two points, it tells which vector β€œpoints towards” the other point. - -```@autodocs -Modules = [Manifolds, ManifoldsBase] -Pages = ["exp_log_geo.jl"] -Order = [:function] -``` - -### Retractions and inverse Retractions - -The exponential and logarithmic map might be too expensive to evaluate or not be available in a very stable numerical way. Retractions provide a possibly cheap, fast and stable alternative. - -The following figure compares the exponential map [`exp`](@ref)`(M, p, X)` on the [`Circle`](@ref)`(β„‚)` (or [`Sphere`](@ref)`(1)` embedded in $ℝ^2$ with one possible retraction, the one based on projections. Note especially that ``\mathrm{dist}(p,q)=\lVert X\rVert_p`` while this is not the case for ``q'``. - -![A comparson of the exponential map and a retraction on the Circle.](assets/images/retraction_illustration_600.png) - -```@autodocs -Modules = [Manifolds, ManifoldsBase] -Pages = ["retractions.jl"] -Order = [:function] -``` - -To distinguish different types of retractions, the last argument of the (inverse) retraction -specifies a type. The following ones are available. - -```@autodocs -Modules = [Manifolds, ManifoldsBase] -Pages = ["retractions.jl"] -Order = [:type] -``` - -### Projections - -A manifold might be embedded in some space. -Often this is implicitly assumed, for example the complex [`Circle`](@ref) is embedded in the complex plane. -Letβ€˜s keep the circle in mind in the following as a simple example. -For the general case see of explicitly stating an embedding and/or distinguising several, different embeddings, see [Embedded Manifolds](@ref EmbeddedmanifoldSec) below. - -To make this a little more concrete, letβ€˜s assume we have a manifold ``\mathcal M`` which is embedded in some manifold ``\mathcal N`` and the image ``i(\mathcal M)`` of the embedding function ``i`` is a closed set (with respect to the topology on ``\mathcal N``). Then we can do two kinds of projections. - -To make this concrete in an example for the Circle ``\mathcal M=\mathcal C := \{ p ∈ β„‚Β | |p| = 1\}`` -the embedding can be chosen to be the manifold ``N = β„‚`` and due to our representation of ``\mathcal C`` as complex numbers already, we have ``i(p) = p`` the identity as the embedding function. - -1. Given a point ``p∈\mathcal N`` we can look for the closest point on the manifold ``\mathcal M`` formally as - -```math - \operatorname*{arg\,min}_{q\in \mathcal M} d_{\mathcal N}(i(q),p) -``` - -And this resulting ``q`` we call the projection of ``p`` onto the manifold ``\mathcal M``. - -2. Given a point ``p∈\mathcal M`` and a vector in ``X\inT_{i(p)}\mathcal N`` in the embedding we can similarly look for the closest point to ``Y∈ T_p\mathcal M`` using the pushforward ``\mathrm{d}i_p`` of the embedding. - -```math - \operatorname*{arg\,min}_{Y\in T_p\mathcal M} \lVert \mathrm{d}i(p)[Y] - X \rVert_{i(p)} -``` - -And we call the resulting ``Y`` the projection of ``X`` onto the tangent space ``T_p\mathcal M`` at ``p``. - -Letβ€˜s look at the little more concrete example of the complex Circle again. -Here, the closest point of ``p ∈ β„‚`` is just the projection onto the circle, or in other words ``q = \frac{p}{\lvert p \rvert}``. A tangent space ``T_p\mathcal C`` in the embedding is the line orthogonal to a point ``p∈\mathcal C`` through the origin. -This can be better visualized by looking at ``p+T_p\mathcal C`` which is actually the line tangent to ``p``. Note that this shift does not change the resulting projection relative to the origin of the tangent space. - -Here the projection can be computed as the classical projection onto the line, i.e. ``Y = X - ⟨X,p⟩X``. - -this is illustrated in the following figure - -![An example illustrating the two kinds of projections on the Circle.](assets/images/projection_illustration_600.png) - -```@autodocs -Modules = [Manifolds, ManifoldsBase] -Pages = ["projections.jl"] -Order = [:function] -``` - -### Remaining functions - -```@autodocs -Modules = [Manifolds, ManifoldsBase] -Pages = ["ManifoldsBase.jl"] -Order = [:type, :function] -``` - -## Number systems - -```@autodocs -Modules = [Manifolds, ManifoldsBase] -Pages = ["numbers.jl"] -Order = [:type, :function] -``` - -## Allocation - -Non-mutating functions in `ManifoldsBase.jl` are typically implemented using mutating variants. -Allocation of new points is performed using a custom mechanism that relies on the following functions: - -* [`allocate`](@ref) that allocates a new point or vector similar to the given one. - This function behaves like `similar` for simple representations of points and vectors (for example `Array{Float64}`). - For more complex types, such as nested representations of [`PowerManifold`](@ref) (see [`NestedPowerRepresentation`](@ref)), [`FVector`](@ref) types, checked types like [`ValidationMPoint`](@ref) and more it operates differently. - While `similar` only concerns itself with the higher level of nested structures, `allocate` maps itself through all levels of nesting until a simple array of numbers is reached and then calls `similar`. - The difference can be most easily seen in the following example: - -```julia -julia> x = similar([[1.0], [2.0]]) -2-element Array{Array{Float64,1},1}: - #undef - #undef - -julia> y = Manifolds.allocate([[1.0], [2.0]]) -2-element Array{Array{Float64,1},1}: - [6.90031725726027e-310] - [6.9003678131654e-310] - -julia> x[1] -ERROR: UndefRefError: access to undefined reference -Stacktrace: - [1] getindex(::Array{Array{Float64,1},1}, ::Int64) at ./array.jl:744 - [2] top-level scope at REPL[12]:1 - -julia> y[1] -1-element Array{Float64,1}: - 6.90031725726027e-310 -``` - -* [`allocate_result`](@ref) allocates a result of a particular function (for example [`exp`](@ref), [`flat`](@ref), etc.) on a particular manifold with particular arguments. - It takes into account the possibility that different arguments may have different numeric [`number_eltype`](@ref) types thorough the [`ManifoldsBase.allocate_result_type`](@ref) function. - -## Bases - -The following functions and types provide support for bases of the tangent space of different manifolds. -Moreover, bases of the cotangent space are also supported, though this description focuses on the tangent space. -An orthonormal basis of the tangent space $T_p \mathcal M$ of (real) dimension $n$ has a real-coefficient basis $e_1, e_2, …, e_n$ if $\mathrm{Re}(g_p(e_i, e_j)) = Ξ΄_{ij}$ for each $i,j ∈ \{1, 2, …, n\}$ where $g_p$ is the Riemannian metric at point $p$. -A vector $X$ from the tangent space $T_p \mathcal M$ can be expressed in Einstein notation as a sum $X = X^i e_i$, where (real) coefficients $X^i$ are calculated as $X^i = \mathrm{Re}(g_p(X, e_i))$. - -Bases are closely related to [atlases](@ref atlases_and_charts). - -The main types are: - -* [`DefaultOrthonormalBasis`](@ref), which is designed to work when no special properties of the tangent space basis are required. - It is designed to make [`get_coordinates`](@ref) and [`get_vector`](@ref) fast. -* [`DiagonalizingOrthonormalBasis`](@ref), which diagonalizes the curvature tensor and makes the curvature in the selected direction equal to 0. -* [`ProjectedOrthonormalBasis`](@ref), which projects a basis of the ambient space and orthonormalizes projections to obtain a basis in a generic way. -* [`CachedBasis`](@ref), which stores (explicitly or implicitly) a precomputed basis at a certain point. - -The main functions are: - -* [`get_basis`](@ref) precomputes a basis at a certain point. -* [`get_coordinates`](@ref) returns coordinates of a tangent vector. -* [`get_vector`](@ref) returns a vector for the specified coordinates. -* [`get_vectors`](@ref) returns a vector of basis vectors. Calling it should be avoided for high-dimensional manifolds. - -Coordinates of a vector in a basis can be stored in an [`FVector`](@ref) to explicitly indicate which basis they are expressed in. -It is useful to avoid potential ambiguities. - -```@autodocs -Modules = [ManifoldsBase,Manifolds] -Pages = ["bases.jl"] -Order = [:type, :function] -``` - -```@autodocs -Modules = [ManifoldsBase,Manifolds] -Pages = ["vector_spaces.jl"] -Order = [:type, :function] -``` - -## Vector transport - -There are three main functions for vector transport: -* [`vector_transport_along`](@ref) -* [`vector_transport_direction`](@ref) -* [`vector_transport_to`](@ref) - -Different types of vector transport are implemented using subtypes of [`AbstractVectorTransportMethod`](@ref): -* [`ParallelTransport`](@ref) -* [`PoleLadderTransport`](@ref) -* [`ProjectionTransport`](@ref) -* [`SchildsLadderTransport`](@ref) - -```@autodocs -Modules = [ManifoldsBase,Manifolds] -Pages = ["vector_transport.jl"] -Order = [:type, :function] -``` - -## A Decorator for manifolds - -A decorator manifold extends the functionality of a [`AbstractManifold`](@ref) in a semi-transparent way. -It internally stores the [`AbstractManifold`](@ref) it extends and by default for functions defined in the [`ManifoldsBase`](interface.md) it acts transparently in the sense that it passes all functions through to the base except those that it actually affects. -For example, because the [`ValidationManifold`](@ref) affects nearly all functions, it overwrites nearly all functions, except a few like [`manifold_dimension`](@ref). -On the other hand, the [`MetricManifold`](@ref) only affects functions that involve metrics, especially [`exp`](@ref) and [`log`](@ref) but not the [`manifold_dimension`](@ref). -Contrary to the previous decorator, the [`MetricManifold`](@ref) does not overwrite functions. -The decorator sets functions like [`exp`](@ref) and [`log`](@ref) to be implemented anew but required to be implemented when specifying a new metric. -An exception is not issued if a metric is additionally set to be the default metric (see [`is_default_metric`](@ref), since this makes all functions act transparently. -this last case assumes that the newly specified metric type is actually the one already implemented on a manifold initially. - -By default, i.e. for a plain new decorator, all functions are transparent, i.e. passed down to the manifold the [`AbstractDecoratorManifold`](@ref) decorates. -To implement a method for a decorator that behaves differently from the method of the same function for the internal manifold, two steps are required. -Let's assume the function is called `f(M, arg1, arg2)`, and our decorator manifold `DM` of type `OurDecoratorManifold` decorates `M`. -Then - -1. set `decorator_transparent_dispatch(f, M::OurDecoratorManifold, args...) = Val(:intransparent)` -2. implement `f(DM::OurDecoratorManifold, arg1, arg2)` - -This makes it possible to extend a manifold or all manifolds with a feature or replace a feature of the original manifold. - -The [`MetricManifold`](@ref) is the best example of the second case, since the default metric indicates for which metric the manifold was originally implemented, such that those functions are just passed through. -This can best be seen in the [`SymmetricPositiveDefinite`](@ref) manifold with its [`LinearAffineMetric`](@ref). - -A final technical note – if several manifolds have similar transparency rules concerning functions from the interface, the last parameter `T` of the [`AbstractDecoratorManifold`](@ref)`{𝔽,T<:`[`AbstractDecoratorType`](@ref)`}` can be used to dispatch on different transparency schemes. - -```@autodocs -Modules = [Manifolds, ManifoldsBase] -Pages = ["DecoratorManifold.jl"] -Order = [:type, :macro, :function] -``` - -## Abstract Power Manifold - -```@autodocs -Modules = [Manifolds, ManifoldsBase] -Pages = ["src/PowerManifold.jl"] -Order = [:macro, :type, :function] -``` - -## ValidationManifold - -[`ValidationManifold`](@ref) is a simple decorator using the [`AbstractDecoratorManifold`](@ref) that β€œdecorates” a manifold with tests that all involved points and vectors are valid for the wrapped manifold. -For example involved input and output paratemers are checked before and after running a function, repectively. -This is done by calling [`is_point`](@ref) or [`is_vector`](@ref) whenever applicable. - -```@autodocs -Modules = [Manifolds, ManifoldsBase] -Pages = ["ValidationManifold.jl"] -Order = [:macro, :type, :function] -``` - -## [EmbeddedManifold](@id EmbeddedmanifoldSec) - -Some manifolds can easily be defined by using a certain embedding. -For example the [`Sphere`](@ref)`(n)` is embedded in [`Euclidean`](@ref)`(n+1)`. -Similar to the metric and [`MetricManifold`](@ref), an embedding is often implicitly assumed. -We introduce the embedded manifolds hence as an [`AbstractDecoratorManifold`](@ref). - -This decorator enables to use such an embedding in an transparent way. -Different types of embeddings can be distinguished using the [`AbstractEmbeddingType`](@ref), -which is an [`AbstractDecoratorType`](@ref). - -### Isometric Embeddings - -For isometric embeddings the type [`AbstractIsometricEmbeddingType`](@ref) can be used to avoid reimplementing the metric. -See [`Sphere`](@ref) or [`Hyperbolic`](@ref) for example. -Here, the exponential map, the logarithmic map, the retraction and its inverse -are set to `:intransparent`, i.e. they have to be implemented. - -Furthermore, the [`TransparentIsometricEmbedding`](@ref) type even states that the exponential -and logarithmic maps as well as retractions and vector transports of the embedding can be -used for the embedded manifold as well. -See [`SymmetricMatrices`](@ref) for an example. - -In both cases of course [`check_point`](@ref) and [`check_vector`](@ref) have to be implemented. - -### Further Embeddings - -A first embedding can also just be given implementing [`embed!`](@ref) ann [`project!`](@ref) -for a manifold. This is considered to be the most usual or default embedding. - -If you have two different embeddings for your manifold, a second one can be specified using -the [`EmbeddedManifold`](@ref), a type that β€œcouples” two manifolds, more precisely a -manifold and its embedding, to define embedding and projection functions between these -two manifolds. - -### Types - -```@autodocs -Modules = [ManifoldsBase, Manifolds] -Pages = ["EmbeddedManifold.jl"] -Order = [:type] -``` - -### Functions - -```@autodocs -Modules = [ManifoldsBase, Manifolds] -Pages = ["EmbeddedManifold.jl"] -Order = [:function] -``` - -## DefaultManifold - -[`DefaultManifold`](@ref ManifoldsBase.DefaultManifold) is a simplified version of [`Euclidean`](@ref) and demonstrates a basic interface implementation. -It can be used to perform simple tests. -Since when using `Manifolds.jl` the [`Euclidean`](@ref) is available, the `DefaultManifold` itself is not exported. - -```@docs -ManifoldsBase.DefaultManifold -``` - -## Error Messages -especially to collect and display errors on [`AbstractPowerManifold`](@ref)s the following -component and collection error messages are available. - -```@autodocs -Modules = [ManifoldsBase] -Pages = ["errors.jl"] -Order = [:type] -``` diff --git a/docs/src/manifolds/connection.md b/docs/src/manifolds/connection.md index 798e06395b..76fb5186e7 100644 --- a/docs/src/manifolds/connection.md +++ b/docs/src/manifolds/connection.md @@ -4,11 +4,11 @@ A connection manifold always consists of a [topological manifold](https://en.wik However, often there is an implicitly assumed (default) connection, like the [`LeviCivitaConnection`](@ref) connection on a Riemannian manifold. It is not necessary to use this decorator if you implement just one (or the first) connection. -If you later introduce a second, the old (first) connection can be used with the (non [`AbstractConnectionManifold`](@ref)) [`AbstractManifold`](@ref), i.e. without an explicitly stated connection. +If you later introduce a second, the old (first) connection can be used without an explicitly stated connection. This manifold decorator serves two purposes: -1. to implement different connections (e.g. in closed form) for one [`AbstractManifold`](@ref) +1. to implement different connections (e.g. in closed form) for one `AbstractManifold` 2. to provide a way to compute geodesics on manifolds, where this [`AbstractAffineConnection`](@ref) does not yield a closed formula. An example of usage can be found in Cartan-Schouten connections, see [`AbstractCartanSchoutenConnection`](@ref). @@ -36,4 +36,4 @@ Order = [:function] ## [Charts and bases of vector spaces](@id connections_charts) -All connection-related functions take a basis of a vector space as one of the arguments. This is needed because generally there is no way to define these functions without referencing a basis. In some cases there is no need to be explicit about this basis, and then for example a [`DefaultOrthonormalBasis`](@ref) object can be used. In cases where being explicit about these bases is needed, for example when using multiple charts, a basis can be specified, for example using [`induced_basis`](@ref Main.Manifolds.induced_basis). +All connection-related functions take a basis of a vector space as one of the arguments. This is needed because generally there is no way to define these functions without referencing a basis. In some cases there is no need to be explicit about this basis, and then for example a [`DefaultOrthonormalBasis`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/bases.html#ManifoldsBase.DefaultOrthonormalBasis) object can be used. In cases where being explicit about these bases is needed, for example when using multiple charts, a basis can be specified, for example using [`induced_basis`](@ref Main.Manifolds.induced_basis). diff --git a/docs/src/manifolds/essentialmanifold.md b/docs/src/manifolds/essentialmanifold.md index 94abf44f6f..7a0ef023a2 100644 --- a/docs/src/manifolds/essentialmanifold.md +++ b/docs/src/manifolds/essentialmanifold.md @@ -1,5 +1,5 @@ # Essential Manifold -The essential manifold is modeled as an [`AbstractPowerManifold`](@ref) of the $3\times3$ [`Rotations`](@ref) and uses [`NestedPowerRepresentation`](@ref). +The essential manifold is modeled as an [`AbstractPowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.AbstractPowerManifold) of the $3\times3$ [`Rotations`](@ref) and uses [`NestedPowerRepresentation`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.NestedPowerRepresentation). ```@autodocs Modules = [Manifolds] diff --git a/docs/src/manifolds/euclidean.md b/docs/src/manifolds/euclidean.md index ccfb38f0f7..1833b0eea3 100644 --- a/docs/src/manifolds/euclidean.md +++ b/docs/src/manifolds/euclidean.md @@ -1,7 +1,7 @@ # Euclidean space The Euclidean space $ℝ^n$ is a simple model space, since it has curvature constantly zero everywhere; hence, nearly all operations simplify. -The easiest way to generate an Euclidean space is to use a field, i.e. [`AbstractNumbers`](@ref), e.g. to create the $ℝ^n$ or $ℝ^{n\times n}$ you can simply type `M = ℝ^n` or `ℝ^(n,n)`, respectively. +The easiest way to generate an Euclidean space is to use a field, i.e. [`AbstractNumbers`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#number-system), e.g. to create the $ℝ^n$ or $ℝ^{n\times n}$ you can simply type `M = ℝ^n` or `ℝ^(n,n)`, respectively. ```@autodocs Modules = [Manifolds] diff --git a/docs/src/manifolds/graph.md b/docs/src/manifolds/graph.md index acda4169f0..f744fdfb27 100644 --- a/docs/src/manifolds/graph.md +++ b/docs/src/manifolds/graph.md @@ -1,6 +1,6 @@ # Graph manifold -For a given graph $G(V,E)$ implemented using [`Graphs.jl`](https://juliagraphs.github.io/Graphs.jl/latest/), the [`GraphManifold`](@ref) models a [`PowerManifold`](@ref) either on the nodes or edges of the graph, depending on the [`GraphManifoldType`](@ref). +For a given graph $G(V,E)$ implemented using [`Graphs.jl`](https://juliagraphs.github.io/Graphs.jl/latest/), the [`GraphManifold`](@ref) models a [`PowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.PowerManifold) either on the nodes or edges of the graph, depending on the [`GraphManifoldType`](@ref). i.e., it's either a $\mathcal M^{\lvert V \rvert}$ for the case of a vertex manifold or a $\mathcal M^{\lvert E \rvert}$ for the case of a edge manifold. ## Example @@ -20,7 +20,7 @@ add_edge!(G, 2, 3) N = GraphManifold(G, M, VertexManifold()) ``` -It supports all [`AbstractPowerManifold`](@ref) operations (it is based on [`NestedPowerRepresentation`](@ref)) and furthermore it is possible to compute a graph logarithm: +It supports all [`AbstractPowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.AbstractPowerManifold) operations (it is based on [`NestedPowerRepresentation`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.NestedPowerRepresentation)) and furthermore it is possible to compute a graph logarithm: ```@setup graph-1 using Manifolds diff --git a/docs/src/manifolds/group.md b/docs/src/manifolds/group.md index 40a7c2c5df..7fed20a0e2 100644 --- a/docs/src/manifolds/group.md +++ b/docs/src/manifolds/group.md @@ -1,6 +1,6 @@ # Group manifolds and actions -Lie groups, groups that are [`AbstractManifold`](@ref)s with a smooth binary group operation [`AbstractGroupOperation`](@ref), are implemented as subtypes of [`AbstractGroupManifold`](@ref) or by decorating an existing manifold with a group operation using [`GroupManifold`](@ref). +Lie groups, groups that are Riemannian manifolds with a smooth binary group operation [`AbstractGroupOperation`](@ref), are implemented as [`AbstractDecoratorManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/decorator.html#ManifoldsBase.AbstractDecoratorManifold) and specifying the group operation using the [`IsGroupManifold`](@ref) or by decorating an existing manifold with a group operation using [`GroupManifold`](@ref). The common addition and multiplication group operations of [`AdditionOperation`](@ref) and [`MultiplicationOperation`](@ref) are provided, though their behavior may be customized for a specific group. @@ -33,6 +33,37 @@ Pages = ["groups/group.jl"] Order = [:type, :function] ``` +### GroupManifold + +As a concrete wrapper for manifolds (e.g. when the manifold per se is a group manifold but another group structure should be implemented), there is the [`GroupManifold`](@ref) + +```@autodocs +Modules = [Manifolds] +Pages = ["groups/GroupManifold.jl"] +Order = [:type, :function] +``` + +### Generic Operations + +For groups based on an addition operation or a group operation, several default implementations are provided. + +#### Addition Operation + +```@autodocs +Modules = [Manifolds] +Pages = ["groups/addition_operation.jl"] +Order = [:type, :function] +``` + +#### Multiplication Operation + +```@autodocs +Modules = [Manifolds] +Pages = ["groups/multiplication_operation.jl"] +Order = [:type, :function] +``` + + ### Product group ```@autodocs diff --git a/docs/src/manifolds/metric.md b/docs/src/manifolds/metric.md index dd34512be0..3ee47b5d15 100644 --- a/docs/src/manifolds/metric.md +++ b/docs/src/manifolds/metric.md @@ -5,11 +5,11 @@ A Riemannian manifold always consists of a [topological manifold](https://en.wik However, often there is an implicitly assumed (default) metric, like the usual inner product on [`Euclidean`](@ref) space. This decorator takes this into account. It is not necessary to use this decorator if you implement just one (or the first) metric. -If you later introduce a second, the old (first) metric can be used with the (non [`MetricManifold`](@ref)) [`AbstractManifold`](@ref), i.e. without an explicitly stated metric. +If you later introduce a second, the old (first) metric can be used with the (non [`MetricManifold`](@ref)) `AbstractManifold`, i.e. without an explicitly stated metric. This manifold decorator serves two purposes: -1. to implement different metrics (e.g. in closed form) for one [`AbstractManifold`](@ref) +1. to implement different metrics (e.g. in closed form) for one `AbstractManifold` 2. to provide a way to compute geodesics on manifolds, where this [`AbstractMetric`](@ref) does not yield closed formula. ```@contents @@ -17,7 +17,7 @@ Pages = ["metric.md"] Depth = 2 ``` -Note that a metric manifold is an [`AbstractConnectionManifold`](@ref) with the [`LeviCivitaConnection`](@ref) of the metric $g$, and thus a large part of metric manifold's functionality relies on this. +Note that a metric manifold is has a [`IsConnectionManifold`](@ref) trait referring to the [`LeviCivitaConnection`](@ref) of the metric $g$, and thus a large part of metric manifold's functionality relies on this. Let's first look at the provided types. diff --git a/docs/src/manifolds/multinomial.md b/docs/src/manifolds/multinomial.md index 3618944109..aa7f9e1d9c 100644 --- a/docs/src/manifolds/multinomial.md +++ b/docs/src/manifolds/multinomial.md @@ -8,7 +8,7 @@ Order = [:type] ## Functions -Most functions are directly implemented for an [`AbstractPowerManifold`](@ref) with [`ArrayPowerRepresentation`](@ref) except the following special cases: +Most functions are directly implemented for an [`AbstractPowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.AbstractPowerManifold) with [`ArrayPowerRepresentation`](@ref) except the following special cases: ```@autodocs Modules = [Manifolds] diff --git a/docs/src/manifolds/oblique.md b/docs/src/manifolds/oblique.md index 90b188889c..36f36233a4 100644 --- a/docs/src/manifolds/oblique.md +++ b/docs/src/manifolds/oblique.md @@ -1,6 +1,6 @@ # Oblique manifold -The oblique manifold $\mathcal{OB}(n,m)$ is modeled as an [`AbstractPowerManifold`](@ref) of the (real-valued) [`Sphere`](@ref) and uses [`ArrayPowerRepresentation`](@ref). +The oblique manifold $\mathcal{OB}(n,m)$ is modeled as an [`AbstractPowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.AbstractPowerManifold) of the (real-valued) [`Sphere`](@ref) and uses [`ArrayPowerRepresentation`](@ref). Points on the torus are hence matrices, $x ∈ ℝ^{n,m}$. ```@autodocs @@ -11,7 +11,7 @@ Order = [:type] ## Functions -Most functions are directly implemented for an [`AbstractPowerManifold`](@ref) with [`ArrayPowerRepresentation`](@ref) except the following special cases: +Most functions are directly implemented for an [`AbstractPowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.AbstractPowerManifold) with [`ArrayPowerRepresentation`](@ref) except the following special cases: ```@autodocs Modules = [Manifolds] diff --git a/docs/src/manifolds/positivenumbers.md b/docs/src/manifolds/positivenumbers.md index 01754003da..eecf0f23f5 100644 --- a/docs/src/manifolds/positivenumbers.md +++ b/docs/src/manifolds/positivenumbers.md @@ -1,6 +1,6 @@ # Positive Numbers -The manifold [`PositiveNumbers`](@ref) represents positive numbers with hyperbolic geometry. Additionally, there are also short forms for its corresponding [`PowerManifold`](@ref)s, i.e. [`PositiveVectors`](@ref), [`PositiveMatrices`](@ref), and [`PositiveArrays`](@ref). +The manifold [`PositiveNumbers`](@ref) represents positive numbers with hyperbolic geometry. Additionally, there are also short forms for its corresponding `PowerManifold`s, i.e. [`PositiveVectors`](@ref), [`PositiveMatrices`](@ref), and [`PositiveArrays`](@ref). ```@autodocs Modules = [Manifolds] diff --git a/docs/src/manifolds/power.md b/docs/src/manifolds/power.md index 6d3ce06f37..3cbb83869f 100644 --- a/docs/src/manifolds/power.md +++ b/docs/src/manifolds/power.md @@ -1,14 +1,14 @@ # Power manifold -A power manifold is based on a [`AbstractManifold`](@ref) $\mathcal M$ to build a $\mathcal M^{n_1 \times n_2 \times \cdots \times n_m}$. +A power manifold is based on a [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) $\mathcal M$ to build a $\mathcal M^{n_1 \times n_2 \times \cdots \times n_m}$. In the case where $m=1$ we can represent a manifold-valued vector of data of length $n_1$, for example a time series. The case where $m=2$ is useful for representing manifold-valued matrices of data of size $n_1 \times n_2$, for example certain types of images. There are three available representations for points and vectors on a power manifold: * [`ArrayPowerRepresentation`](@ref) (the default one), very efficient but only applicable when points on the underlying manifold are represented using plain `AbstractArray`s. -* [`NestedPowerRepresentation`](@ref), applicable to any manifold. It assumes that points on the underlying manifold are represented using mutable data types. -* [`NestedReplacingPowerRepresentation`](@ref), applicable to any manifold. It does not mutate points on the underlying manifold, replacing them instead when appropriate. +* [`NestedPowerRepresentation`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.NestedPowerRepresentation), applicable to any manifold. It assumes that points on the underlying manifold are represented using mutable data types. +* [`NestedReplacingPowerRepresentation`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.NestedReplacingPowerRepresentation), applicable to any manifold. It does not mutate points on the underlying manifold, replacing them instead when appropriate. Below are some examples of usage of these representations. @@ -46,7 +46,7 @@ using HybridArrays, StaticArrays q = HybridArray{Tuple{3,StaticArrays.Dynamic()},Float64,2}(p) ``` -which is still a valid point on `M` and [`PowerManifold`](@ref) works with these, too. +which is still a valid point on `M` and [`PowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.PowerManifold) works with these, too. An advantage of this representation is that it is quite efficient, especially when a `HybridArray` (from the [HybridArrays.jl](https://github.com/mateuszbaran/HybridArrays.jl) package) is used to represent a point on the power manifold. A disadvantage is not being able to easily identify parts of the multidimensional array that correspond to a single point on the base manifold. @@ -54,7 +54,7 @@ Another problem is, that accessing a single point is ` p[:, 1]` which might be u ### `NestedPowerRepresentation` -For the [`NestedPowerRepresentation`](@ref) we can now do +For the [`NestedPowerRepresentation`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.NestedPowerRepresentation) we can now do ```@example 2 using Manifolds @@ -66,7 +66,7 @@ p = [ [1.0, 0.0, 0.0], ] ``` -which is again a valid point so [`is_point`](@ref)`(M, p)` here also yields true. +which is again a valid point so `is_point(M, p)` here also yields true. A disadvantage might be that with nested arrays one loses a little bit of performance. The data however is nicely encapsulated. Accessing the first data item is just `p[1]`. @@ -87,7 +87,7 @@ get_component(M, p, 4) ### `NestedReplacingPowerRepresentation` -The final representation is the [`NestedReplacingPowerRepresentation`](@ref). It is similar to the [`NestedPowerRepresentation`](@ref) but it does not perform mutating operations on the points on the underlying manifold. The example below uses this representation to store points on a power manifold of the [`SpecialEuclidean`](@ref) group in-line in an `Vector` for improved efficiency. When having a mixture of both, i.e. an array structure that is nested (like [Β΄NestedPowerRepresentation](@ref)) in the sense that the elements of the main vector are immutable, then changing the elements can not be done in a mutating way and hence [`NestedReplacingPowerRepresentation`](@ref) has to be used. +The final representation is the [`NestedReplacingPowerRepresentation`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.NestedReplacingPowerRepresentation). It is similar to the [`NestedPowerRepresentation`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.NestedPowerRepresentation) but it does not perform mutating operations on the points on the underlying manifold. The example below uses this representation to store points on a power manifold of the [`SpecialEuclidean`](@ref) group in-line in an `Vector` for improved efficiency. When having a mixture of both, i.e. an array structure that is nested (like [Β΄NestedPowerRepresentation](@ref)) in the sense that the elements of the main vector are immutable, then changing the elements can not be done in a mutating way and hence [`NestedReplacingPowerRepresentation`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.NestedReplacingPowerRepresentation) has to be used. ```@example 4 using Manifolds, StaticArrays diff --git a/docs/src/manifolds/torus.md b/docs/src/manifolds/torus.md index d1011c9282..447092618f 100644 --- a/docs/src/manifolds/torus.md +++ b/docs/src/manifolds/torus.md @@ -1,6 +1,6 @@ # Torus -The torus $𝕋^d β‰… [-Ο€,Ο€)^d$ is modeled as an [`AbstractPowerManifold`](@ref) of the (real-valued) [`Circle`](@ref) and uses [`ArrayPowerRepresentation`](@ref). +The torus $𝕋^d β‰… [-Ο€,Ο€)^d$ is modeled as an [`AbstractPowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.AbstractPowerManifold) of the (real-valued) [`Circle`](@ref) and uses [`ArrayPowerRepresentation`](@ref). Points on the torus are hence row vectors, $x ∈ ℝ^{d}$. ## Example @@ -17,7 +17,7 @@ X = log(M, p, q) ## Types and functions -Most functions are directly implemented for an [`AbstractPowerManifold`](@ref) with [`ArrayPowerRepresentation`](@ref) except the following special cases: +Most functions are directly implemented for an [`AbstractPowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.AbstractPowerManifold) with [`ArrayPowerRepresentation`](@ref) except the following special cases: ```@autodocs Modules = [Manifolds] diff --git a/docs/src/manifolds/vector_bundle.md b/docs/src/manifolds/vector_bundle.md index 608501cc5b..98d1a9e976 100644 --- a/docs/src/manifolds/vector_bundle.md +++ b/docs/src/manifolds/vector_bundle.md @@ -18,7 +18,7 @@ This is also considered a manifold. ## FVector -For cases where confusion between different types of vectors is possible, the type [`FVector`](@ref) can be used to express which type of vector space the vector belongs to. +For cases where confusion between different types of vectors is possible, the type [`FVector`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.FVector) can be used to express which type of vector space the vector belongs to. It is used for example in musical isomorphisms (the [`flat`](@ref) and [`sharp`](@ref) functions) that are used to go from a tangent space to cotangent space and vice versa. ## Example diff --git a/docs/src/misc/internals.md b/docs/src/misc/internals.md index 1fbeef87a1..3b0673f9de 100644 --- a/docs/src/misc/internals.md +++ b/docs/src/misc/internals.md @@ -13,7 +13,6 @@ Manifolds.mul!_safe Manifolds.nzsign Manifolds.realify Manifolds.realify! -Manifolds.size_to_tuple Manifolds.select_from_tuple Manifolds.unrealify! Manifolds.usinc diff --git a/docs/src/misc/notation.md b/docs/src/misc/notation.md index 88e1a810d0..a38c97811a 100644 --- a/docs/src/misc/notation.md +++ b/docs/src/misc/notation.md @@ -12,7 +12,7 @@ Within the documented functions, the utf8 symbols are used whenever possible, as | ``\tau_p`` | action map by group element ``p`` | ``\mathrm{L}_p``, ``\mathrm{R}_p`` | either left or right | | ``\operatorname{Ad}_p(X)`` | adjoint action of element ``p`` of a Lie group on the element ``X`` of the corresponding Lie algebra | | | | ``\times`` | Cartesian product of two manifolds |Β | see [`ProductManifold`](@ref) | -| ``^{\wedge}`` | (n-ary) Cartesian power of a manifold |Β | see [`PowerManifold`](@ref) | +| ``^{\wedge}`` | (n-ary) Cartesian power of a manifold |Β | see [`PowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.PowerManifold) | | ``a`` | coordinates of a point in a chart | | see [`get_parameters`](@ref) | | ``\frac{\mathrm{D}}{\mathrm{d}t}`` | covariant derivative of a vector field ``X(t)`` | | | | ``T^*_p \mathcal M`` | the cotangent space at ``p`` | | | diff --git a/src/Manifolds.jl b/src/Manifolds.jl index dc739392c3..c31558b4c2 100644 --- a/src/Manifolds.jl +++ b/src/Manifolds.jl @@ -1,9 +1,17 @@ module Manifolds import ManifoldsBase: + @trait_function, _access_nested, + _get_basis, + _injectivity_radius, + _inverse_retract, + _inverse_retract!, _read, + _retract, + _retract!, _write, + active_traits, allocate, allocate_coordinates, allocate_result, @@ -12,58 +20,119 @@ import ManifoldsBase: array_value, base_manifold, check_point, - check_point__transparent, + check_size, check_vector, copy, copyto!, + default_inverse_retraction_method, + default_retraction_method, + default_vector_transport_method, decorated_manifold, - decorator_transparent_dispatch, - default_decorator_dispatch, distance, dual_basis, embed, embed!, exp, exp!, - exp!__intransparent, get_basis, + get_basis_default, + get_basis_diagonalizing, + get_basis_orthogonal, + get_basis_orthonormal, + get_basis_vee, get_component, get_coordinates, get_coordinates!, + get_coordinates_diagonalizing, + get_coordinates_diagonalizing!, + get_coordinates_orthogonal, + get_coordinates_orthonormal, + get_coordinates_orthogonal!, + get_coordinates_orthonormal!, + get_coordinates_vee!, get_embedding, get_iterator, get_vector, get_vector!, + get_vector_diagonalizing, + get_vector_diagonalizing!, + get_vector_orthogonal, + get_vector_orthonormal, + get_vector_orthogonal!, + get_vector_orthonormal!, get_vectors, gram_schmidt, hat, hat!, injectivity_radius, + _injectivity_radius, + injectivity_radius_exp, inner, - inner__intransparent, + isapprox, is_point, is_vector, inverse_retract, inverse_retract!, + _inverse_retract, + _inverse_retract!, + inverse_retract_caley!, + inverse_retract_embedded!, + inverse_retract_nlsolve!, + inverse_retract_pade!, + inverse_retract_polar!, + inverse_retract_project!, + inverse_retract_qr!, + inverse_retract_softmax!, log, log!, manifold_dimension, mid_point, mid_point!, + norm, number_eltype, number_of_coordinates, + parallel_transport_along, + parallel_transport_along!, + parallel_transport_direction, + parallel_transport_direction!, + parallel_transport_to, + parallel_transport_to!, + parent_trait, power_dimensions, project, project!, representation_size, retract, retract!, + retract_caley!, + retract_exp_ode!, + retract_pade!, + retract_polar!, + retract_project!, + retract_qr!, + retract_softmax!, set_component!, vector_space_dimension, + vector_transport_along, # just specified in Euclidean - the next 5 as well + vector_transport_along_diff, + vector_transport_along_project, + vector_transport_along!, + vector_transport_along_diff!, + vector_transport_along_project!, vector_transport_direction, + vector_transport_direction_diff, vector_transport_direction!, + vector_transport_direction_diff!, vector_transport_to, + vector_transport_to_diff, + vector_transport_to_project, vector_transport_to!, + vector_transport_to_diff!, + vector_transport_to_project!, # some overwrite layer 2 + _vector_transport_direction, + _vector_transport_direction!, + _vector_transport_to, + _vector_transport_to!, vee, vee!, zero_vector, @@ -77,7 +146,6 @@ import Base: identity, in, inv, - isapprox, isempty, length, ndims, @@ -98,10 +166,7 @@ using ManifoldsBase: ℍ, AbstractBasis, AbstractDecoratorManifold, - AbstractDecoratorType, - AbstractEmbeddedManifold, AbstractInverseRetractionMethod, - AbstractIsometricEmbeddingType, AbstractManifold, AbstractManifoldPoint, AbstractNumbers, @@ -110,19 +175,20 @@ using ManifoldsBase: AbstractPowerManifold, AbstractPowerRepresentation, AbstractRetractionMethod, + AbstractTrait, AbstractVectorTransportMethod, AbstractLinearVectorTransportMethod, ApproximateInverseRetraction, ApproximateRetraction, CachedBasis, + CayleyRetraction, + CayleyInverseRetraction, + ComplexNumbers, ComponentManifoldError, CompositeManifoldError, CotangentSpaceType, CoTFVector, DefaultBasis, - DefaultEmbeddingType, - DefaultIsometricEmbeddingType, - DefaultManifold, DefaultOrthogonalBasis, DefaultOrthonormalBasis, DefaultOrDiagonalizingBasis, @@ -130,15 +196,23 @@ using ManifoldsBase: DiagonalizingOrthonormalBasis, DifferentiatedRetractionVectorTransport, EmbeddedManifold, + EmptyTrait, ExponentialRetraction, FVector, - InversePowerRetraction, + IsIsometricEmbeddedManifold, + IsEmbeddedManifold, + IsEmbeddedSubmanifold, + IsExplicitDecorator, LogarithmicInverseRetraction, ManifoldsBase, NestedPowerRepresentation, NestedReplacingPowerRepresentation, - NLsolveInverseRetraction, + TraitList, + NLSolveInverseRetraction, + ODEExponentialRetraction, OutOfInjectivityRadiusError, + PadeRetraction, + PadeInverseRetraction, ParallelTransport, PolarInverseRetraction, PolarRetraction, @@ -146,45 +220,40 @@ using ManifoldsBase: PowerManifold, PowerManifoldNested, PowerManifoldNestedReplacing, - PowerRetraction, - PowerVectorTransport, ProjectedOrthonormalBasis, ProjectionInverseRetraction, ProjectionRetraction, ProjectionTransport, + QuaternionNumbers, QRInverseRetraction, QRRetraction, + RealNumbers, ScaledVectorTransport, SchildsLadderTransport, + SoftmaxRetraction, + SoftmaxInverseRetraction, TangentSpaceType, TCoTSpaceType, TFVector, - TransparentIsometricEmbedding, TVector, ValidationManifold, ValidationMPoint, ValidationTVector, VectorSpaceType, VeeOrthogonalBasis, - @decorator_transparent_fallback, - @decorator_transparent_function, - @decorator_transparent_signature, @invoke_maker, _euclidean_basis_vector, - _extract_val, combine_allocation_promotion_functions, default_inverse_retraction_method, geodesic, - is_decorator_transparent, - is_default_decorator, - manifold_function_not_implemented_message, + merge_traits, + next_trait, number_system, real_dimension, rep_size_to_colons, shortest_geodesic, size_to_tuple, - vector_transport_along, - vector_transport_along! + trait using Markdown: @doc_str using Random using RecipesBase @@ -285,8 +354,11 @@ include("manifolds/EssentialManifold.jl") # # Group Manifolds +include("groups/GroupManifold.jl") # a) generics +include("groups/addition_operation.jl") +include("groups/multiplication_operation.jl") include("groups/connections.jl") include("groups/metric.jl") include("groups/group_action.jl") @@ -311,8 +383,8 @@ include("groups/special_euclidean.jl") Base.in(p, M::AbstractManifold; kwargs...) p ∈ M -Check, whether a point `p` is a valid point (i.e. in) a [`AbstractManifold`](@ref) `M`. -This method employs [`is_point`](@ref) deactivating the error throwing option. +Check, whether a point `p` is a valid point (i.e. in) a [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) `M`. +This method employs [`is_point`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#ManifoldsBase.is_point) deactivating the error throwing option. """ Base.in(p, M::AbstractManifold; kwargs...) = is_point(M, p, false; kwargs...) @@ -321,29 +393,19 @@ Base.in(p, M::AbstractManifold; kwargs...) = is_point(M, p, false; kwargs...) X ∈ TangentSpaceAtPoint(M,p) Check whether `X` is a tangent vector from (in) the tangent space $T_p\mathcal M$, i.e. -the [`TangentSpaceAtPoint`](@ref) at `p` on the [`AbstractManifold`](@ref) `M`. -This method uses [`is_vector`](@ref) deactivating the error throw option. +the [`TangentSpaceAtPoint`](@ref) at `p` on the [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) `M`. +This method uses [`is_vector`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#ManifoldsBase.is_vector) deactivating the error throw option. """ function Base.in(X, TpM::TangentSpaceAtPoint; kwargs...) return is_vector(base_manifold(TpM), TpM.point, X, false; kwargs...) end function __init__() - @require FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" begin - using .FiniteDiff - include("differentiation/finite_diff.jl") - end - @require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" begin using .FiniteDifferences include("differentiation/finite_differences.jl") end - @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin - using .ForwardDiff - include("differentiation/forward_diff.jl") - end - @require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" begin using .OrdinaryDiffEq: ODEProblem, AutoVern9, Rodas5, solve include("differentiation/ode.jl") @@ -354,26 +416,12 @@ function __init__() include("nlsolve.jl") end - @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin - using .ReverseDiff: ReverseDiff - include("differentiation/reverse_diff.jl") - end - @require Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" begin using .Test: Test include("tests/tests_general.jl") export test_manifold include("tests/tests_group.jl") export test_group, test_action - @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin - include("tests/tests_forwarddiff.jl") - export test_forwarddiff - end - - @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin - include("tests/tests_reversediff.jl") - export test_reversediff - end end @require Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" begin @@ -390,11 +438,6 @@ function __init__() end end - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin - using .Zygote: Zygote - include("differentiation/zygote.jl") - end - return nothing end @@ -446,8 +489,9 @@ export HyperboloidTVector, export AbstractNumbers, ℝ, β„‚, ℍ # decorator manifolds -export AbstractDecoratorManifold, MetricDecoratorType -export AbstractGroupDecoratorType, DefaultGroupDecoratorType, TransparentGroupDecoratorType +export AbstractDecoratorManifold +export IsIsometricEmbeddedManifold, IsEmbeddedManifold, IsEmbeddedSubmanifold +export IsDefaultMetric, IsDefaultConnection, IsMetricManifold, IsConnectionManifold export ValidationManifold, ValidationMPoint, ValidationTVector, ValidationCoTVector export CotangentBundle, CotangentSpaceAtPoint, CotangentBundleFibers, CotangentSpace, FVector @@ -465,8 +509,7 @@ export VectorBundleFibers export AbstractVectorTransportMethod, DifferentiatedRetractionVectorTransport, ParallelTransport, ProjectedPointDistribution export PoleLadderTransport, SchildsLadderTransport -export PowerVectorTransport, ProductVectorTransport -export AbstractEmbeddedManifold +export ProductVectorTransport export AbstractAffineConnection, AbstractConnectionManifold, ConnectionManifold, LeviCivitaConnection export AbstractCartanSchoutenConnection, @@ -487,8 +530,6 @@ export AbstractMetric, CanonicalMetric, MetricManifold export AbstractAtlas, RetractionAtlas -export AbstractEmbeddingType, AbstractIsometricEmbeddingType -export DefaultEmbeddingType, DefaultIsometricEmbeddingType, TransparentIsometricEmbedding export AbstractVectorTransportMethod, ParallelTransport, ProjectionTransport export AbstractRetractionMethod, CayleyRetraction, @@ -570,14 +611,12 @@ export Γ—, inverse_retract, inverse_retract!, isapprox, - is_decorator_transparent, - is_default_decorator, + is_default_connection, is_default_metric, - is_group_decorator, + is_group_manifold, is_identity, is_point, is_vector, - isapprox, kurtosis, local_metric, local_metric_jacobian, @@ -639,7 +678,6 @@ export Γ—, # Lie group types & functions export AbstractGroupAction, AbstractGroupOperation, - AbstractGroupManifold, ActionDirection, AdditionOperation, CircleGroup, @@ -663,6 +701,10 @@ export AbstractGroupAction, SpecialOrthogonal, TranslationGroup, TranslationAction +export AbstractInvarianceTrait +export IsMetricManifold, IsConnectionManifold +export IsGroupManifold, + HasLeftInvariantMetric, HasRightInvariantMetric, HasBiinvariantMetric export adjoint_action, adjoint_action!, affine_matrix, @@ -678,19 +720,26 @@ export adjoint_action, direction, exp_lie, exp_lie!, - g_manifold, + group_manifold, geodesic, get_coordinates_lie, get_coordinates_lie!, + get_coordinates_orthogonal, + get_coordinates_orthonormal, + get_coordinates_orthogonal!, + get_coordinates_orthonormal!, + get_vector_diagonalizing!, get_vector_lie, get_vector_lie!, + get_vector_orthogonal, + get_vector_orthonormal, + get_coordinates_vee!, has_biinvariant_metric, has_invariant_metric, identity_element, identity_element!, inv, inv!, - invariant_metric_dispatch, inverse_apply, inverse_apply!, inverse_apply_diff, @@ -724,14 +773,7 @@ export AbstractBasis, export OutOfInjectivityRadiusError export get_basis, get_coordinates, get_coordinates!, get_vector, get_vector!, get_vectors, number_system -# differentiation -export AbstractDiffBackend, - AbstractRiemannianDiffBackend, - ExplicitEmbeddedBackend, - FiniteDifferencesBackend, - TangentDiffBackend, - RiemannianProjectionBackend -export default_differential_backend, set_default_differential_backend! + # atlases and charts export get_point, get_point!, get_parameters, get_parameters! diff --git a/src/atlases.jl b/src/atlases.jl index 9255905aad..a524bb6356 100644 --- a/src/atlases.jl +++ b/src/atlases.jl @@ -3,7 +3,7 @@ AbstractAtlas{𝔽} An abstract class for atlases whith charts that have values in the vector space `𝔽ⁿ` -for some value of `n`. `𝔽` is a number system determined by an [`AbstractNumbers`](@ref) +for some value of `n`. `𝔽` is a number system determined by an [`AbstractNumbers`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#number-system) object. """ abstract type AbstractAtlas{𝔽} end @@ -28,8 +28,8 @@ In short: The coordinates with respect to a basis are used together with a retra # See also -[`AbstractAtlas`](@ref), [`AbstractInverseRetractionMethod`](@ref), -[`AbstractRetractionMethod`](@ref), [`AbstractBasis`](@ref) +[`AbstractAtlas`](@ref), [`AbstractInverseRetractionMethod`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.AbstractInverseRetractionMethod), +[`AbstractRetractionMethod`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.AbstractRetractionMethod), [`AbstractBasis`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/bases.html#ManifoldsBase.AbstractBasis) """ struct RetractionAtlas{ 𝔽, @@ -90,6 +90,12 @@ function allocate_result(M::AbstractManifold, f::typeof(get_parameters), p) T = allocate_result_type(M, f, (p,)) return allocate(p, T, manifold_dimension(M)) end +# disambiguation +@invoke_maker 1 AbstractManifold allocate_result( + M::AbstractDecoratorManifold, + f::typeof(get_parameters), + p, +) function get_parameters!(M::AbstractManifold, a, A::RetractionAtlas, i, p) return get_coordinates!(M, a, i, inverse_retract(M, i, p, A.invretr), A.basis) @@ -122,6 +128,12 @@ function allocate_result(M::AbstractManifold, f::typeof(get_point), a) T = allocate_result_type(M, f, (a,)) return allocate(a, T, representation_size(M)...) end +# disambiguation +@invoke_maker 1 AbstractManifold allocate_result( + M::AbstractDecoratorManifold, + f::typeof(get_point), + a, +) function get_point(M::AbstractManifold, A::RetractionAtlas, i, a) return retract(M, i, get_vector(M, i, a, A.basis), A.retr) @@ -209,7 +221,7 @@ chart (`A`, `i`). # See also -[`VectorSpaceType`](@ref), [`AbstractAtlas`](@ref) +[`VectorSpaceType`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/bases.html#ManifoldsBase.VectorSpaceType), [`AbstractAtlas`](@ref) """ induced_basis(M::AbstractManifold, A::AbstractAtlas, i, VST::VectorSpaceType) @@ -265,7 +277,7 @@ and the set ``\{X_1,\ldots,X_n\}`` is the chart-induced basis of ``T_p\mathcal M # See also -[`VectorSpaceType`](@ref), [`AbstractBasis`](@ref) +[`VectorSpaceType`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/bases.html#ManifoldsBase.VectorSpaceType), [`AbstractBasis`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/bases.html#ManifoldsBase.AbstractBasis) """ struct InducedBasis{𝔽,VST<:VectorSpaceType,TA<:AbstractAtlas,TI} <: AbstractBasis{𝔽,VST} vs::VST @@ -303,6 +315,32 @@ function dual_basis( return induced_basis(M, B.A, B.i, TangentSpace) end +function ManifoldsBase._get_coordinates(M::AbstractManifold, p, X, B::InducedBasis) + return get_coordinates_induced_basis(M, p, X, B) +end +function get_coordinates_induced_basis(M::AbstractManifold, p, X, B::InducedBasis) + Y = allocate_result(M, get_coordinates, p, X, B) + return get_coordinates_induced_basis!(M, Y, p, X, B) +end + +function ManifoldsBase._get_coordinates!(M::AbstractManifold, Y, p, X, B::InducedBasis) + return get_coordinates_induced_basis!(M, Y, p, X, B) +end +function get_coordinates_induced_basis! end + +function ManifoldsBase._get_vector(M::AbstractManifold, p, c, B::InducedBasis) + return get_vector_induced_basis(M, p, c, B) +end +function get_vector_induced_basis(M::AbstractManifold, p, c, B::InducedBasis) + Y = allocate_result(M, get_vector, p, c) + return get_vector!(M, Y, p, c, B) +end + +function ManifoldsBase._get_vector!(M::AbstractManifold, Y, p, c, B::InducedBasis) + return get_vector_induced_basis!(M, Y, p, c, B) +end +function get_vector_induced_basis! end + """ local_metric(M::AbstractManifold, p, B::InducedBasis) diff --git a/src/cotangent_space.jl b/src/cotangent_space.jl index db0d7b1b1c..7cdef27ec3 100644 --- a/src/cotangent_space.jl +++ b/src/cotangent_space.jl @@ -19,18 +19,13 @@ function (ΞΎ::RieszRepresenterCotangentVector)(Y) return inner(ΞΎ.manifold, ΞΎ.p, ΞΎ.X, Y) end -@decorator_transparent_signature flat!( - M::AbstractDecoratorManifold, - ΞΎ::CoTFVector, - p, - X::TFVector, -) +@trait_function flat!(M::AbstractDecoratorManifold, ΞΎ::CoTFVector, p, X::TFVector) @doc raw""" flat(M::AbstractManifold, p, X) Compute the flat isomorphism (one of the musical isomorphisms) of tangent vector `X` -from the vector space of type `M` at point `p` from the underlying [`AbstractManifold`](@ref). +from the vector space of type `M` at point `p` from the underlying `AbstractManifold`. The function can be used for example to transform vectors from the tangent bundle to vectors from the cotangent bundle @@ -41,6 +36,8 @@ function flat(M::AbstractManifold, p, X::TFVector{<:Any,<:AbstractBasis}) return CoTFVector(X.data, dual_basis(M, p, X.basis)) end +is_metric_function(::typeof(flat)) = true + function flat!(::AbstractManifold, ΞΎ::RieszRepresenterCotangentVector, p, X) # TODO: maybe assert that ΞΎ.p is equal to p? Allowing for varying p in ΞΎ leads to # issues with power manifold. @@ -144,7 +141,7 @@ end sharp(M::AbstractManifold, p, ΞΎ) Compute the sharp isomorphism (one of the musical isomorphisms) of vector `ΞΎ` -from the vector space `M` at point `p` from the underlying [`AbstractManifold`](@ref). +from the vector space `M` at point `p` from the underlying `AbstractManifold`. The function can be used for example to transform vectors from the cotangent bundle to vectors from the tangent bundle @@ -152,58 +149,18 @@ $β™― : T^{*}\mathcal M β†’ T\mathcal M$ """ sharp(::AbstractManifold, p, ΞΎ) -@decorator_transparent_signature sharp( - M::AbstractDecoratorManifold, - X::TFVector, - p, - ΞΎ::CoTFVector, -) +@trait_function sharp(M::AbstractDecoratorManifold, X::TFVector, p, ΞΎ::CoTFVector) sharp(::AbstractManifold, p, ΞΎ::RieszRepresenterCotangentVector) = ΞΎ.X function sharp(M::AbstractManifold, p, X::CoTFVector{<:Any,<:AbstractBasis}) return TFVector(X.data, dual_basis(M, p, X.basis)) end -@decorator_transparent_signature sharp!( - M::AbstractDecoratorManifold, - X::TFVector, - p, - ΞΎ::CoTFVector, -) +is_metric_function(::typeof(sharp)) = true + +@trait_function sharp!(M::AbstractDecoratorManifold, X::TFVector, p, ΞΎ::CoTFVector) function sharp!(::AbstractManifold, X, p, ΞΎ::RieszRepresenterCotangentVector) copyto!(X, ΞΎ.X) return X end - -# -# Introduce transparency for connection manfiolds -# (a) new functions & other parents -for f in [flat, sharp] - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractConnectionManifold, - args..., - ) - return Val(:parent) - end - end, - ) -end - -# (b) changes / intransparencies. -for f in [flat!, sharp!] - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractConnectionManifold, - args..., - ) - return Val(:intransparent) - end - end, - ) -end diff --git a/src/differentiation/differentiation.jl b/src/differentiation/differentiation.jl index 4e06db64d2..b9760ec530 100644 --- a/src/differentiation/differentiation.jl +++ b/src/differentiation/differentiation.jl @@ -115,62 +115,3 @@ function set_default_differential_backend!(backend::AbstractDiffBackend) _current_default_differential_backend.backend = backend return backend end - -@doc raw""" - ODEExponentialRetraction{T<:AbstractRetractionMethod, B<:AbstractBasis} <: AbstractRetractionMethod - -Approximate the exponential map on the manifold by evaluating the ODE descripting the geodesic at 1, -assuming the default connection of the given manifold by solving the ordinary differential -equation - -```math -\frac{d^2}{dt^2} p^k + Ξ“^k_{ij} \frac{d}{dt} p_i \frac{d}{dt} p_j = 0, -``` - -where ``Ξ“^k_{ij}`` are the Christoffel symbols of the second kind, and -the Einstein summation convention is assumed. - -See [`solve_exp_ode`](@ref) for further details. - -# Constructor - - ODEExponentialRetraction( - r::AbstractRetractionMethod, - b::AbstractBasis=DefaultOrthogonalBasis(), - ) - -Generate the retraction with a retraction to use internally (for some approaches) -and a basis for the tangent space(s). -""" -struct ODEExponentialRetraction{T<:AbstractRetractionMethod,B<:AbstractBasis} <: - AbstractRetractionMethod - retraction::T - basis::B -end -function ODEExponentialRetraction(r::T) where {T<:AbstractRetractionMethod} - return ODEExponentialRetraction(r, DefaultOrthonormalBasis()) -end -function ODEExponentialRetraction(::T, b::CachedBasis) where {T<:AbstractRetractionMethod} - return throw( - DomainError( - b, - "Cached Bases are currently not supported, since the basis has to be implemented in a surrounding of the start point as well.", - ), - ) -end -function ODEExponentialRetraction(r::ExponentialRetraction, ::AbstractBasis) - return throw( - DomainError( - r, - "You can not use the exponential map as an inner method to solve the ode for the exponential map.", - ), - ) -end -function ODEExponentialRetraction(r::ExponentialRetraction, ::CachedBasis) - return throw( - DomainError( - r, - "Neither the exponential map nor a Cached Basis can be used with this retraction type.", - ), - ) -end diff --git a/src/differentiation/finite_diff.jl b/src/differentiation/finite_diff.jl deleted file mode 100644 index bd5403d60a..0000000000 --- a/src/differentiation/finite_diff.jl +++ /dev/null @@ -1,38 +0,0 @@ - -""" - FiniteDiffBackend <: AbstractDiffBackend - -A type to specify / use differentiation backend based on FiniteDiff package. - -# Constructor - FiniteDiffBackend(method::Val{Symbol} = Val{:central}) -""" -struct FiniteDiffBackend{TM<:Val} <: AbstractDiffBackend - method::TM -end - -FiniteDiffBackend() = FiniteDiffBackend(Val(:central)) - -function _derivative(f, p, ::FiniteDiffBackend{Method}) where {Method} - return FiniteDiff.finite_difference_derivative(f, p, Method) -end - -function _gradient(f, p, ::FiniteDiffBackend{Method}) where {Method} - return FiniteDiff.finite_difference_gradient(f, p, Method) -end - -function _gradient!(f, X, p, ::FiniteDiffBackend{Method}) where {Method} - return FiniteDiff.finite_difference_gradient!(X, f, p, Method) -end - -function _jacobian(f, p, ::FiniteDiffBackend{Method}) where {Method} - return FiniteDiff.finite_difference_jacobian(f, p, Method) -end - -function _jacobian!(f, X, p, ::FiniteDiffBackend{Method}) where {Method} - return FiniteDiff.finite_difference_jacobian!(X, f, p, Method) -end - -if default_differential_backend() === NoneDiffBackend() - set_default_differential_backend!(FiniteDiffBackend()) -end diff --git a/src/differentiation/forward_diff.jl b/src/differentiation/forward_diff.jl deleted file mode 100644 index c6a0f0762d..0000000000 --- a/src/differentiation/forward_diff.jl +++ /dev/null @@ -1,31 +0,0 @@ - -""" - ForwardDiffBackend <: AbstractDiffBackend - -Differentiation backend based on the ForwardDiff.jl package. -""" -struct ForwardDiffBackend <: AbstractDiffBackend end - -function Manifolds._derivative(f, p, ::ForwardDiffBackend) - return ForwardDiff.derivative(f, p) -end - -function _derivative!(f, X, t, ::ForwardDiffBackend) - return ForwardDiff.derivative!(X, f, t) -end - -function _gradient(f, p, ::ForwardDiffBackend) - return ForwardDiff.gradient(f, p) -end - -function _gradient!(f, X, t, ::ForwardDiffBackend) - return ForwardDiff.gradient!(X, f, t) -end - -function _jacobian(f, p, ::ForwardDiffBackend) - return ForwardDiff.jacobian(f, p) -end - -if default_differential_backend() === NoneDiffBackend() - set_default_differential_backend!(ForwardDiffBackend()) -end diff --git a/src/differentiation/ode.jl b/src/differentiation/ode.jl index c9c52ffae5..73c5dc1225 100644 --- a/src/differentiation/ode.jl +++ b/src/differentiation/ode.jl @@ -33,3 +33,15 @@ function solve_exp_ode( q = sol.u[1][(n + 1):(2 * n)] return q end +# also define exp! for metric manifold anew in this case +function exp!( + ::TraitList{IsMetricManifold}, + M::AbstractDecoratorManifold, + q, + p, + X; + kwargs..., +) + copyto!(M, q, solve_exp_ode(M, p, X; kwargs...)) + return q +end diff --git a/src/differentiation/reverse_diff.jl b/src/differentiation/reverse_diff.jl deleted file mode 100644 index 4ec0081c01..0000000000 --- a/src/differentiation/reverse_diff.jl +++ /dev/null @@ -1,13 +0,0 @@ -struct ReverseDiffBackend <: AbstractDiffBackend end - -function Manifolds._gradient(f, p, ::ReverseDiffBackend) - return ReverseDiff.gradient(f, p) -end - -function Manifolds._gradient!(f, X, p, ::ReverseDiffBackend) - return ReverseDiff.gradient!(X, f, p) -end - -if default_differential_backend() === NoneDiffBackend() - set_default_differential_backend!(ReverseDiffBackend()) -end diff --git a/src/differentiation/riemannian_diff.jl b/src/differentiation/riemannian_diff.jl index 47ba4caae9..2fcdf94a6e 100644 --- a/src/differentiation/riemannian_diff.jl +++ b/src/differentiation/riemannian_diff.jl @@ -46,11 +46,11 @@ end @doc raw""" TangentDiffBackend <: AbstractRiemannianDiffBackend -A backend that uses a tangent space and a basis therein to derive an +A backend that uses tangent spaces and bases therein to derive an intrinsic differentiation scheme. -Since it works in a tangent space, methods might require a retraction and an -inverse retraction as well as a basis. +Since it works in tangent spaces at argument and function value, methods might require a +retraction and an inverse retraction as well as a basis. In the tangent space itself, this backend then employs an (Euclidean) [`AbstractDiffBackend`](@ref) @@ -63,64 +63,70 @@ where `diff_backend` is an [`AbstractDiffBackend`](@ref) to be used on the tange With the keyword arguments -* `retraction` an [`AbstractRetractionMethod`](@ref) ([`ExponentialRetraction`](@ref) by default) -* `inverse_retraction` an [`AbstractInverseRetractionMethod`](@ref) ([`LogarithmicInverseRetraction`](@ref) by default) -* `basis` an [`AbstractBasis`](@ref) ([`DefaultOrthogonalBasis`](@ref) by default) +* `retraction` an [AbstractRetractionMethod](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html) (`ExponentialRetraction` by default) +* `inverse_retraction` an [AbstractInverseRetractionMethod](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html) `LogarithmicInverseRetraction` by default) +* `basis_arg` an [AbstractBasis](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/bases.html) (`DefaultOrthogonalBasis` by default) +* `basis_val` an [AbstractBasis](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/bases.html) (`DefaultOrthogonalBasis` by default) """ struct TangentDiffBackend{ TAD<:AbstractDiffBackend, TR<:AbstractRetractionMethod, TIR<:AbstractInverseRetractionMethod, - TB<:AbstractBasis, + TBarg<:AbstractBasis, + TBval<:AbstractBasis, } <: AbstractRiemannianDiffBackend diff_backend::TAD retraction::TR inverse_retraction::TIR - basis::TB + basis_arg::TBarg + basis_val::TBval end function TangentDiffBackend( diff_backend::TAD; retraction::TR=ExponentialRetraction(), inverse_retraction::TIR=LogarithmicInverseRetraction(), - basis::TB=DefaultOrthonormalBasis(), + basis_arg::TBarg=DefaultOrthonormalBasis(), + basis_val::TBval=DefaultOrthonormalBasis(), ) where { TAD<:AbstractDiffBackend, TR<:AbstractRetractionMethod, TIR<:AbstractInverseRetractionMethod, - TB<:AbstractBasis, + TBarg<:AbstractBasis, + TBval<:AbstractBasis, } - return TangentDiffBackend{TAD,TR,TIR,TB}( + return TangentDiffBackend{TAD,TR,TIR,TBarg,TBval}( diff_backend, retraction, inverse_retraction, - basis, + basis_arg, + basis_val, ) end function differential(M::AbstractManifold, f, t::Real, backend::TangentDiffBackend) p = f(t) - onb_coords = _derivative(zero(number_eltype(p)), backend.diff_backend) do h + onb_coords = Manifolds._derivative(zero(number_eltype(p)), backend.diff_backend) do h return get_coordinates( M, p, inverse_retract(M, p, f(t + h), backend.inverse_retraction), - backend.basis, + backend.basis_val, ) end - return get_vector(M, p, onb_coords, backend.basis) + return get_vector(M, p, onb_coords, backend.basis_val) end function differential!(M::AbstractManifold, f, X, t::Real, backend::TangentDiffBackend) p = f(t) - onb_coords = _derivative(zero(number_eltype(p)), backend.diff_backend) do h + onb_coords = Manifolds._derivative(zero(number_eltype(p)), backend.diff_backend) do h return get_coordinates( M, p, inverse_retract(M, p, f(t + h), backend.inverse_retraction), - backend.basis, + backend.basis_val, ) end - return get_vector!(M, X, p, onb_coords, backend.basis) + return get_vector!(M, X, p, onb_coords, backend.basis_val) end @doc raw""" @@ -133,7 +139,7 @@ This method uses the internal `backend.diff_backend` (Euclidean) on the function ``` which is given on the tangent space. In detail, the gradient can be written in -terms of the `backend.basis`. We illustrate it here for an [`AbstractOrthonormalBasis`](@ref), +terms of the `backend.basis_arg`. We illustrate it here for an [AbstractOrthonormalBasis](https://juliamanifolds.github.io/Manifolds.jl/stable/interface.html#ManifoldsBase.AbstractOrthonormalBasis), since that simplifies notations: ```math @@ -152,19 +158,19 @@ writing ``p=\exp_p(0)`` we see that this is a finite difference of ``f\circ\exp_ a function on the tangent space, so we can also use other (Euclidean) backends """ function gradient(M::AbstractManifold, f, p, backend::TangentDiffBackend) - X = get_coordinates(M, p, zero_vector(M, p), backend.basis) - onb_coords = _gradient(X, backend.diff_backend) do Y - return f(retract(M, p, get_vector(M, p, Y, backend.basis), backend.retraction)) + X = get_coordinates(M, p, zero_vector(M, p), backend.basis_arg) + onb_coords = Manifolds._gradient(X, backend.diff_backend) do Y + return f(retract(M, p, get_vector(M, p, Y, backend.basis_arg), backend.retraction)) end - return get_vector(M, p, onb_coords, backend.basis) + return get_vector(M, p, onb_coords, backend.basis_arg) end function gradient!(M::AbstractManifold, f, X, p, backend::TangentDiffBackend) - X2 = get_coordinates(M, p, zero_vector(M, p), backend.basis) - onb_coords = _gradient(X2, backend.diff_backend) do Y - return f(retract(M, p, get_vector(M, p, Y, backend.basis), backend.retraction)) + X2 = get_coordinates(M, p, zero_vector(M, p), backend.basis_arg) + onb_coords = Manifolds._gradient(X2, backend.diff_backend) do Y + return f(retract(M, p, get_vector(M, p, Y, backend.basis_arg), backend.retraction)) end - return get_vector!(M, X, p, onb_coords, backend.basis) + return get_vector!(M, X, p, onb_coords, backend.basis_arg) end @doc raw""" @@ -200,13 +206,45 @@ struct RiemannianProjectionBackend{TADBackend<:AbstractDiffBackend} <: end function gradient(M::AbstractManifold, f, p, backend::RiemannianProjectionBackend) - amb_grad = _gradient(f, p, backend.diff_backend) + amb_grad = Manifolds._gradient(f, p, backend.diff_backend) return change_representer(M, EuclideanMetric(), p, project(M, p, amb_grad)) end function gradient!(M::AbstractManifold, f, X, p, backend::RiemannianProjectionBackend) amb_grad = embed(M, p, X) - _gradient!(f, amb_grad, p, backend.diff_backend) + Manifolds._gradient!(f, amb_grad, p, backend.diff_backend) project!(M, X, p, amb_grad) return change_representer!(M, X, EuclideanMetric(), p, X) end + +function jacobian( + M_dom::AbstractManifold, + M_codom::AbstractManifold, + f, + p, + backend::TangentDiffBackend, +) + X = get_coordinates(M_dom, p, zero_vector(M_dom, p), backend.basis_arg) + q = f(p) + onb_coords = Manifolds._jacobian(X, backend.diff_backend) do Y + return get_coordinates( + M_codom, + q, + inverse_retract( + M_codom, + q, + f( + retract( + M_dom, + p, + get_vector(M_dom, p, Y, backend.basis_arg), + backend.retraction, + ), + ), + backend.inverse_retraction, + ), + backend.basis_val, + ) + end + return onb_coords +end diff --git a/src/differentiation/zygote.jl b/src/differentiation/zygote.jl deleted file mode 100644 index 883e027e30..0000000000 --- a/src/differentiation/zygote.jl +++ /dev/null @@ -1,13 +0,0 @@ -struct ZygoteDiffBackend <: AbstractDiffBackend end - -function Manifolds._gradient(f, p, ::ZygoteDiffBackend) - return Zygote.gradient(f, p)[1] -end - -function Manifolds._gradient!(f, X, p, ::ZygoteDiffBackend) - return copyto!(X, Zygote.gradient(f, p)[1]) -end - -if default_differential_backend() === NoneDiffBackend() - set_default_differential_backend!(ZygoteDiffBackend()) -end diff --git a/src/distributions.jl b/src/distributions.jl index 4dd99c7f04..4cfc4ea9dd 100644 --- a/src/distributions.jl +++ b/src/distributions.jl @@ -40,7 +40,7 @@ struct MPointvariate <: VariateForm end MPointSupport(M::AbstractManifold) Value support for manifold-valued distributions (values from given -[`AbstractManifold`](@ref) `M`). +[`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) `M`). """ struct MPointSupport{TM<:AbstractManifold} <: ValueSupport manifold::TM @@ -82,7 +82,7 @@ Optionally a random number generator `rng` to be used can be specified. An optio Usually a uniform distribution should be expected for compact manifolds and a Gaussian-like distribution for non-compact manifolds and tangent vectors, although it is - not guaranteed. The distribution may change between releases. + not guaranteed. The distribution may change between releases. `rand` methods for specific manifolds may take additional keyword arguments. diff --git a/src/groups/GroupManifold.jl b/src/groups/GroupManifold.jl new file mode 100644 index 0000000000..4f381c9fb3 --- /dev/null +++ b/src/groups/GroupManifold.jl @@ -0,0 +1,91 @@ +""" + GroupManifold{𝔽,M<:AbstractManifold{𝔽},O<:AbstractGroupOperation} <: AbstractDecoratorManifold{𝔽} + +Decorator for a smooth manifold that equips the manifold with a group operation, thus making +it a Lie group. See [`IsGroupManifold`](@ref) for more details. + +Group manifolds by default forward metric-related operations to the wrapped manifold. + + +# Constructor + + GroupManifold(manifold, op) +""" +struct GroupManifold{𝔽,M<:AbstractManifold{𝔽},O<:AbstractGroupOperation} <: + AbstractDecoratorManifold{𝔽} + manifold::M + op::O +end + +@inline function active_traits(f, M::GroupManifold, args...) + return merge_traits( + IsGroupManifold(M.op), + active_traits(f, M.manifold, args...), + IsExplicitDecorator(), + ) +end + +decorated_manifold(G::GroupManifold) = G.manifold + +(op::AbstractGroupOperation)(M::AbstractManifold) = GroupManifold(M, op) +function (::Type{T})(M::AbstractManifold) where {T<:AbstractGroupOperation} + return GroupManifold(M, T()) +end + +function inverse_retract( + ::TraitList{<:IsGroupManifold}, + G::GroupManifold, + p, + q, + method::GroupLogarithmicInverseRetraction, +) + conv = direction(method) + pinvq = inverse_translate(G, p, q, conv) + Xβ‚‘ = log_lie(G, pinvq) + return translate_diff(G, p, Identity(G), Xβ‚‘, conv) +end + +function inverse_retract!( + ::TraitList{<:IsGroupManifold}, + G::GroupManifold, + X, + p, + q, + method::GroupLogarithmicInverseRetraction, +) + conv = direction(method) + pinvq = inverse_translate(G, p, q, conv) + Xβ‚‘ = log_lie(G, pinvq) + return translate_diff!(G, X, p, Identity(G), Xβ‚‘, conv) +end + +function is_point( + ::TraitList{<:IsGroupManifold}, + G::GroupManifold, + e::Identity, + te=false; + kwargs..., +) + ie = is_identity(G, e; kwargs...) + (te && !ie) && throw(DomainError(e, "The provided identity is not a point on $G.")) + return ie +end + +function is_vector( + t::TraitList{<:IsGroupManifold}, + G::GroupManifold, + e::Identity, + X, + te=false, + cbp=true; + kwargs..., +) + if cbp + ie = is_identity(G, e; kwargs...) + (te && !ie) && throw(DomainError(e, "The provided identity is not a point on $G.")) + (!te) && return ie + end + return is_vector(G.manifold, identity_element(G), X, te, false; kwargs...) +end + +Base.show(io::IO, G::GroupManifold) = print(io, "GroupManifold($(G.manifold), $(G.op))") diff --git a/src/groups/addition_operation.jl b/src/groups/addition_operation.jl new file mode 100644 index 0000000000..90e8e92ce0 --- /dev/null +++ b/src/groups/addition_operation.jl @@ -0,0 +1,156 @@ + +""" + AdditionOperation <: AbstractGroupOperation + +Group operation that consists of simple addition. +""" +struct AdditionOperation <: AbstractGroupOperation end + +Base.:+(e::Identity{AdditionOperation}) = e +Base.:+(e::Identity{AdditionOperation}, ::Identity{AdditionOperation}) = e +Base.:+(::Identity{AdditionOperation}, p) = p +Base.:+(p, ::Identity{AdditionOperation}) = p + +Base.:-(e::Identity{AdditionOperation}) = e +Base.:-(e::Identity{AdditionOperation}, ::Identity{AdditionOperation}) = e +Base.:-(::Identity{AdditionOperation}, p) = -p +Base.:-(p, ::Identity{AdditionOperation}) = p + +Base.:*(e::Identity{AdditionOperation}, p) = e +Base.:*(p, e::Identity{AdditionOperation}) = e +Base.:*(e::Identity{AdditionOperation}, ::Identity{AdditionOperation}) = e + +const AdditionGroupTrait = TraitList{<:IsGroupManifold{AdditionOperation}} + +adjoint_action(::AdditionGroupTrait, G::AbstractDecoratorManifold, p, X) = X + +function adjoint_action!(::AdditionGroupTrait, G::AbstractDecoratorManifold, Y, p, X) + return copyto!(G, Y, p, X) +end + +identity_element(::AdditionGroupTrait, G::AbstractDecoratorManifold, p::Number) = zero(p) + +function identity_element!(::AdditionGroupTrait, G::AbstractDecoratorManifold, p) where {𝔽} + return fill!(p, zero(eltype(p))) +end + +Base.inv(::AdditionGroupTrait, G::AbstractDecoratorManifold, p) = -p +Base.inv(::AdditionGroupTrait, G::AbstractDecoratorManifold, e::Identity) = e + +inv!(::AdditionGroupTrait, G::AbstractDecoratorManifold, q, p) = copyto!(G, q, -p) +function inv!( + ::AdditionGroupTrait, + G::AbstractDecoratorManifold, + q, + ::Identity{AdditionOperation}, +) + return identity_element!(G, q) +end +function inv!( + ::AdditionGroupTrait, + G::AbstractDecoratorManifold, + q::Identity{AdditionOperation}, + e::Identity{AdditionOperation}, +) + return q +end + +function is_identity(::AdditionGroupTrait, G::AbstractDecoratorManifold, q; kwargs...) + return isapprox(G, q, zero(q); kwargs...) +end +# resolve ambiguities +function is_identity( + ::AdditionGroupTrait, + G::AbstractDecoratorManifold, + ::Identity{AdditionOperation}; + kwargs..., +) + return true +end +function is_identity( + ::AdditionGroupTrait, + G::AbstractDecoratorManifold, + ::Identity; + kwargs..., +) + return false +end + +compose(::AdditionGroupTrait, G::AbstractDecoratorManifold, p, q) = p + q + +function compose!(::AdditionGroupTrait, G::AbstractDecoratorManifold, x, p, q) + x .= p .+ q + return x +end + +function translate_diff( + ::AdditionGroupTrait, + G::AbstractDecoratorManifold, + p, + q, + X, + ::ActionDirection, +) + return X +end + +function translate_diff!( + ::AdditionGroupTrait, + G::AbstractDecoratorManifold, + Y, + p, + q, + X, + ::ActionDirection, +) + return copyto!(G, Y, p, X) +end + +function inverse_translate_diff( + ::AdditionGroupTrait, + G::AbstractDecoratorManifold, + p, + q, + X, + ::ActionDirection, +) + return X +end + +function inverse_translate_diff!( + ::AdditionGroupTrait, + G::AbstractDecoratorManifold, + Y, + p, + q, + X, + ::ActionDirection, +) + return copyto!(G, Y, p, X) +end + +exp_lie(::AdditionGroupTrait, G::AbstractDecoratorManifold, X) = X + +exp_lie!(::AdditionGroupTrait, G::AbstractDecoratorManifold, q, X) = copyto!(G, q, X) + +log_lie(::AdditionGroupTrait, G::AbstractDecoratorManifold, q) = q +function log_lie( + ::AdditionGroupTrait, + G::AbstractDecoratorManifold, + ::Identity{AdditionOperation}, +) + return zero_vector(G, identity_element(G)) +end +log_lie!(::AdditionGroupTrait, G::AbstractDecoratorManifold, X, q) = copyto!(G, X, q) +function log_lie!( + ::AdditionGroupTrait, + G::AbstractDecoratorManifold, + X, + ::Identity{AdditionOperation}, +) + return zero_vector!(G, X, identity_element(G)) +end + +lie_bracket(::AdditionGroupTrait, G::AbstractDecoratorManifold, X, Y) = zero(X) + +lie_bracket!(::AdditionGroupTrait, G::AbstractDecoratorManifold, Z, X, Y) = fill!(Z, 0) diff --git a/src/groups/circle_group.jl b/src/groups/circle_group.jl index 13dcfce570..6d499d8226 100644 --- a/src/groups/circle_group.jl +++ b/src/groups/circle_group.jl @@ -8,40 +8,58 @@ const CircleGroup = GroupManifold{β„‚,Circle{β„‚},MultiplicationOperation} CircleGroup() = GroupManifold(Circle{β„‚}(), MultiplicationOperation()) -Base.show(io::IO, ::CircleGroup) = print(io, "CircleGroup()") - -invariant_metric_dispatch(::CircleGroup, ::ActionDirection) = Val(true) +@inline function active_traits(f, M::CircleGroup, args...) + return merge_traits( + IsGroupManifold(M.op), + IsDefaultMetric(EuclideanMetric()), + HasBiinvariantMetric(), + active_traits(f, M.manifold, args...), + IsExplicitDecorator(), + ) +end -default_metric_dispatch(::MetricManifold{β„‚,CircleGroup,EuclideanMetric}) = Val(true) +Base.show(io::IO, ::CircleGroup) = print(io, "CircleGroup()") adjoint_action(::CircleGroup, p, X) = X adjoint_action!(::CircleGroup, Y, p, X) = copyto!(Y, X) -function _compose(G::CircleGroup, p::AbstractVector, q::AbstractVector) - return map(compose, repeated(G), p, q) +function compose( + ::MultiplicationGroupTrait, + G::CircleGroup, + p::AbstractArray{<:Any,0}, + q::AbstractArray{<:Any,0}, +) + return map((pp, qq) -> compose(G, pp, qq), p, q) end -_compose!(G::CircleGroup, x, p, q) = copyto!(x, compose(G, p, q)) +function compose!( + ::MultiplicationGroupTrait, + G::CircleGroup, + x, + p::AbstractArray{<:Any,0}, + q::AbstractArray{<:Any,0}, +) + return copyto!(x, compose(G, p, q)) +end identity_element(G::CircleGroup) = 1.0 identity_element(::CircleGroup, p::Number) = one(p) -identity_element(::CircleGroup, p::AbstractArray) = map(i -> one(eltype(p)), p) -Base.inv(G::CircleGroup, p::AbstractVector) = map(inv, repeated(G), p) +Base.inv(G::CircleGroup, p::AbstractArray{<:Any,0}) = map(pp -> inv(G, pp), p) function inverse_translate( ::CircleGroup, - p::AbstractVector, - q::AbstractVector, + p::AbstractArray{<:Any,0}, + q::AbstractArray{<:Any,0}, ::LeftAction, ) return map(/, q, p) end function inverse_translate( ::CircleGroup, - p::AbstractVector, - q::AbstractVector, + p::AbstractArray{<:Any,0}, + q::AbstractArray{<:Any,0}, ::RightAction, ) return map(/, q, p) @@ -76,16 +94,15 @@ end exp_lie!(G::CircleGroup, q, X) = (q .= exp_lie(G, X)) -function log_lie(::CircleGroup, q) +function _log_lie(::CircleGroup, q) return map(q) do z cosΞΈ, sinΞΈ = reim(z) ΞΈ = atan(sinΞΈ, cosΞΈ) return ΞΈ * im end end -log_lie(::CircleGroup, e::Identity{MultiplicationOperation}) = 0.0 * im -_log_lie!(G::CircleGroup, X, q) = (X .= log_lie(G, q)) +_log_lie!(G::CircleGroup, X, q) = (X .= _log_lie(G, q)) function number_of_coordinates(G::CircleGroup, B::AbstractBasis) return number_of_coordinates(base_manifold(G), B) @@ -101,39 +118,61 @@ const RealCircleGroup = GroupManifold{ℝ,Circle{ℝ},AdditionOperation} RealCircleGroup() = GroupManifold(Circle{ℝ}(), AdditionOperation()) -Base.show(io::IO, ::RealCircleGroup) = print(io, "RealCircleGroup()") +@inline function active_traits(f, M::RealCircleGroup, args...) + return merge_traits( + IsGroupManifold(M.op), + IsDefaultMetric(EuclideanMetric()), + HasBiinvariantMetric(), + active_traits(f, M.manifold, args...), + IsExplicitDecorator(), + ) +end -invariant_metric_dispatch(::RealCircleGroup, ::ActionDirection) = Val(true) +Base.show(io::IO, ::RealCircleGroup) = print(io, "RealCircleGroup()") -default_metric_dispatch(::MetricManifold{ℝ,RealCircleGroup,EuclideanMetric}) = Val(true) +is_default_metric(::RealCircleGroup, ::EuclideanMetric) = true -_compose(::RealCircleGroup, p, q) = sym_rem(p + q) -function _compose(G::RealCircleGroup, p::AbstractVector, q::AbstractVector) - return map(compose, repeated(G), p, q) +# Lazy overwrite since this is a rare case of nonmutating foo. +compose(::RealCircleGroup, p, q) = sym_rem(p + q) +compose(::RealCircleGroup, ::Identity{AdditionOperation}, q) = sym_rem(q) +compose(::RealCircleGroup, p, ::Identity{AdditionOperation}) = sym_rem(p) +function compose( + ::RealCircleGroup, + e::Identity{AdditionOperation}, + ::Identity{AdditionOperation}, +) + return e end -function _compose!(::RealCircleGroup, x, p, q) - x .= sym_rem.(p .+ q) - return x +compose!(::RealCircleGroup, x, p, q) = copyto!(x, sym_rem(p + q)) +compose!(::RealCircleGroup, x, ::Identity{AdditionOperation}, q) = copyto!(x, sym_rem(q)) +compose!(::RealCircleGroup, x, p, ::Identity{AdditionOperation}) = copyto!(x, sym_rem(p)) +function compose!( + ::RealCircleGroup, + ::Identity{AdditionOperation}, + e::Identity{AdditionOperation}, + ::Identity{AdditionOperation}, +) + return e end identity_element(G::RealCircleGroup) = 0.0 identity_element(::RealCircleGroup, p::AbstractArray) = map(i -> zero(eltype(p)), p) -Base.inv(G::RealCircleGroup, p::AbstractVector) = map(inv, repeated(G), p) +Base.inv(G::RealCircleGroup, p::AbstractArray{<:Any,0}) = map(pp -> inv(G, pp), p) function inverse_translate( ::RealCircleGroup, - p::AbstractVector, - q::AbstractVector, + p::AbstractArray{<:Any,0}, + q::AbstractArray{<:Any,0}, ::LeftAction, ) return map((x, y) -> sym_rem(x - y), q, p) end function inverse_translate( ::RealCircleGroup, - p::AbstractVector, - q::AbstractVector, + p::AbstractArray{<:Any,0}, + q::AbstractArray{<:Any,0}, ::RightAction, ) return map((x, y) -> sym_rem(x - y), q, p) diff --git a/src/groups/connections.jl b/src/groups/connections.jl index 8d36f44756..552fd65842 100644 --- a/src/groups/connections.jl +++ b/src/groups/connections.jl @@ -47,7 +47,7 @@ const CartanSchoutenPlusGroup{𝔽,M} = ConnectionManifold{𝔽,M,CartanSchouten const CartanSchoutenZeroGroup{𝔽,M} = ConnectionManifold{𝔽,M,CartanSchoutenZero} """ - exp!(M::ConnectionManifold{𝔽,<:AbstractGroupManifold{𝔽},<:AbstractCartanSchoutenConnection}, q, p, X) where {𝔽} + exp!(M::ConnectionManifold{𝔽,<:AbstractDecoratorManifold{𝔽},<:AbstractCartanSchoutenConnection}, q, p, X) where {𝔽} Compute the exponential map on the [`ConnectionManifold`](@ref) `M` with a Cartan-Schouten connection. See Sections 5.3.2 and 5.3.3 of [^Pennec2020] for details. @@ -59,7 +59,11 @@ connection. See Sections 5.3.2 and 5.3.3 of [^Pennec2020] for details. > doi: 10.1016/B978-0-12-814725-2.00012-1. """ function exp!( - M::ConnectionManifold{𝔽,<:AbstractGroupManifold{𝔽},<:AbstractCartanSchoutenConnection}, + M::ConnectionManifold{ + 𝔽, + <:AbstractDecoratorManifold{𝔽}, + <:AbstractCartanSchoutenConnection, + }, q, p, X, @@ -69,7 +73,7 @@ function exp!( end """ - log!(M::ConnectionManifold{𝔽,<:AbstractGroupManifold{𝔽},<:AbstractCartanSchoutenConnection}, Y, p, q) where {𝔽} + log!(M::ConnectionManifold{𝔽,<:AbstractDecoratorManifold{𝔽},<:AbstractCartanSchoutenConnection}, Y, p, q) where {𝔽} Compute the logarithmic map on the [`ConnectionManifold`](@ref) `M` with a Cartan-Schouten connection. See Sections 5.3.2 and 5.3.3 of [^Pennec2020] for details. @@ -81,7 +85,11 @@ connection. See Sections 5.3.2 and 5.3.3 of [^Pennec2020] for details. > doi: 10.1016/B978-0-12-814725-2.00012-1. """ function log!( - M::ConnectionManifold{𝔽,<:AbstractGroupManifold{𝔽},<:AbstractCartanSchoutenConnection}, + M::ConnectionManifold{ + 𝔽, + <:AbstractDecoratorManifold{𝔽}, + <:AbstractCartanSchoutenConnection, + }, Y, p, q, @@ -92,7 +100,7 @@ function log!( end """ - vector_transport_to(M::CartanSchoutenMinusGroup, p, X, q, ::ParallelTransport) + parallel_transport_to(M::CartanSchoutenMinusGroup, p, X, q) Transport tangent vector `X` at point `p` on the group manifold `M` with the [`CartanSchoutenMinus`](@ref) connection to point `q`. See [^Pennec2020] for details. @@ -103,14 +111,16 @@ Transport tangent vector `X` at point `p` on the group manifold `M` with the > Analysis, X. Pennec, S. Sommer, and T. Fletcher, Eds. Academic Press, 2020, pp. 169–229. > doi: 10.1016/B978-0-12-814725-2.00012-1. """ -vector_transport_to(M::CartanSchoutenMinusGroup, p, X, q, ::ParallelTransport) +function parallel_transport_to(M::CartanSchoutenMinusGroup, p, X, q) + return inverse_translate_diff(M.manifold, q, p, X, LeftAction()) +end -function vector_transport_to!(M::CartanSchoutenMinusGroup, Y, p, X, q, ::ParallelTransport) +function parallel_transport_to!(M::CartanSchoutenMinusGroup, Y, p, X, q) return inverse_translate_diff!(M.manifold, Y, q, p, X, LeftAction()) end """ - vector_transport_to(M::CartanSchoutenPlusGroup, p, X, q, ::ParallelTransport) + vector_transport_to(M::CartanSchoutenPlusGroup, p, X, q) Transport tangent vector `X` at point `p` on the group manifold `M` with the [`CartanSchoutenPlus`](@ref) connection to point `q`. See [^Pennec2020] for details. @@ -121,14 +131,14 @@ Transport tangent vector `X` at point `p` on the group manifold `M` with the > Analysis, X. Pennec, S. Sommer, and T. Fletcher, Eds. Academic Press, 2020, pp. 169–229. > doi: 10.1016/B978-0-12-814725-2.00012-1. """ -vector_transport_to(M::CartanSchoutenPlusGroup, p, X, q, ::ParallelTransport) +parallel_transport_to(M::CartanSchoutenPlusGroup, p, X, q) -function vector_transport_to!(M::CartanSchoutenPlusGroup, Y, p, X, q, ::ParallelTransport) +function parallel_transport_to!(M::CartanSchoutenPlusGroup, Y, p, X, q) return inverse_translate_diff!(M.manifold, Y, q, p, X, RightAction()) end """ - vector_transport_direction(M::CartanSchoutenZeroGroup, ::Identity, X, d, ::ParallelTransport) + parallel_transport_direction(M::CartanSchoutenZeroGroup, ::Identity, X, d) Transport tangent vector `X` at identity on the group manifold with the [`CartanSchoutenZero`](@ref) connection in the direction `d`. See [^Pennec2020] for details. @@ -139,44 +149,30 @@ Transport tangent vector `X` at identity on the group manifold with the > Analysis, X. Pennec, S. Sommer, and T. Fletcher, Eds. Academic Press, 2020, pp. 169–229. > doi: 10.1016/B978-0-12-814725-2.00012-1. """ -vector_transport_direction( - M::CartanSchoutenZeroGroup, - Y, - ::Identity, - X, - d, - ::ParallelTransport, -) +function parallel_transport_direction(M::CartanSchoutenZeroGroup, p::Identity, X, d) + dexp_half = exp_lie(M.manifold, d / 2) + Y = translate_diff(M.manifold, dexp_half, p, X, RightAction()) + return translate_diff(M.manifold, dexp_half, p, Y, LeftAction()) +end -function vector_transport_direction!( - M::CartanSchoutenZeroGroup, - Y, - p::Identity, - X, - d, - ::ParallelTransport, -) +function parallel_transport_direction!(M::CartanSchoutenZeroGroup, Y, p::Identity, X, d) dexp_half = exp_lie(M.manifold, d / 2) translate_diff!(M.manifold, Y, dexp_half, p, X, RightAction()) return translate_diff!(M.manifold, Y, dexp_half, p, Y, LeftAction()) end """ - vector_transport_to(M::CartanSchoutenZeroGroup, ::Identity, X, q, m::ParallelTransport) + parallel_transport_to(M::CartanSchoutenZeroGroup, p::Identity, X, q) Transport vector `X` at identity of group `M` equipped with the [`CartanSchoutenZero`](@ref) connection to point `q` using parallel transport. """ -vector_transport_to(::CartanSchoutenZeroGroup, ::Identity, X, q, ::ParallelTransport) +function parallel_transport_to(M::CartanSchoutenZeroGroup, p::Identity, X, q) + d = log_lie(M.manifold, q) + return parallel_transport_direction(M, p, X, d) +end -function vector_transport_to!( - M::CartanSchoutenZeroGroup, - Y, - p::Identity, - X, - q, - m::ParallelTransport, -) +function parallel_transport_to!(M::CartanSchoutenZeroGroup, Y, p::Identity, X, q) d = log_lie(M.manifold, q) - return vector_transport_direction!(M, Y, p, X, d, m) + return parallel_transport_direction!(M, Y, p, X, d) end diff --git a/src/groups/general_linear.jl b/src/groups/general_linear.jl index d9b8b92906..b685c8081d 100644 --- a/src/groups/general_linear.jl +++ b/src/groups/general_linear.jl @@ -1,6 +1,6 @@ @doc raw""" GeneralLinear{n,𝔽} <: - AbstractGroupManifold{𝔽,MultiplicationOperation,DefaultEmbeddingType} + AbstractDecoratorManifold{𝔽} The general linear group, that is, the group of all invertible matrices in ``𝔽^{nΓ—n}``. @@ -16,8 +16,16 @@ vector in the Lie algebra, and ``βŸ¨β‹…,β‹…βŸ©_\mathrm{F}`` denotes the Frobeniu By default, tangent vectors ``X_p`` are represented with their corresponding Lie algebra vectors ``X_e = p^{-1}X_p``. """ -struct GeneralLinear{n,𝔽} <: - AbstractGroupManifold{𝔽,MultiplicationOperation,DefaultGroupDecoratorType} end +struct GeneralLinear{n,𝔽} <: AbstractDecoratorManifold{𝔽} end + +function active_traits(f, ::GeneralLinear, args...) + return merge_traits( + IsGroupManifold(MultiplicationOperation()), + IsEmbeddedManifold(), + HasLeftInvariantMetric(), + IsDefaultMetric(EuclideanMetric()), + ) +end GeneralLinear(n, 𝔽::AbstractNumbers=ℝ) = GeneralLinear{n,𝔽}() @@ -26,8 +34,6 @@ function allocation_promotion_function(::GeneralLinear{n,β„‚}, f, ::Tuple) where end function check_point(G::GeneralLinear, p; kwargs...) - mpv = check_point(decorated_manifold(G), p; kwargs...) - mpv === nothing || return mpv detp = det(p) if iszero(detp) return DomainError( @@ -38,27 +44,16 @@ function check_point(G::GeneralLinear, p; kwargs...) return nothing end check_point(::GeneralLinear, ::Identity{MultiplicationOperation}) = nothing -function check_point( - G::GeneralLinear, - e::Identity{O}; - kwargs..., -) where {O<:AbstractGroupOperation} - return invoke(check_point, Tuple{AbstractGroupManifold,typeof(e)}, G, e; kwargs...) -end function check_vector(G::GeneralLinear, p, X; kwargs...) - mpv = check_vector(decorated_manifold(G), p, X; kwargs...) - mpv === nothing || return mpv return nothing end -decorated_manifold(::GeneralLinear{n,𝔽}) where {n,𝔽} = Euclidean(n, n; field=𝔽) - -default_metric_dispatch(::GeneralLinear, ::EuclideanMetric) = Val(true) -default_metric_dispatch(::GeneralLinear, ::LeftInvariantMetric{EuclideanMetric}) = Val(true) - distance(G::GeneralLinear, p, q) = norm(G, p, log(G, p, q)) +embed(::GeneralLinear, p) = p +embed(::GeneralLinear, p, X) = X + @doc raw""" exp(G::GeneralLinear, p, X) @@ -131,6 +126,8 @@ function get_coordinates!( return copyto!(Xⁱ, X) end +get_embedding(::GeneralLinear{n,𝔽}) where {n,𝔽} = Euclidean(n, n; field=𝔽) + function get_vector( ::GeneralLinear{n,ℝ}, p, @@ -156,15 +153,8 @@ function exp_lie!(::GeneralLinear{1}, q, X) end exp_lie!(::GeneralLinear{2}, q, X) = copyto!(q, exp(SizedMatrix{2,2}(X))) -function _log_lie!(::GeneralLinear{1}, X, p) - X[1] = log(p[1]) - return X -end - inner(::GeneralLinear, p, X, Y) = dot(X, Y) -invariant_metric_dispatch(::GeneralLinear, ::LeftAction) = Val(true) - inverse_translate_diff(::GeneralLinear, p, q, X, ::LeftAction) = X inverse_translate_diff(::GeneralLinear, p, q, X, ::RightAction) = p * X / p @@ -190,7 +180,7 @@ The algorithm proceeds in two stages. First, the point ``r = p^{-1} q`` is proje nearest element (under the Frobenius norm) of the direct product subgroup ``\mathrm{O}(n) Γ— S^+``, whose logarithmic map is exactly computed using the matrix logarithm. This initial tangent vector is then refined using the -[`NLsolveInverseRetraction`](@ref). +[`NLSolveInverseRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.NLSolveInverseRetraction). For `GeneralLinear(n, β„‚)`, the logarithmic map is instead computed on the realified supergroup `GeneralLinear(2n)` and the resulting tangent vector is then complexified. @@ -211,7 +201,7 @@ function log!(G::GeneralLinear{n,𝔽}, X, p, q) where {n,𝔽} pinvqα΅£ = realify(pinvq, 𝔽) Xα΅£ = realify(X, 𝔽) log_safe!(Xα΅£, _project_Un_S⁺(pinvqα΅£)) - inverse_retraction = NLsolveInverseRetraction(ExponentialRetraction(), Xα΅£) + inverse_retraction = NLSolveInverseRetraction(ExponentialRetraction(), Xα΅£) inverse_retract!(Gα΅£, Xα΅£, Identity(G), pinvqα΅£, inverse_retraction) unrealify!(X, Xα΅£, 𝔽, n) end @@ -224,10 +214,19 @@ function log!(::GeneralLinear{1}, X, p, q) return X end -manifold_dimension(G::GeneralLinear) = manifold_dimension(decorated_manifold(G)) +function _log_lie!(::GeneralLinear{1}, X, p) + X[1] = log(p[1]) + return X +end + +manifold_dimension(G::GeneralLinear) = manifold_dimension(get_embedding(G)) LinearAlgebra.norm(::GeneralLinear, p, X) = norm(X) +parallel_transport_to(::GeneralLinear, p, X, q) = X + +parallel_transport_to!(::GeneralLinear, Y, p, X, q) = copyto!(Y, X) + project(::GeneralLinear, p) = p project(::GeneralLinear, p, X) = X @@ -242,7 +241,3 @@ translate_diff(::GeneralLinear, p, q, X, ::RightAction) = p \ X * p function translate_diff!(G::GeneralLinear, Y, p, q, X, conv::ActionDirection) return copyto!(Y, translate_diff(G, p, q, X, conv)) end - -vector_transport_to(::GeneralLinear, p, X, q, ::ParallelTransport) = X - -vector_transport_to!(::GeneralLinear, Y, p, X, q, ::ParallelTransport) = copyto!(Y, X) diff --git a/src/groups/group.jl b/src/groups/group.jl index 75b6dc6770..90366df7bb 100644 --- a/src/groups/group.jl +++ b/src/groups/group.jl @@ -5,130 +5,101 @@ Abstract type for smooth binary operations $∘$ on elements of a Lie group $\ma ```math ∘ : \mathcal{G} Γ— \mathcal{G} β†’ \mathcal{G} ``` -An operation can be either defined for a specific [`AbstractGroupManifold`](@ref) over +An operation can be either defined for a specific group manifold over number system `𝔽` or in general, by defining for an operation `Op` the following methods: - identity_element!(::AbstractGroupManifold{𝔽,Op}, q, q) - inv!(::AbstractGroupManifold{𝔽,Op}, q, p) - _compose!(::AbstractGroupManifold{𝔽,Op}, x, p, q) + identity_element!(::AbstractDecoratorManifold, q, q) + inv!(::AbstractDecoratorManifold, q, p) + _compose!(::AbstractDecoratorManifold, x, p, q) Note that a manifold is connected with an operation by wrapping it with a decorator, -[`AbstractGroupManifold`](@ref). In typical cases the concrete wrapper -[`GroupManifold`](@ref) can be used. +[`AbstractDecoratorManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/decorator.html#ManifoldsBase.AbstractDecoratorManifold) using the [`IsGroupManifold`](@ref) to specify the operation. +For a concrete case the concrete wrapper [`GroupManifold`](@ref) can be used. """ abstract type AbstractGroupOperation end """ - abstract type AbstractGroupDecroatorType <: AbstractDecoratorType + IsGroupManifold{O<:AbstractGroupOperation} <: AbstractTrait -A common decorator type for all group decorators. -It is similar to [`DefaultEmbeddingType`](@ref) but for groups. -""" -abstract type AbstractGroupDecoratorType <: AbstractDecoratorType end +A trait to declare an [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) as a manifold with group structure +with operation of type `O`. -""" - struct DefaultGroupDecoratorType <: AbstractDecoratorType +Using this trait you can turn a manifold that you implement _implictly_ into a Lie group. +If you wish to decorate an existing manifold with one (or different) [`AbstractGroupAction`](@ref)s, +see [`GroupManifold`](@ref). -The default group decorator type with no special properties. -""" -struct DefaultGroupDecoratorType <: AbstractGroupDecoratorType end -""" - struct TransparentGroupDecoratorType <: AbstractDecoratorType +# Constructor -A transparent group decorator type that acts transparently, similar to -the [`TransparentIsometricEmbedding`](@ref), i.e. it passes through all metric-related functions such as -logarithmic and exponential map as well as retraction and inverse retractions -to the manifold it decorates. + IsGroupManifold(op) """ -struct TransparentGroupDecoratorType <: AbstractGroupDecoratorType end - -@doc raw""" - AbstractGroupManifold{𝔽,O<:AbstractGroupOperation} <: AbstractDecoratorManifold{𝔽} +struct IsGroupManifold{O<:AbstractGroupOperation} <: AbstractTrait + op::O +end -Abstract type for a Lie group, a group that is also a smooth manifold with an -[`AbstractGroupOperation`](@ref), a smooth binary operation. `AbstractGroupManifold`s must -implement at least [`inv`](@ref), [`compose`](@ref), and -[`translate_diff`](@ref). """ -abstract type AbstractGroupManifold{𝔽,O<:AbstractGroupOperation,T<:AbstractDecoratorType} <: - AbstractDecoratorManifold{𝔽,T} end + AbstractInvarianceTrait <: AbstractTrait +A common supertype for anz [`AbstractTrait`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/decorator.html#ManifoldsBase.AbstractTrait) related to metric invariance """ - GroupManifold{𝔽,M<:AbstractManifold{𝔽},O<:AbstractGroupOperation} <: AbstractGroupManifold{𝔽,O} +abstract type AbstractInvarianceTrait <: AbstractTrait end -Decorator for a smooth manifold that equips the manifold with a group operation, thus making -it a Lie group. See [`AbstractGroupManifold`](@ref) for more details. - -Group manifolds by default forward metric-related operations to the wrapped manifold. - -# Constructor - - GroupManifold(manifold, op) """ -struct GroupManifold{𝔽,M<:AbstractManifold{𝔽},O<:AbstractGroupOperation} <: - AbstractGroupManifold{𝔽,O,TransparentGroupDecoratorType} - manifold::M - op::O -end + HasLeftInvariantMetric <: AbstractInvarianceTrait -Base.show(io::IO, G::GroupManifold) = print(io, "GroupManifold($(G.manifold), $(G.op))") +Specify that a certain the metric of a [`GroupManifold`](@ref) is a left-invariant metric +""" +struct HasLeftInvariantMetric <: AbstractInvarianceTrait end -const GROUP_MANIFOLD_BASIS_DISAMBIGUATION = - [AbstractDecoratorManifold, ValidationManifold, VectorBundle] +direction(::HasLeftInvariantMetric) = LeftAction() +direction(::Type{HasLeftInvariantMetric}) = LeftAction() """ - base_group(M::AbstractManifold) -> AbstractGroupManifold + HasRightInvariantMetric <: AbstractInvarianceTrait -Un-decorate `M` until an `AbstractGroupManifold` is encountered. -Return an error if the [`base_manifold`](@ref) is reached without encountering a group. +Specify that a certain the metric of a [`GroupManifold`](@ref) is a right-invariant metric """ -base_group(M::AbstractDecoratorManifold) = base_group(decorated_manifold(M)) -function base_group(::AbstractManifold) - return error("base_group: no base group found.") -end -base_group(G::AbstractGroupManifold) = G +struct HasRightInvariantMetric <: AbstractInvarianceTrait end -""" - base_manifold(M::AbstractGroupManifold, d::Val{N} = Val(-1)) +direction(::HasRightInvariantMetric) = RightAction() +direction(::Type{HasRightInvariantMetric}) = RightAction() -Return the base manifold of `M` that is enhanced with its group. -While functions like `inner` might be overwritten to use the (decorated) manifold -representing the group, the `base_manifold` is the manifold itself. -Hence for this abstract case, just `M` is returned. """ -base_manifold(M::AbstractGroupManifold, ::Val{N}=Val(-1)) where {N} = M + HasBiinvariantMetric <: AbstractInvarianceTrait +Specify that a certain the metric of a [`GroupManifold`](@ref) is a bi-invariant metric """ - base_manifold(M::GroupManifold, d::Val{N} = Val(-1)) +struct HasBiinvariantMetric <: AbstractInvarianceTrait end +function parent_trait(::HasBiinvariantMetric) + return ManifoldsBase.TraitList(HasLeftInvariantMetric(), HasRightInvariantMetric()) +end -Return the base manifold of `M` that is enhanced with its group. -Here, the internally stored enhanced manifold `M.manifold` is returned. """ -base_manifold(G::GroupManifold, ::Val{N}=Val(-1)) where {N} = G.manifold - -decorator_group_dispatch(::AbstractManifold) = Val(false) -function decorator_group_dispatch(M::AbstractDecoratorManifold) - return decorator_group_dispatch(decorated_manifold(M)) -end -decorator_group_dispatch(::AbstractGroupManifold) = Val(true) + is_group_manifold(G::GroupManifold) + is_group_manifoldd(G::AbstractManifold, o::AbstractGroupOperation) -function is_group_decorator(M::AbstractManifold) - return _extract_val(decorator_group_dispatch(M)) +returns whether an [`AbstractDecoratorManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/decorator.html#ManifoldsBase.AbstractDecoratorManifold) is a group manifold with +[`AbstractGroupOperation`](@ref) `o`. +For a [`GroupManifold`](@ref) `G` this checks whether the right operations is stored within `G`. +""" +is_group_manifold(::AbstractManifold, ::AbstractGroupOperation) = false + +@trait_function is_group_manifold(M::AbstractDecoratorManifold, op::AbstractGroupOperation) +function is_group_manifold( + ::TraitList{<:IsGroupManifold{<:O}}, + ::AbstractDecoratorManifold, + ::O, +) where {O<:AbstractGroupOperation} + return true end - -default_decorator_dispatch(::AbstractGroupManifold) = Val(false) - -(op::AbstractGroupOperation)(M::AbstractManifold) = GroupManifold(M, op) -function (::Type{T})(M::AbstractManifold) where {T<:AbstractGroupOperation} - return GroupManifold(M, T()) +@trait_function is_group_manifold(M::AbstractDecoratorManifold) +is_group_manifold(::AbstractManifold) = false +function is_group_manifold( + t::TraitList{<:IsGroupManifold{<:AbstractGroupOperation}}, + M::AbstractDecoratorManifold, +) + return is_group_manifold(M, t.head.op) end -manifold_dimension(G::GroupManifold) = manifold_dimension(G.manifold) - -################### -# Action directions -################### - """ ActionDirection @@ -159,24 +130,20 @@ switch_direction(::ActionDirection) switch_direction(::LeftAction) = RightAction() switch_direction(::RightAction) = LeftAction() -################################## -# General Identity element methods -################################## - @doc raw""" Identity{O<:AbstractGroupOperation} -Represent the group identity element ``e ∈ \mathcal{G}`` on an [`AbstractGroupManifold`](@ref) `G` +Represent the group identity element ``e ∈ \mathcal{G}`` on a Lie group ``\mathcal G`` with [`AbstractGroupOperation`](@ref) of type `O`. Similar to the philosophy that points are agnostic of their group at hand, the identity does not store the group `g` it belongs to. However it depends on the type of the [`AbstractGroupOperation`](@ref) used. -See also [`identity_element`](@ref) on how to obtain the corresponding [`AbstractManifoldPoint`](@ref) or array representation. +See also [`identity_element`](@ref) on how to obtain the corresponding [`AbstractManifoldPoint`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifoldPoint) or array representation. # Constructors - Identity(G::AbstractGroupManifold{𝔽,O}) + Identity(G::AbstractDecoratorManifold{𝔽}) Identity(o::O) Identity(::Type{O}) @@ -184,10 +151,13 @@ create the identity of the corresponding subtype `O<:`[`AbstractGroupOperation`] """ struct Identity{O<:AbstractGroupOperation} end -function Identity(::AbstractGroupManifold{𝔽,O}) where {𝔽,O<:AbstractGroupOperation} +@trait_function Identity(M::AbstractDecoratorManifold) +function Identity( + ::TraitList{<:IsGroupManifold{O}}, + ::AbstractDecoratorManifold, +) where {O<:AbstractGroupOperation} return Identity{O}() end -Identity(M::AbstractDecoratorManifold) = Identity(base_group(M)) Identity(::O) where {O<:AbstractGroupOperation} = Identity(O) Identity(::Type{O}) where {O<:AbstractGroupOperation} = Identity{O}() @@ -195,92 +165,132 @@ Identity(::Type{O}) where {O<:AbstractGroupOperation} = Identity{O}() number_eltype(::Identity) = Bool @doc raw""" - identity_element(G::AbstractGroupManifold) + identity_element(G) -Return a point representation of the [`Identity`](@ref) on the [`AbstractGroupManifold`](@ref) `G`. +Return a point representation of the [`Identity`](@ref) on the [`IsGroupManifold`](@ref) `G`. By default this representation is the default array or number representation. -It should return the corresponding [`AbstractManifoldPoint`](@ref) of points on `G` if +It should return the corresponding default representation of ``e`` as a point on `G` if points are not represented by arrays. """ -identity_element(G::AbstractGroupManifold) -@decorator_transparent_function function identity_element(G::AbstractGroupManifold) +identity_element(G::AbstractDecoratorManifold) +@trait_function identity_element(G::AbstractDecoratorManifold) +function identity_element(::TraitList{<:IsGroupManifold}, G::AbstractDecoratorManifold) q = allocate_result(G, identity_element) return identity_element!(G, q) end -@decorator_transparent_signature identity_element!(G::AbstractGroupManifold, p) +@trait_function identity_element!(G::AbstractDecoratorManifold, p) -function allocate_result(G::AbstractGroupManifold, ::typeof(identity_element)) +function allocate_result(G::AbstractDecoratorManifold, ::typeof(identity_element)) return zeros(representation_size(G)...) end @doc raw""" - identity_element(G::AbstractGroupManifold, p) + identity_element(G::AbstractDecoratorManifold, p) -Return a point representation of the [`Identity`](@ref) on the [`AbstractGroupManifold`](@ref) `G`, +Return a point representation of the [`Identity`](@ref) on the [`IsGroupManifold`](@ref) `G`, where `p` indicates the type to represent the identity. """ -identity_element(G::AbstractGroupManifold, p) -@decorator_transparent_function function identity_element(G::AbstractGroupManifold, p) +identity_element(G::AbstractDecoratorManifold, p) +@trait_function identity_element(G::AbstractDecoratorManifold, p) +function identity_element(::TraitList{<:IsGroupManifold}, G::AbstractDecoratorManifold, p) q = allocate_result(G, identity_element, p) return identity_element!(G, q) end +function check_size( + ::TraitList{<:IsGroupManifold{<:O}}, + M::AbstractDecoratorManifold, + ::Identity{<:O}, +) where {O<:AbstractGroupOperation} + return nothing +end +function check_size(::EmptyTrait, M::AbstractDecoratorManifold, e::Identity) + return DomainError(0, "$M seems to not be a group manifold with $e.") +end @doc raw""" - identity_element!(G::AbstractGroupManifold, p) - -Return a point representation of the [`Identity`](@ref) on the [`AbstractGroupManifold`](@ref) `G` -in place of `p`. -""" -identity_element!(G::AbstractGroupManifold, p) - -@doc raw""" - is_identity(G, q; kwargs) + is_identity(G::AbstractDecoratorManifold, q; kwargs) -Check whether `q` is the identity on the [`AbstractGroupManifold`](@ref) `G`, i.e. it is either +Check whether `q` is the identity on the [`IsGroupManifold`](@ref) `G`, i.e. it is either the [`Identity`](@ref)`{O}` with the corresponding [`AbstractGroupOperation`](@ref) `O`, or (approximately) the correct point representation. """ -is_identity(G::AbstractGroupManifold, q) - -@decorator_transparent_function function is_identity(G::AbstractGroupManifold, q; kwargs...) +is_identity(G::AbstractDecoratorManifold, q) +@trait_function is_identity(G::AbstractDecoratorManifold, q; kwargs...) +function is_identity( + ::TraitList{<:IsGroupManifold}, + G::AbstractDecoratorManifold, + q; + kwargs..., +) return isapprox(G, identity_element(G), q; kwargs...) end function is_identity( - ::AbstractGroupManifold{𝔽,O}, + ::TraitList{<:IsGroupManifold{O}}, + G::AbstractDecoratorManifold, ::Identity{O}; kwargs..., -) where {𝔽,O<:AbstractGroupOperation} +) where {O<:AbstractGroupOperation} return true end -is_identity(::AbstractGroupManifold, ::Identity; kwargs...) = false +function is_identity( + ::TraitList{<:IsGroupManifold}, + ::AbstractDecoratorManifold, + ::Identity; + kwargs..., +) + return false +end -function isapprox( - G::AbstractGroupManifold{𝔽,O}, +@inline function isapprox( + ::TraitList{<:IsGroupManifold{O}}, + G::AbstractDecoratorManifold, p::Identity{O}, q; kwargs..., -) where {𝔽,O<:AbstractGroupOperation} +) where {O<:AbstractGroupOperation} return is_identity(G, q; kwargs...) end -function isapprox( - G::AbstractGroupManifold{𝔽,O}, +@inline function isapprox( + ::TraitList{<:IsGroupManifold{O}}, + G::AbstractDecoratorManifold, p, q::Identity{O}; kwargs..., -) where {𝔽,O<:AbstractGroupOperation} +) where {O<:AbstractGroupOperation} return is_identity(G, p; kwargs...) end function isapprox( - G::AbstractGroupManifold{𝔽,O}, + ::TraitList{<:IsGroupManifold{O}}, + G::AbstractDecoratorManifold, p::Identity{O}, q::Identity{O}; kwargs..., -) where {𝔽,O<:AbstractGroupOperation} +) where {O<:AbstractGroupOperation} return true end function isapprox( - G::AbstractGroupManifold{𝔽,O}, + ::TraitList{<:IsGroupManifold{O}}, + G::AbstractDecoratorManifold, + p::Identity{O}, + q::Identity; + kwargs..., +) where {O<:AbstractGroupOperation} + return false +end +function isapprox( + ::TraitList{<:IsGroupManifold{O}}, + G::AbstractDecoratorManifold, + p::Identity, + q::Identity{O}; + kwargs..., +) where {O<:AbstractGroupOperation} + return false +end + +@inline function isapprox( + ::TraitList{IsGroupManifold{O}}, + G::AbstractDecoratorManifold, p::Identity{O}, X, Y; @@ -288,37 +298,62 @@ function isapprox( ) where {𝔽,O<:AbstractGroupOperation} return isapprox(G, identity_element(G), X, Y; kwargs...) end -Base.isapprox(::AbstractGroupManifold, ::Identity, ::Identity; kwargs...) = false +function Base.isapprox( + ::TraitList{<:IsGroupManifold}, + ::AbstractDecoratorManifold, + ::Identity, + ::Identity; + kwargs..., +) + return false +end function Base.show(io::IO, ::Identity{O}) where {O<:AbstractGroupOperation} return print(io, "Identity($O)") end function check_point( - G::AbstractGroupManifold{𝔽,O}, - e::Identity{O}; - kwargs..., -) where {𝔽,M,O<:AbstractGroupOperation} - return nothing -end - -function check_point( - G::AbstractGroupManifold{𝔽,O1}, + ::TraitList{<:IsGroupManifold{O1}}, + G::AbstractDecoratorManifold, e::Identity{O2}; kwargs..., -) where {𝔽,M,O1<:AbstractGroupOperation,O2<:AbstractGroupOperation} +) where {O1<:AbstractGroupOperation,O2<:AbstractGroupOperation} return DomainError( e, "The Identity $e does not lie on $G, since its the identity with respect to $O2 and not $O1.", ) end -########################## -# Group-specific functions -########################## +function is_point( + ::TraitList{<:IsGroupManifold}, + G::AbstractDecoratorManifold, + e::Identity, + te=false; + kwargs..., +) + ie = is_identity(G, e; kwargs...) + (te && !ie) && throw(DomainError(e, "The provided identity is not a point on $G.")) + return ie +end + +function is_vector( + t::TraitList{<:IsGroupManifold}, + G::AbstractDecoratorManifold, + e::Identity, + X, + te=false, + cbp=true; + kwargs..., +) + if cbp + ie = is_point(G, e; kwargs...) + (!te) && return ie + end + return is_vector(next_trait(t), G, identity_element(G), X, te, false; kwargs...) +end @doc raw""" - adjoint_action(G::AbstractGroupManifold, p, X) + adjoint_action(G::AbstractDecoratorManifold, p, X) Adjoint action of the element `p` of the Lie group `G` on the element `X` of the corresponding Lie algebra. @@ -335,154 +370,163 @@ where $e$ is the identity element of `G`. Note that the adjoint representation of a Lie group isn't generally faithful. Notably the adjoint representation of SO(2) is trivial. """ -adjoint_action(G::AbstractGroupManifold, p, X) -@decorator_transparent_function :intransparent function adjoint_action( - G::AbstractGroupManifold, - p, - Xβ‚‘, -) +adjoint_action(G::AbstractDecoratorManifold, p, X) +@trait_function adjoint_action(G::AbstractDecoratorManifold, p, Xβ‚‘) +function adjoint_action(::TraitList{<:IsGroupManifold}, G::AbstractDecoratorManifold, p, Xβ‚‘) Xβ‚š = translate_diff(G, p, Identity(G), Xβ‚‘, LeftAction()) Y = inverse_translate_diff(G, p, p, Xβ‚š, RightAction()) return Y end -function adjoint_action!(G::AbstractGroupManifold, Y, p, Xβ‚‘) +@trait_function adjoint_action!(G::AbstractDecoratorManifold, Y, p, Xβ‚‘) +function adjoint_action!( + ::TraitList{<:IsGroupManifold}, + G::AbstractDecoratorManifold, + Y, + p, + Xβ‚‘, +) Xβ‚š = translate_diff(G, p, Identity(G), Xβ‚‘, LeftAction()) inverse_translate_diff!(G, Y, p, p, Xβ‚š, RightAction()) return Y end @doc raw""" - inv(G::AbstractGroupManifold, p) + inv(G::AbstractDecoratorManifold, p) Inverse $p^{-1} ∈ \mathcal{G}$ of an element $p ∈ \mathcal{G}$, such that $p \circ p^{-1} = p^{-1} \circ p = e ∈ \mathcal{G}$, where $e$ is the [`Identity`](@ref) element of $\mathcal{G}$. """ -inv(::AbstractGroupManifold, ::Any...) -@decorator_transparent_function function Base.inv(G::AbstractGroupManifold, p) +inv(::AbstractDecoratorManifold, ::Any...) +@trait_function Base.inv(G::AbstractDecoratorManifold, p) +function Base.inv(::TraitList{<:IsGroupManifold}, G::AbstractDecoratorManifold, p) q = allocate_result(G, inv, p) return inv!(G, q, p) end function Base.inv( - ::AbstractGroupManifold{𝔽,O}, + ::TraitList{IsGroupManifold{O}}, + ::AbstractDecoratorManifold, e::Identity{O}, -) where {𝔽,O<:AbstractGroupOperation} +) where {O<:AbstractGroupOperation} return e end -@decorator_transparent_function function inv!(G::AbstractGroupManifold, q, p) +@trait_function inv!(G::AbstractDecoratorManifold, q, p) +function inv!(::TraitList{<:IsGroupManifold}, G::AbstractDecoratorManifold, q, p) return inv!(G.manifold, q, p) end function inv!( - G::AbstractGroupManifold{𝔽,O}, + ::TraitList{IsGroupManifold{O}}, + G::AbstractDecoratorManifold, q, ::Identity{O}, -) where {𝔽,O<:AbstractGroupOperation} +) where {O<:AbstractGroupOperation} return identity_element!(G, q) end function Base.copyto!( - ::AbstractGroupManifold{𝔽,O}, + ::TraitList{IsGroupManifold{O}}, + ::AbstractDecoratorManifold, e::Identity{O}, ::Identity{O}, -) where {𝔽,O<:AbstractGroupOperation} +) where {O<:AbstractGroupOperation} return e end function Base.copyto!( - G::AbstractGroupManifold{𝔽,O}, + ::TraitList{IsGroupManifold{O}}, + G::AbstractDecoratorManifold, p, ::Identity{O}, -) where {𝔽,O<:AbstractGroupOperation} +) where {O<:AbstractGroupOperation} return identity_element!(G, p) end @doc raw""" - compose(G::AbstractGroupManifold, p, q) + compose(G::AbstractDecoratorManifold, p, q) Compose elements ``p,q ∈ \mathcal{G}`` using the group operation ``p \circ q``. For implementing composition on a new group manifold, please overload `_compose` instead so that methods with [`Identity`](@ref) arguments are not ambiguous. """ -compose(::AbstractGroupManifold, ::Any...) +compose(::AbstractDecoratorManifold, ::Any...) -@decorator_transparent_function function compose( - G::AbstractGroupManifold{𝔽,Op}, - p, - q, -) where {𝔽,Op<:AbstractGroupOperation} +@trait_function compose(G::AbstractDecoratorManifold, p, q) +function compose(::TraitList{<:IsGroupManifold}, G::AbstractDecoratorManifold, p, q) return _compose(G, p, q) end function compose( - ::AbstractGroupManifold{𝔽,Op}, - ::Identity{Op}, + ::AbstractDecoratorManifold, + ::Identity{O}, p, -) where {𝔽,Op<:AbstractGroupOperation} +) where {O<:AbstractGroupOperation} return p end function compose( - ::AbstractGroupManifold{𝔽,Op}, + ::AbstractDecoratorManifold, p, - ::Identity{Op}, -) where {𝔽,Op<:AbstractGroupOperation} + ::Identity{O}, +) where {O<:AbstractGroupOperation} return p end function compose( - ::AbstractGroupManifold{𝔽,Op}, - e::Identity{Op}, - ::Identity{Op}, -) where {𝔽,Op<:AbstractGroupOperation} + ::AbstractDecoratorManifold, + e::Identity{O}, + ::Identity{O}, +) where {O<:AbstractGroupOperation} return e end -function _compose(G::AbstractGroupManifold, p, q) +function _compose(G::AbstractDecoratorManifold, p, q) x = allocate_result(G, compose, p, q) return _compose!(G, x, p, q) end -@decorator_transparent_signature compose!(M::AbstractDecoratorManifold, x, p, q) +@trait_function compose!(M::AbstractDecoratorManifold, x, p, q) -compose!(G::AbstractGroupManifold, x, q, p) = _compose!(G, x, q, p) +function compose!(::TraitList{<:IsGroupManifold}, G::AbstractDecoratorManifold, x, q, p) + return _compose!(G, x, q, p) +end function compose!( - G::AbstractGroupManifold{𝔽,Op}, + G::AbstractDecoratorManifold, q, p, - ::Identity{Op}, -) where {𝔽,Op<:AbstractGroupOperation} + ::Identity{O}, +) where {O<:AbstractGroupOperation} return copyto!(G, q, p) end function compose!( - G::AbstractGroupManifold{𝔽,Op}, + G::AbstractDecoratorManifold, q, - ::Identity{Op}, + ::Identity{O}, p, -) where {𝔽,Op<:AbstractGroupOperation} +) where {O<:AbstractGroupOperation} return copyto!(G, q, p) end function compose!( - G::AbstractGroupManifold{𝔽,Op}, + G::AbstractDecoratorManifold, q, - ::Identity{Op}, - e::Identity{Op}, -) where {𝔽,Op<:AbstractGroupOperation} + ::Identity{O}, + e::Identity{O}, +) where {O<:AbstractGroupOperation} return identity_element!(G, q) end function compose!( - ::AbstractGroupManifold{𝔽,Op}, - e::Identity{Op}, - ::Identity{Op}, - ::Identity{Op}, -) where {𝔽,Op<:AbstractGroupOperation} + ::AbstractDecoratorManifold, + e::Identity{O}, + ::Identity{O}, + ::Identity{O}, +) where {O<:AbstractGroupOperation} return e end transpose(e::Identity) = e @doc raw""" - hat(M::AbstractGroupManifold{𝔽,O}, ::Identity{O}, Xⁱ) where {𝔽,O<:AbstractGroupOperation} + hat(M::AbstractDecoratorManifold{𝔽,O}, ::Identity{O}, Xⁱ) where {𝔽,O<:AbstractGroupOperation} Given a basis $e_i$ on the tangent space at a the [`Identity`}(@ref) and tangent component vector ``X^i``, compute the equivalent vector representation @@ -497,19 +541,19 @@ vector to an array representation. The [`vee`](@ref) map is the `hat` map's inverse. """ function hat( - G::AbstractGroupManifold{𝔽,O}, + M::AbstractDecoratorManifold, ::Identity{O}, X, -) where {𝔽,O<:AbstractGroupOperation} - return get_vector_lie(G, X, VeeOrthogonalBasis()) +) where {O<:AbstractGroupOperation} + return get_vector_lie(M, X, VeeOrthogonalBasis()) end function hat!( - G::AbstractGroupManifold{𝔽,O}, + M::AbstractDecoratorManifold, Y, ::Identity{O}, X, -) where {𝔽,O<:AbstractGroupOperation} - return get_vector_lie!(G, Y, X, VeeOrthogonalBasis()) +) where {O<:AbstractGroupOperation} + return get_vector_lie!(M, Y, X, VeeOrthogonalBasis()) end function hat(M::AbstractManifold, e::Identity, ::Any) return throw(ErrorException("On $M there exsists no identity $e")) @@ -534,18 +578,18 @@ vector to a vector representation. The [`hat`](@ref) map is the `vee` map's inverse. """ function vee( - M::AbstractGroupManifold{𝔽,O}, + M::AbstractDecoratorManifold, ::Identity{O}, X, -) where {𝔽,O<:AbstractGroupOperation} +) where {O<:AbstractGroupOperation} return get_coordinates_lie(M, X, VeeOrthogonalBasis()) end function vee!( - M::AbstractGroupManifold{𝔽,O}, + M::AbstractDecoratorManifold, Y, ::Identity{O}, X, -) where {𝔽,O<:AbstractGroupOperation} +) where {O<:AbstractGroupOperation} return get_coordinates_lie!(M, Y, X, VeeOrthogonalBasis()) end function vee(M::AbstractManifold, e::Identity, ::Any) @@ -556,23 +600,25 @@ function vee!(M::AbstractManifold, ::Any, e::Identity, ::Any) end """ - lie_bracket(G::AbstractGroupManifold, X, Y) + lie_bracket(G::AbstractDecoratorManifold, X, Y) -Lie bracket between elements `X` and `Y` of the Lie algebra corresponding to the Lie group `G`. +Lie bracket between elements `X` and `Y` of the Lie algebra corresponding to +the Lie group `G`, cf. [`IsGroupManifold`](@ref). This can be used to compute the adjoint representation of a Lie algebra. Note that this representation isn't generally faithful. Notably the adjoint representation of 𝔰𝔬(2) is trivial. """ -lie_bracket(G::AbstractGroupManifold, X, Y) -@decorator_transparent_signature lie_bracket(M::AbstractDecoratorManifold, X, Y) +lie_bracket(G::AbstractDecoratorManifold, X, Y) +@trait_function lie_bracket(M::AbstractDecoratorManifold, X, Y) + +@trait_function lie_bracket!(M::AbstractDecoratorManifold, Z, X, Y) _action_order(p, q, ::LeftAction) = (p, q) _action_order(p, q, ::RightAction) = (q, p) @doc raw""" - translate(G::AbstractGroupManifold, p, q) - translate(G::AbstractGroupManifold, p, q, conv::ActionDirection=LeftAction()]) + translate(G::AbstractDecoratorManifold, p, q, conv::ActionDirection=LeftAction()]) Translate group element $q$ by $p$ with the translation $Ο„_p$ with the specified `conv`ention, either left ($L_p$) or right ($R_p$), defined as @@ -583,12 +629,16 @@ R_p &: q ↦ q \circ p. \end{aligned} ``` """ -translate(::AbstractGroupManifold, ::Any...) -@decorator_transparent_function function translate(G::AbstractGroupManifold, p, q) - return translate(G, p, q, LeftAction()) -end -@decorator_transparent_function function translate( - G::AbstractGroupManifold, +translate(::AbstractDecoratorManifold, ::Any...) +@trait_function translate( + G::AbstractDecoratorManifold, + p, + q, + conv::ActionDirection=LeftAction(), +) +function translate( + ::TraitList{<:IsGroupManifold}, + G::AbstractDecoratorManifold, p, q, conv::ActionDirection, @@ -596,11 +646,16 @@ end return compose(G, _action_order(p, q, conv)...) end -@decorator_transparent_function function translate!(G::AbstractGroupManifold, X, p, q) - return translate!(G, X, p, q, LeftAction()) -end -@decorator_transparent_function function translate!( - G::AbstractGroupManifold, +@trait_function translate!( + G::AbstractDecoratorManifold, + X, + p, + q, + conv::ActionDirection=LeftAction(), +) +function translate!( + ::TraitList{<:IsGroupManifold}, + G::AbstractDecoratorManifold, X, p, q, @@ -610,8 +665,7 @@ end end @doc raw""" - inverse_translate(G::AbstractGroupManifold, p, q) - inverse_translate(G::AbstractGroupManifold, p, q, conv::ActionDirection=LeftAction()) + inverse_translate(G::AbstractDecoratorManifold, p, q, conv::ActionDirection=LeftAction()) Inverse translate group element $q$ by $p$ with the inverse translation $Ο„_p^{-1}$ with the specified `conv`ention, either left ($L_p^{-1}$) or right ($R_p^{-1}$), defined as @@ -622,12 +676,16 @@ R_p^{-1} &: q ↦ q \circ p^{-1}. \end{aligned} ``` """ -inverse_translate(::AbstractGroupManifold, ::Any...) -@decorator_transparent_function function inverse_translate(G::AbstractGroupManifold, p, q) - return inverse_translate(G, p, q, LeftAction()) -end -@decorator_transparent_function function inverse_translate( - G::AbstractGroupManifold, +inverse_translate(::AbstractDecoratorManifold, ::Any...) +@trait_function inverse_translate( + G::AbstractDecoratorManifold, + p, + q, + conv::ActionDirection=LeftAction(), +) +function inverse_translate( + ::TraitList{<:IsGroupManifold}, + G::AbstractDecoratorManifold, p, q, conv::ActionDirection, @@ -635,16 +693,16 @@ end return translate(G, inv(G, p), q, conv) end -@decorator_transparent_function function inverse_translate!( - G::AbstractGroupManifold, +@trait_function inverse_translate!( + G::AbstractDecoratorManifold, X, p, q, + conv::ActionDirection=LeftAction(), ) - return inverse_translate!(G, X, p, q, LeftAction()) -end -@decorator_transparent_function function inverse_translate!( - G::AbstractGroupManifold, +function inverse_translate!( + ::TraitList{<:IsGroupManifold}, + G::AbstractDecoratorManifold, X, p, q, @@ -654,8 +712,7 @@ end end @doc raw""" - translate_diff(G::AbstractGroupManifold, p, q, X) - translate_diff(G::AbstractGroupManifold, p, q, X, conv::ActionDirection=LeftAction()) + translate_diff(G::AbstractDecoratorManifold, p, q, X, conv::ActionDirection=LeftAction()) For group elements $p, q ∈ \mathcal{G}$ and tangent vector $X ∈ T_q \mathcal{G}$, compute the action of the differential of the translation $Ο„_p$ by $p$ on $X$, with the specified @@ -664,43 +721,37 @@ left or right `conv`ention. The differential transports vectors: (\mathrm{d}Ο„_p)_q : T_q \mathcal{G} β†’ T_{Ο„_p q} \mathcal{G}\\ ``` """ -translate_diff(::AbstractGroupManifold, ::Any...) -@decorator_transparent_function function translate_diff(G::AbstractGroupManifold, p, q, X) - return translate_diff(G, p, q, X, LeftAction()) -end -@decorator_transparent_function function translate_diff( - G::AbstractGroupManifold, +translate_diff(::AbstractDecoratorManifold, ::Any...) +@trait_function translate_diff( + G::AbstractDecoratorManifold, p, q, X, - conv::ActionDirection, + conv::ActionDirection=LeftAction(), ) - Y = allocate_result(G, translate_diff, X, p, q) - translate_diff!(G, Y, p, q, X, conv) - return Y -end - -@decorator_transparent_function function translate_diff!( - G::AbstractGroupManifold, - Y, +function translate_diff( + ::TraitList{<:IsGroupManifold}, + G::AbstractDecoratorManifold, p, q, X, + conv::ActionDirection, ) - return translate_diff!(G, Y, p, q, X, LeftAction()) + Y = allocate_result(G, translate_diff, X, p, q) + translate_diff!(G, Y, p, q, X, conv) + return Y end -@decorator_transparent_signature translate_diff!( - M::AbstractDecoratorManifold, +@trait_function translate_diff!( + G::AbstractDecoratorManifold, Y, p, q, X, - conv::ActionDirection, + conv::ActionDirection=LeftAction(), ) @doc raw""" - inverse_translate_diff(G::AbstractGroupManifold, p, q, X) - inverse_translate_diff(G::AbstractGroupManifold, p, q, X, conv::ActionDirection=LeftAction()) + inverse_translate_diff(G::AbstractDecoratorManifold, p, q, X, conv::ActionDirection=LeftAction()) For group elements $p, q ∈ \mathcal{G}$ and tangent vector $X ∈ T_q \mathcal{G}$, compute the action on $X$ of the differential of the inverse translation $Ο„_p$ by $p$, with the @@ -709,17 +760,17 @@ specified left or right `conv`ention. The differential transports vectors: (\mathrm{d}Ο„_p^{-1})_q : T_q \mathcal{G} β†’ T_{Ο„_p^{-1} q} \mathcal{G}\\ ``` """ -inverse_translate_diff(::AbstractGroupManifold, ::Any...) -@decorator_transparent_function function inverse_translate_diff( - G::AbstractGroupManifold, +inverse_translate_diff(::AbstractDecoratorManifold, ::Any...) +@trait_function inverse_translate_diff( + G::AbstractDecoratorManifold, p, q, X, + conv::ActionDirection=LeftAction(), ) - return inverse_translate_diff(G, p, q, X, LeftAction()) -end -@decorator_transparent_function function inverse_translate_diff( - G::AbstractGroupManifold, +function inverse_translate_diff( + ::TraitList{<:IsGroupManifold}, + G::AbstractDecoratorManifold, p, q, X, @@ -728,17 +779,17 @@ end return translate_diff(G, inv(G, p), q, X, conv) end -@decorator_transparent_function function inverse_translate_diff!( - G::AbstractGroupManifold, +@trait_function inverse_translate_diff!( + G::AbstractDecoratorManifold, Y, p, q, X, + conv::ActionDirection=LeftAction(), ) - return inverse_translate_diff!(G, Y, p, q, X, LeftAction()) -end -@decorator_transparent_function function inverse_translate_diff!( - G::AbstractGroupManifold, +function inverse_translate_diff!( + ::TraitList{<:IsGroupManifold}, + G::AbstractDecoratorManifold, Y, p, q, @@ -749,7 +800,8 @@ end end @doc raw""" - exp_lie(G::AbstractGroupManifold, X) + exp_lie(G, X) + exp_lie!(G, q, X) Compute the group exponential of the Lie algebra element `X`. It is equivalent to the exponential map defined by the [`CartanSchoutenMinus`](@ref) connection. @@ -776,32 +828,28 @@ following properties: In general, the group exponential map is distinct from the Riemannian exponential map [`exp`](@ref). -``` -exp_lie(G::AbstractGroupManifold{𝔽,AdditionOperation}, X) where {𝔽} -``` - -Compute $q = X$. - - exp_lie(G::AbstractGroupManifold{𝔽,MultiplicationOperation}, X) where {𝔽} - -For `Number` and `AbstractMatrix` types of `X`, compute the usual numeric/matrix -exponential, +For example for the [`MultiplicationOperation`](@ref) and either `Number` or `AbstractMatrix` +the Lie exponential is the numeric/matrix exponential. ````math \exp X = \operatorname{Exp} X = \sum_{n=0}^∞ \frac{1}{n!} X^n. ```` + +Since this function also depends on the group operation, make sure to implement +the corresponding trait version `exp_lie(::TraitList{<:IsGroupManifold}, G, X)`. """ -exp_lie(::AbstractGroupManifold, ::Any...) -@decorator_transparent_function function exp_lie(G::AbstractGroupManifold, X) +exp_lie(G::AbstractManifold, X) +@trait_function exp_lie(M::AbstractDecoratorManifold, X) +function exp_lie(::TraitList{<:IsGroupManifold}, G::AbstractDecoratorManifold, X) q = allocate_result(G, exp_lie, X) return exp_lie!(G, q, X) end -@decorator_transparent_signature exp_lie!(M::AbstractDecoratorManifold, q, X) +@trait_function exp_lie!(M::AbstractDecoratorManifold, q, X) @doc raw""" - log_lie(G::AbstractGroupManifold, q) - log_lie!(G::AbstractGroupManifold, X, q) + log_lie(G, q) + log_lie!(G, X, q) Compute the Lie group logarithm of the Lie group element `q`. It is equivalent to the logarithmic map defined by the [`CartanSchoutenMinus`](@ref) connection. @@ -814,7 +862,7 @@ $q = \exp X$ In general, the group logarithm map is distinct from the Riemannian logarithm map [`log`](@ref). - For matrix Llie groups this is equal to the (matrix) logarithm: + For matrix Lie groups this is equal to the (matrix) logarithm: ````math \log q = \operatorname{Log} q = \sum_{n=1}^∞ \frac{(-1)^{n+1}}{n} (q - e)^n, @@ -823,37 +871,42 @@ $q = \exp X$ where $e$ here is the [`Identity`](@ref) element, that is, $1$ for numeric $q$ or the identity matrix $I_m$ for matrix $q ∈ ℝ^{m Γ— m}$. -Since this function handles [`Identity`](@ref) arguments, the preferred function to override -is `_log_lie!`. +Since this function also depends on the group operation, make sure to implement +either +* `_log_lie(G, q)` and `_log_lie!(G, X, q)` for the points not being the [`Identity`](@ref) +* the trait version `log_lie(::TraitList{<:IsGroupManifold}, G, e)`, `log_lie(::TraitList{<:IsGroupManifold}, G, X, e)` for own implementations of the identity case. """ -log_lie(::AbstractGroupManifold, ::Any...) -@decorator_transparent_function function log_lie(G::AbstractGroupManifold, q) - X = allocate_result(G, log_lie, q) - return log_lie!(G, X, q) +log_lie(::AbstractDecoratorManifold, q) +@trait_function log_lie(G::AbstractDecoratorManifold, q) +function log_lie(::TraitList{<:IsGroupManifold}, G::AbstractDecoratorManifold, q) + return _log_lie(G, q) end function log_lie( - G::AbstractGroupManifold{𝔽,Op}, - ::Identity{Op}, -) where {𝔽,Op<:AbstractGroupOperation} + ::TraitList{<:IsGroupManifold{O}}, + G::AbstractDecoratorManifold, + ::Identity{O}, +) where {O<:AbstractGroupOperation} return zero_vector(G, identity_element(G)) end +# though identity was taken care of – as usual restart decorator dispatch +function _log_lie(G::AbstractDecoratorManifold, q) + X = allocate_result(G, log_lie, q) + return log_lie!(G, X, q) +end -@decorator_transparent_function function log_lie!(G::AbstractGroupManifold, X, q) +@trait_function log_lie!(G::AbstractDecoratorManifold, X, q) +function log_lie!(::TraitList{<:IsGroupManifold}, G::AbstractDecoratorManifold, X, q) return _log_lie!(G, X, q) end - function log_lie!( - G::AbstractGroupManifold{𝔽,Op}, + ::TraitList{<:IsGroupManifold{O}}, + G::AbstractDecoratorManifold, X, - ::Identity{Op}, -) where {𝔽,Op<:AbstractGroupOperation} + ::Identity{O}, +) where {O<:AbstractGroupOperation} return zero_vector!(G, X, identity_element(G)) end -############################ -# Group-specific Retractions -############################ - """ GroupExponentialRetraction{D<:ActionDirection} <: AbstractRetractionMethod @@ -898,7 +951,7 @@ direction(::GroupLogarithmicInverseRetraction{D}) where {D} = D() @doc raw""" retract( - G::AbstractGroupManifold, + G::AbstractDecoratorManifold, p, X, method::GroupExponentialRetraction{<:ActionDirection}, @@ -917,7 +970,13 @@ where $\exp$ is the group exponential ([`exp_lie`](@ref)), and $(\mathrm{d}Ο„_p^ the action of the differential of inverse translation $Ο„_p^{-1}$ evaluated at $p$ (see [`inverse_translate_diff`](@ref)). """ -function retract(G::AbstractGroupManifold, p, X, method::GroupExponentialRetraction) +function retract( + ::TraitList{<:IsGroupManifold}, + G::AbstractDecoratorManifold, + p, + X, + method::GroupExponentialRetraction, +) conv = direction(method) Xβ‚‘ = inverse_translate_diff(G, p, p, X, conv) pinvq = exp_lie(G, Xβ‚‘) @@ -925,7 +984,14 @@ function retract(G::AbstractGroupManifold, p, X, method::GroupExponentialRetract return q end -function retract!(G::AbstractGroupManifold, q, p, X, method::GroupExponentialRetraction) +function retract!( + ::TraitList{<:IsGroupManifold}, + G::AbstractDecoratorManifold, + q, + p, + X, + method::GroupExponentialRetraction, +) conv = direction(method) Xβ‚‘ = inverse_translate_diff(G, p, p, X, conv) pinvq = exp_lie(G, Xβ‚‘) @@ -934,7 +1000,7 @@ end @doc raw""" inverse_retract( - G::AbstractGroupManifold, + G::AbstractDecoratorManifold, p, X, method::GroupLogarithmicInverseRetraction{<:ActionDirection}, @@ -953,7 +1019,13 @@ where $\log$ is the group logarithm ([`log_lie`](@ref)), and $(\mathrm{d}Ο„_p)_e action of the differential of translation $Ο„_p$ evaluated at the identity element $e$ (see [`translate_diff`](@ref)). """ -function inverse_retract(G::GroupManifold, p, q, method::GroupLogarithmicInverseRetraction) +function inverse_retract( + ::TraitList{<:IsGroupManifold}, + G::AbstractDecoratorManifold, + p, + q, + method::GroupLogarithmicInverseRetraction, +) conv = direction(method) pinvq = inverse_translate(G, p, q, conv) Xβ‚‘ = log_lie(G, pinvq) @@ -961,7 +1033,8 @@ function inverse_retract(G::GroupManifold, p, q, method::GroupLogarithmicInverse end function inverse_retract!( - G::AbstractGroupManifold, + ::TraitList{<:IsGroupManifold}, + G::AbstractDecoratorManifold, X, p, q, @@ -973,404 +1046,28 @@ function inverse_retract!( return translate_diff!(G, X, p, Identity(G), Xβ‚‘, conv) end -################################# -# Overloads for AdditionOperation -################################# - -""" - AdditionOperation <: AbstractGroupOperation - -Group operation that consists of simple addition. -""" -struct AdditionOperation <: AbstractGroupOperation end - -const AdditionGroup = AbstractGroupManifold{𝔽,AdditionOperation} where {𝔽} - -Base.:+(e::Identity{AdditionOperation}) = e -Base.:+(e::Identity{AdditionOperation}, ::Identity{AdditionOperation}) = e -Base.:+(::Identity{AdditionOperation}, p) = p -Base.:+(p, ::Identity{AdditionOperation}) = p - -Base.:-(e::Identity{AdditionOperation}) = e -Base.:-(e::Identity{AdditionOperation}, ::Identity{AdditionOperation}) = e -Base.:-(::Identity{AdditionOperation}, p) = -p -Base.:-(p, ::Identity{AdditionOperation}) = p - -Base.:*(e::Identity{AdditionOperation}, p) = e -Base.:*(p, e::Identity{AdditionOperation}) = e -Base.:*(e::Identity{AdditionOperation}, ::Identity{AdditionOperation}) = e - -adjoint_action(::AdditionGroup, p, X) = X - -adjoint_action!(G::AdditionGroup, Y, p, X) = copyto!(G, Y, p, X) - -identity_element(::AdditionGroup, p::Number) = zero(p) - -function identity_element!(::AbstractGroupManifold{𝔽,<:AdditionOperation}, p) where {𝔽} - return fill!(p, zero(eltype(p))) -end - -Base.inv(::AdditionGroup, p) = -p -Base.inv(::AdditionGroup, e::Identity) = e - -inv!(G::AdditionGroup, q, p) = copyto!(G, q, -p) -inv!(G::AdditionGroup, q, ::Identity{AdditionOperation}) = identity_element!(G, q) -inv!(::AdditionGroup, q::Identity{AdditionOperation}, e::Identity{AdditionOperation}) = q - -function is_identity(G::AdditionGroup, q; kwargs...) - return isapprox(G, q, zero(q); kwargs...) -end -function is_identity(G::AdditionGroup, e::Identity; kwargs...) - return invoke(is_identity, Tuple{AbstractGroupManifold,typeof(e)}, G, e; kwargs...) -end - -_compose(::AdditionGroup, p, q) = p + q - -function _compose!(::AdditionGroup, x, p, q) - x .= p .+ q - return x -end - -translate_diff(::AdditionGroup, p, q, X, ::ActionDirection) = X - -translate_diff!(G::AdditionGroup, Y, p, q, X, ::ActionDirection) = copyto!(G, Y, p, X) - -inverse_translate_diff(::AdditionGroup, p, q, X, ::ActionDirection) = X - -function inverse_translate_diff!(G::AdditionGroup, Y, p, q, X, ::ActionDirection) - return copyto!(G, Y, p, X) -end - -exp_lie(::AdditionGroup, X) = X - -exp_lie!(G::AdditionGroup, q, X) = copyto!(G, q, X) - -log_lie(::AdditionGroup, q) = q -function log_lie(G::AdditionGroup, ::Identity{AdditionOperation}) - return zero_vector(G, identity_element(G)) -end - -_log_lie!(G::AdditionGroup, X, q) = copyto!(G, X, q) - -lie_bracket(::AdditionGroup, X, Y) = zero(X) - -lie_bracket!(::AdditionGroup, Z, X, Y) = fill!(Z, 0) - -####################################### -# Overloads for MultiplicationOperation -####################################### - -""" - MultiplicationOperation <: AbstractGroupOperation - -Group operation that consists of multiplication. -""" -struct MultiplicationOperation <: AbstractGroupOperation end - -const MultiplicationGroup = AbstractGroupManifold{𝔽,MultiplicationOperation} where {𝔽} - -Base.:*(e::Identity{MultiplicationOperation}) = e -Base.:*(::Identity{MultiplicationOperation}, p) = p -Base.:*(p, ::Identity{MultiplicationOperation}) = p -Base.:*(e::Identity{MultiplicationOperation}, ::Identity{MultiplicationOperation}) = e -Base.:*(::Identity{MultiplicationOperation}, e::Identity{AdditionOperation}) = e -Base.:*(e::Identity{AdditionOperation}, ::Identity{MultiplicationOperation}) = e - -Base.:/(p, ::Identity{MultiplicationOperation}) = p -Base.:/(::Identity{MultiplicationOperation}, p) = inv(p) -Base.:/(e::Identity{MultiplicationOperation}, ::Identity{MultiplicationOperation}) = e - -Base.:\(p, ::Identity{MultiplicationOperation}) = inv(p) -Base.:\(::Identity{MultiplicationOperation}, p) = p -Base.:\(e::Identity{MultiplicationOperation}, ::Identity{MultiplicationOperation}) = e - -LinearAlgebra.det(::Identity{MultiplicationOperation}) = true -LinearAlgebra.adjoint(e::Identity{MultiplicationOperation}) = e - -function identity_element!(::MultiplicationGroup, p::AbstractMatrix) - return copyto!(p, I) -end - -function identity_element!(G::MultiplicationGroup, p::AbstractArray) - if length(p) == 1 - fill!(p, one(eltype(p))) - else - throw(DimensionMismatch("Array $p cannot be set to identity element of group $G")) - end - return p -end - -function is_identity(G::MultiplicationGroup, q::Number; kwargs...) - return isapprox(G, q, one(q); kwargs...) -end -function is_identity(G::MultiplicationGroup, q::AbstractVector; kwargs...) - return length(q) == 1 && isapprox(G, q[], one(q[]); kwargs...) -end -function is_identity(G::MultiplicationGroup, q::AbstractMatrix; kwargs...) - return isapprox(G, q, I; kwargs...) -end -function is_identity(G::MultiplicationGroup, e::Identity; kwargs...) - return invoke(is_identity, Tuple{AbstractGroupManifold,typeof(e)}, G, e; kwargs...) -end - -LinearAlgebra.mul!(q, ::Identity{MultiplicationOperation}, p) = copyto!(q, p) -LinearAlgebra.mul!(q, p, ::Identity{MultiplicationOperation}) = copyto!(q, p) -function LinearAlgebra.mul!( - q::AbstractMatrix, - ::Identity{MultiplicationOperation}, - ::Identity{MultiplicationOperation}, -) - return copyto!(q, I) -end -function LinearAlgebra.mul!( - q, - ::Identity{MultiplicationOperation}, - ::Identity{MultiplicationOperation}, -) - return copyto!(q, one(q)) -end -function LinearAlgebra.mul!( - q::Identity{MultiplicationOperation}, - ::Identity{MultiplicationOperation}, - ::Identity{MultiplicationOperation}, -) - return q -end -Base.one(e::Identity{MultiplicationOperation}) = e - -Base.inv(::MultiplicationGroup, p) = inv(p) -Base.inv(::MultiplicationGroup, e::Identity{MultiplicationOperation}) = e - -inv!(G::MultiplicationGroup, q, p) = copyto!(q, inv(G, p)) -function inv!(G::MultiplicationGroup, q, ::Identity{MultiplicationOperation}) - return identity_element!(G, q) -end -function inv!( - ::MultiplicationGroup, - q::Identity{MultiplicationOperation}, - e::Identity{MultiplicationOperation}, -) - return q -end - -_compose(::MultiplicationGroup, p, q) = p * q - -_compose!(::MultiplicationGroup, x, p, q) = mul!_safe(x, p, q) - -inverse_translate(::MultiplicationGroup, p, q, ::LeftAction) = p \ q -inverse_translate(::MultiplicationGroup, p, q, ::RightAction) = q / p - -function inverse_translate!(G::MultiplicationGroup, x, p, q, conv::ActionDirection) - return copyto!(x, inverse_translate(G, p, q, conv)) -end - -function exp_lie!(G::MultiplicationGroup, q, X) - X isa Union{Number,AbstractMatrix} && return copyto!(q, exp(X)) - return error( - "exp_lie! not implemented on $(typeof(G)) for vector $(typeof(X)) and element $(typeof(q)).", - ) -end - -log_lie!(::MultiplicationGroup, X::AbstractMatrix, q::AbstractMatrix) = log_safe!(X, q) - -lie_bracket(::MultiplicationGroup, X, Y) = mul!(X * Y, Y, X, -1, true) - -function lie_bracket!(::MultiplicationGroup, Z, X, Y) - mul!(Z, X, Y) - mul!(Z, Y, X, -1, true) - return Z -end - @doc raw""" - get_vector_lie(G::AbstractGroupManifold, a, B::AbstractBasis) + get_vector_lie(G::AbstractDecoratorManifold, a, B::AbstractBasis) Reconstruct a tangent vector from the Lie algebra of `G` from cooordinates `a` of a basis `B`. This is similar to calling [`get_vector`](@ref) at the `p=`[`Identity`](@ref)`(G)`. """ -function get_vector_lie(G::AbstractGroupManifold, X, B::AbstractBasis) - return get_vector(G, identity_element(G), X, B) +function get_vector_lie(G::AbstractManifold, X, B::AbstractBasis) + return get_vector(base_manifold(G), identity_element(G), X, B) end -function get_vector_lie!(G::AbstractGroupManifold, Y, X, B::AbstractBasis) - return get_vector!(G, Y, identity_element(G), X, B) +function get_vector_lie!(G::AbstractManifold, Y, X, B::AbstractBasis) + return get_vector!(base_manifold(G), Y, identity_element(G), X, B) end @doc raw""" - get_coordinates_lie(G::AbstractGroupManifold, X, B::AbstractBasis) + get_coordinates_lie(G::AbstractManifold, X, B::AbstractBasis) Get the coordinates of an element `X` from the Lie algebra og `G` with respect to a basis `B`. This is similar to calling [`get_coordinates`](@ref) at the `p=`[`Identity`](@ref)`(G)`. """ -function get_coordinates_lie(G::AbstractGroupManifold, X, B::AbstractBasis) - return get_coordinates(G, identity_element(G), X, B) -end -function get_coordinates_lie!(G::AbstractGroupManifold, a, X, B::AbstractBasis) - return get_coordinates!(G, a, identity_element(G), X, B) -end - -# (a) changes / parent. -for f in [ - embed, - get_basis, - vector_transport_direction, - vector_transport_direction!, - vector_transport_to, -] - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractGroupManifold{𝔽,O,<:AbstractGroupDecoratorType}, - args..., - ) where {𝔽,O} - return Val(:parent) - end - end, - ) -end -for f in [ - get_coordinates, - get_coordinates!, - get_vector, - get_vector!, - inverse_retract!, - mid_point!, - project, - retract!, - vector_transport_along, -] - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractGroupManifold{𝔽,O,DefaultGroupDecoratorType}, - args..., - ) where {𝔽,O} - return Val(:parent) - end - end, - ) -end -# (b) changes / transparencies -for f in [ - check_point, - check_vector, - copy, - copyto!, - distance, - exp, - exp!, - embed!, - get_coordinates, - get_coordinates!, - get_vector, - get_vector!, - inner, - inverse_retract, - inverse_retract!, - isapprox, - log, - log!, - mid_point, - mid_point!, - project!, - project, - retract, - retract!, - vector_transport_along, - vector_transport_direction, -] - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractGroupManifold{𝔽,O,<:TransparentGroupDecoratorType}, - args..., - ) where {𝔽,O} - return Val(:transparent) - end - end, - ) -end - -# (c) changes / intransparencies. -for f in [ - compose, - compose!, - exp_lie, - exp_lie!, - log_lie, - log_lie!, - translate, - translate!, - translate_diff, - translate_diff!, -] - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractGroupManifold, - args..., - ) - return Val(:intransparent) - end - end, - ) +function get_coordinates_lie(G::AbstractManifold, X, B::AbstractBasis) + return get_coordinates(base_manifold(G), identity_element(G), X, B) end -# (d) specials -for f in [vector_transport_along!, vector_transport_direction!, vector_transport_to!] - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractGroupManifold{𝔽,O,<:TransparentGroupDecoratorType}, - Y, - p, - X, - q, - ::T, - ) where {𝔽,O,T} - return Val(:transparent) - end - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractGroupManifold{𝔽,O,<:AbstractGroupDecoratorType}, - Y, - p, - X, - q, - ::T, - ) where {𝔽,O,T} - return Val(:intransparent) - end - end, - ) - for m in [PoleLadderTransport, SchildsLadderTransport, ScaledVectorTransport] - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractGroupManifold{𝔽,O,<:TransparentGroupDecoratorType}, - Y, - p, - X, - q, - ::$m, - ) where {𝔽,O} - return Val(:parent) - end - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractGroupManifold{𝔽,O,<:AbstractGroupDecoratorType}, - Y, - p, - X, - q, - ::$m, - ) where {𝔽,O} - return Val(:parent) - end - end, - ) - end +function get_coordinates_lie!(G::AbstractManifold, a, X, B::AbstractBasis) + return get_coordinates!(base_manifold(G), a, identity_element(G), X, B) end diff --git a/src/groups/group_action.jl b/src/groups/group_action.jl index 3a633a2578..5a243662ab 100644 --- a/src/groups/group_action.jl +++ b/src/groups/group_action.jl @@ -13,13 +13,17 @@ The group that acts in action `A`. base_group(A::AbstractGroupAction) = error("base_group not implemented for $(typeof(A)).") """ - g_manifold(A::AbstractGroupAction) + group_manifold(A::AbstractGroupAction) The manifold the action `A` acts upon. """ -g_manifold(A::AbstractGroupAction) = error("g_manifold not implemented for $(typeof(A)).") +function group_manifold(A::AbstractGroupAction) + return error("group_manifold not implemented for $(typeof(A)).") +end -allocate_result(A::AbstractGroupAction, f, p...) = allocate_result(g_manifold(A), f, p...) +function allocate_result(A::AbstractGroupAction, f, p...) + return allocate_result(group_manifold(A), f, p...) +end """ direction(::AbstractGroupAction{AD}) -> AD diff --git a/src/groups/group_operation_action.jl b/src/groups/group_operation_action.jl index 97de1beea9..1a46f5dac9 100644 --- a/src/groups/group_operation_action.jl +++ b/src/groups/group_operation_action.jl @@ -1,5 +1,5 @@ @doc raw""" - GroupOperationAction(group::AbstractGroupManifold, AD::ActionDirection = LeftAction()) + GroupOperationAction(group::AbstractDecoratorManifold, AD::ActionDirection = LeftAction()) Action of a group upon itself via left or right translation. """ @@ -8,10 +8,10 @@ struct GroupOperationAction{G,AD} <: AbstractGroupAction{AD} end function GroupOperationAction( - G::AbstractGroupManifold, + G::TM, ::TAD=LeftAction(), -) where {TAD<:ActionDirection} - return GroupOperationAction{typeof(G),TAD}(G) +) where {TM<:AbstractDecoratorManifold,TAD<:ActionDirection} + return GroupOperationAction{TM,TAD}(G) end function Base.show(io::IO, A::GroupOperationAction) @@ -20,7 +20,7 @@ end base_group(A::GroupOperationAction) = A.group -g_manifold(A::GroupOperationAction) = A.group +group_manifold(A::GroupOperationAction) = A.group function switch_direction(A::GroupOperationAction) return GroupOperationAction(A.group, switch_direction(direction(A))) diff --git a/src/groups/metric.jl b/src/groups/metric.jl index 9304f33f46..e0ab8f6783 100644 --- a/src/groups/metric.jl +++ b/src/groups/metric.jl @@ -1,74 +1,6 @@ -@doc raw""" - InvariantMetric{G<:AbstractMetric,D<:ActionDirection} <: AbstractMetric - -Extend a metric on the Lie algebra of an [`AbstractGroupManifold`](@ref) to the whole group -via translation in the specified direction. - -Given a group $\mathcal{G}$ and a left- or right group translation map $Ο„$ on the group, a -metric $g$ is $Ο„$-invariant if it has the inner product - -````math -g_p(X, Y) = g_{Ο„_q p}((\mathrm{d}Ο„_q)_p X, (\mathrm{d}Ο„_q)_p Y), -```` - -for all $p,q ∈ \mathcal{G}$ and $X,Y ∈ T_p \mathcal{G}$, where $(\mathrm{d}Ο„_q)_p$ is the -differential of translation by $q$ evaluated at $p$ (see [`translate_diff`](@ref)). - -`InvariantMetric` constructs an (assumed) $Ο„$-invariant metric by extending the inner -product of a metric $h_e$ on the Lie algebra to the whole group: - -````math -g_p(X, Y) = h_e((\mathrm{d}Ο„_p^{-1})_p X, (\mathrm{d}Ο„_p^{-1})_p Y). -```` - -!!! warning - The invariance condition is not checked and must be verified for the entire group. - To verify the condition for a set of points numerically, use - [`has_approx_invariant_metric`](@ref). - -The convenient aliases [`LeftInvariantMetric`](@ref) and [`RightInvariantMetric`](@ref) are -provided. - -# Constructor - - InvariantMetric(metric::AbstractMetric, conv::ActionDirection = LeftAction()) -""" -struct InvariantMetric{G<:AbstractMetric,D<:ActionDirection} <: AbstractMetric - metric::G - function InvariantMetric{G,D}(metric::G) where {G<:AbstractMetric,D<:ActionDirection} - return new(metric) - end -end - -function InvariantMetric(metric::MC, conv=LeftAction()) where {MC<:AbstractMetric} - return InvariantMetric{MC,typeof(conv)}(metric) -end - -const LeftInvariantMetric{G} = InvariantMetric{G,LeftAction} where {G<:AbstractMetric} - -""" - LeftInvariantMetric(metric::AbstractMetric) - -Alias for a left-[`InvariantMetric`](@ref). -""" -function LeftInvariantMetric(metric::T) where {T<:AbstractMetric} - return InvariantMetric{T,LeftAction}(metric) -end - -const RightInvariantMetric{G} = InvariantMetric{G,RightAction} where {G<:AbstractMetric} - -""" - RightInvariantMetric(metric::AbstractMetric) - -Alias for a right-[`InvariantMetric`](@ref). -""" -function RightInvariantMetric(metric::T) where {T<:AbstractMetric} - return InvariantMetric{T,RightAction}(metric) -end - @doc raw""" has_approx_invariant_metric( - G::AbstractGroupManifold, + G::AbstractDecoratorManifold, p, X, Y, @@ -90,22 +22,25 @@ This is necessary but not sufficient for invariance. Optionally, `kwargs` passed to `isapprox` may be provided. """ -has_approx_invariant_metric( - ::AbstractGroupManifold, - ::Any, - ::Any, - ::Any, - ::Any, - ::ActionDirection, -) -@decorator_transparent_function function has_approx_invariant_metric( - M::AbstractGroupManifold, +has_approx_invariant_metric(::AbstractDecoratorManifold, p, X, Y, qs, ::ActionDirection) +@trait_function has_approx_invariant_metric( + M::AbstractDecoratorManifold, p, X, Y, qs, conv::ActionDirection=LeftAction(); kwargs..., +) +function has_approx_invariant_metric( + ::TraitList{<:IsGroupManifold}, + M::AbstractDecoratorManifold, + p, + X, + Y, + qs, + conv::ActionDirection; + kwargs..., ) gpXY = inner(M, p, X, Y) for q in qs @@ -117,127 +52,190 @@ has_approx_invariant_metric( return true end -direction(::InvariantMetric{G,D}) where {G,D} = D() - -function exp!( - M::MetricManifold{𝔽,<:AbstractGroupManifold,<:InvariantMetric}, - q, - p, - X, -) where {𝔽} - if has_biinvariant_metric(M) - conv = direction(metric(M)) - return retract!(base_group(M), q, p, X, GroupExponentialRetraction(conv)) - end - return invoke(exp!, Tuple{MetricManifold,typeof(q),typeof(p),typeof(X)}, M, q, p, X) -end - """ - biinvariant_metric_dispatch(G::AbstractGroupManifold) -> Val + direction(::AbstractDecoratorManifold) -> AD -Return `Val(true)` if the metric on the manifold is bi-invariant, that is, if the metric -is both left- and right-invariant (see [`invariant_metric_dispatch`](@ref)). +Get the direction of the action a certain Lie group with its implicit metric has """ -function biinvariant_metric_dispatch(M::AbstractManifold) - return Val( - invariant_metric_dispatch(M, LeftAction()) === Val(true) && - invariant_metric_dispatch(M, RightAction()) === Val(true), - ) -end +direction(::AbstractDecoratorManifold) -has_biinvariant_metric(M::AbstractManifold) = _extract_val(biinvariant_metric_dispatch(M)) +@trait_function direction(M::AbstractDecoratorManifold) -@doc raw""" - invariant_metric_dispatch(G::AbstractGroupManifold, conv::ActionDirection) -> Val +direction(::TraitList{HasLeftInvariantMetric}, ::AbstractDecoratorManifold) = LeftAction() -Return `Val(true)` if the metric on the group $\mathcal{G}$ is invariant under translations -by the specified direction, that is, given a group $\mathcal{G}$, a left- or right group -translation map $Ο„$, and a metric $g_e$ on the Lie algebra, a $Ο„$-invariant metric at -any point $p ∈ \mathcal{G}$ is defined as a metric with the inner product +direction(::TraitList{HasRightInvariantMetric}, ::AbstractDecoratorManifold) = RightAction() -````math -g_p(X, Y) = g_{Ο„_q p}((\mathrm{d}Ο„_q)_p X, (\mathrm{d}Ο„_q)_p Y), -```` +function exp(::TraitList{HasLeftInvariantMetric}, M::MetricManifold, p, X) + return retract(M.manifold, p, X, GroupExponentialRetraction(LeftAction())) +end +function exp!(::TraitList{HasLeftInvariantMetric}, M::MetricManifold, q, p, X) + return retract!(M.manifold, q, p, X, GroupExponentialRetraction(LeftAction())) +end +function exp(::TraitList{HasRightInvariantMetric}, M::MetricManifold, p, X) + return retract(M.manifold, p, X, GroupExponentialRetraction(RightAction())) +end +function exp!(::TraitList{HasRightInvariantMetric}, M::MetricManifold, q, p, X) + return retract!(M.manifold, q, p, X, GroupExponentialRetraction(RightAction())) +end +function exp(::TraitList{HasBiinvariantMetric}, M::MetricManifold, p, X) + return exp(M.manifold, p, X) +end +function exp!(::TraitList{HasBiinvariantMetric}, M::MetricManifold, q, p, X) + return exp!(M.manifold, q, p, X) +end -for $X,Y ∈ T_q \mathcal{G}$ and all $q ∈ \mathcal{G}$, where $(\mathrm{d}Ο„_q)_p$ is the -differential of translation by $q$ evaluated at $p$ (see [`translate_diff`](@ref)). -""" -invariant_metric_dispatch(::MetricManifold, ::ActionDirection) +@trait_function has_invariant_metric(M::AbstractDecoratorManifold, op::ActionDirection) -@decorator_transparent_signature invariant_metric_dispatch( - M::AbstractDecoratorManifold, - conv::ActionDirection, +# Fallbacks / false +has_invariant_metric(::AbstractManifold, op::ActionDirection) = false +function has_invariant_metric( + ::TraitList{<:HasLeftInvariantMetric}, + ::AbstractDecoratorManifold, + ::LeftAction, ) -function invariant_metric_dispatch(M::MetricManifold, conv::ActionDirection) - is_default_metric(M) && return invariant_metric_dispatch(M.manifold, conv) - return Val(false) + return true end -function invariant_metric_dispatch( - M::MetricManifold{𝔽,<:AbstractManifold,<:InvariantMetric}, - conv::ActionDirection, -) where {𝔽} - direction(metric(M)) === conv && return Val(true) - return invoke(invariant_metric_dispatch, Tuple{MetricManifold,typeof(conv)}, M, conv) +function has_invariant_metric( + ::TraitList{<:HasRightInvariantMetric}, + ::AbstractDecoratorManifold, + ::RightAction, +) + return true end -invariant_metric_dispatch(::AbstractManifold, ::ActionDirection) = Val(false) -function has_invariant_metric(M::AbstractManifold, conv::ActionDirection) - return _extract_val(invariant_metric_dispatch(M, conv)) +@trait_function has_biinvariant_metric(M::AbstractDecoratorManifold) + +# fallbavk / default: false +has_biinvariant_metric(::AbstractManifold) = false +function has_biinvariant_metric( + ::TraitList{<:HasBiinvariantMetric}, + ::AbstractDecoratorManifold, +) + return true end -function inner(M::MetricManifold{𝔽,<:AbstractManifold,<:InvariantMetric}, p, X, Y) where {𝔽} - imetric = metric(M) - conv = direction(imetric) - N = MetricManifold(M.manifold, imetric.metric) +function inner( + t::TraitList{IT}, + M::AbstractDecoratorManifold, + p, + X, + Y, +) where {IT<:AbstractInvarianceTrait} + conv = direction(t, M) Xβ‚‘ = inverse_translate_diff(M, p, p, X, conv) Yβ‚‘ = inverse_translate_diff(M, p, p, Y, conv) - return inner(N, Identity(N), Xβ‚‘, Yβ‚‘) + return inner(next_trait(t), M, Identity(M), Xβ‚‘, Yβ‚‘) end - -function default_metric_dispatch( - M::MetricManifold{𝔽,<:AbstractManifold,<:InvariantMetric}, -) where {𝔽} - imetric = metric(M) - N = MetricManifold(M.manifold, imetric.metric) - default_metric_dispatch(N) === Val(true) || return Val(false) - return invariant_metric_dispatch(N, direction(imetric)) +function inner( + t::TraitList{<:IsGroupManifold}, + M::AbstractDecoratorManifold, + ::Identity, + X, + Y, +) + return inner(next_trait(t), M, identity_element(M), X, Y) end -function log!( - M::MetricManifold{𝔽,<:AbstractGroupManifold,<:InvariantMetric}, +function inverse_translate_diff( + ::TraitList{IsMetricManifold}, + M::MetricManifold, + p, + q, X, + conv::ActionDirection, +) + return inverse_translate_diff(M.manifold, p, q, X, conv) +end +function inverse_translate_diff!( + ::TraitList{IsMetricManifold}, + M::MetricManifold, + Y, p, q, -) where {𝔽} - if has_biinvariant_metric(M) - imetric = metric(M) - conv = direction(imetric) - return inverse_retract!( - base_group(M), - X, - p, - q, - GroupLogarithmicInverseRetraction(conv), - ) - end - return invoke(log!, Tuple{MetricManifold,typeof(X),typeof(p),typeof(q)}, M, X, p, q) + X, + conv::ActionDirection, +) + return inverse_translate_diff!(M.manifold, Y, p, q, X, conv) +end + +function log(::TraitList{HasLeftInvariantMetric}, M::MetricManifold, p, q) + return inverse_retract( + M.manifold, + p, + q, + GroupLogarithmicInverseRetraction(LeftAction()), + ) +end +function log!(::TraitList{HasLeftInvariantMetric}, M::MetricManifold, X, p, q) + return inverse_retract!( + M.manifold, + X, + p, + q, + GroupLogarithmicInverseRetraction(LeftAction()), + ) +end +function log(::TraitList{HasRightInvariantMetric}, M::MetricManifold, p, q) + return inverse_retract( + M.manifold, + p, + q, + GroupLogarithmicInverseRetraction(RightAction()), + ) +end +function log!(::TraitList{HasRightInvariantMetric}, M::MetricManifold, X, p, q) + return inverse_retract!( + M.manifold, + X, + p, + q, + GroupLogarithmicInverseRetraction(RightAction()), + ) +end +function log(::TraitList{HasBiinvariantMetric}, M::MetricManifold, p, q) + return log(M.manifold, p, q) +end +function log!(::TraitList{HasBiinvariantMetric}, M::MetricManifold, X, p, q) + return log!(M.manifold, X, p, q) end function LinearAlgebra.norm( - M::MetricManifold{𝔽,<:AbstractManifold,<:InvariantMetric}, + t::TraitList{IT}, + M::AbstractDecoratorManifold, p, X, -) where {𝔽} - imetric = metric(M) - conv = direction(imetric) - N = MetricManifold(M.manifold, imetric.metric) +) where {IT<:AbstractInvarianceTrait} + conv = direction(t, M) Xβ‚‘ = inverse_translate_diff(M, p, p, X, conv) - return norm(N, Identity(N), Xβ‚‘) + return norm(next_trait(t), M, Identity(M), Xβ‚‘) +end +function LinearAlgebra.norm( + t::TraitList{<:IsGroupManifold}, + M::AbstractDecoratorManifold, + ::Identity, + X, +) + return norm(next_trait(t), M, identity_element(M), X) end -function Base.show(io::IO, metric::LeftInvariantMetric) - return print(io, "LeftInvariantMetric($(metric.metric))") +function translate_diff!( + ::TraitList{IsMetricManifold}, + M::MetricManifold, + p, + q, + X, + conv::ActionDirection, +) + return translate_diff(M.manifold, p, q, X, conv) end -function Base.show(io::IO, metric::RightInvariantMetric) - return print(io, "RightInvariantMetric($(metric.metric))") +function translate_diff!( + ::TraitList{IsMetricManifold}, + M::MetricManifold, + Y, + p, + q, + X, + conv::ActionDirection, +) + return translate_diff!(M.manifold, Y, p, q, X, conv) end diff --git a/src/groups/multiplication_operation.jl b/src/groups/multiplication_operation.jl new file mode 100644 index 0000000000..f1a3c3e880 --- /dev/null +++ b/src/groups/multiplication_operation.jl @@ -0,0 +1,198 @@ + +""" + MultiplicationOperation <: AbstractGroupOperation + +Group operation that consists of multiplication. +""" +struct MultiplicationOperation <: AbstractGroupOperation end + +const MultiplicationGroupTrait = TraitList{<:IsGroupManifold{<:MultiplicationOperation}} + +Base.:*(e::Identity{MultiplicationOperation}) = e +Base.:*(::Identity{MultiplicationOperation}, p) = p +Base.:*(p, ::Identity{MultiplicationOperation}) = p +Base.:*(e::Identity{MultiplicationOperation}, ::Identity{MultiplicationOperation}) = e +Base.:*(::Identity{MultiplicationOperation}, e::Identity{AdditionOperation}) = e +Base.:*(e::Identity{AdditionOperation}, ::Identity{MultiplicationOperation}) = e + +Base.:/(p, ::Identity{MultiplicationOperation}) = p +Base.:/(::Identity{MultiplicationOperation}, p) = inv(p) +Base.:/(e::Identity{MultiplicationOperation}, ::Identity{MultiplicationOperation}) = e + +Base.:\(p, ::Identity{MultiplicationOperation}) = inv(p) +Base.:\(::Identity{MultiplicationOperation}, p) = p +Base.:\(e::Identity{MultiplicationOperation}, ::Identity{MultiplicationOperation}) = e + +LinearAlgebra.det(::Identity{MultiplicationOperation}) = true +LinearAlgebra.adjoint(e::Identity{MultiplicationOperation}) = e + +function identity_element!( + ::MultiplicationGroupTrait, + G::AbstractDecoratorManifold, + p::AbstractMatrix, +) + return copyto!(p, I) +end + +function identity_element!( + ::MultiplicationGroupTrait, + G::AbstractDecoratorManifold, + p::AbstractArray, +) + if length(p) == 1 + fill!(p, one(eltype(p))) + else + throw(DimensionMismatch("Array $p cannot be set to identity element of group $G")) + end + return p +end + +function is_identity( + ::MultiplicationGroupTrait, + G::AbstractDecoratorManifold, + q::Number; + kwargs..., +) + return isapprox(G, q, one(q); kwargs...) +end +function is_identity( + ::MultiplicationGroupTrait, + G::AbstractDecoratorManifold, + q::AbstractArray{<:Any,0}; + kwargs..., +) + return isapprox(G, q[], one(q[]); kwargs...) +end +function is_identity( + ::MultiplicationGroupTrait, + G::AbstractDecoratorManifold, + q::AbstractMatrix; + kwargs..., +) + return isapprox(G, q, I; kwargs...) +end + +LinearAlgebra.mul!(q, ::Identity{MultiplicationOperation}, p) = copyto!(q, p) +LinearAlgebra.mul!(q, p, ::Identity{MultiplicationOperation}) = copyto!(q, p) +function LinearAlgebra.mul!( + q::AbstractMatrix, + ::Identity{MultiplicationOperation}, + ::Identity{MultiplicationOperation}, +) + return copyto!(q, I) +end +function LinearAlgebra.mul!( + q, + ::Identity{MultiplicationOperation}, + ::Identity{MultiplicationOperation}, +) + return copyto!(q, one(q)) +end +function LinearAlgebra.mul!( + q::Identity{MultiplicationOperation}, + ::Identity{MultiplicationOperation}, + ::Identity{MultiplicationOperation}, +) + return q +end +Base.one(e::Identity{MultiplicationOperation}) = e + +Base.inv(::MultiplicationGroupTrait, G::AbstractDecoratorManifold, p) = inv(p) +function Base.inv( + ::MultiplicationGroupTrait, + G::AbstractDecoratorManifold, + e::Identity{MultiplicationOperation}, +) + return e +end + +inv!(::MultiplicationGroupTrait, G::AbstractDecoratorManifold, q, p) = copyto!(q, inv(G, p)) +function inv!( + ::MultiplicationGroupTrait, + G::AbstractDecoratorManifold, + q, + ::Identity{MultiplicationOperation}, +) + return identity_element!(G, q) +end +function inv!( + ::MultiplicationGroupTrait, + G::AbstractDecoratorManifold, + q::Identity{MultiplicationOperation}, + e::Identity{MultiplicationOperation}, +) + return q +end + +compose(::MultiplicationGroupTrait, G::AbstractDecoratorManifold, p, q) = p * q + +function compose!(::MultiplicationGroupTrait, G::AbstractDecoratorManifold, x, p, q) + return mul!_safe(x, p, q) +end + +function inverse_translate( + ::MultiplicationGroupTrait, + G::AbstractDecoratorManifold, + p, + q, + ::LeftAction, +) + return p \ q +end +function inverse_translate( + ::MultiplicationGroupTrait, + G::AbstractDecoratorManifold, + p, + q, + ::RightAction, +) + return q / p +end + +function inverse_translate!( + ::MultiplicationGroupTrait, + G::AbstractDecoratorManifold, + x, + p, + q, + conv::ActionDirection, +) + return copyto!(x, inverse_translate(G, p, q, conv)) +end + +function exp_lie!( + ::MultiplicationGroupTrait, + G::AbstractDecoratorManifold, + q, + X::Union{Number,AbstractMatrix}, +) + copyto!(q, exp(X)) + return q +end + +function log_lie!( + ::MultiplicationGroupTrait, + G::AbstractDecoratorManifold, + X::AbstractMatrix, + q::AbstractMatrix, +) + return log_safe!(X, q) +end +function log_lie!( + ::MultiplicationGroupTrait, + G::AbstractDecoratorManifold, + X, + ::Identity{MultiplicationOperation}, +) + return zero_vector!(G, X, identity_element(G)) +end + +function lie_bracket(::MultiplicationGroupTrait, G::AbstractDecoratorManifold, X, Y) + return mul!(X * Y, Y, X, -1, true) +end + +function lie_bracket!(::MultiplicationGroupTrait, G::AbstractDecoratorManifold, Z, X, Y) + mul!(Z, X, Y) + mul!(Z, Y, X, -1, true) + return Z +end diff --git a/src/groups/product_group.jl b/src/groups/product_group.jl index e41bff6800..4c32482381 100644 --- a/src/groups/product_group.jl +++ b/src/groups/product_group.jl @@ -12,7 +12,7 @@ const ProductGroup{𝔽,T} = GroupManifold{𝔽,ProductManifold{𝔽,T},ProductO Decorate a product manifold with a [`ProductOperation`](@ref). -Each submanifold must also be an [`AbstractGroupManifold`](@ref) or a decorated instance of +Each submanifold must also have a [`IsGroupManifold`](@ref) or a decorated instance of one. This type is mostly useful for equipping the direct product of group manifolds with an [`Identity`](@ref) element. @@ -20,20 +20,13 @@ one. This type is mostly useful for equipping the direct product of group manifo ProductGroup(manifold::ProductManifold) """ function ProductGroup(manifold::ProductManifold{𝔽}) where {𝔽} - if !all(is_group_decorator, manifold.manifolds) + if !all(is_group_manifold, manifold.manifolds) error("All submanifolds of product manifold must be or decorate groups.") end op = ProductOperation() return GroupManifold(manifold, op) end -function decorator_transparent_dispatch(::typeof(exp_lie!), M::ProductGroup, q, X) - return Val(:transparent) -end -function decorator_transparent_dispatch(::typeof(log_lie!), M::ProductGroup, X, q) - return Val(:transparent) -end - function identity_element(G::ProductGroup) M = G.manifold return ProductRepr(map(identity_element, M.manifolds)) @@ -45,14 +38,11 @@ function identity_element!(G::ProductGroup, p) return p end -function is_identity(G::ProductGroup, p; kwargs...) +function is_identity(G::ProductGroup, p::Identity{<:ProductOperation}; kwargs...) pes = submanifold_components(G, p) M = G.manifold # Inner prodct manifold (of groups) return all(map((M, pe) -> is_identity(M, pe; kwargs...), M.manifolds, pes)) end -function is_identity(G::ProductGroup, e::Identity; kwargs...) - return invoke(is_identity, Tuple{AbstractGroupManifold,typeof(e)}, G, e; kwargs...) -end function Base.show(io::IO, ::MIME"text/plain", G::ProductGroup) print( diff --git a/src/groups/rotation_action.jl b/src/groups/rotation_action.jl index f1c67e8003..f07157a4e7 100644 --- a/src/groups/rotation_action.jl +++ b/src/groups/rotation_action.jl @@ -34,18 +34,18 @@ const RotationActionOnVector{N,F,TAD} = RotationAction{ base_group(A::RotationAction) = A.SOn -g_manifold(A::RotationAction) = A.manifold +group_manifold(A::RotationAction) = A.manifold function switch_direction(A::RotationAction{TM,TSO,TAD}) where {TM,TSO,TAD} return RotationAction(A.manifold, A.SOn, switch_direction(TAD())) end -apply(A::RotationActionOnVector{N,F,LeftAction}, a, p) where {N,F} = a * p +apply(::RotationActionOnVector{N,F,LeftAction}, a, p) where {N,F} = a * p function apply(A::RotationActionOnVector{N,F,RightAction}, a, p) where {N,F} return inv(base_group(A), a) * p end -apply!(A::RotationActionOnVector{N,F,LeftAction}, q, a, p) where {N,F} = mul!(q, a, p) +apply!(::RotationActionOnVector{N,F,LeftAction}, q, a, p) where {N,F} = mul!(q, a, p) function inverse_apply(A::RotationActionOnVector{N,F,LeftAction}, a, p) where {N,F} return inv(base_group(A), a) * p @@ -65,7 +65,7 @@ function apply_diff(A::RotationActionOnVector{N,F,RightAction}, a, p, X) where { return inv(base_group(A), a) * X end -function apply_diff!(A::RotationActionOnVector{N,F,LeftAction}, Y, a, p, X) where {N,F} +function apply_diff!(::RotationActionOnVector{N,F,LeftAction}, Y, a, p, X) where {N,F} return mul!(Y, a, X) end function apply_diff!(A::RotationActionOnVector{N,F,RightAction}, Y, a, p, X) where {N,F} @@ -103,7 +103,7 @@ end base_group(::RotationAroundAxisAction) = RealCircleGroup() -g_manifold(::RotationAroundAxisAction) = Euclidean(3) +group_manifold(::RotationAroundAxisAction) = Euclidean(3) @doc raw""" apply(A::RotationAroundAxisAction, ΞΈ, p) diff --git a/src/groups/semidirect_product_group.jl b/src/groups/semidirect_product_group.jl index 057d87154c..d4562979b2 100644 --- a/src/groups/semidirect_product_group.jl +++ b/src/groups/semidirect_product_group.jl @@ -7,12 +7,6 @@ Group operation of a semidirect product group. The operation consists of the ope """ struct SemidirectProductOperation{A<:AbstractGroupAction} <: AbstractGroupOperation action::A - function SemidirectProductOperation{A}(action::A) where {A<:AbstractGroupAction} - return new(action) - end -end -function SemidirectProductOperation(action::A) where {A<:AbstractGroupAction} - return SemidirectProductOperation{A}(action) end function Base.show(io::IO, op::SemidirectProductOperation) @@ -45,7 +39,7 @@ function SemidirectProductGroup( H::GroupManifold{𝔽}, A::AbstractGroupAction, ) where {𝔽} - N === g_manifold(A) || error("Subgroup $(N) must be the G-manifold of action $(A)") + N === group_manifold(A) || error("Subgroup $(N) must be the G-manifold of action $(A)") H === base_group(A) || error("Subgroup $(H) must be the base group of action $(A)") op = SemidirectProductOperation(A) M = ProductManifold(N, H) @@ -78,15 +72,16 @@ function identity_element!(G::SemidirectProductGroup, q) return q end -function is_identity(G::SemidirectProductGroup, p; kwargs...) +function is_identity( + G::SemidirectProductGroup, + p::Identity{<:SemidirectProductOperation}; + kwargs..., +) M = base_manifold(G) N, H = M.manifolds nq, hq = submanifold_components(G, p) return is_identity(N, nq; kwargs...) && is_identity(H, hq; kwargs...) end -function is_identity(G::SemidirectProductGroup, e::Identity; kwargs...) - return invoke(is_identity, Tuple{AbstractGroupManifold,typeof(e)}, G, e; kwargs...) -end function Base.show(io::IO, G::SemidirectProductGroup) M = base_manifold(G) diff --git a/src/groups/special_euclidean.jl b/src/groups/special_euclidean.jl index e7b3fe33ce..ced2a8362c 100644 --- a/src/groups/special_euclidean.jl +++ b/src/groups/special_euclidean.jl @@ -42,11 +42,10 @@ function SpecialEuclidean(n) return SemidirectProductGroup(Tn, SOn, A) end -const SpecialEuclideanIdentity{N} = Identity{ - SemidirectProductOperation{ - RotationAction{TranslationGroup{Tuple{N},ℝ},SpecialOrthogonal{N},LeftAction}, - }, +const SpecialEuclideanOperation{N} = SemidirectProductOperation{ + RotationAction{TranslationGroup{Tuple{N},ℝ},SpecialOrthogonal{N},LeftAction}, } +const SpecialEuclideanIdentity{N} = Identity{SpecialEuclideanOperation{N}} Base.show(io::IO, ::SpecialEuclidean{n}) where {n} = print(io, "SpecialEuclidean($(n))") @@ -153,11 +152,8 @@ function affine_matrix(::SpecialEuclidean{n}, ::SpecialEuclideanIdentity{n}) whe return Diagonal{Float64}(I, n) end -function check_point(G::SpecialEuclidean{n}, p::AbstractMatrix; kwargs...) where {n} +function check_point(G::SpecialEuclideanManifold{n}, p::AbstractMatrix; kwargs...) where {n} errs = DomainError[] - # Valid matrix - err1 = check_point(Euclidean(n + 1, n + 1), p) - !isnothing(err1) && push!(errs, err1) # homogeneous if !isapprox(p[end, :], [zeros(size(p, 2) - 1)..., 1]; kwargs...) push!( @@ -179,16 +175,26 @@ function check_point(G::SpecialEuclidean{n}, p::AbstractMatrix; kwargs...) where end return length(errs) == 0 ? nothing : first(errs) end + +function check_size(G::SpecialEuclideanManifold{n}, p::AbstractMatrix; kwargs...) where {n} + return check_size(Euclidean(n + 1, n + 1), p) +end +function check_size( + G::SpecialEuclideanManifold{n}, + p::AbstractMatrix, + X::AbstractMatrix; + kwargs..., +) where {n} + return check_size(Euclidean(n + 1, n + 1), X) +end + function check_vector( - G::SpecialEuclidean{n}, + G::SpecialEuclideanManifold{n}, p::AbstractMatrix, X::AbstractMatrix; kwargs..., ) where {n} errs = DomainError[] - # Valid matrix - err1 = check_point(Euclidean(n + 1, n + 1), X) - !isnothing(err1) && push!(errs, err1) # homogeneous if !isapprox(X[end, :], zeros(size(X, 2)); kwargs...) push!( @@ -246,9 +252,9 @@ function allocate_result(::SpecialEuclidean{n}, ::typeof(screw_matrix), X...) wh return allocate(X[1], Size(n + 1, n + 1)) end -_compose(::SpecialEuclidean, p::AbstractMatrix, q::AbstractMatrix) = p * q +compose(::SpecialEuclidean, p::AbstractMatrix, q::AbstractMatrix) = p * q -function _compose!( +function compose!( ::SpecialEuclidean, x::AbstractMatrix, p::AbstractMatrix, @@ -486,7 +492,6 @@ function _log_lie!(G::SpecialEuclidean{3}, X, q) @inbounds _padvector!(G, X) return X end - """ lie_bracket(G::SpecialEuclidean, X::ProductRepr, Y::ProductRepr) lie_bracket(G::SpecialEuclidean, X::AbstractMatrix, Y::AbstractMatrix) diff --git a/src/groups/special_linear.jl b/src/groups/special_linear.jl index 2e8b17646c..02f2853188 100644 --- a/src/groups/special_linear.jl +++ b/src/groups/special_linear.jl @@ -1,6 +1,5 @@ @doc raw""" - SpecialLinear{n,𝔽} <: - AbstractGroupManifold{𝔽,MultiplicationOperation,DefaultEmbeddingType} + SpecialLinear{n,𝔽} <: AbstractDecoratorManifold The special linear group ``\mathrm{SL}(n,𝔽)`` that is, the group of all invertible matrices with unit determinant in ``𝔽^{nΓ—n}``. @@ -14,20 +13,26 @@ The default metric is the same left-``\mathrm{GL}(n)``-right-``\mathrm{O}(n)``-i metric used for [`GeneralLinear(n, 𝔽)`](@ref). The resulting geodesic on ``\mathrm{GL}(n,𝔽)`` emanating from an element of ``\mathrm{SL}(n,𝔽)`` in the direction of an element of ``𝔰𝔩(n, 𝔽)`` is a closed subgroup of ``\mathrm{SL}(n,𝔽)``. As a result, most -metric functions forward to `GeneralLinear`. +metric functions forward to [`GeneralLinear`](@ref). """ -struct SpecialLinear{n,𝔽} <: - AbstractGroupManifold{𝔽,MultiplicationOperation,TransparentGroupDecoratorType} end +struct SpecialLinear{n,𝔽} <: AbstractDecoratorManifold{𝔽} end SpecialLinear(n, 𝔽::AbstractNumbers=ℝ) = SpecialLinear{n,𝔽}() +@inline function active_traits(f, ::SpecialLinear, args...) + return merge_traits( + IsGroupManifold(MultiplicationOperation()), + IsEmbeddedSubmanifold(), + HasLeftInvariantMetric(), + IsDefaultMetric(EuclideanMetric()), + ) +end + function allocation_promotion_function(::SpecialLinear{n,β„‚}, f, args::Tuple) where {n} return complex end function check_point(G::SpecialLinear{n,𝔽}, p; kwargs...) where {n,𝔽} - mpv = check_point(Euclidean(n, n; field=𝔽), p; kwargs...) - mpv === nothing || return mpv detp = det(p) if !isapprox(detp, 1; kwargs...) return DomainError( @@ -38,18 +43,8 @@ function check_point(G::SpecialLinear{n,𝔽}, p; kwargs...) where {n,𝔽} end return nothing end -check_point(G::SpecialLinear, ::Identity{MultiplicationOperation}; kwargs...) = nothing -function check_point( - G::SpecialLinear, - e::Identity{O}; - kwargs..., -) where {O<:AbstractGroupOperation} - return invoke(check_point, Tuple{AbstractGroupManifold,typeof(e)}, G, e; kwargs...) -end function check_vector(G::SpecialLinear, p, X; kwargs...) - mpv = check_vector(decorated_manifold(G), p, X; kwargs...) - mpv === nothing || return mpv trX = tr(inverse_translate_diff(G, p, p, X, LeftAction())) if !isapprox(trX, 0; kwargs...) return DomainError( @@ -61,10 +56,10 @@ function check_vector(G::SpecialLinear, p, X; kwargs...) return nothing end -decorated_manifold(::SpecialLinear{n,𝔽}) where {n,𝔽} = GeneralLinear(n, 𝔽) +embed(::SpecialLinear, p) = p +embed(::SpecialLinear, p, X) = X -default_metric_dispatch(::SpecialLinear, ::EuclideanMetric) = Val(true) -default_metric_dispatch(::SpecialLinear, ::LeftInvariantMetric{EuclideanMetric}) = Val(true) +get_embedding(::SpecialLinear{n,𝔽}) where {n,𝔽} = GeneralLinear(n, 𝔽) inverse_translate_diff(::SpecialLinear, p, q, X, ::LeftAction) = X inverse_translate_diff(::SpecialLinear, p, q, X, ::RightAction) = p * X / p @@ -74,7 +69,7 @@ function inverse_translate_diff!(G::SpecialLinear, Y, p, q, X, conv::ActionDirec end function manifold_dimension(G::SpecialLinear) - return manifold_dimension(decorated_manifold(G)) - real_dimension(number_system(G)) + return manifold_dimension(get_embedding(G)) - real_dimension(number_system(G)) end @doc raw""" @@ -131,10 +126,6 @@ function project!(G::SpecialLinear{n}, Y, p, X) where {n} return Y end -function decorator_transparent_dispatch(::typeof(project), ::SpecialLinear, args...) - return Val(:parent) -end - Base.show(io::IO, ::SpecialLinear{n,𝔽}) where {n,𝔽} = print(io, "SpecialLinear($n, $𝔽)") translate_diff(::SpecialLinear, p, q, X, ::LeftAction) = X diff --git a/src/groups/special_orthogonal.jl b/src/groups/special_orthogonal.jl index 2987f461c8..ac7422487a 100644 --- a/src/groups/special_orthogonal.jl +++ b/src/groups/special_orthogonal.jl @@ -8,18 +8,41 @@ Special orthogonal group $\mathrm{SO}(n)$ represented by rotation matrices. """ const SpecialOrthogonal{n} = GroupManifold{ℝ,Rotations{n},MultiplicationOperation} -invariant_metric_dispatch(::SpecialOrthogonal, ::ActionDirection) = Val(true) - -function default_metric_dispatch( - ::MetricManifold{𝔽,<:SpecialOrthogonal,EuclideanMetric}, -) where {𝔽} - return Val(true) +@inline function active_traits(f, ::SpecialOrthogonal, args...) + if is_metric_function(f) + #pass to Rotations by default - but keep Group Decorator for the retraction + return merge_traits( + IsGroupManifold(MultiplicationOperation()), + IsExplicitDecorator(), + ) + else + return merge_traits( + IsGroupManifold(MultiplicationOperation()), + HasBiinvariantMetric(), + IsDefaultMetric(EuclideanMetric()), + IsExplicitDecorator(), #pass to Rotations by default/last fallback + ) + end end -default_metric_dispatch(::SpecialOrthogonal, ::EuclideanMetric) = Val(true) SpecialOrthogonal(n) = SpecialOrthogonal{n}(Rotations(n), MultiplicationOperation()) -Base.show(io::IO, ::SpecialOrthogonal{n}) where {n} = print(io, "SpecialOrthogonal($(n))") +function allocate_result( + ::SpecialOrthogonal, + ::typeof(exp), + ::Identity{MultiplicationOperation}, + X, +) + return allocate(X) +end +function allocate_result( + ::SpecialOrthogonal, + ::typeof(log), + ::Identity{MultiplicationOperation}, + q, +) + return allocate(q) +end Base.inv(::SpecialOrthogonal, p) = transpose(p) Base.inv(::SpecialOrthogonal, e::Identity{MultiplicationOperation}) = e @@ -27,13 +50,6 @@ Base.inv(::SpecialOrthogonal, e::Identity{MultiplicationOperation}) = e inverse_translate(G::SpecialOrthogonal, p, q, ::LeftAction) = inv(G, p) * q inverse_translate(G::SpecialOrthogonal, p, q, ::RightAction) = q * inv(G, p) -translate_diff(::SpecialOrthogonal, p, q, X, ::LeftAction) = X -translate_diff(G::SpecialOrthogonal, p, q, X, ::RightAction) = inv(G, p) * X * p - -function translate_diff!(G::SpecialOrthogonal, Y, p, q, X, conv::ActionDirection) - return copyto!(Y, translate_diff(G, p, q, X, conv)) -end - function inverse_translate_diff(G::SpecialOrthogonal, p, q, X, conv::ActionDirection) return translate_diff(G, inv(G, p), q, X, conv) end @@ -42,19 +58,11 @@ function inverse_translate_diff!(G::SpecialOrthogonal, Y, p, q, X, conv::ActionD return copyto!(Y, inverse_translate_diff(G, p, q, X, conv)) end -function allocate_result( - ::GT, - ::typeof(exp), - ::Identity, - X, -) where {n,GT<:SpecialOrthogonal{n}} - return allocate(X) -end -function allocate_result( - ::GT, - ::typeof(log), - ::Identity, - q, -) where {n,GT<:SpecialOrthogonal{n}} - return allocate(q) +translate_diff(::SpecialOrthogonal, p, q, X, ::LeftAction) = X +translate_diff(G::SpecialOrthogonal, p, q, X, ::RightAction) = inv(G, p) * X * p + +function translate_diff!(G::SpecialOrthogonal, Y, p, q, X, conv::ActionDirection) + return copyto!(Y, translate_diff(G, p, q, X, conv)) end + +Base.show(io::IO, ::SpecialOrthogonal{n}) where {n} = print(io, "SpecialOrthogonal($(n))") diff --git a/src/groups/translation_action.jl b/src/groups/translation_action.jl index 970e81461e..ed68926a34 100644 --- a/src/groups/translation_action.jl +++ b/src/groups/translation_action.jl @@ -30,7 +30,7 @@ end base_group(A::TranslationAction) = A.Rn -g_manifold(A::TranslationAction) = A.manifold +group_manifold(A::TranslationAction) = A.manifold function switch_direction(A::TranslationAction{TM,TRN,TAD}) where {TM,TRN,TAD} return TranslationAction(A.manifold, A.Rn, switch_direction(TAD())) diff --git a/src/groups/translation_group.jl b/src/groups/translation_group.jl index bc60a2612d..c4838ce672 100644 --- a/src/groups/translation_group.jl +++ b/src/groups/translation_group.jl @@ -18,16 +18,18 @@ function TranslationGroup(n::Int...; field::AbstractNumbers=ℝ) ) end -identity_element!(::TranslationGroup, p) = fill!(p, 0) - -invariant_metric_dispatch(::TranslationGroup, ::ActionDirection) = Val(true) - -function default_metric_dispatch( - ::MetricManifold{𝔽,<:TranslationGroup,EuclideanMetric}, -) where {𝔽} - return Val(true) +@inline function active_traits(f, M::TranslationGroup, args...) + return merge_traits( + IsGroupManifold(M.op), + IsDefaultMetric(EuclideanMetric()), + HasBiinvariantMetric(), + active_traits(f, M.manifold, args...), + IsExplicitDecorator(), + ) end +identity_element!(::TranslationGroup, p) = fill!(p, 0) + function Base.show(io::IO, ::TranslationGroup{N,𝔽}) where {N,𝔽} return print(io, "TranslationGroup($(join(N.parameters, ", ")); field = $(𝔽))") end diff --git a/src/groups/validation_group.jl b/src/groups/validation_group.jl index 025c77a25c..98ec45915d 100644 --- a/src/groups/validation_group.jl +++ b/src/groups/validation_group.jl @@ -6,9 +6,7 @@ array_value(e::Identity) = e array_point(p) = ValidationMPoint(p) array_point(p::ValidationMPoint) = p -const ValidationGroup{𝔽} = ValidationManifold{𝔽,G} where {G<:AbstractGroupManifold} - -function adjoint_action(M::ValidationGroup, p, X; kwargs...) +function adjoint_action(M::ValidationManifold, p, X; kwargs...) is_point(M, p, true; kwargs...) eM = Identity(M.manifold) is_vector(M, eM, X, true; kwargs...) @@ -17,7 +15,7 @@ function adjoint_action(M::ValidationGroup, p, X; kwargs...) return Y end -function adjoint_action!(M::ValidationGroup, Y, p, X; kwargs...) +function adjoint_action!(M::ValidationManifold, Y, p, X; kwargs...) is_point(M, p, true; kwargs...) eM = Identity(M.manifold) is_vector(M, eM, X, true; kwargs...) @@ -26,24 +24,24 @@ function adjoint_action!(M::ValidationGroup, Y, p, X; kwargs...) return Y end -Identity(M::ValidationGroup) = array_point(Identity(M.manifold)) -identity_element!(M::ValidationGroup, p) = identity_element!(M.manifold, array_value(p)) +Identity(M::ValidationManifold) = array_point(Identity(M.manifold)) +identity_element!(M::ValidationManifold, p) = identity_element!(M.manifold, array_value(p)) -function Base.inv(M::ValidationGroup, p; kwargs...) +function Base.inv(M::ValidationManifold, p; kwargs...) is_point(M, p, true; kwargs...) q = array_point(inv(M.manifold, array_value(p))) is_point(M, q, true; kwargs...) return q end -function inv!(M::ValidationGroup, q, p; kwargs...) +function inv!(M::ValidationManifold, q, p; kwargs...) is_point(M, p, true; kwargs...) inv!(M.manifold, array_value(q), array_value(p)) is_point(M, q, true; kwargs...) return q end -function lie_bracket(M::ValidationGroup, X, Y) +function lie_bracket(M::ValidationManifold, X, Y) eM = Identity(M.manifold) is_vector(M, eM, X, true) is_vector(M, eM, Y, true) @@ -52,7 +50,7 @@ function lie_bracket(M::ValidationGroup, X, Y) return Z end -function lie_bracket!(M::ValidationGroup, Z, X, Y) +function lie_bracket!(M::ValidationManifold, Z, X, Y) eM = Identity(M.manifold) is_vector(M, eM, X, true) is_vector(M, eM, Y, true) @@ -61,15 +59,44 @@ function lie_bracket!(M::ValidationGroup, Z, X, Y) return Z end -function compose(M::ValidationGroup, p, q; kwargs...) +function compose(M::ValidationManifold, p, q; kwargs...) is_point(M, p, true; kwargs...) is_point(M, q, true; kwargs...) x = array_point(compose(M.manifold, array_value(p), array_value(q))) is_point(M, x, true; kwargs...) return x end +function compose(M::ValidationManifold, p::Identity, q; kwargs...) + is_point(M, p, true; kwargs...) + is_point(M, q, true; kwargs...) + x = array_point(compose(M.manifold, p, array_value(q))) + is_point(M, x, true; kwargs...) + return x +end +function compose(M::ValidationManifold, p, q::Identity; kwargs...) + is_point(M, p, true; kwargs...) + is_point(M, q, true; kwargs...) + x = array_point(compose(M.manifold, array_value(p), q)) + is_point(M, x, true; kwargs...) + return x +end + +function compose!(M::ValidationManifold, x, p, q; kwargs...) + is_point(M, p, true; kwargs...) + is_point(M, q, true; kwargs...) + compose!(M.manifold, array_value(x), array_value(p), array_value(q)) + is_point(M, x, true; kwargs...) + return x +end -function compose!(M::ValidationGroup, x, p, q; kwargs...) +function compose!(M::ValidationManifold, x, p::Identity, q; kwargs...) + is_point(M, p, true; kwargs...) + is_point(M, q, true; kwargs...) + compose!(M.manifold, array_value(x), array_value(p), array_value(q)) + is_point(M, x, true; kwargs...) + return x +end +function compose!(M::ValidationManifold, x, p, q::Identity; kwargs...) is_point(M, p, true; kwargs...) is_point(M, q, true; kwargs...) compose!(M.manifold, array_value(x), array_value(p), array_value(q)) @@ -204,28 +231,14 @@ function inverse_translate_diff!( end function exp_lie(M::ValidationManifold, X; kwargs...) - is_vector( - M, - Identity(M.manifold), - array_value(X), - true; - check_base_point=false, - kwargs..., - ) + is_vector(M, Identity(M.manifold), array_value(X), true; kwargs...) q = array_point(exp_lie(M.manifold, array_value(X))) is_point(M, q, true; kwargs...) return q end function exp_lie!(M::ValidationManifold, q, X; kwargs...) - is_vector( - M, - Identity(M.manifold), - array_value(X), - true; - check_base_point=false, - kwargs..., - ) + is_vector(M, Identity(M.manifold), array_value(X), true; kwargs...) exp_lie!(M.manifold, array_value(q), array_value(X)) is_point(M, q, true; kwargs...) return q @@ -234,27 +247,13 @@ end function log_lie(M::ValidationManifold, q; kwargs...) is_point(M, q, true; kwargs...) X = ValidationTVector(log_lie(M.manifold, array_value(q))) - is_vector( - M, - Identity(M.manifold), - array_value(X), - true; - check_base_point=false, - kwargs..., - ) + is_vector(M, Identity(M.manifold), array_value(X), true; kwargs...) return X end function log_lie!(M::ValidationManifold, X, q; kwargs...) is_point(M, q, true; kwargs...) log_lie!(M.manifold, array_value(X), array_value(q)) - is_vector( - M, - Identity(M.manifold), - array_value(X), - true; - check_base_point=false, - kwargs..., - ) + is_vector(M, Identity(M.manifold), array_value(X), true; kwargs...) return X end diff --git a/src/manifolds/CenteredMatrices.jl b/src/manifolds/CenteredMatrices.jl index 33d922b727..83bfc342b9 100644 --- a/src/manifolds/CenteredMatrices.jl +++ b/src/manifolds/CenteredMatrices.jl @@ -1,5 +1,5 @@ @doc raw""" - CenteredMatrices{m,n,𝔽} <: AbstractEmbeddedManifold{𝔽,TransparentIsometricEmbedding} + CenteredMatrices{m,n,𝔽} <: AbstractDecoratorManifold{𝔽} The manifold of $m Γ— n$ real-valued or complex-valued matrices whose columns sum to zero, i.e. ````math @@ -12,12 +12,14 @@ where $𝔽 ∈ \{ℝ,β„‚\}$. Generate the manifold of `m`-by-`n` (`field`-valued) matrices whose columns sum to zero. """ -struct CenteredMatrices{M,N,𝔽} <: AbstractEmbeddedManifold{𝔽,TransparentIsometricEmbedding} end +struct CenteredMatrices{M,N,𝔽} <: AbstractDecoratorManifold{𝔽} end function CenteredMatrices(m::Int, n::Int, field::AbstractNumbers=ℝ) return CenteredMatrices{m,n,field}() end +active_traits(f, ::CenteredMatrices, args...) = merge_traits(IsEmbeddedSubmanifold()) + @doc raw""" check_point(M::CenteredMatrices{m,n,𝔽}, p; kwargs...) @@ -28,8 +30,6 @@ zero. The tolerance for the column sums of `p` can be set using `kwargs...`. """ function check_point(M::CenteredMatrices{m,n,𝔽}, p; kwargs...) where {m,n,𝔽} - mpv = invoke(check_point, Tuple{supertype(typeof(M)),typeof(p)}, M, p; kwargs...) - mpv === nothing || return mpv if !isapprox(sum(p, dims=1), zeros(1, n); kwargs...) return DomainError( p, @@ -46,19 +46,10 @@ end Check whether `X` is a tangent vector to manifold point `p` on the [`CenteredMatrices`](@ref) `M`, i.e. that `X` is a matrix of size `(m, n)` whose columns -sum to zero and its values are from the correct [`AbstractNumbers`](@ref). +sum to zero and its values are from the correct [`AbstractNumbers`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#number-system). The tolerance for the column sums of `p` and `X` can be set using `kwargs...`. """ function check_vector(M::CenteredMatrices{m,n,𝔽}, p, X; kwargs...) where {m,n,𝔽} - mpv = invoke( - check_vector, - Tuple{supertype(typeof(M)),typeof(p),typeof(X)}, - M, - p, - X; - kwargs..., - ) - mpv === nothing || return mpv if !isapprox(sum(X, dims=1), zeros(1, n); kwargs...) return DomainError( X, @@ -68,7 +59,10 @@ function check_vector(M::CenteredMatrices{m,n,𝔽}, p, X; kwargs...) where {m,n return nothing end -decorated_manifold(M::CenteredMatrices{m,n,𝔽}) where {m,n,𝔽} = Euclidean(m, n; field=𝔽) +embed(::CenteredMatrices, p) = p +embed(::CenteredMatrices, p, X) = X + +get_embedding(::CenteredMatrices{m,n,𝔽}) where {m,n,𝔽} = Euclidean(m, n; field=𝔽) @doc raw""" manifold_dimension(M::CenteredMatrices{m,n,𝔽}) @@ -79,7 +73,7 @@ Return the manifold dimension of the [`CenteredMatrices`](@ref) `m`-by-`n` matri ````math \dim(\mathcal M) = (m*n - n) \dim_ℝ 𝔽, ```` -where $\dim_ℝ 𝔽$ is the [`real_dimension`](@ref) of `𝔽`. +where $\dim_ℝ 𝔽$ is the [`real_dimension`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.real_dimension-Tuple{ManifoldsBase.AbstractNumbers}) of `𝔽`. """ function manifold_dimension(::CenteredMatrices{m,n,𝔽}) where {m,n,𝔽} return (m * n - n) * real_dimension(𝔽) diff --git a/src/manifolds/CholeskySpace.jl b/src/manifolds/CholeskySpace.jl index de4f8ad50f..373b4a14cf 100644 --- a/src/manifolds/CholeskySpace.jl +++ b/src/manifolds/CholeskySpace.jl @@ -28,12 +28,8 @@ entries on the diagonal. The tolerance for the tests can be set using the `kwargs...`. """ function check_point(M::CholeskySpace, p; kwargs...) - if size(p) != representation_size(M) - return DomainError( - size(p), - "The point $(p) does not lie on $(M), since its size is not $(representation_size(M)).", - ) - end + cks = check_size(M, p) + cks === nothing || return cks if !isapprox(norm(strictlyUpperTriangular(p)), 0.0; kwargs...) return DomainError( norm(UpperTriangular(p) - Diagonal(p)), @@ -58,12 +54,6 @@ and a symmetric matrix. The tolerance for the tests can be set using the `kwargs...`. """ function check_vector(M::CholeskySpace, p, X; kwargs...) - if size(X) != representation_size(M) - return DomainError( - size(X), - "The vector $(X) is not a tangent to a point on $(M) since its size does not match $(representation_size(M)).", - ) - end if !isapprox(norm(strictlyUpperTriangular(X)), 0.0; kwargs...) return DomainError( norm(UpperTriangular(X) - Diagonal(X)), @@ -188,7 +178,7 @@ strictlyLowerTriangular(p) = LowerTriangular(p) - Diagonal(diag(p)) strictlyUpperTriangular(p) = UpperTriangular(p) - Diagonal(diag(p)) @doc raw""" - vector_transport_to(M::CholeskySpace, p, X, q, ::ParallelTransport) + parallel_transport_to(M::CholeskySpace, p, X, q) Parallely transport the tangent vector `X` at `p` along the geodesic to `q` on the [`CholeskySpace`](@ref) manifold `M`. The formula reads @@ -201,9 +191,9 @@ on the [`CholeskySpace`](@ref) manifold `M`. The formula reads where $⌊\cdotβŒ‹$ denotes the strictly lower triangular matrix, and $\operatorname{diag}$ extracts the diagonal matrix. """ -vector_transport_to(::CholeskySpace, ::Any, ::Any, ::Any, ::ParallelTransport) +parallel_transport_to(::CholeskySpace, ::Any, ::Any, ::Any) -function vector_transport_to!(::CholeskySpace, Y, p, X, q, ::ParallelTransport) +function parallel_transport_to!(::CholeskySpace, Y, p, X, q) return copyto!(Y, strictlyLowerTriangular(p) + Diagonal(diag(q) .* diag(X) ./ diag(p))) end diff --git a/src/manifolds/Circle.jl b/src/manifolds/Circle.jl index 1adb8d1a45..aa821a7da0 100644 --- a/src/manifolds/Circle.jl +++ b/src/manifolds/Circle.jl @@ -9,7 +9,7 @@ $\lvert z\rvert = 1$. Circle(𝔽=ℝ) Generate the `ℝ`-valued Circle represented by angles, which -alternatively can be set to use the [`AbstractNumbers`](@ref) `𝔽=β„‚` to obtain the circle +alternatively can be set to use the [`AbstractNumbers`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#number-system) `𝔽=β„‚` to obtain the circle represented by `β„‚`-valued circle of unit numbers. """ struct Circle{𝔽} <: AbstractManifold{𝔽} end @@ -43,6 +43,22 @@ function check_point(M::Circle{β„‚}, p; kwargs...) end return nothing end +check_size(::Circle, ::Number) = nothing +function check_size(M::Circle, p) + (size(p) == ()) && return nothing + return DomainError( + size(p), + "The point $p can not belong to the $M, since it is not a number nor a vector of size (1,).", + ) +end +check_size(::Circle, ::Number, ::Number) = nothing +function check_size(M::Circle, p, X) + (size(X) == ()) && return nothing + return DomainError( + size(X), + "The vector $X is not a tangent vector to $p on $M, since it is not a number nor a vector of size (1,).", + ) +end """ check_vector(M::Circle, p, X; kwargs...) @@ -115,118 +131,65 @@ function exp!(M::Circle{β„‚}, q, p, X) return q end -function get_basis(::Circle{ℝ}, p, B::DiagonalizingOrthonormalBasis) +function get_basis_diagonalizing(::Circle{ℝ}, p, B::DiagonalizingOrthonormalBasis) sbv = sign(B.frame_direction[]) vs = @SVector [@SVector [sbv == 0 ? one(sbv) : sbv]] return CachedBasis(B, (@SVector [0]), vs) end -get_coordinates(::Circle{ℝ}, p, X, ::AbstractBasis{<:Any,TangentSpaceType}) = X -get_coordinates(::Circle{ℝ}, p, X, ::DefaultOrthonormalBasis{<:Any,TangentSpaceType}) = X -function get_coordinates(M::Circle{ℝ}, p, X, B::DiagonalizingOrthonormalBasis) +get_coordinates_orthonormal(::Circle{ℝ}, p, X, ::RealNumbers) = @SVector [X[]] +get_coordinates_orthonormal!(::Circle{ℝ}, c, p, X, ::RealNumbers) = (c .= X) +function get_coordinates_diagonalizing(::Circle{ℝ}, p, X, B::DiagonalizingOrthonormalBasis) sbv = sign(B.frame_direction[]) return X .* (sbv == 0 ? one(sbv) : sbv) end +function get_coordinates_diagonalizing!( + M::Circle{ℝ}, + Y, + p, + X, + B::DiagonalizingOrthonormalBasis, +) + Y[] = get_coordinates_diagonalizing(M, p, X, B)[] + return Y +end + """ get_coordinates(M::Circle{β„‚}, p, X, B::DefaultOrthonormalBasis) Return tangent vector coordinates in the Lie algebra of the [`Circle`](@ref). """ -function get_coordinates( - ::Circle{β„‚}, - p, - X, - ::DefaultOrthonormalBasis{<:Any,TangentSpaceType}, -) +get_coordinates(::Circle{β„‚}, p, X, ::DefaultOrthonormalBasis{<:Any,TangentSpaceType}) +function get_coordinates_orthonormal!(M::Circle{β„‚}, Y, p, X, n::RealNumbers) + Y[] = get_coordinates_orthonormal(M, p, X, n)[] + return Y +end +function get_coordinates_orthonormal(::Circle{β„‚}, p, X, ::RealNumbers) X, p = X[1], p[1] Xⁱ = imag(X) * real(p) - real(X) * imag(p) return @SVector [Xⁱ] end -function get_coordinates!( - M::Circle, - Y::AbstractArray, - p, - X, - B::DefaultOrthonormalBasis{<:Any,TangentSpaceType}, -) - Y[] = get_coordinates(M, p, X, B)[] - return Y -end -function get_coordinates!( - M::Circle, - Y::AbstractArray, - p, - X, - B::DiagonalizingOrthonormalBasis, -) - Y[] = get_coordinates(M, p, X, B)[] - return Y -end - -eval( - quote - @invoke_maker 1 AbstractManifold get_coordinates!( - M::Circle, - Y::AbstractArray, - p, - X, - B::VeeOrthogonalBasis, - ) - end, -) - -get_vector(::Circle{ℝ}, p, X, ::AbstractBasis{ℝ,TangentSpaceType}) = X -get_vector(::Circle{ℝ}, p, X, ::DefaultOrthonormalBasis{ℝ,TangentSpaceType}) = X -function get_vector(::Circle{ℝ}, p, X, B::DiagonalizingOrthonormalBasis) +get_vector_orthonormal(::Circle{ℝ}, p, c, ::RealNumbers) = Scalar(c[]) +# the method below is required for FD and AD differentiation in ManifoldDiff.jl +# if changed, make sure no tests in that repository get broken +get_vector_orthonormal(::Circle{ℝ}, p::AbstractVector, c, ::RealNumbers) = c +get_vector_orthonormal!(::Circle{ℝ}, X, p, c, ::RealNumbers) = (X .= c[]) +function get_vector_diagonalizing(::Circle{ℝ}, p, c, B::DiagonalizingOrthonormalBasis) sbv = sign(B.frame_direction[]) - return X .* (sbv == 0 ? one(sbv) : sbv) + return c .* (sbv == 0 ? one(sbv) : sbv) end """ get_vector(M::Circle{β„‚}, p, X, B::DefaultOrthonormalBasis) Return tangent vector from the coordinates in the Lie algebra of the [`Circle`](@ref). """ -function get_vector(::Circle{β„‚}, p, X, ::AbstractBasis{<:Any,TangentSpaceType}) - @SVector [1im * X[1] * p[1]] +function get_vector_orthonormal(::Circle{β„‚}, p, c, ::RealNumbers) + @SArray fill(1im * c[1] * p[1]) end -eval( - quote - @invoke_maker 4 AbstractBasis get_vector( - M::Circle{β„‚}, - p, - X, - B::DefaultOrthonormalBasis{<:Any,TangentSpaceType}, - ) - end, -) - -for BT in [AbstractBasis{<:Any,TangentSpaceType}] - eval(quote - function get_vector!(::Circle{ℝ}, Y::AbstractArray, p, X, ::$BT) - Y[] = X[] - return Y - end - end) - eval(quote - function get_vector!(::Circle{β„‚}, Y::AbstractArray, p, X, ::$BT) - Y[] = 1im * X[1] * p[1] - return Y - end - end) -end -for BT in ManifoldsBase.DISAMBIGUATION_BASIS_TYPES, CT in [Circle, Circle{ℝ}, Circle{β„‚}] - eval( - quote - @invoke_maker 5 $(supertype(BT)) get_vector!( - M::$CT, - Y::AbstractArray, - p, - X, - B::$BT, - ) - end, - ) +function get_vector_orthonormal!(::Circle{β„‚}, X, p, c, ::RealNumbers) + X .= 1im * c[1] * p[1] + return X end @doc raw""" @@ -235,17 +198,6 @@ end Return the injectivity radius on the [`Circle`](@ref) `M`, i.e. $Ο€$. """ injectivity_radius(::Circle) = Ο€ -injectivity_radius(::Circle, ::ExponentialRetraction) = Ο€ -injectivity_radius(::Circle, ::Any) = Ο€ -injectivity_radius(::Circle, ::Any, ::ExponentialRetraction) = Ο€ -eval( - quote - @invoke_maker 1 AbstractManifold injectivity_radius( - M::Circle, - rm::AbstractRetractionMethod, - ) - end, -) @doc raw""" inner(M::Circle, p, X, Y) @@ -271,13 +223,6 @@ inner(::Circle, ::Any...) @inline inner(::Circle{ℝ}, p::Real, X::Real, Y::Real) = X * Y @inline inner(::Circle{β„‚}, p, X, Y) = complex_dot(X, Y) -function inverse_retract(M::Circle, p::Number, q::Number) - return inverse_retract(M, p, q, LogarithmicInverseRetraction()) -end -function inverse_retract(M::Circle, p::Number, q::Number, ::LogarithmicInverseRetraction) - return log(M, p, q) -end - @doc raw""" log(M::Circle, p, q) @@ -388,7 +333,7 @@ end mid_point(M::Circle{ℝ}, p1, p2) = exp(M, p1, 0.5 * log(M, p1, p2)) mid_point(::Circle{β„‚}, p1::Complex, p2::Complex) = exp(im * (angle(p1) + angle(p2)) / 2) -mid_point(M::Circle{β„‚}, p1::StaticArray, p2::StaticArray) = SA[mid_point(M, p1[], p2[])] +mid_point(M::Circle{β„‚}, p1::StaticArray, p2::StaticArray) = Scalar(mid_point(M, p1[], p2[])) @inline LinearAlgebra.norm(::Circle, p, X) = sum(abs, X) @@ -439,14 +384,15 @@ function Random.rand(::Circle{ℝ}; vector_at=nothing, Οƒ::Real=1.0) if vector_at === nothing return sym_rem(rand() * 2 * Ο€) else - return Οƒ * randn() + # written like that to properly handle `vector_at` being a number or a one-element array + return map(_ -> Οƒ * randn(), vector_at) end end function Random.rand(rng::AbstractRNG, ::Circle{ℝ}; vector_at=nothing, Οƒ::Real=1.0) if vector_at === nothing return sym_rem(rand(rng) * 2 * Ο€) else - return Οƒ * randn(rng) + return map(_ -> Οƒ * randn(rng), vector_at) end end @@ -484,7 +430,7 @@ end sym_rem(x, T=Ο€) where {N} = map(sym_rem, x, Ref(T)) @doc raw""" - vector_transport_to(M::Circle, p, X, q, ::ParallelTransport) + parallel_transport_to(M::Circle, p, X, q) Compute the parallel transport of `X` from the tangent space at `p` to the tangent space at `q` on the [`Circle`](@ref) `M`. @@ -497,15 +443,10 @@ complex plane. ```` where [`log`](@ref) denotes the logarithmic map on `M`. """ -vector_transport_to(::Circle, ::Any, ::Any, ::Any, ::ParallelTransport) -vector_transport_to(::Circle{ℝ}, p::Real, X::Real, q::Real, ::ParallelTransport) = X -function vector_transport_to( - M::Circle{β„‚}, - p::Number, - X::Number, - q::Number, - ::ParallelTransport, -) +parallel_transport_to(::Circle, ::Any, ::Any, ::Any) + +parallel_transport_to(::Circle{ℝ}, p::Real, X::Real, q::Real) = X +function parallel_transport_to(M::Circle{β„‚}, p::Number, X::Number, q::Number) X_pq = log(M, p, q) Xnorm = norm(M, p, X_pq) Y = X @@ -516,8 +457,8 @@ function vector_transport_to( return Y end -vector_transport_to!(::Circle{ℝ}, Y, p, X, q, ::ParallelTransport) = (Y .= X) -function vector_transport_to!(M::Circle{β„‚}, Y, p, X, q, ::ParallelTransport) +parallel_transport_to!(::Circle{ℝ}, Y, p, X, q) = (Y .= X) +function parallel_transport_to!(M::Circle{β„‚}, Y, p, X, q) X_pq = log(M, p, q) Xnorm = norm(M, p, X_pq) Y .= X @@ -528,16 +469,5 @@ function vector_transport_to!(M::Circle{β„‚}, Y, p, X, q, ::ParallelTransport) return Y end -function vector_transport_direction( - M::Circle, - p::Number, - X::Number, - Y::Number, - m::AbstractVectorTransportMethod, -) - q = exp(M, p, Y) - return vector_transport_to(M, p, X, q, m) -end - -zero_vector(::Circle, p::Number) = zero(p) +zero_vector(::Circle, p::T) where {T<:Number} = zero(p) zero_vector!(::Circle, X, p) = fill!(X, 0) diff --git a/src/manifolds/ConnectionManifold.jl b/src/manifolds/ConnectionManifold.jl index b24302e7c1..0db43a236b 100644 --- a/src/manifolds/ConnectionManifold.jl +++ b/src/manifolds/ConnectionManifold.jl @@ -12,48 +12,98 @@ The [Levi-Civita connection](https://en.wikipedia.org/wiki/Levi-Civita_connectio """ struct LeviCivitaConnection <: AbstractAffineConnection end -struct MetricDecoratorType <: AbstractDecoratorType end - """ - AbstractConnectionManifold{𝔽,M<:AbstractManifold{𝔽},G<:AbstractAffineConnection} <: AbstractDecoratorManifold{𝔽} - -Equip an [`AbstractManifold`](@ref) explicitly with an [`AbstractAffineConnection`](@ref) `G`. - -`AbstractConnectionManifold` is defined by values of [`christoffel_symbols_second`](@ref), -which is used for a default implementation of [`exp`](@ref), [`log`](@ref) and -[`vector_transport_to`](@ref). Closed-form formulae for particular connection manifolds may -be explicitly implemented when available. - -An overview of basic properties of affine connection manifolds can be found in [^Pennec2020]. + IsConnectionManifold <: AbstractTrait -[^Pennec2020]: - > X. Pennec and M. Lorenzi, β€œ5 - Beyond Riemannian geometry: The affine connection - > setting for transformation groups,” in Riemannian Geometric Statistics in Medical Image - > Analysis, X. Pennec, S. Sommer, and T. Fletcher, Eds. Academic Press, 2020, pp. 169–229. - > doi: 10.1016/B978-0-12-814725-2.00012-1. +Specify that a certain decorated Manifold is a connection manifold in the sence that it provides +explicit connection properties, extending/changing the default connection properties of a manifold. """ -abstract type AbstractConnectionManifold{𝔽} <: - AbstractDecoratorManifold{𝔽,MetricDecoratorType} end +struct IsConnectionManifold <: AbstractTrait end """ - connection(M::AbstractManifold) + IsDefaultConnection{G<:AbstractAffineConnection} -Get the connection (an object of a subtype of [`AbstractAffineConnection`](@ref)) -of [`AbstractManifold`](@ref) `M`. +Specify that a certain [`AbstractAffineConnection`](@ref) is the default connection for a manifold. +This way the corresponding [`ConnectionManifold`](@ref) falls back to the default methods +of the manifold it decorates. """ -connection(::AbstractManifold) +struct IsDefaultConnection{C<:AbstractAffineConnection} <: AbstractTrait + connection::C +end +parent_trait(::IsDefaultConnection) = IsConnectionManifold() """ + ConnectionManifold{𝔽,,M<:AbstractManifold{𝔽},G<:AbstractAffineConnection} <: AbstractDecoratorManifold{𝔽} + +# Constructor + ConnectionManifold(M, C) -Decorate the [`AbstractManifold`](@ref) `M` with [`AbstractAffineConnection`](@ref) `C`. +Decorate the [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) `M` with [`AbstractAffineConnection`](@ref) `C`. """ struct ConnectionManifold{𝔽,M<:AbstractManifold{𝔽},C<:AbstractAffineConnection} <: - AbstractConnectionManifold{𝔽} + AbstractDecoratorManifold{𝔽} manifold::M connection::C end +function Base.filter(f, t::TraitList) + if f(t.head) + return merge_traits(t.head, filter(f, t.tail)) + else + return filter(f, t.tail) + end +end +Base.filter(f, t::EmptyTrait) = t + +function active_traits(f, M::ConnectionManifold, args...) + return merge_traits( + is_default_connection(M.manifold, M.connection) ? + IsDefaultConnection(M.connection) : EmptyTrait(), + IsConnectionManifold(), + filter(x -> x isa IsGroupManifold, active_traits(f, M.manifold, args...)), + ) +end + +@doc raw""" + christoffel_symbols_first( + M::AbstractManifold, + p, + B::AbstractBasis; + backend::AbstractDiffBackend = default_differential_backend(), + ) + +Compute the Christoffel symbols of the first kind in local coordinates of basis `B`. +The Christoffel symbols are (in Einstein summation convention) + +````math +Ξ“_{ijk} = \frac{1}{2} \Bigl[g_{kj,i} + g_{ik,j} - g_{ij,k}\Bigr], +```` + +where ``g_{ij,k}=\frac{βˆ‚}{βˆ‚ p^k} g_{ij}`` is the coordinate +derivative of the local representation of the metric tensor. The dimensions of +the resulting multi-dimensional array are ordered ``(i,j,k)``. +""" +christoffel_symbols_first(::AbstractManifold, ::Any, B::AbstractBasis) +function christoffel_symbols_first( + M::AbstractManifold, + p, + B::AbstractBasis; + backend::AbstractDiffBackend=default_differential_backend(), +) + βˆ‚g = local_metric_jacobian(M, p, B; backend=backend) + n = size(βˆ‚g, 1) + Ξ“ = allocate(βˆ‚g, Size(n, n, n)) + @einsum Ξ“[i, j, k] = 1 / 2 * (βˆ‚g[k, j, i] + βˆ‚g[i, k, j] - βˆ‚g[i, j, k]) + return Ξ“ +end +@trait_function christoffel_symbols_first( + M::AbstractDecoratorManifold, + p, + B::AbstractBasis; + kwargs..., +) + @doc raw""" christoffel_symbols_second( M::AbstractManifold, @@ -75,9 +125,20 @@ where ``Ξ“_{ijk}`` are the Christoffel symbols of the first kind representation of the metric tensor. The dimensions of the resulting multi-dimensional array are ordered ``(l,i,j)``. """ -christoffel_symbols_second(::AbstractManifold, ::Any, ::AbstractBasis) +function christoffel_symbols_second( + M::AbstractManifold, + p, + B::AbstractBasis; + backend::AbstractDiffBackend=default_differential_backend(), +) + Ginv = inverse_local_metric(M, p, B) + Γ₁ = christoffel_symbols_first(M, p, B; backend=backend) + Ξ“β‚‚ = allocate(Γ₁) + @einsum Ξ“β‚‚[l, i, j] = Ginv[k, l] * Γ₁[i, j, k] + return Ξ“β‚‚ +end -@decorator_transparent_signature christoffel_symbols_second( +@trait_function christoffel_symbols_second( M::AbstractDecoratorManifold, p, B::AbstractBasis; @@ -118,13 +179,21 @@ function christoffel_symbols_second_jacobian( ) return βˆ‚Ξ“ end -@decorator_transparent_signature christoffel_symbols_second_jacobian( +@trait_function christoffel_symbols_second_jacobian( M::AbstractDecoratorManifold, p, B::AbstractBasis; kwargs..., ) +""" + connection(M::AbstractManifold) + +Get the connection (an object of a subtype of [`AbstractAffineConnection`](@ref)) +of [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) `M`. +""" +connection(::AbstractManifold) + """ connection(M::ConnectionManifold) @@ -132,26 +201,27 @@ Return the connection associated with [`ConnectionManifold`](@ref) `M`. """ connection(M::ConnectionManifold) = M.connection -Base.copyto!(M::AbstractConnectionManifold, q, p) = copyto!(M.manifold, q, p) -Base.copyto!(M::AbstractConnectionManifold, Y, p, X) = copyto!(M.manifold, Y, p, X) +decorated_manifold(M::ConnectionManifold) = M.manifold + +default_retraction_method(M::ConnectionManifold) = default_retraction_method(M.manifold) @doc raw""" - exp(M::AbstractConnectionManifold, p, X) + exp(::TraitList{IsConnectionManifold}, M::AbstractDecoratorManifold, p, X) -Compute the exponential map on the [`AbstractConnectionManifold`](@ref) `M` equipped with +Compute the exponential map on a manifold that [`IsConnectionManifold`](@ref) `M` equipped with corresponding affine connection. -If `M` is a [`MetricManifold`](@ref) with a metric that was declared the default metric -using [`is_default_metric`](@ref), this method falls back to `exp(M, p, X)`. +If `M` is a [`MetricManifold`](@ref) with a [`IsDefaultMetric`](@ref) trait, +this method falls back to `exp(M, p, X)`. Otherwise it numerically integrates the underlying ODE, see [`solve_exp_ode`](@ref). Currently, the numerical integration is only accurate when using a single coordinate chart that covers the entire manifold. This excludes coordinates in an embedded space. """ -exp(::AbstractConnectionManifold, ::Any...) +exp(::TraitList{IsConnectionManifold}, M::AbstractDecoratorManifold, p, X) -@decorator_transparent_fallback function exp!(M::AbstractConnectionManifold, q, p, X) +function exp!(::TraitList{IsConnectionManifold}, M::AbstractDecoratorManifold, q, p, X) return retract!( M, q, @@ -171,32 +241,48 @@ gaussian_curvature(::AbstractManifold, ::Any, ::AbstractBasis) function gaussian_curvature(M::AbstractManifold, p, B::AbstractBasis; kwargs...) return ricci_curvature(M, p, B; kwargs...) / 2 end -@decorator_transparent_signature gaussian_curvature( +@trait_function gaussian_curvature( M::AbstractDecoratorManifold, p, B::AbstractBasis; kwargs..., ) -function injectivity_radius(M::AbstractConnectionManifold, p) - return injectivity_radius(base_manifold(M), p) -end -function injectivity_radius(M::AbstractConnectionManifold, m::AbstractRetractionMethod) - return injectivity_radius(base_manifold(M), m) -end -function injectivity_radius(M::AbstractConnectionManifold, m::ExponentialRetraction) - return injectivity_radius(base_manifold(M), m) -end -function injectivity_radius(M::AbstractConnectionManifold, p, m::AbstractRetractionMethod) - return injectivity_radius(base_manifold(M), p, m) +""" + is_default_connection(M::AbstractManifold, G::AbstractAffineConnection) + +returns whether an [`AbstractAffineConnection`](@ref) is the default metric on the manifold `M` or not. +This can be set by defining this function, or setting the [`IsDefaultConnection`](@ref) trait for an +[`AbstractDecoratorManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/decorator.html#ManifoldsBase.AbstractDecoratorManifold). +""" +is_default_connection(M::AbstractManifold, G::AbstractAffineConnection) +@trait_function is_default_connection( + M::AbstractDecoratorManifold, + G::AbstractAffineConnection, +) +function is_default_connection( + ::TraitList{IsDefaultConnection{C}}, + ::AbstractDecoratorManifold, + ::C, +) where {C<:AbstractAffineConnection} + return true end -function injectivity_radius(M::AbstractConnectionManifold, p, m::ExponentialRetraction) - return injectivity_radius(base_manifold(M), p, m) +function is_default_connection(M::ConnectionManifold) + return is_default_connection(M.manifold, M.connection) end +is_default_connection(::AbstractManifold, ::AbstractAffineConnection) = false -function retract!(M::AbstractDecoratorManifold, q, p, X, r::ODEExponentialRetraction) - sol = solve_exp_ode(M, p, X; basis=r.basis, dense=false) - return copyto!(q, sol) +function retract_exp_ode!( + M::AbstractManifold, + q, + p, + X, + ::AbstractRetractionMethod, + b::AbstractBasis, +) + sol = solve_exp_ode(M, p, X; basis=b, dense=false) + copyto!(q, sol) + return q end """ @@ -214,12 +300,7 @@ function ricci_tensor(M::AbstractManifold, p, B::AbstractBasis; kwargs...) @einsum Ric[i, j] = R[l, i, l, j] return Ric end -@decorator_transparent_signature ricci_tensor( - M::AbstractDecoratorManifold, - p, - B::AbstractBasis; - kwargs..., -) +@trait_function ricci_tensor(M::AbstractDecoratorManifold, p, B::AbstractBasis; kwargs...) @doc raw""" riemann_tensor(M::AbstractManifold, p, B::AbstractBasis; backend::AbstractDiffBackend=default_differential_backend()) @@ -251,12 +332,7 @@ function riemann_tensor( βˆ‚Ξ“[l, i, k, j] - βˆ‚Ξ“[l, i, j, k] + Ξ“[s, i, k] * Ξ“[l, s, j] - Ξ“[s, i, j] * Ξ“[l, s, k] return R end -@decorator_transparent_signature riemann_tensor( - M::AbstractDecoratorManifold, - p, - B::AbstractBasis; - kwargs..., -) +@trait_function riemann_tensor(M::AbstractDecoratorManifold, p, B::AbstractBasis; kwargs...) @doc raw""" solve_exp_ode( @@ -294,105 +370,12 @@ in an embedded space. ``` """ function solve_exp_ode(M, p, X; kwargs...) - return error( - """ - solve_exp_ode not implemented on $(typeof(M)) for point $(typeof(p)), vector $(typeof(X)). - For a suitable default, enter `using OrdinaryDiffEq`. - """, - ) -end - -# -# Introduce transparency -# (a) new functions & other parents -for f in [ - christoffel_symbols_second_jacobian, - exp, - gaussian_curvature, - get_coordinates, - get_vector, - log, - mean, - median, - project, - ricci_tensor, - riemann_tensor, - vector_transport_along, - vector_transport_direction, - vector_transport_direction!, #since it has a default using _to! - vector_transport_to, -] - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractConnectionManifold, - args..., - ) - return Val(:parent) - end - end, + throw( + ErrorException( + """ + solve_exp_ode not implemented on $(typeof(M)) for point $(typeof(p)), vector $(typeof(X)). + For a suitable default, enter `using OrdinaryDiffEq`. + """, + ), ) end - -# (b) changes / intransparencies. -for f in [ - christoffel_symbols_second, # this is basic for connection manifolds but not for metric manifolds - exp!, - get_coordinates!, - get_vector!, - get_basis, - inner, - inverse_retract!, - log!, - mean!, - median!, - norm, - project!, - retract!, - vector_transport_along!, - vector_transport_to!, -] - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractConnectionManifold, - args..., - ) - return Val(:intransparent) - end - end, - ) -end -# (c) special cases -function decorator_transparent_dispatch( - ::typeof(exp!), - M::AbstractConnectionManifold, - q, - p, - X, - t, -) - return Val(:parent) -end -function decorator_transparent_dispatch( - ::typeof(inverse_retract!), - M::AbstractConnectionManifold, - X, - p, - q, - m::LogarithmicInverseRetraction, -) - return Val(:parent) -end -function decorator_transparent_dispatch( - ::typeof(retract!), - M::AbstractConnectionManifold, - q, - p, - X, - m::ExponentialRetraction, -) - return Val(:parent) -end diff --git a/src/manifolds/Elliptope.jl b/src/manifolds/Elliptope.jl index 981827d379..21514e9cff 100644 --- a/src/manifolds/Elliptope.jl +++ b/src/manifolds/Elliptope.jl @@ -1,5 +1,5 @@ @doc raw""" - Elliptope{N,K} <: AbstractEmbeddedManifold{ℝ,DefaultIsometricEmbeddingType} + Elliptope{N,K} <: AbstractDecoratorManifold{ℝ} The Elliptope manifold, also known as the set of correlation matrices, consists of all symmetric positive semidefinite matrices of rank $k$ with unit diagonal, i.e., @@ -46,7 +46,9 @@ generates the manifold $\mathcal E(n,k) \subset ℝ^{n Γ— n}$. > doi: [10.1137/080731359](https://doi.org/10.1137/080731359), > arXiv: [0807.4423](http://arxiv.org/abs/0807.4423). """ -struct Elliptope{N,K} <: AbstractEmbeddedManifold{ℝ,DefaultIsometricEmbeddingType} end +struct Elliptope{N,K} <: AbstractDecoratorManifold{ℝ} end + +active_traits(f, ::Elliptope, args...) = merge_traits(IsEmbeddedManifold()) Elliptope(n::Int, k::Int) = Elliptope{n,k}() @@ -61,8 +63,6 @@ Since $p$ is by construction positive semidefinite, this is not checked. The tolerances for positive semidefiniteness and unit trace can be set using the `kwargs...`. """ function check_point(M::Elliptope{N,K}, q; kwargs...) where {N,K} - mpv = invoke(check_point, Tuple{supertype(typeof(M)),typeof(q)}, M, q; kwargs...) - mpv === nothing || return mpv row_norms_sq = sum(abs2, q; dims=2) if !all(isapprox.(row_norms_sq, 1.0; kwargs...)) return DomainError( @@ -85,15 +85,6 @@ The tolerance for the base point check and zero diagonal can be set using the `k Note that symmetric of $X$ holds by construction an is not explicitly checked. """ function check_vector(M::Elliptope{N,K}, q, Y; kwargs...) where {N,K} - mpv = invoke( - check_vector, - Tuple{supertype(typeof(M)),typeof(q),typeof(Y)}, - M, - q, - Y; - kwargs..., - ) - mpv === nothing || return mpv X = q * Y' + Y * q' n = diag(X) if !all(isapprox.(n, 0.0; kwargs...)) @@ -105,9 +96,7 @@ function check_vector(M::Elliptope{N,K}, q, Y; kwargs...) where {N,K} return nothing end -function decorated_manifold(M::Elliptope) - return Euclidean(representation_size(M)...; field=ℝ) -end +get_embedding(M::Elliptope) = Euclidean(representation_size(M)...; field=ℝ) @doc raw""" manifold_dimension(M::Elliptope) @@ -130,6 +119,7 @@ project `q` onto the manifold [`Elliptope`](@ref) `M`, by normalizing the rows o project(::Elliptope, ::Any) project!(::Elliptope, r, q) = copyto!(r, q ./ (sqrt.(sum(abs2, q, dims=2)))) +project(::Elliptope, q) = q ./ (sqrt.(sum(abs2, q, dims=2))) """ project(M::Elliptope, q, Y) @@ -151,7 +141,7 @@ compute a projection based retraction by projecting $q+Y$ back onto the manifold """ retract(::Elliptope, ::Any, ::Any, ::ProjectionRetraction) -retract!(M::Elliptope, r, q, Y, ::ProjectionRetraction) = project!(M, r, q + Y) +retract_project!(M::Elliptope, r, q, Y) = project!(M, r, q + Y) @doc raw""" representation_size(M::Elliptope) @@ -173,11 +163,7 @@ transport the tangent vector `X` at `p` to `q` by projecting it onto the tangent at `q`. """ vector_transport_to(::Elliptope, ::Any, ::Any, ::Any, ::ProjectionTransport) - -function vector_transport_to!(M::Elliptope, Y, p, X, q, ::ProjectionTransport) - project!(M, Y, q, X) - return Y -end +vector_transport_to_project!(M::Elliptope, Y, p, X, q) = project!(M, Y, q, X) @doc raw""" zero_vector(M::Elliptope,p) diff --git a/src/manifolds/EssentialManifold.jl b/src/manifolds/EssentialManifold.jl index e2d3ea491e..cb040784e7 100644 --- a/src/manifolds/EssentialManifold.jl +++ b/src/manifolds/EssentialManifold.jl @@ -69,12 +69,6 @@ Check whether the matrix is a valid point on the [`EssentialManifold`](@ref) `M` i.e. a 2-element array containing SO(3) matrices. """ function check_point(M::EssentialManifold, p; kwargs...) - if length(p) != 2 - return DomainError( - length(p), - "The point $(p) does not lie on $M, since it does not contain exactly two elements.", - ) - end return check_point( PowerManifold(M.manifold, NestedPowerRepresentation(), 2), p; @@ -89,12 +83,6 @@ Check whether `X` is a tangent vector to manifold point `p` on the [`EssentialMa i.e. `X` has to be a 2-element array of `3`-by-`3` skew-symmetric matrices. """ function check_vector(M::EssentialManifold, p, X; kwargs...) - if length(X) != 2 - return DomainError( - length(X), - "$(X) is not a tangent vector to the manifold $M, since it does not contain exactly two elements.", - ) - end return check_vector( PowerManifold(M.manifold, NestedPowerRepresentation(), 2), p, @@ -138,11 +126,6 @@ where $\tilde X$ is the horizontal lift of $X$[^TronDaniilidis2017]. """ exp(::EssentialManifold, ::Any...) -function exp!(M::EssentialManifold, q, p, X) - exp!.(Ref(M.manifold), q, p, X) - return q -end - get_iterator(::EssentialManifold) = Base.OneTo(2) function isapprox(M::EssentialManifold, p, q; kwargs...) @@ -416,6 +399,10 @@ function manifold_dimension(::EssentialManifold) return 5 end +function power_dimensions(::EssentialManifold) + return (2,) +end + @doc raw""" project(M::EssentialManifold, p, X) @@ -453,36 +440,27 @@ function Base.show(io::IO, M::EssentialManifold) return print(io, "EssentialManifold($(M.is_signed))") end -function vector_transport_direction(M::EssentialManifold, p, X, d) - return vector_transport_direction(M, p, X, d, ParallelTransport()) -end - -function vector_transport_direction!(M::EssentialManifold, Y, p, X, d) - return vector_transport_direction!(M, Y, p, X, d, ParallelTransport()) +function parallel_transport_direction(M::EssentialManifold, p, X, d) + return parallel_transport_to(M, p, X, exp(M, p, d)) end - -function vector_transport_direction!(M::EssentialManifold, Y, p, X, d, m::ParallelTransport) - y = exp(M, p, d) - return vector_transport_to!(M, Y, p, X, y, m) +function parallel_transport_direction!(M::EssentialManifold, Y, p, X, d) + parallel_transport_to!(M, Y, p, X, exp(M, p, d)) + return Y end @doc raw""" - vector_transport_to(M::EssentialManifold, p, X, q, method::ParallelTransport) + parallel_transport_to(M::EssentialManifold, p, X, q) Compute the vector transport of the tangent vector `X` at `p` to `q` on the [`EssentialManifold`](@ref) `M` using left translation of the ambient group. """ -vector_transport_to(::EssentialManifold, ::Any, ::Any, ::Any, ::ParallelTransport) - -function vector_transport_to(M::EssentialManifold, p, X, q) - return vector_transport_to(M, p, X, q, ParallelTransport()) -end - -function vector_transport_to!(M::EssentialManifold, Y, p, X, q) - return vector_transport_to!(M, Y, p, X, q, ParallelTransport()) +function parallel_transport_to(::EssentialManifold, p, X, q) + # group operation in the ambient group + pq = [qe' * pe for (pe, qe) in zip(p, q)] + # left translation + return [pqe * Xe * pqe' for (pqe, Xe) in zip(pq, X)] end - -function vector_transport_to!(::EssentialManifold, Y, p, X, q, ::ParallelTransport) +function parallel_transport_to!(::EssentialManifold, Y, p, X, q) # group operation in the ambient group pq = [qe' * pe for (pe, qe) in zip(p, q)] # left translation @@ -499,7 +477,6 @@ Project `X` onto the vertical space $T_{\text{vp}}\text{SO}(3)^2$ with ```` where $e_z$ is the third unit vector, $X_i ∈ T_{p}\text{SO}(3)$ for $i=1,2,$ and it holds $R_i = R_0 R'_i, i=1,2,$ where $R'_i$ is part of the pose of camera $i$ $g_i = (R_i,T'_i) ∈ \text{SE}(3)$ and $R_0 ∈ \text{SO}(3)$ such that $R_0(T'_2-T'_1) = e_z$ [^TronDaniilidis2017]. - """ function vert_proj(M::EssentialManifold, p, X) return sum(vert_proj.(Ref(M.manifold), p, X)) diff --git a/src/manifolds/Euclidean.jl b/src/manifolds/Euclidean.jl index a6237e34b0..ebaa72fef6 100644 --- a/src/manifolds/Euclidean.jl +++ b/src/manifolds/Euclidean.jl @@ -18,19 +18,26 @@ elements are interpreted as ``n_1 Γ— n_2 Γ— … Γ— n_i`` arrays. For ``i=2`` we obtain a matrix space. The default `field=ℝ` can also be set to `field=β„‚`. The dimension of this space is ``k \dim_ℝ 𝔽``, where ``\dim_ℝ 𝔽`` is the -[`real_dimension`](@ref) of the field ``𝔽``. +[`real_dimension`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.real_dimension-Tuple{ManifoldsBase.AbstractNumbers}) of the field ``𝔽``. Euclidean(; field=ℝ) Generate the 1D Euclidean manifold for an `ℝ`-, `β„‚`-valued real- or complex-valued immutable values (in contrast to 1-element arrays from the constructor above). """ -struct Euclidean{N,𝔽} <: AbstractManifold{𝔽} where {N<:Tuple} end +struct Euclidean{N,𝔽} <: AbstractDecoratorManifold{𝔽} where {N<:Tuple} end function Euclidean(n::Vararg{Int,I}; field::AbstractNumbers=ℝ) where {I} return Euclidean{Tuple{n...},field}() end +function active_traits(f, ::Euclidean, args...) + return merge_traits( + IsDefaultMetric(EuclideanMetric()), + IsDefaultConnection(LeviCivitaConnection()), + ) +end + Base.:^(𝔽::AbstractNumbers, n) = Euclidean(n...; field=𝔽) """ @@ -74,12 +81,6 @@ function check_point(M::Euclidean{N,𝔽}, p) where {N,𝔽} "The matrix $(p) is neither a real- nor complex-valued matrix, so it does not lie on $(M).", ) end - if size(p) != representation_size(M) - return DomainError( - size(p), - "The matrix $(p) does not lie on $(M), since its dimensions ($(size(p))) are wrong (expected: $(representation_size(M))).", - ) - end return nothing end @@ -96,12 +97,6 @@ function check_vector(M::Euclidean{N,𝔽}, p, X; kwargs...) where {N,𝔽} "The matrix $(X) is neither a real- nor complex-valued matrix, so it can not be a tangent vector to $(p) on $(M).", ) end - if size(X) != representation_size(M) - return DomainError( - size(X), - "The matrix $(X) does not lie in the tangent space of $(p) on $(M), since its dimensions $(size(X)) are wrong (expected: $(representation_size(M))).", - ) - end return nothing end @@ -138,9 +133,6 @@ Embed the tangent vector `X` at point `p` in `M`. Equivalent to an identity map. """ embed(::Euclidean, p, X) -embed!(::Euclidean, q, p) = copyto!(q, p) -embed!(::Euclidean, Y, p, X) = copyto!(Y, X) - function embed!( ::EmbeddedManifold{𝔽,Euclidean{nL,𝔽},Euclidean{mL,𝔽2}}, q, @@ -181,80 +173,111 @@ Base.exp(::Euclidean, p::Number, q::Number) = p + q exp!(::Euclidean, q, p, X) = (q .= p .+ X) -function get_basis(::Euclidean, p, B::DefaultOrthonormalBasis{ℝ,TangentSpaceType}) - vecs = [_euclidean_basis_vector(p, i) for i in eachindex(p)] - return CachedBasis(B, vecs) -end -function get_basis( - ::Euclidean{<:Tuple,β„‚}, +function get_basis_diagonalizing( + M::Euclidean, p, - B::DefaultOrthonormalBasis{β„‚,TangentSpaceType}, -) - vecs = [_euclidean_basis_vector(p, i) for i in eachindex(p)] - return CachedBasis(B, [vecs; im * vecs]) -end -function get_basis(M::Euclidean, p, B::DiagonalizingOrthonormalBasis) - vecs = get_vectors(M, p, get_basis(M, p, DefaultOrthonormalBasis())) + B::DiagonalizingOrthonormalBasis{𝔽}, +) where {𝔽} + vecs = get_vectors(M, p, get_basis(M, p, DefaultOrthonormalBasis(𝔽))) eigenvalues = zeros(real(eltype(p)), manifold_dimension(M)) return CachedBasis(B, DiagonalizingBasisData(B.frame_direction, eigenvalues, vecs)) end -function get_coordinates!( +function get_coordinates_orthonormal!(M::Euclidean, c, p, X, ::RealNumbers) + S = representation_size(M) + PS = prod(S) + copyto!(c, reshape(X, PS)) + return c +end + +function get_coordinates_induced_basis!( M::Euclidean, - Y, + c, p, X, - ::Union{ - DefaultOrDiagonalizingBasis{ℝ}, - InducedBasis{ℝ,TangentSpaceType,<:RetractionAtlas}, - }, + ::InducedBasis{ℝ,TangentSpaceType,<:RetractionAtlas}, ) S = representation_size(M) PS = prod(S) - copyto!(Y, reshape(X, PS)) - return Y + copyto!(c, reshape(X, PS)) + return c end -function get_coordinates!( + +function get_coordinates_orthonormal!( M::Euclidean{<:Tuple,β„‚}, - Y, + c, ::Any, X, - ::DefaultOrDiagonalizingBasis{β„‚}, + ::ComplexNumbers, ) S = representation_size(M) PS = prod(S) - Y .= [reshape(real.(X), PS)..., reshape(imag(X), PS)...] - return Y + c .= [reshape(real.(X), PS)..., reshape(imag(X), PS)...] + return c end -function get_vector!( - M::Euclidean, - Y, +function get_coordinates_diagonalizing!( + M::Euclidean{<:Tuple,β„‚}, + c, ::Any, X, - ::Union{ - DefaultOrDiagonalizingBasis{ℝ}, - InducedBasis{ℝ,TangentSpaceType,<:RetractionAtlas}, - }, + ::DiagonalizingOrthonormalBasis{β„‚}, ) S = representation_size(M) - copyto!(Y, reshape(X, S)) + PS = prod(S) + c .= [reshape(real.(X), PS)..., reshape(imag(X), PS)...] + return c +end +function get_coordinates_diagonalizing!( + M::Euclidean, + c, + p, + X, + ::DiagonalizingOrthonormalBasis{ℝ}, +) where {𝔽} + S = representation_size(M) + PS = prod(S) + copyto!(c, reshape(X, PS)) + return c +end + +function get_vector_orthonormal!(M::Euclidean, Y, ::Any, c, ::RealNumbers) + S = representation_size(M) + copyto!(Y, reshape(c, S)) return Y end -function get_vector!( - ::Euclidean, - Y::AbstractVector, +function get_vector_diagonalizing!( + M::Euclidean, + Y, ::Any, - X, - ::DefaultOrDiagonalizingBasis{ℝ}, + c, + B::DiagonalizingOrthonormalBasis, ) - copyto!(Y, X) + S = representation_size(M) + copyto!(Y, reshape(c, S)) return Y end -function get_vector!(M::Euclidean{<:Tuple,β„‚}, Y, ::Any, X, ::DefaultOrDiagonalizingBasis{β„‚}) +function get_vector_induced_basis!(M::Euclidean, Y, ::Any, c, B::InducedBasis) S = representation_size(M) - N = div(length(X), 2) - copyto!(Y, reshape(X[1:N] + im * X[(N + 1):end], S)) + copyto!(Y, reshape(c, S)) + return Y +end +function get_vector_orthonormal!(M::Euclidean{<:Tuple,β„‚}, Y, ::Any, c, ::ComplexNumbers) + S = representation_size(M) + N = div(length(c), 2) + copyto!(Y, reshape(c[1:N] + im * c[(N + 1):end], S)) + return Y +end +function get_vector_diagonalizing!( + M::Euclidean{<:Tuple,β„‚}, + Y, + ::Any, + c, + ::DiagonalizingOrthonormalBasis{β„‚}, +) + S = representation_size(M) + N = div(length(c), 2) + copyto!(Y, reshape(c[1:N] + im * c[(N + 1):end], S)) return Y end @doc raw""" @@ -313,8 +336,6 @@ function inverse_local_metric( return local_metric(M, p, B) end -default_metric_dispatch(::Euclidean, ::EuclideanMetric) = Val(true) - function local_metric( ::MetricManifold{𝔽,<:AbstractManifold,EuclideanMetric}, p, @@ -330,18 +351,6 @@ function local_metric( return Diagonal(ones(SVector{size(p, 1),eltype(p)})) end -function inverse_retract(M::Euclidean{Tuple{}}, x::T, y::T) where {T<:Number} - return inverse_retract(M, x, y, LogarithmicInverseRetraction()) -end -function inverse_retract( - M::Euclidean{Tuple{}}, - x::Number, - y::Number, - ::LogarithmicInverseRetraction, -) - return log(M, x, y) -end - @doc raw""" log(M::Euclidean, p, q) @@ -353,6 +362,7 @@ which in this case is just """ Base.log(::Euclidean, ::Any...) Base.log(::Euclidean{Tuple{}}, p::Number, q::Number) = q - p +Base.log(::Euclidean, p, q) = q .- p log!(::Euclidean, X, p, q) = (X .= q .- p) @@ -370,7 +380,7 @@ end manifold_dimension(M::Euclidean) Return the manifold dimension of the [`Euclidean`](@ref) `M`, i.e. -the product of all array dimensions and the [`real_dimension`](@ref) of the +the product of all array dimensions and the [`real_dimension`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.real_dimension-Tuple{ManifoldsBase.AbstractNumbers}) of the underlying number system. """ function manifold_dimension(M::Euclidean{N,𝔽}) where {N,𝔽} @@ -389,10 +399,6 @@ function Statistics.mean( end Statistics.mean(::Euclidean, x::AbstractVector; kwargs...) = mean(x) -function Statistics.mean!(M::Euclidean, p, x::AbstractVector, w::AbstractVector; kwargs...) - return mean!(M, p, x, w, GeodesicInterpolation(); kwargs...) -end - function StatsBase.mean_and_var( ::Euclidean{Tuple{}}, x::AbstractVector{<:Number}; @@ -411,14 +417,6 @@ function StatsBase.mean_and_var( m, v = mean_and_var(x, w; corrected=corrected, kwargs...) return m, sum(v) end -function StatsBase.mean_and_var( - M::Euclidean, - x::AbstractVector, - w::AbstractWeights; - kwargs..., -) - return mean_and_var(M, x, w, GeodesicInterpolation(); kwargs...) -end Statistics.median(::Euclidean{Tuple{}}, x::AbstractVector{<:Number}; kwargs...) = median(x) function Statistics.median( @@ -472,6 +470,30 @@ function project!( return q end +""" + parallel_transport_along(M::Euclidean, p, X, c) + +the parallel transport on [`Euclidean`](@ref) is the identiy, i.e. returns `X`. +""" +parallel_transport_along(::Euclidean, ::Any, X, c::AbstractVector) = X +parallel_transport_along!(::Euclidean, Y, ::Any, X, c::AbstractVector) = copyto!(Y, X) + +""" + parallel_transport_direction(M::Euclidean, p, X, d) + +the parallel transport on [`Euclidean`](@ref) is the identiy, i.e. returns `X`. +""" +parallel_transport_direction(::Euclidean, ::Any, X, ::Any) = X +parallel_transport_direction!(::Euclidean, Y, ::Any, X, ::Any) = copyto!(Y, X) + +""" + parallel_transport_to(M::Euclidean, p, X, q) + +the parallel transport on [`Euclidean`](@ref) is the identiy, i.e. returns `X`. +""" +parallel_transport_to(::Euclidean, ::Any, X, ::Any) = X +parallel_transport_to!(::Euclidean, Y, ::Any, X, ::Any) = copyto!(Y, X) + @doc raw""" project(M::Euclidean, p) @@ -529,18 +551,54 @@ end function Base.show(io::IO, ::Euclidean{N,𝔽}) where {N,𝔽} return print(io, "Euclidean($(join(N.parameters, ", ")); field = $(𝔽))") end - +# +# Vector Transport +# +# The following functions are defined on layer 1 already, since +# a) its independent of the transport or retraction method +# b) no amibuities occur +# c) Euclidean is so basic, that these are plain defaults +# +function vector_transport_along( + ::Euclidean, + ::Any, + X, + ::AbstractVector, + method::AbstractVectorTransportMethod, +) + return X +end +function vector_transport_along!( + M::Euclidean, + Y, + ::Any, + X, + ::AbstractVector, + ::AbstractVectorTransportMethod=default_vector_transport_method(M), +) + return copyto!(Y, X) +end function vector_transport_direction( - M::Euclidean{Tuple{}}, - p::Number, - X::Number, - Y::Number, - m::AbstractVectorTransportMethod, + M::Euclidean, + ::Any, + X, + ::Any, + ::AbstractVectorTransportMethod=default_vector_transport_method(M), + ::AbstractRetractionMethod=default_retraction_method(M), ) - q = exp(M, p, Y) - return vector_transport_to(M, p, X, q, m) + return X +end +function vector_transport_direction!( + M::Euclidean, + Y, + ::Any, + X, + ::Any, + ::AbstractVectorTransportMethod=default_vector_transport_method(M), + ::AbstractRetractionMethod=default_retraction_method(M), +) + return copyto!(Y, X) end - """ vector_transport_to(M::Euclidean, p, X, q, ::AbstractVectorTransportMethod) @@ -549,41 +607,28 @@ on the [`Euclidean`](@ref) `M`, which simplifies to the identity. """ vector_transport_to(::Euclidean, ::Any, ::Any, ::Any, ::AbstractVectorTransportMethod) function vector_transport_to( - ::Euclidean{Tuple{}}, - ::Number, - X::Number, - ::Number, - ::AbstractVectorTransportMethod, + M::Euclidean, + ::Any, + X, + ::Any, + ::AbstractVectorTransportMethod=default_vector_transport_method(M), + ::AbstractRetractionMethod=default_retraction_method(M), ) return X end function vector_transport_to!( - ::Euclidean, + M::Euclidean, Y, ::Any, X, ::Any, - ::AbstractVectorTransportMethod, + ::AbstractVectorTransportMethod=default_vector_transport_method(M), + ::AbstractRetractionMethod=default_retraction_method(M), ) return copyto!(Y, X) end -for VT in ManifoldsBase.VECTOR_TRANSPORT_DISAMBIGUATION - eval( - quote - @invoke_maker 6 AbstractVectorTransportMethod vector_transport_to!( - M::Euclidean, - Y, - p, - X, - q, - B::$VT, - ) - end, - ) -end - Statistics.var(::Euclidean, x::AbstractVector; kwargs...) = sum(var(x; kwargs...)) function Statistics.var(::Euclidean, x::AbstractVector{<:Number}, m::Number; kwargs...) return sum(var(x; mean=m, kwargs...)) diff --git a/src/manifolds/FixedRankMatrices.jl b/src/manifolds/FixedRankMatrices.jl index 276b41f51c..afcc7664f2 100644 --- a/src/manifolds/FixedRankMatrices.jl +++ b/src/manifolds/FixedRankMatrices.jl @@ -1,5 +1,5 @@ @doc raw""" - FixedRankMatrices{m,n,k,𝔽} <: AbstractEmbeddedManifold{𝔽,DefaultEmbeddingType} + FixedRankMatrices{m,n,k,𝔽} <: AbstractDecoratorManifold{𝔽} The manifold of ``m Γ— n`` real-valued or complex-valued matrices of fixed rank ``k``, i.e. ````math @@ -42,11 +42,13 @@ Generate the manifold of `m`-by-`n` (`field`-valued) matrices of rank `k`. > doi: [10.1137/110845768](https://doi.org/10.1137/110845768), > arXiv: [1209.3834](https://arxiv.org/abs/1209.3834). """ -struct FixedRankMatrices{M,N,K,𝔽} <: AbstractEmbeddedManifold{𝔽,DefaultEmbeddingType} end +struct FixedRankMatrices{M,N,K,𝔽} <: AbstractDecoratorManifold{𝔽} end function FixedRankMatrices(m::Int, n::Int, k::Int, field::AbstractNumbers=ℝ) return FixedRankMatrices{m,n,k,field}() end +active_traits(f, ::FixedRankMatrices, args...) = merge_traits(IsEmbeddedManifold()) + @doc raw""" SVDMPoint <: AbstractManifoldPoint @@ -119,10 +121,6 @@ function allocate(X::UMVTVector, ::Type{T}) where {T} return UMVTVector(allocate(X.U, T), allocate(X.M, T), allocate(X.Vt, T)) end -function allocate_result(::FixedRankMatrices{m,n,k}, ::typeof(embed), vals...) where {m,n,k} - #note that vals is (p,) or (X,p) but both first entries have a U of correct type - return similar(typeof(vals[1].U), m, n) -end function allocate_result( ::FixedRankMatrices{m,n,k}, ::typeof(project), @@ -201,40 +199,54 @@ shape, i.e. `p.U` and `p.Vt` have to be unitary. The keyword arguments are passe function check_point(M::FixedRankMatrices{m,n,k}, p; kwargs...) where {m,n,k} r = rank(p; kwargs...) s = "The point $(p) does not lie on $(M), " - if size(p) != (m, n) - return DomainError(size(p), string(s, "since its size is wrong.")) - end if r > k return DomainError(r, string(s, "since its rank is too large ($(r)).")) end return nothing end -function check_point(M::FixedRankMatrices{m,n,k}, x::SVDMPoint; kwargs...) where {m,n,k} - s = "The point $(x) does not lie on $(M), " - if (size(x.U) != (m, k)) || (length(x.S) != k) || (size(x.Vt) != (k, n)) - return DomainError( - [size(x.U)..., length(x.S), size(x.Vt)...], - string( - s, - "since the dimensions do not fit (expected $(n)x$(m) rank $(k) got $(size(x.U,1))x$(size(x.Vt,2)) rank $(size(x.S)).", - ), - ) - end - if !isapprox(x.U' * x.U, one(zeros(k, k)); kwargs...) +function check_point(M::FixedRankMatrices{m,n,k}, p::SVDMPoint; kwargs...) where {m,n,k} + s = "The point $(p) does not lie on $(M), " + if !isapprox(p.U' * p.U, one(zeros(k, k)); kwargs...) return DomainError( - norm(x.U' * x.U - one(zeros(k, k))), + norm(p.U' * p.U - one(zeros(k, k))), string(s, " since U is not orthonormal/unitary."), ) end - if !isapprox(x.Vt * x.Vt', one(zeros(k, k)); kwargs...) + if !isapprox(p.Vt * p.Vt', one(zeros(k, k)); kwargs...) return DomainError( - norm(x.Vt * x.Vt' - one(zeros(k, k))), + norm(p.Vt * p.Vt' - one(zeros(k, k))), string(s, " since V is not orthonormal/unitary."), ) end return nothing end +function check_size(M::FixedRankMatrices{m,n,k}, p::SVDMPoint) where {m,n,k} + if (size(p.U) != (m, k)) || (length(p.S) != k) || (size(p.Vt) != (k, n)) + return DomainError( + [size(p.U)..., length(p.S), size(p.Vt)...], + "The point $(p) does not lie on $(M) since the dimensions do not fit (expected $(n)x$(m) rank $(k) got $(size(p.U,1))x$(size(p.Vt,2)) rank $(size(p.S,1)).", + ) + end +end +function check_size(M::FixedRankMatrices{m,n,k}, p) where {m,n,k} + pS = svd(p) + if (size(pS.U) != (m, k)) || (length(pS.S) != k) || (size(pS.Vt) != (k, n)) + return DomainError( + [size(pS.U)..., length(pS.S), size(pS.Vt)...], + "The point $(p) does not lie on $(M) since the dimensions do not fit (expected $(n)x$(m) rank $(k) got $(size(pS.U,1))x$(size(pS.Vt,2)) rank $(size(pS.S,1)).", + ) + end +end +function check_size(M::FixedRankMatrices{m,n,k}, p, X::UMVTVector) where {m,n,k} + if (size(X.U) != (m, k)) || (size(X.Vt) != (k, n)) || (size(X.M) != (k, k)) + return DomainError( + cat(size(X.U), size(X.M), size(X.Vt), dims=1), + "The tangent vector $(X) is not a tangent vector to $(p) on $(M), since matrix dimensions do not agree (expected $(m)x$(k), $(k)x$(k), $(k)x$(n)).", + ) + end +end + @doc raw""" check_vector(M:FixedRankMatrices{m,n,k}, p, X; kwargs...) @@ -248,12 +260,6 @@ function check_vector( X::UMVTVector; kwargs..., ) where {m,n,k} - if (size(X.U) != (m, k)) || (size(X.Vt) != (k, n)) || (size(X.M) != (k, k)) - return DomainError( - cat(size(X.U), size(X.M), size(X.Vt), dims=1), - "The tangent vector $(X) is not a tangent vector to $(p) on $(M), since matrix dimensions do not agree (expected $(m)x$(k), $(k)x$(k), $(k)x$(n)).", - ) - end if !isapprox(X.U' * p.U, zeros(k, k); kwargs...) return DomainError( norm(X.U' * p.U - zeros(k, k)), @@ -288,9 +294,11 @@ end Embed the point `p` from its `SVDMPoint` representation into the set of ``mΓ—n`` matrices by computing ``USV^{\mathrm{H}}``. """ -embed(::FixedRankMatrices, p) +function embed(::FixedRankMatrices, p::SVDMPoint) + return p.U * Diagonal(p.S) * p.Vt +end -function embed!(::FixedRankMatrices, q, p) +function embed!(::FixedRankMatrices, q, p::SVDMPoint) return mul!(q, p.U * Diagonal(p.S), p.Vt) end @@ -305,9 +313,11 @@ The formula reads U_pMV_p^{\mathrm{H}} + U_XV_p^{\mathrm{H}} + U_pV_X^{\mathrm{H}} ``` """ -embed(::FixedRankMatrices, p, X) +function embed(::FixedRankMatrices, p::SVDMPoint, X::UMVTVector) + return (p.U * X.M .+ X.U) * p.Vt + p.U * X.Vt +end -function embed!(::FixedRankMatrices, Y, p, X) +function embed!(::FixedRankMatrices, Y, p::SVDMPoint, X::UMVTVector) tmp = p.U * X.M tmp .+= X.U mul!(Y, tmp, p.Vt) @@ -360,7 +370,7 @@ of dimension `m`x`n` of rank `k`, namely \dim(\mathcal M) = k(m + n - k) \dim_ℝ 𝔽, ```` -where ``\dim_ℝ 𝔽`` is the [`real_dimension`](@ref) of `𝔽`. +where ``\dim_ℝ 𝔽`` is the [`real_dimension`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.real_dimension-Tuple{ManifoldsBase.AbstractNumbers}) of `𝔽`. """ function manifold_dimension(::FixedRankMatrices{m,n,k,𝔽}) where {m,n,k,𝔽} return (m + n - k) * k * real_dimension(𝔽) @@ -520,12 +530,11 @@ singular values and ``U`` and ``V`` are shortened accordingly. """ retract(::FixedRankMatrices, ::Any, ::Any, ::PolarRetraction) -function retract!( +function retract_polar!( ::FixedRankMatrices{m,n,k}, q::SVDMPoint, p::SVDMPoint, X::UMVTVector, - ::PolarRetraction, ) where {m,n,k} QU, RU = qr([p.U X.U]) QV, RV = qr([p.Vt' X.Vt']) @@ -595,7 +604,7 @@ of `X` to `q`. """ vector_transport_to!(::FixedRankMatrices, ::Any, ::Any, ::Any, ::ProjectionTransport) -function vector_transport_to!(M::FixedRankMatrices, Y, p, X, q, ::ProjectionTransport) +function vector_transport_to_project!(M::FixedRankMatrices, Y, p, X, q) return project!(M, Y, q, embed(M, p, X)) end diff --git a/src/manifolds/GeneralizedGrassmann.jl b/src/manifolds/GeneralizedGrassmann.jl index 41112e5dda..8962105745 100644 --- a/src/manifolds/GeneralizedGrassmann.jl +++ b/src/manifolds/GeneralizedGrassmann.jl @@ -1,5 +1,5 @@ @doc raw""" - GeneralizedGrassmann{n,k,𝔽} <: AbstractEmbeddedManifold{𝔽,DefaultEmbeddingType} + GeneralizedGrassmann{n,k,𝔽} <: AbstractDecoratorManifold{𝔽} The generalized Grassmann manifold $\operatorname{Gr}(n,k,B)$ consists of all subspaces spanned by $k$ linear independent vectors $𝔽^n$, where $𝔽 ∈ \{ℝ, β„‚\}$ is either the real- (or complex-) valued vectors. @@ -43,8 +43,7 @@ The manifold is named after Generate the (real-valued) Generalized Grassmann manifold of $n\times k$ dimensional orthonormal matrices with scalar product `B`. """ -struct GeneralizedGrassmann{n,k,𝔽,TB<:AbstractMatrix} <: - AbstractEmbeddedManifold{𝔽,DefaultEmbeddingType} +struct GeneralizedGrassmann{n,k,𝔽,TB<:AbstractMatrix} <: AbstractDecoratorManifold{𝔽} B::TB end @@ -57,6 +56,8 @@ function GeneralizedGrassmann( return GeneralizedGrassmann{n,k,field,typeof(B)}(B) end +active_traits(f, ::GeneralizedGrassmann, args...) = merge_traits(IsEmbeddedManifold()) + @doc raw""" change_representer(M::GeneralizedGrassmann, ::EuclideanMetric, p, X) @@ -98,22 +99,7 @@ a `n`-by-`k` matrix of unitary column vectors with respect to the B inner prudct of correct `eltype` with respect to `𝔽`. """ function check_point(M::GeneralizedGrassmann{n,k,𝔽}, p; kwargs...) where {n,k,𝔽} - mpv = invoke( - check_point, - Tuple{typeof(get_embedding(M)),typeof(p)}, - get_embedding(M), - p; - kwargs..., - ) - mpv === nothing || return mpv - c = p' * M.B * p - if !isapprox(c, one(c); kwargs...) - return DomainError( - norm(c - one(c)), - "The point $(p) does not lie on $(M), because x'Bx is not the unit matrix.", - ) - end - return nothing + return nothing # everything already checked in the embedding (generalized Stiefel) end @doc raw""" @@ -130,26 +116,7 @@ where $\cdot^{\mathrm{H}}$ denotes the complex conjugate transpose or Hermitian, $\overline{\cdot}$ the (elementwise) complex conjugate, and $0_k$ denotes the $k Γ— k$ zero natrix. """ function check_vector(M::GeneralizedGrassmann{n,k,𝔽}, p, X; kwargs...) where {n,k,𝔽} - mpv = invoke( - check_vector, - Tuple{typeof(get_embedding(M)),typeof(p),typeof(X)}, - get_embedding(M), - p, - X; - kwargs..., - ) - mpv === nothing || return mpv - if !isapprox(p' * M.B * X, -conj(X' * M.B * p); kwargs...) - return DomainError( - norm(p' * M.B * X + conj(X' * M.B * p)), - "The matrix $(X) does not lie in the tangent space of $(p) on $(M), since x'Bv + v'Bx is not the zero matrix.", - ) - end - return nothing -end - -function decorated_manifold(M::GeneralizedGrassmann{N,K,𝔽}) where {N,K,𝔽} - return Euclidean(N, K; field=𝔽) + return nothing # everything already checked in the embedding (generalized Stiefel) end @doc raw""" @@ -179,6 +146,9 @@ function distance(M::GeneralizedGrassmann, p, q) return sqrt(sum(x -> abs2(acos(clamp(x, -1, 1))), a)) end +embed(::GeneralizedGrassmann, p) = p +embed(::GeneralizedGrassmann, p, X) = X + @doc raw""" exp(M::GeneralizedGrassmann, p, X) @@ -213,17 +183,13 @@ Return the injectivity radius on the [`GeneralizedGrassmann`](@ref) `M`, which is $\frac{Ο€}{2}$. """ injectivity_radius(::GeneralizedGrassmann) = Ο€ / 2 -injectivity_radius(::GeneralizedGrassmann, ::ExponentialRetraction) = Ο€ / 2 -injectivity_radius(::GeneralizedGrassmann, ::Any) = Ο€ / 2 -injectivity_radius(::GeneralizedGrassmann, ::Any, ::ExponentialRetraction) = Ο€ / 2 -eval( - quote - @invoke_maker 1 AbstractManifold injectivity_radius( - M::GeneralizedGrassmann, - rm::AbstractRetractionMethod, - ) - end, -) +injectivity_radius(::GeneralizedGrassmann, p) = Ο€ / 2 +injectivity_radius(::GeneralizedGrassmann, ::AbstractRetractionMethod) = Ο€ / 2 +injectivity_radius(::GeneralizedGrassmann, p, ::AbstractRetractionMethod) = Ο€ / 2 + +function get_embedding(M::GeneralizedGrassmann{N,K,𝔽}) where {N,K,𝔽} + return GeneralizedStiefel(N, K, M.B, 𝔽) +end @doc raw""" inner(M::GeneralizedGrassmann, p, X, Y) @@ -250,7 +216,7 @@ end log(M::GeneralizedGrassmann, p, q) Compute the logarithmic map on the [`GeneralizedGrassmann`](@ref) `M`$ = \mathcal M=\mathrm{Gr}(n,k,B)$, -i.e. the tangent vector `X` whose corresponding [`geodesic`](@ref) starting from `p` +i.e. the tangent vector `X` whose corresponding [`geodesic`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#ManifoldsBase.geodesic-Tuple{AbstractManifold,%20Any,%20Any}) starting from `p` reaches `q` after time 1 on `M`. The formula reads ````math @@ -286,7 +252,7 @@ Return the dimension of the [`GeneralizedGrassmann(n,k,𝔽)`](@ref) manifold `M \dim \operatorname{Gr}(n,k,B) = k(n-k) \dim_ℝ 𝔽, ```` -where $\dim_ℝ 𝔽$ is the [`real_dimension`](@ref) of `𝔽`. +where $\dim_ℝ 𝔽$ is the [`real_dimension`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.real_dimension-Tuple{ManifoldsBase.AbstractNumbers}) of `𝔽`. """ function manifold_dimension(::GeneralizedGrassmann{n,k,𝔽}) where {n,k,𝔽} return k * (n - k) * real_dimension(𝔽) @@ -306,14 +272,8 @@ Compute the Riemannian [`mean`](@ref mean(M::AbstractManifold, args...)) of `x` """ mean(::GeneralizedGrassmann{n,k} where {n,k}, ::Any...) -function Statistics.mean!( - M::GeneralizedGrassmann{n,k}, - p, - x::AbstractVector, - w::AbstractVector; - kwargs..., -) where {n,k} - return mean!(M, p, x, w, GeodesicInterpolationWithinRadius(Ο€ / 4); kwargs...) +function default_estimation_method(::GeneralizedGrassmann, ::typeof(mean)) + return GeodesicInterpolationWithinRadius(Ο€ / 4) end @doc raw""" @@ -365,17 +325,17 @@ Return the represenation size or matrix dimension of a point on the [`Generalize @doc raw""" retract(M::GeneralizedGrassmann, p, X, ::PolarRetraction) -Compute the SVD-based retraction [`PolarRetraction`](@ref) on the +Compute the SVD-based retraction [`PolarRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.PolarRetraction) on the [`GeneralizedGrassmann`](@ref) `M`, by [`project`](@ref project(M::GeneralizedGrassmann, p))ing $p + X$ onto `M`. """ retract(::GeneralizedGrassmann, ::Any, ::Any, ::PolarRetraction) -function retract!(M::GeneralizedGrassmann, q, p, X, ::PolarRetraction) +function retract_polar!(M::GeneralizedGrassmann, q, p, X) project!(M, q, p + X) return q end -function retract!(M::GeneralizedGrassmann, q, p, X, ::ProjectionRetraction) +function retract_project!(M::GeneralizedGrassmann, q, p, X) project!(M, q, p + X) return q end @@ -384,20 +344,6 @@ function Base.show(io::IO, M::GeneralizedGrassmann{n,k,𝔽}) where {n,k,𝔽} return print(io, "GeneralizedGrassmann($(n), $(k), $(M.B), $(𝔽))") end -@doc raw""" - vector_transport_to(M::GeneralizedGrassmann, p, X, q, ::ProjectionTransport) - -Compute the vector transport of the tangent vector `X` at `p` to `q`, -using the [`project`](@ref project(::GeneralizedGrassmann, ::Any...)) -of `X` to `q`. -""" -vector_transport_to(::GeneralizedGrassmann, ::Any, ::Any, ::Any, ::ProjectionTransport) - -function vector_transport_to!(M::GeneralizedGrassmann, Y, p, X, q, ::ProjectionTransport) - project!(M, Y, q, X) - return Y -end - @doc raw""" zero_vector(M::GeneralizedGrassmann, p) diff --git a/src/manifolds/GeneralizedStiefel.jl b/src/manifolds/GeneralizedStiefel.jl index 7aad63d12f..a7ccaf450b 100644 --- a/src/manifolds/GeneralizedStiefel.jl +++ b/src/manifolds/GeneralizedStiefel.jl @@ -1,5 +1,5 @@ @doc raw""" - GeneralizedStiefel{n,k,𝔽,B} <: AbstractEmbeddedManifold{𝔽,DefaultEmbeddingType} + GeneralizedStiefel{n,k,𝔽,B} <: AbstractDecoratorManifold{𝔽} The Generalized Stiefel manifold consists of all $n\times k$, $n\geq k$ orthonormal matrices w.r.t. an arbitrary scalar product with symmetric positive definite matrix @@ -35,8 +35,7 @@ The manifold is named after Generate the (real-valued) Generalized Stiefel manifold of $n\times k$ dimensional orthonormal matrices with scalar product `B`. """ -struct GeneralizedStiefel{n,k,𝔽,TB<:AbstractMatrix} <: - AbstractEmbeddedManifold{𝔽,DefaultEmbeddingType} +struct GeneralizedStiefel{n,k,𝔽,TB<:AbstractMatrix} <: AbstractDecoratorManifold{𝔽} B::TB end @@ -49,17 +48,17 @@ function GeneralizedStiefel( return GeneralizedStiefel{n,k,𝔽,typeof(B)}(B) end +active_traits(f, ::GeneralizedStiefel, args...) = merge_traits(IsEmbeddedManifold()) + @doc raw""" check_point(M::GeneralizedStiefel, p; kwargs...) Check whether `p` is a valid point on the [`GeneralizedStiefel`](@ref) `M`=$\operatorname{St}(n,k,B)$, -i.e. that it has the right [`AbstractNumbers`](@ref) type and $x^{\mathrm{H}}Bx$ +i.e. that it has the right [`AbstractNumbers`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#number-system) type and $x^{\mathrm{H}}Bx$ is (approximately) the identity, where $\cdot^{\mathrm{H}}$ is the complex conjugate transpose. The settings for approximately can be set with `kwargs...`. """ function check_point(M::GeneralizedStiefel{n,k,𝔽}, p; kwargs...) where {n,k,𝔽} - mpv = invoke(check_point, Tuple{supertype(typeof(M)),typeof(p)}, M, p; kwargs...) - mpv === nothing || return mpv c = p' * M.B * p if !isapprox(c, one(c); kwargs...) return DomainError( @@ -70,25 +69,24 @@ function check_point(M::GeneralizedStiefel{n,k,𝔽}, p; kwargs...) where {n,k, return nothing end +# overwrite passing to embedding +function check_size(M::GeneralizedStiefel{n,k,𝔽}, p) where {n,k,𝔽} + return check_size(get_embedding(M), p) #avoid embed, since it uses copyto! +end +function check_size(M::GeneralizedStiefel{n,k,𝔽}, p, X) where {n,k,𝔽} + return check_size(get_embedding(M), p, X) #avoid embed, since it uses copyto! +end + @doc raw""" check_vector(M::GeneralizedStiefel, p, X; kwargs...) Check whether `X` is a valid tangent vector at `p` on the [`GeneralizedStiefel`](@ref) -`M`=$\operatorname{St}(n,k,B)$, i.e. the [`AbstractNumbers`](@ref) fits, +`M`=$\operatorname{St}(n,k,B)$, i.e. the [`AbstractNumbers`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#number-system) fits, `p` is a valid point on `M` and it (approximately) holds that $p^{\mathrm{H}}BX + \overline{X^{\mathrm{H}}Bp} = 0$, where `kwargs...` is passed to the `isapprox`. """ function check_vector(M::GeneralizedStiefel{n,k,𝔽}, p, X; kwargs...) where {n,k,B,𝔽} - mpv = invoke( - check_vector, - Tuple{supertype(typeof(M)),typeof(p),typeof(X)}, - M, - p, - X; - kwargs..., - ) - mpv === nothing || return mpv if !isapprox(p' * M.B * X, -conj(X' * M.B * p); kwargs...) return DomainError( norm(p' * M.B * X + conj(X' * M.B * p)), @@ -98,7 +96,7 @@ function check_vector(M::GeneralizedStiefel{n,k,𝔽}, p, X; kwargs...) where {n return nothing end -decorated_manifold(M::GeneralizedStiefel{N,K,𝔽}) where {N,K,𝔽} = Euclidean(N, K; field=𝔽) +get_embedding(::GeneralizedStiefel{N,K,𝔽}) where {N,K,𝔽} = Euclidean(N, K; field=𝔽) @doc raw""" inner(M::GeneralizedStiefel, p, X, Y) @@ -176,22 +174,22 @@ end retract(M::GeneralizedStiefel, p, X, ::PolarRetraction) retract(M::GeneralizedStiefel, p, X, ::ProjectionRetraction) -Compute the SVD-based retraction [`PolarRetraction`](@ref) on the +Compute the SVD-based retraction [`PolarRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.PolarRetraction) on the [`GeneralizedStiefel`](@ref) manifold `M`, which in this case is the same as the projection based retraction employing the exponential map in the embedding and projecting the result back to the manifold. -The default retraction for this manifold is the [`ProjectionRetraction`](@ref). +The default retraction for this manifold is the [`ProjectionRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.ProjectionRetraction). """ retract(::GeneralizedStiefel, ::Any...) -retract(M::GeneralizedStiefel, p, X) = retract(M, p, X, ProjectionRetraction()) -retract!(M::GeneralizedStiefel, q, p, X) = retract!(M, q, p, X, ProjectionRetraction()) -function retract!(M::GeneralizedStiefel, q, p, X, ::PolarRetraction) +default_retraction_method(::GeneralizedStiefel) = ProjectionRetraction() + +function retract_polar!(M::GeneralizedStiefel, q, p, X) project!(M, q, p + X) return q end -function retract!(M::GeneralizedStiefel, q, p, X, ::ProjectionRetraction) +function retract_project!(M::GeneralizedStiefel, q, p, X) project!(M, q, p + X) return q end @@ -199,17 +197,3 @@ end function Base.show(io::IO, M::GeneralizedStiefel{n,k,𝔽}) where {n,k,𝔽} return print(io, "GeneralizedStiefel($(n), $(k), $(M.B), $(𝔽))") end - -@doc raw""" - vector_transport_to(M::GeneralizedStiefel, p, X, q, ::ProjectionTransport) - -Compute the vector transport of the tangent vector `X` at `p` to `q`, -using the [`project`](@ref project(::GeneralizedStiefel, ::Any...)) -of `X` to `q`. -""" -vector_transport_to(::GeneralizedStiefel, ::Any, ::Any, ::Any, ::ProjectionTransport) - -function vector_transport_to!(M::GeneralizedStiefel, Y, p, X, q, ::ProjectionTransport) - project!(M, Y, q, X) - return Y -end diff --git a/src/manifolds/GraphManifold.jl b/src/manifolds/GraphManifold.jl index b36fe7c65e..b9154d118a 100644 --- a/src/manifolds/GraphManifold.jl +++ b/src/manifolds/GraphManifold.jl @@ -23,12 +23,12 @@ struct VertexManifold <: GraphManifoldType end @doc raw""" GraphManifold{G,𝔽,M,T} <: AbstractPowerManifold{𝔽,M,NestedPowerRepresentation} -Build a manifold, that is a [`PowerManifold`](@ref) of the [`AbstractManifold`](@ref) `M` either on +Build a manifold, that is a [`PowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.PowerManifold) of the [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) `M` either on the edges or vertices of a graph `G` depending on the [`GraphManifoldType`](@ref) `T`. # Fields * `G` is an `AbstractSimpleGraph` -* `M` is a [`AbstractManifold`](@ref) +* `M` is a `AbstractManifold` """ struct GraphManifold{G<:AbstractGraph,𝔽,TM,T<:GraphManifoldType} <: AbstractPowerManifold{𝔽,TM,NestedPowerRepresentation} @@ -65,22 +65,10 @@ passes the [`check_point`](@ref) test for the base manifold `M.manifold`. """ check_point(::GraphManifold, ::Any...) function check_point(M::VertexGraphManifold, p; kwargs...) - if size(p) != (nv(M.graph),) - return DomainError( - length(p), - "The number of points in `x` ($(length(p))) does not match the number of nodes in the graph ($(nv(M.graph))).", - ) - end PM = PowerManifold(M.manifold, NestedPowerRepresentation(), nv(M.graph)) return check_point(PM, p; kwargs...) end function check_point(M::EdgeGraphManifold, p; kwargs...) - if size(p) != (ne(M.graph),) - return DomainError( - length(p), - "The number of points in `x` ($(size(p))) does not match the number of edges in the graph ($(ne(M.graph))).", - ) - end PM = PowerManifold(M.manifold, NestedPowerRepresentation(), ne(M.graph)) return check_point(PM, p; kwargs...) end @@ -97,22 +85,10 @@ together with its corresponding entry of `p` passes the """ check_vector(::GraphManifold, ::Any...) function check_vector(M::VertexGraphManifold, p, X; kwargs...) - if size(X) != (nv(M.graph),) - return DomainError( - length(X), - "The number of points in `v` ($(size(X)) does not match the number of nodes in the graph ($(nv(M.graph))).", - ) - end PM = PowerManifold(M.manifold, NestedPowerRepresentation(), nv(M.graph)) return check_vector(PM, p, X; kwargs...) end function check_vector(M::EdgeGraphManifold, p, X; kwargs...) - if size(X) != (ne(M.graph),) - return DomainError( - length(X), - "The number of elements in `v` ($(size(X)) does not match the number of edges in the graph ($(ne(M.graph))).", - ) - end PM = PowerManifold(M.manifold, NestedPowerRepresentation(), ne(M.graph)) return check_vector(PM, p, X; kwargs...) end @@ -206,6 +182,9 @@ function manifold_dimension(M::EdgeGraphManifold) return manifold_dimension(M.manifold) * ne(M.graph) end +power_dimensions(M::EdgeGraphManifold) = (ne(M.graph),) +power_dimensions(M::VertexGraphManifold) = (nv(M.graph),) + function _show_graph_manifold(io::IO, M; man_desc="", pre="") println(io, "GraphManifold\nGraph:") sg = sprint(show, "text/plain", M.graph, context=io, sizehint=0) diff --git a/src/manifolds/Grassmann.jl b/src/manifolds/Grassmann.jl index d8a6df5caa..8f40323b24 100644 --- a/src/manifolds/Grassmann.jl +++ b/src/manifolds/Grassmann.jl @@ -1,5 +1,5 @@ @doc raw""" - Grassmann{n,k,𝔽} <: AbstractEmbeddedManifold{𝔽,DefaultIsometricEmbeddingType} + Grassmann{n,k,𝔽} <: AbstractDecoratorManifold{𝔽} The Grassmann manifold $\operatorname{Gr}(n,k)$ consists of all subspaces spanned by $k$ linear independent vectors $𝔽^n$, where $𝔽 ∈ \{ℝ, β„‚\}$ is either the real- (or complex-) valued vectors. @@ -51,67 +51,16 @@ The manifold is named after Generate the Grassmann manifold $\operatorname{Gr}(n,k)$, where the real-valued case `field = ℝ` is the default. """ -struct Grassmann{n,k,𝔽} <: AbstractEmbeddedManifold{𝔽,DefaultIsometricEmbeddingType} end +struct Grassmann{n,k,𝔽} <: AbstractDecoratorManifold{𝔽} end Grassmann(n::Int, k::Int, field::AbstractNumbers=ℝ) = Grassmann{n,k,field}() +active_traits(f, ::Grassmann, args...) = merge_traits(IsIsometricEmbeddedManifold()) + function allocation_promotion_function(M::Grassmann{n,k,β„‚}, f, args::Tuple) where {n,k} return complex end -@doc raw""" - check_point(M::Grassmann{n,k,𝔽}, p) - -Check whether `p` is representing a point on the [`Grassmann`](@ref) `M`, i.e. its -a `n`-by-`k` matrix of unitary column vectors and of correct `eltype` with respect to `𝔽`. -""" -function check_point(M::Grassmann{n,k,𝔽}, p; kwargs...) where {n,k,𝔽} - mpv = invoke(check_point, Tuple{supertype(typeof(M)),typeof(p)}, M, p; kwargs...) - mpv === nothing || return mpv - c = p' * p - if !isapprox(c, one(c); kwargs...) - return DomainError( - norm(c - one(c)), - "The point $(p) does not lie on $(M), because x'x is not the unit matrix.", - ) - end - return nothing -end - -@doc raw""" - check_vector(M::Grassmann{n,k,𝔽}, p, X; kwargs...) - -Check whether `X` is a tangent vector in the tangent space of `p` on -the [`Grassmann`](@ref) `M`, i.e. that `X` is of size and type as well as that - -````math - p^{\mathrm{H}}X + X^{\mathrm{H}}p = 0_k, -```` - -where $\cdot^{\mathrm{H}}$ denotes the complex conjugate transpose or Hermitian -and $0_k$ the $k Γ— k$ zero matrix. -""" -function check_vector(M::Grassmann{n,k,𝔽}, p, X; kwargs...) where {n,k,𝔽} - mpv = invoke( - check_vector, - Tuple{supertype(typeof(M)),typeof(p),typeof(X)}, - M, - p, - X; - kwargs..., - ) - mpv === nothing || return mpv - if !isapprox(p' * X, -conj(X' * p); kwargs...) - return DomainError( - norm(p' * X + conj(X' * p)), - "The matrix $(X) does not lie in the tangent space of $(p) on $(M), since p'X + X'p is not the zero matrix.", - ) - end - return nothing -end - -decorated_manifold(::Grassmann{N,K,𝔽}) where {N,K,𝔽} = Euclidean(N, K; field=𝔽) - @doc raw""" distance(M::Grassmann, p, q) @@ -138,6 +87,9 @@ function distance(::Grassmann, p, q) return sqrt(sum(x -> abs2(acos(clamp(x, -1, 1))), a)) end +embed(::Grassmann, p) = p +embed(::Grassmann, p, X) = X + @doc raw""" exp(M::Grassmann, p, X) @@ -166,6 +118,10 @@ function exp!(M::Grassmann, q, p, X) return copyto!(q, Array(qr(z).Q)) end +function get_embedding(::Grassmann{N,K,𝔽}) where {N,K,𝔽} + return Stiefel(N, K, 𝔽) +end + @doc raw""" injectivity_radius(M::Grassmann) injectivity_radius(M::Grassmann, p) @@ -173,17 +129,9 @@ end Return the injectivity radius on the [`Grassmann`](@ref) `M`, which is $\frac{Ο€}{2}$. """ injectivity_radius(::Grassmann) = Ο€ / 2 -injectivity_radius(::Grassmann, ::ExponentialRetraction) = Ο€ / 2 -injectivity_radius(::Grassmann, ::Any) = Ο€ / 2 -injectivity_radius(::Grassmann, ::Any, ::ExponentialRetraction) = Ο€ / 2 -eval( - quote - @invoke_maker 1 AbstractManifold injectivity_radius( - M::Grassmann, - rm::AbstractRetractionMethod, - ) - end, -) +injectivity_radius(::Grassmann, p) = Ο€ / 2 +injectivity_radius(::Grassmann, ::AbstractRetractionMethod) = Ο€ / 2 +injectivity_radius(::Grassmann, p, ::AbstractRetractionMethod) = Ο€ / 2 @doc raw""" inner(M::Grassmann, p, X, Y) @@ -202,7 +150,7 @@ inner(::Grassmann, p, X, Y) = dot(X, Y) @doc raw""" inverse_retract(M::Grassmann, p, q, ::PolarInverseRetraction) -Compute the inverse retraction for the [`PolarRetraction`](@ref), on the +Compute the inverse retraction for the [`PolarRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.PolarRetraction), on the [`Grassmann`](@ref) manifold `M`, i.e., ````math @@ -213,14 +161,14 @@ where $\cdot^{\mathrm{H}}$ denotes the complex conjugate transposed or Hermitian """ inverse_retract(::Grassmann, ::Any, ::Any, ::PolarInverseRetraction) -function inverse_retract!(::Grassmann, X, p, q, ::PolarInverseRetraction) +function inverse_retract_polar!(::Grassmann, X, p, q) return copyto!(X, q / (p' * q) - p) end @doc raw""" inverse_retract(M, p, q, ::QRInverseRetraction) -Compute the inverse retraction for the [`QRRetraction`](@ref), on the +Compute the inverse retraction for the [`QRRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.QRRetraction), on the [`Grassmann`](@ref) manifold `M`, i.e., ````math @@ -230,7 +178,7 @@ where $\cdot^{\mathrm{H}}$ denotes the complex conjugate transposed or Hermitian """ inverse_retract(::Grassmann, ::Any, ::Any, ::QRInverseRetraction) -inverse_retract!(::Grassmann, X, p, q, ::QRInverseRetraction) = copyto!(X, q / (p' * q) - p) +inverse_retract_qr!(::Grassmann, X, p, q) = copyto!(X, q / (p' * q) - p) function Base.isapprox(M::Grassmann, p, X, Y; kwargs...) return isapprox(sqrt(inner(M, p, zero_vector(M, p), X - Y)), 0; kwargs...) @@ -241,7 +189,7 @@ Base.isapprox(M::Grassmann, p, q; kwargs...) = isapprox(distance(M, p, q), 0.0; log(M::Grassmann, p, q) Compute the logarithmic map on the [`Grassmann`](@ref) `M`$ = \mathcal M=\mathrm{Gr}(n,k)$, -i.e. the tangent vector `X` whose corresponding [`geodesic`](@ref) starting from `p` +i.e. the tangent vector `X` whose corresponding [`geodesic`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#ManifoldsBase.geodesic-Tuple{AbstractManifold,%20Any,%20Any}) starting from `p` reaches `q` after time 1 on `M`. The formula reads ````math @@ -277,7 +225,7 @@ Return the dimension of the [`Grassmann(n,k,𝔽)`](@ref) manifold `M`, i.e. \dim \operatorname{Gr}(n,k) = k(n-k) \dim_ℝ 𝔽, ```` -where $\dim_ℝ 𝔽$ is the [`real_dimension`](@ref) of `𝔽`. +where $\dim_ℝ 𝔽$ is the [`real_dimension`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.real_dimension-Tuple{ManifoldsBase.AbstractNumbers}) of `𝔽`. """ manifold_dimension(::Grassmann{n,k,𝔽}) where {n,k,𝔽} = k * (n - k) * real_dimension(𝔽) @@ -295,14 +243,8 @@ Compute the Riemannian [`mean`](@ref mean(M::AbstractManifold, args...)) of `x` """ mean(::Grassmann{n,k} where {n,k}, ::Any...) -function Statistics.mean!( - M::Grassmann{n,k}, - p, - x::AbstractVector, - w::AbstractVector; - kwargs..., -) where {n,k} - return mean!(M, p, x, w, GeodesicInterpolationWithinRadius(Ο€ / 4); kwargs...) +function default_estimation_method(::Grassmann, ::typeof(mean)) + return GeodesicInterpolationWithinRadius(Ο€ / 4) end @doc raw""" @@ -379,7 +321,7 @@ Return the represenation size or matrix dimension of a point on the [`Grassmann` @doc raw""" retract(M::Grassmann, p, X, ::PolarRetraction) -Compute the SVD-based retraction [`PolarRetraction`](@ref) on the +Compute the SVD-based retraction [`PolarRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.PolarRetraction) on the [`Grassmann`](@ref) `M`. With $USV = p + X$ the retraction reads ````math \operatorname{retr}_p X = UV^\mathrm{H}, @@ -389,10 +331,15 @@ where $\cdot^{\mathrm{H}}$ denotes the complex conjugate transposed or Hermitian """ retract(::Grassmann, ::Any, ::Any, ::PolarRetraction) +function retract_polar!(::Grassmann, q, p, X) + s = svd(p + X) + return mul!(q, s.U, s.Vt) +end + @doc raw""" retract(M::Grassmann, p, X, ::QRRetraction ) -Compute the QR-based retraction [`QRRetraction`](@ref) on the +Compute the QR-based retraction [`QRRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.QRRetraction) on the [`Grassmann`](@ref) `M`. With $QR = p + X$ the retraction reads ````math \operatorname{retr}_p X = QD, @@ -404,11 +351,7 @@ D = \operatorname{diag}\left( \operatorname{sgn}\left(R_{ii}+\frac{1}{2}\right)_ """ retract(::Grassmann, ::Any, ::Any, ::QRRetraction) -function retract!(::Grassmann, q, p, X, ::PolarRetraction) - s = svd(p + X) - return mul!(q, s.U, s.Vt) -end -function retract!(::Grassmann{N,K}, q, p, X, ::QRRetraction) where {N,K} +function retract_qr!(::Grassmann{N,K}, q, p, X) where {N,K} qrfac = qr(p + X) d = diag(qrfac.R) D = Diagonal(sign.(d .+ 1 // 2)) @@ -453,11 +396,6 @@ projecting it onto the tangent space at q. """ vector_transport_to(::Grassmann, ::Any, ::Any, ::Any, ::ProjectionTransport) -function vector_transport_to!(M::Grassmann, Y, p, X, q, ::ProjectionTransport) - project!(M, Y, q, X) - return Y -end - @doc raw""" zero_vector(M::Grassmann, p) diff --git a/src/manifolds/Hyperbolic.jl b/src/manifolds/Hyperbolic.jl index a9e5af25f3..4390f5a9ea 100644 --- a/src/manifolds/Hyperbolic.jl +++ b/src/manifolds/Hyperbolic.jl @@ -1,5 +1,5 @@ @doc raw""" - Hyperbolic{N} <: AbstractEmbeddedManifold{ℝ,DefaultIsometricEmbeddingType} + Hyperbolic{N} <: AbstractDecoratorManifold{ℝ} The hyperbolic space $\mathcal H^n$ represented by $n+1$-Tuples, i.e. embedded in the [`Lorentz`](@ref)ian manifold equipped with the [`MinkowskiMetric`](@ref) @@ -33,10 +33,14 @@ and the PoincarΓ© half space model, see [`PoincareHalfSpacePoint`](@ref) and [`P Generate the Hyperbolic manifold of dimension `n`. """ -struct Hyperbolic{N} <: AbstractEmbeddedManifold{ℝ,DefaultIsometricEmbeddingType} end +struct Hyperbolic{N} <: AbstractDecoratorManifold{ℝ} end Hyperbolic(n::Int) = Hyperbolic{n}() +function active_traits(f, ::Hyperbolic, args...) + return merge_traits(IsIsometricEmbeddedManifold(), IsDefaultMetric(MinkowskiMetric())) +end + @doc raw""" HyperboloidPoint <: AbstractManifoldPoint @@ -102,90 +106,27 @@ struct PoincareHalfSpaceTVector{TValue<:AbstractVector} <: TVector value::TValue end +ManifoldsBase.@manifold_element_forwards HyperboloidPoint value +ManifoldsBase.@manifold_vector_forwards HyperboloidTVector value +ManifoldsBase.@default_manifold_fallbacks Hyperbolic HyperboloidPoint HyperboloidTVector value value + +ManifoldsBase.@manifold_element_forwards PoincareBallPoint value +ManifoldsBase.@manifold_vector_forwards PoincareBallTVector value + +ManifoldsBase.@manifold_element_forwards PoincareHalfSpacePoint value +ManifoldsBase.@manifold_vector_forwards PoincareHalfSpaceTVector value + include("HyperbolicHyperboloid.jl") include("HyperbolicPoincareBall.jl") include("HyperbolicPoincareHalfspace.jl") -_HyperbolicPointTypes = [HyperboloidPoint, PoincareBallPoint, PoincareHalfSpacePoint] -_HyperbolicTangentTypes = - [HyperboloidTVector, PoincareBallTVector, PoincareHalfSpaceTVector] -_HyperbolicTypes = [_HyperbolicPointTypes..., _HyperbolicTangentTypes...] - -for T in _HyperbolicTangentTypes - @eval begin - Base.:*(v::$T, s::Number) = $T(v.value * s) - Base.:*(s::Number, v::$T) = $T(s * v.value) - Base.:/(v::$T, s::Number) = $T(v.value / s) - Base.:\(s::Number, v::$T) = $T(s \ v.value) - Base.:+(v::$T, w::$T) = $T(v.value + w.value) - Base.:-(v::$T, w::$T) = $T(v.value - w.value) - Base.:-(v::$T) = $T(-v.value) - Base.:+(v::$T) = $T(v.value) - end -end +_ExtraHyperbolicPointTypes = [PoincareBallPoint, PoincareHalfSpacePoint] +_ExtraHyperbolicTangentTypes = [PoincareBallTVector, PoincareHalfSpaceTVector] +_ExtraHyperbolicTypes = [_ExtraHyperbolicPointTypes..., _ExtraHyperbolicTangentTypes...] -for T in _HyperbolicTypes - @eval begin - Base.:(==)(v::$T, w::$T) = (v.value == w.value) - - allocate(p::$T) = $T(allocate(p.value)) - allocate(p::$T, ::Type{P}) where {P} = $T(allocate(p.value, P)) - allocate(p::$T, ::Type{P}, dims::Tuple) where {P} = $T(allocate(p.value, P, dims)) - - @inline Base.copy(p::$T) = $T(copy(p.value)) - function Base.copyto!(q::$T, p::$T) - copyto!(q.value, p.value) - return q - end - - Base.similar(p::$T) = $T(similar(p.value)) - - function Broadcast.BroadcastStyle(::Type{<:$T}) - return Broadcast.Style{$T}() - end - function Broadcast.BroadcastStyle( - ::Broadcast.AbstractArrayStyle{0}, - b::Broadcast.Style{$T}, - ) - return b - end - - Broadcast.instantiate(bc::Broadcast.Broadcasted{Broadcast.Style{$T},Nothing}) = bc - function Broadcast.instantiate(bc::Broadcast.Broadcasted{Broadcast.Style{$T}}) - Broadcast.check_broadcast_axes(bc.axes, bc.args...) - return bc - end - - Broadcast.broadcastable(v::$T) = v - - @inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.Style{$T}}) - return $T(Broadcast._broadcast_getindex(bc, 1)) - end - - Base.@propagate_inbounds Broadcast._broadcast_getindex(v::$T, I) = v.value - - Base.axes(v::$T) = axes(v.value) - - @inline function Base.copyto!( - dest::$T, - bc::Broadcast.Broadcasted{Broadcast.Style{$T}}, - ) - axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) - # Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match - if bc.f === identity && bc.args isa Tuple{$T} # only a single input argument to broadcast! - A = bc.args[1] - if axes(dest) == axes(A) - return copyto!(dest, A) - end - end - bcβ€² = Broadcast.preprocess(dest, bc) - # Performance may vary depending on whether `@inbounds` is placed outside the - # for loop or not. (cf. https://github.com/JuliaLang/julia/issues/38086) - copyto!(dest.value, bcβ€²[1]) - return dest - end - end -end +_HyperbolicPointTypes = [HyperboloidPoint, _ExtraHyperbolicPointTypes...] +_HyperbolicTangentTypes = [HyperboloidTVector, _ExtraHyperbolicTangentTypes...] +_HyperbolicTypes = [_HyperbolicPointTypes..., _HyperbolicTangentTypes...] for (P, T) in zip(_HyperbolicPointTypes, _HyperbolicTangentTypes) @eval allocate(p::$P, ::Type{$T}) = $T(allocate(p.value)) @@ -245,9 +186,10 @@ for (P, T) in zip(_HyperbolicPointTypes, _HyperbolicTangentTypes) end end -decorated_manifold(::Hyperbolic{N}) where {N} = Lorentz(N + 1, MinkowskiMetric()) +get_embedding(::Hyperbolic{N}) where {N} = Lorentz(N + 1, MinkowskiMetric()) -default_metric_dispatch(::Hyperbolic, ::MinkowskiMetric) = Val(true) +embed(::Hyperbolic, p::AbstractArray) = p +embed(::Hyperbolic, p::AbstractArray, X::AbstractArray) = X @doc raw""" exp(M::Hyperbolic, p, X) @@ -265,7 +207,7 @@ the [`Lorentz`](@ref)ian manifold. """ exp(::Hyperbolic, ::Any...) -for (P, T) in zip(_HyperbolicPointTypes, _HyperbolicTangentTypes) +for (P, T) in zip(_ExtraHyperbolicPointTypes, _ExtraHyperbolicTangentTypes) @eval function exp!(M::Hyperbolic, q::$P, p::$P, X::$T) q.value .= convert( @@ -283,24 +225,13 @@ end Return the injectivity radius on the [`Hyperbolic`](@ref), which is $∞$. """ injectivity_radius(::Hyperbolic) = Inf -injectivity_radius(::Hyperbolic, ::ExponentialRetraction) = Inf -injectivity_radius(::Hyperbolic, ::Any) = Inf -injectivity_radius(::Hyperbolic, ::Any, ::ExponentialRetraction) = Inf -eval( - quote - @invoke_maker 1 AbstractManifold injectivity_radius( - M::Hyperbolic, - rm::AbstractRetractionMethod, - ) - end, -) - -for T in _HyperbolicPointTypes + +for T in _ExtraHyperbolicPointTypes @eval function isapprox(::Hyperbolic, p::$T, q::$T; kwargs...) return isapprox(p.value, q.value; kwargs...) end end -for (P, T) in zip(_HyperbolicPointTypes, _HyperbolicTangentTypes) +for (P, T) in zip(_ExtraHyperbolicPointTypes, _ExtraHyperbolicTangentTypes) @eval function isapprox(::Hyperbolic, ::$P, X::$T, Y::$T; kwargs...) return isapprox(X.value, Y.value; kwargs...) end @@ -310,7 +241,7 @@ end log(M::Hyperbolic, p, q) Compute the logarithmic map on the [`Hyperbolic`](@ref) space $\mathcal H^n$, the tangent -vector representing the [`geodesic`](@ref) starting from `p` +vector representing the [`geodesic`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#ManifoldsBase.geodesic-Tuple{AbstractManifold,%20Any,%20Any}) starting from `p` reaches `q` after time 1. The formula reads for $p β‰  q$ ```math @@ -323,7 +254,7 @@ the [`Lorentz`](@ref)ian manifold. For $p=q$ the logarihmic map is equal to the """ log(::Hyperbolic, ::Any...) -for (P, T) in zip(_HyperbolicPointTypes, _HyperbolicTangentTypes) +for (P, T) in zip(_ExtraHyperbolicPointTypes, _ExtraHyperbolicTangentTypes) @eval function log!(M::Hyperbolic, X::$T, p::$P, q::$P) X.value .= convert( @@ -356,13 +287,7 @@ Compute the Riemannian [`mean`](@ref mean(M::AbstractManifold, args...)) of `x` """ mean(::Hyperbolic, ::Any...) -function Statistics.mean!(M::Hyperbolic, p, x::AbstractVector, w::AbstractVector; kwargs...) - return mean!(M, p, x, w, CyclicProximalPointEstimation(); kwargs...) -end - -for T in _HyperbolicTypes - @eval number_eltype(p::$T) = typeof(one(eltype(p.value))) -end +default_estimation_method(::Hyperbolic, ::typeof(mean)) = CyclicProximalPointEstimation() @doc raw""" project(M::Hyperbolic, p, X) @@ -390,10 +315,10 @@ for T in _HyperbolicTypes end @doc raw""" - vector_transport_to(M::Hyperbolic, p, X, q, ::ParallelTransport) + parallel_transport_to(M::Hyperbolic, p, X, q) Compute the paralllel transport of the `X` from the tangent space at `p` on the -[`Hyperbolic`](@ref) space $\mathcal H^n$ to the tangent at `q` along the [`geodesic`](@ref) +[`Hyperbolic`](@ref) space $\mathcal H^n$ to the tangent at `q` along the [`geodesic`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#ManifoldsBase.geodesic-Tuple{AbstractManifold,%20Any,%20Any}) connecting `p` and `q`. The formula reads ````math @@ -402,27 +327,19 @@ connecting `p` and `q`. The formula reads ```` where $⟨\cdot,\cdot⟩_p$ denotes the inner product in the tangent space at `p`. """ -vector_transport_to(::Hyperbolic, ::Any, ::Any, ::Any, ::ParallelTransport) +parallel_transport_to(::Hyperbolic, ::Any, ::Any, ::Any) -for (P, T) in zip(_HyperbolicPointTypes, _HyperbolicTangentTypes) - @eval function vector_transport_to!( - M::Hyperbolic, - Y::$T, - p::$P, - X::$T, - q::$P, - m::ParallelTransport, - ) +for (P, T) in zip(_ExtraHyperbolicPointTypes, _ExtraHyperbolicTangentTypes) + @eval function parallel_transport_to!(M::Hyperbolic, Y::$T, p::$P, X::$T, q::$P) Y.value .= convert( $T, convert(AbstractVector, q), - vector_transport_to( + parallel_transport_to( M, convert(AbstractVector, p), convert(AbstractVector, p, X), convert(AbstractVector, q), - m, ), ).value return Y diff --git a/src/manifolds/HyperbolicHyperboloid.jl b/src/manifolds/HyperbolicHyperboloid.jl index f852b987b1..90df9309d7 100644 --- a/src/manifolds/HyperbolicHyperboloid.jl +++ b/src/manifolds/HyperbolicHyperboloid.jl @@ -26,8 +26,6 @@ function change_metric!(::Hyperbolic, ::Any, ::EuclideanMetric, ::Any, ::Any) end function check_point(M::Hyperbolic, p; kwargs...) - mpv = invoke(check_point, Tuple{supertype(typeof(M)),typeof(p)}, M, p; kwargs...) - mpv === nothing || return mpv if !isapprox(minkowski_metric(p, p), -1.0; kwargs...) return DomainError( minkowski_metric(p, p), @@ -36,20 +34,8 @@ function check_point(M::Hyperbolic, p; kwargs...) end return nothing end -function check_point(M::Hyperbolic, p::HyperboloidPoint; kwargs...) - return check_point(M, p.value; kwargs...) -end function check_vector(M::Hyperbolic, p, X; kwargs...) - mpv = invoke( - check_vector, - Tuple{supertype(typeof(M)),typeof(p),typeof(X)}, - M, - p, - X; - kwargs..., - ) - mpv === nothing || return mpv if !isapprox(minkowski_metric(p, X), 0.0; kwargs...) return DomainError( abs(minkowski_metric(p, X)), @@ -58,9 +44,6 @@ function check_vector(M::Hyperbolic, p, X; kwargs...) end return nothing end -function check_vector(M::Hyperbolic, p::HyperboloidPoint, X::HyperboloidTVector; kwargs...) - return check_vector(M, p.value, X.value; kwargs...) -end function convert(::Type{HyperboloidTVector}, X::T) where {T<:AbstractVector} return HyperboloidTVector(X) @@ -244,8 +227,14 @@ where $⟨\cdot,\cdot⟩_{\mathrm{M}}$ denotes the [`MinkowskiMetric`](@ref) on the [`Lorentz`](@ref)ian manifold. """ distance(::Hyperbolic, p, q) = acosh(max(-minkowski_metric(p, q), 1.0)) -function distance(M::Hyperbolic, p::HyperboloidPoint, q::HyperboloidPoint) - return distance(M, p.value, q.value) + +embed(M::Hyperbolic, p::HyperboloidPoint) = embed(M, p.value) +embed!(M::Hyperbolic, q, p::HyperboloidPoint) = embed!(M, q, p.value) +function embed(M::Hyperbolic, p::HyperboloidPoint, X::HyperboloidTVector) + return embed(M, p.value, X.value) +end +function embed!(M::Hyperbolic, Y, p::HyperboloidPoint, X::HyperboloidTVector) + return embed!(M, Y, p.value, X.value) end function exp!(M::Hyperbolic, q, p, X) @@ -254,16 +243,26 @@ function exp!(M::Hyperbolic, q, p, X) return copyto!(q, cosh(vn) * p + sinh(vn) / vn * X) end -function get_basis(M::Hyperbolic, p, B::DefaultOrthonormalBasis{ℝ,TangentSpaceType}) - n = manifold_dimension(M) +# overwrite the default construction on level 2 (dispatching on basis) +# since this function should not call get_vector (that relies on get_basis itself on H2) +function _get_basis( + M::Hyperbolic, + p, + B::DefaultOrthonormalBasis{ℝ,TangentSpaceType}; + kwargs..., +) + return get_basis_orthonormal(M, p, ℝ) +end + +function get_basis_orthonormal(M::Hyperbolic{n}, p, r::RealNumbers) where {n} V = [ _hyperbolize(M, p, [i == k ? one(eltype(p)) : zero(eltype(p)) for k in 1:n]) for i in 1:n ] - return CachedBasis(B, gram_schmidt(M, p, V)) + return CachedBasis(DefaultOrthonormalBasis(r), gram_schmidt(M, p, V)) end -function get_basis(M::Hyperbolic, p, B::DiagonalizingOrthonormalBasis) +function get_basis_diagonalizing(M::Hyperbolic, p, B::DiagonalizingOrthonormalBasis) n = manifold_dimension(M) X = B.frame_direction V = [ @@ -297,20 +296,23 @@ Compute the coordinates of the vector `X` with respect to the orthogonalized ver the unit vectors from $ℝ^n$, where $n$ is the manifold dimension of the [`Hyperbolic`](@ref) `M`, utting them intop the tangent space at `p` and orthonormalizing them. """ -get_coordinates(M::Hyperbolic, p, X, B::DefaultOrthonormalBasis) +get_coordinates(M::Hyperbolic, p, X, ::DefaultOrthonormalBasis) -function get_coordinates!( +function get_coordinates_orthonormal(M::Hyperbolic, p, X, r::RealNumbers) + return get_coordinates(M, p, X, get_basis_orthonormal(M, p, r)) +end +function get_coordinates_orthonormal!(M::Hyperbolic, c, p, X, r::RealNumbers) + c = get_coordinates!(M, c, p, X, get_basis_orthonormal(M, p, r)) + return c +end +function get_coordinates_diagonalizing!( M::Hyperbolic, c, p, X, - B::DefaultOrthonormalBasis{ℝ,TangentSpaceType}, + B::DiagonalizingOrthonormalBasis, ) - c = get_coordinates!(M, c, p, X, get_basis(M, p, B)) - return c -end -function get_coordinates!(M::Hyperbolic, c, p, X, B::DiagonalizingOrthonormalBasis) - c = get_coordinates!(M, c, p, X, get_basis(M, p, B)) + c = get_coordinates!(M, c, p, X, get_basis_diagonalizing(M, p, B)) return c end @@ -323,8 +325,8 @@ the unit vectors from $ℝ^n$, where $n$ is the manifold dimension of the [`Hype """ get_vector(M::Hyperbolic, p, c, ::DefaultOrthonormalBasis) -function get_vector!(M::Hyperbolic, X, p, c, B::DefaultOrthonormalBasis{ℝ,TangentSpaceType}) - X = get_vector!(M, X, p, c, get_basis(M, p, B)) +function get_vector_orthonormal!(M::Hyperbolic, X, p, c, r::RealNumbers) + X = get_vector!(M, X, p, c, get_basis(M, p, DefaultOrthonormalBasis(r))) return X end function get_vector!(M::Hyperbolic, X, p, c, B::DiagonalizingOrthonormalBasis) @@ -365,14 +367,7 @@ g_p(X,Y) = ⟨X,Y⟩_{\mathrm{M}} = -X_{n}Y_{n} + \displaystyle\sum_{k=1}^{n-1} ```` This employs the metric of the embedding, see [`Lorentz`](@ref) space. """ -function inner( - M::Hyperbolic, - p::HyperboloidPoint, - X::HyperboloidTVector, - Y::HyperboloidTVector, -) - return inner(M, p.value, X.value, Y.value) -end +inner(M::Hyperbolic, p, X, Y) function log!(M::Hyperbolic, X, p, q) scp = minkowski_metric(p, q) @@ -435,7 +430,7 @@ function Random.rand!( return pX end -function vector_transport_to!(M::Hyperbolic, Y, p, X, q, ::ParallelTransport) +function parallel_transport_to!(M::Hyperbolic, Y, p, X, q) w = log(M, p, q) wn = norm(M, p, w) wn < eps(eltype(p + q)) && return copyto!(Y, X) diff --git a/src/manifolds/HyperbolicPoincareBall.jl b/src/manifolds/HyperbolicPoincareBall.jl index 0499b02f2a..a5252c963e 100644 --- a/src/manifolds/HyperbolicPoincareBall.jl +++ b/src/manifolds/HyperbolicPoincareBall.jl @@ -44,8 +44,6 @@ function change_metric!( end function check_point(M::Hyperbolic{N}, p::PoincareBallPoint; kwargs...) where {N} - mpv = check_point(Euclidean(N), p.value; kwargs...) - mpv === nothing || return mpv if !(norm(p.value) < 1) return DomainError( norm(p.value), @@ -54,6 +52,30 @@ function check_point(M::Hyperbolic{N}, p::PoincareBallPoint; kwargs...) where {N end end +function check_size(M::Hyperbolic{N}, p::PoincareBallPoint) where {N} + if size(p.value, 1) != N + !(norm(p.value) < 1) + return DomainError( + size(p.value, 1), + "The point $p does not lie on $M since its length is not $N.", + ) + end +end + +function check_size( + M::Hyperbolic{N}, + p::PoincareBallPoint, + X::PoincareBallTVector; + kwargs..., +) where {N} + if size(X.value, 1) != N + return DomainError( + size(X.value, 1), + "The tangent vector $X can not be a tangent vector for $M since its length is not $N.", + ) + end +end + @doc raw""" convert(::Type{PoincareBallPoint}, p::HyperboloidPoint) convert(::Type{PoincareBallPoint}, p::T) where {T<:AbstractVector} @@ -247,6 +269,12 @@ function distance(::Hyperbolic, p::PoincareBallPoint, q::PoincareBallPoint) ) end +embed(::Hyperbolic, p::PoincareBallPoint) = p.value +embed!(::Hyperbolic, q, p::PoincareBallPoint) = copyto!(q, p.value) +embed(::Hyperbolic, p::PoincareBallPoint, X::PoincareBallTVector) = X.value +embed!(::Hyperbolic, Y, p::PoincareBallPoint, X::PoincareBallTVector) = copyto!(Y, X.value) +get_embedding(::Hyperbolic{n}, p::PoincareBallPoint) where {n} = Euclidean(n) + @doc raw""" inner(::Hyperbolic, p::PoincareBallPoint, X::PoincareBallTVector, Y::PoincareBallTVector) @@ -264,6 +292,10 @@ function inner( return 4 / (1 - norm(p.value)^2)^2 * dot(X.value, Y.value) end +function norm(M::Hyperbolic, p::PoincareBallPoint, X::PoincareBallTVector) + return sqrt(inner(M, p, X, X)) +end + @doc raw""" project(::Hyperbolic, ::PoincareBallPoint, ::PoincareBallTVector) diff --git a/src/manifolds/HyperbolicPoincareHalfspace.jl b/src/manifolds/HyperbolicPoincareHalfspace.jl index 06f3cbca2d..973a39b8e7 100644 --- a/src/manifolds/HyperbolicPoincareHalfspace.jl +++ b/src/manifolds/HyperbolicPoincareHalfspace.jl @@ -1,6 +1,4 @@ function check_point(M::Hyperbolic{N}, p::PoincareHalfSpacePoint; kwargs...) where {N} - mpv = check_point(Euclidean(N), p.value; kwargs...) - mpv === nothing || return mpv if !(last(p.value) > 0) return DomainError( norm(p.value), @@ -9,6 +7,30 @@ function check_point(M::Hyperbolic{N}, p::PoincareHalfSpacePoint; kwargs...) whe end end +function check_size(M::Hyperbolic{N}, p::PoincareHalfSpacePoint) where {N} + if size(p.value, 1) != N + !(norm(p.value) < 1) + return DomainError( + size(p.value, 1), + "The point $p does not lie on $M since its length is not $N.", + ) + end +end + +function check_size( + M::Hyperbolic{N}, + p::PoincareHalfSpacePoint, + X::PoincareHalfSpaceTVector; + kwargs..., +) where {N} + if size(X.value, 1) != N + return DomainError( + size(X.value, 1), + "The tangent vector $X can not be a tangent vector for $M since its length is not $N.", + ) + end +end + @doc raw""" convert(::Type{PoincareHalfSpacePoint}, p::PoincareBallPoint) @@ -182,6 +204,14 @@ function distance(::Hyperbolic, p::PoincareHalfSpacePoint, q::PoincareHalfSpaceP return acosh(1 + norm(p.value .- q.value)^2 / (2 * p.value[end] * q.value[end])) end +embed(::Hyperbolic, p::PoincareHalfSpacePoint) = p.value +embed!(::Hyperbolic, q, p::PoincareHalfSpacePoint) = copyto!(q, p.value) +embed(::Hyperbolic, p::PoincareHalfSpacePoint, X::PoincareHalfSpaceTVector) = X.value +function embed!(::Hyperbolic, Y, p::PoincareHalfSpacePoint, X::PoincareHalfSpaceTVector) + return copyto!(Y, X.value) +end +get_embedding(::Hyperbolic{n}, p::PoincareHalfSpacePoint) where {n} = Euclidean(n) + @doc raw""" inner( ::Hyperbolic{n}, @@ -204,6 +234,10 @@ function inner( return dot(X.value, Y.value) / last(p.value)^2 end +function norm(M::Hyperbolic, p::PoincareHalfSpacePoint, X::PoincareHalfSpaceTVector) + return sqrt(inner(M, p, X, X)) +end + @doc raw""" project(::Hyperbolic, ::PoincareHalfSpacePoint ::PoincareHalfSpaceTVector) diff --git a/src/manifolds/MetricManifold.jl b/src/manifolds/MetricManifold.jl index e7f6db3d2d..d03c9a1724 100644 --- a/src/manifolds/MetricManifold.jl +++ b/src/manifolds/MetricManifold.jl @@ -16,6 +16,26 @@ If `M` is already a metric manifold, the inner manifold with the new `metric` is """ abstract type AbstractMetric end +""" + IsMetricManifold <: AbstractTrait + +Specify that a certain decorated Manifold is a metric manifold in the sence that it provides +explicit metric properties, extending/changing the default metric properties of a manifold. +""" +struct IsMetricManifold <: AbstractTrait end + +""" + IsDefaultMetric{G<:AbstractMetric} + +Specify that a certain [`AbstractMetric`](@ref) is the default metric for a manifold. +This way the corresponding [`MetricManifold`](@ref) falls back to the default methods +of the manifold it decorates. +""" +struct IsDefaultMetric{G<:AbstractMetric} <: AbstractTrait + metric::G +end +parent_trait(::IsDefaultMetric) = IsMetricManifold() + # piping syntax for decoration (metric::AbstractMetric)(M::AbstractManifold) = MetricManifold(M, metric) (::Type{T})(M::AbstractManifold) where {T<:AbstractMetric} = MetricManifold(M, T()) @@ -23,7 +43,7 @@ abstract type AbstractMetric end """ MetricManifold{𝔽,M<:AbstractManifold{𝔽},G<:AbstractMetric} <: AbstractDecoratorManifold{𝔽} -Equip a [`AbstractManifold`](@ref) explicitly with a [`AbstractMetric`](@ref) `G`. +Equip a [`AbstractManifold`](https://juliamanifolds.github.io/Manifolds.jl/latest/interface.html#ManifoldsBase.AbstractManifold) explicitly with a [`AbstractMetric`](@ref) `G`. For a Metric AbstractManifold, by default, assumes, that you implement the linear form from [`local_metric`](@ref) in order to evaluate the exponential map. @@ -36,13 +56,22 @@ you can of course still implement that directly. MetricManifold(M, G) -Generate the [`AbstractManifold`](@ref) `M` as a manifold with the [`AbstractMetric`](@ref) `G`. +Generate the [`AbstractManifold`](https://juliamanifolds.github.io/Manifolds.jl/latest/interface.html#ManifoldsBase.AbstractManifold) `M` as a manifold with the [`AbstractMetric`](@ref) `G`. """ struct MetricManifold{𝔽,M<:AbstractManifold{𝔽},G<:AbstractMetric} <: - AbstractConnectionManifold{𝔽} + AbstractDecoratorManifold{𝔽} manifold::M metric::G end + +function active_traits(f, M::MetricManifold, args...) + return merge_traits( + is_default_metric(M.manifold, M.metric) ? IsDefaultMetric(M.metric) : EmptyTrait(), + IsMetricManifold(), + active_traits(f, M.manifold, args...), + is_metric_function(f) ? EmptyTrait() : IsExplicitDecorator(), + ) +end # remetricise instead of double-decorating (metric::AbstractMetric)(M::MetricManifold) = MetricManifold(M.manifold, metric) (::Type{T})(M::MetricManifold) where {T<:AbstractMetric} = MetricManifold(M.manifold, T()) @@ -56,10 +85,14 @@ inner product ``g(X, X) > 0`` whenever ``X`` is not the zero vector. """ abstract type RiemannianMetric <: AbstractMetric end +decorated_manifold(M::MetricManifold) = M.manifold + +get_embedding(M::MetricManifold) = get_embedding(M.manifold) + @doc raw""" change_metric(M::AbstractcManifold, G2::AbstractMetric, p, X) -On the [`AbstractManifold`](@ref) `M` with implicitly given metric ``g_1`` +On the [`AbstractManifold`](https://juliamanifolds.github.io/Manifolds.jl/latest/interface.html#ManifoldsBase.AbstractManifold) `M` with implicitly given metric ``g_1`` and a second [`AbstractMetric`](@ref) ``g_2`` this function performs a change of metric in the sense that it returns the tangent vector ``Z=BX`` such that the linear map ``B`` fulfills @@ -90,8 +123,17 @@ function change_metric(M::AbstractManifold, G::AbstractMetric, p, X) Y = allocate_result(M, change_metric, X, p) # this way we allocate a tangent return change_metric!(M, Y, G, p, X) end -function change_metric!(M::AbstractManifold, Y, G::AbstractMetric, p, X) - is_default_metric(M, G) && return copyto!(M, Y, p, X) +function change_metric!( + ::T, + M::AbstractDecoratorManifold, + Y, + ::G, + p, + X, +) where {G<:AbstractMetric,T<:TraitList{<:IsDefaultMetric{<:G}}} + return copyto!(M, Y, p, X) +end +function change_metric!(M::MetricManifold, Y, G::AbstractMetric, p, X) M.metric === G && return copyto!(M, Y, p, X) # no metric change # TODO: For local metric, inverse_local metric, det_local_metric: Introduce a default basis? B = DefaultOrthogonalBasis() @@ -104,25 +146,14 @@ function change_metric!(M::AbstractManifold, Y, G::AbstractMetric, p, X) return get_vector!(M, Y, p, z, B) end -@decorator_transparent_signature change_metric( - M::AbstractDecoratorManifold, - G::AbstractMetric, - X, - p, -) -@decorator_transparent_signature change_metric!( - M::AbstractDecoratorManifold, - Y, - G::AbstractMetric, - X, - p, -) +@trait_function change_metric(M::AbstractDecoratorManifold, G::AbstractMetric, X, p) +@trait_function change_metric!(M::AbstractDecoratorManifold, Y, G::AbstractMetric, X, p) @doc raw""" change_representer(M::AbstractManifold, G2::AbstractMetric, p, X) Convert the representer `X` of a linear function (in other words a cotangent vector at `p`) -in the tangent space at `p` on the [`AbstractManifold`](@ref) `M` given with respect to the +in the tangent space at `p` on the [`AbstractManifold`](https://juliamanifolds.github.io/Manifolds.jl/latest/interface.html#ManifoldsBase.AbstractManifold) `M` given with respect to the [`AbstractMetric`](@ref) `G2` into the representer with respect to the (implicit) metric of `M`. In order to convert `X` into the representer with respect to the (implicitly given) metric ``g_1`` of `M`, @@ -168,23 +199,28 @@ function change_representer(M::AbstractManifold, G::AbstractMetric, p, X) return change_representer!(M, Y, G, p, X) end -@decorator_transparent_signature change_representer( +@trait_function change_representer(M::AbstractDecoratorManifold, G::AbstractMetric, X, p) +@trait_function change_representer!( M::AbstractDecoratorManifold, + Y, G::AbstractMetric, X, p, ) -@decorator_transparent_signature change_representer!( + +# Default fallback II: Default metric (not yet hit, check subtyping?) +function change_representer!( + ::T, M::AbstractDecoratorManifold, Y, - G::AbstractMetric, - X, + ::G, p, -) - -# Default fallback I: compute in local metric representations + X, +) where {G<:AbstractMetric,T<:TraitList{<:IsDefaultMetric{<:G}}} + return copyto!(M, Y, p, X) +end +# Default fallback II: compute in local metric representations function change_representer!(M::AbstractManifold, Y, G::AbstractMetric, p, X) - is_default_metric(M, G) && return copyto!(M, Y, p, X) M.metric === G && return copyto!(M, Y, p, X) # no metric change # TODO: For local metric, inverse_local metric, det_local_metric: Introduce a default basis? B = DefaultOrthogonalBasis() @@ -195,58 +231,6 @@ function change_representer!(M::AbstractManifold, Y, G::AbstractMetric, p, X) return get_vector!(M, Y, p, z, B) end -@doc raw""" - christoffel_symbols_first( - M::MetricManifold, - p, - B::AbstractBasis; - backend::AbstractDiffBackend = default_differential_backend(), - ) - -Compute the Christoffel symbols of the first kind in local coordinates of basis `B`. -The Christoffel symbols are (in Einstein summation convention) - -````math -Ξ“_{ijk} = \frac{1}{2} \Bigl[g_{kj,i} + g_{ik,j} - g_{ij,k}\Bigr], -```` - -where ``g_{ij,k}=\frac{βˆ‚}{βˆ‚ p^k} g_{ij}`` is the coordinate -derivative of the local representation of the metric tensor. The dimensions of -the resulting multi-dimensional array are ordered ``(i,j,k)``. -""" -christoffel_symbols_first(::AbstractManifold, ::Any, B::AbstractBasis) -function christoffel_symbols_first( - M::AbstractManifold, - p, - B::AbstractBasis; - backend::AbstractDiffBackend=default_differential_backend(), -) - βˆ‚g = local_metric_jacobian(M, p, B; backend=backend) - n = size(βˆ‚g, 1) - Ξ“ = allocate(βˆ‚g, Size(n, n, n)) - @einsum Ξ“[i, j, k] = 1 / 2 * (βˆ‚g[k, j, i] + βˆ‚g[i, k, j] - βˆ‚g[i, j, k]) - return Ξ“ -end -@decorator_transparent_signature christoffel_symbols_first( - M::AbstractDecoratorManifold, - p, - B::AbstractBasis; - kwargs..., -) - -function christoffel_symbols_second( - M::AbstractManifold, - p, - B::AbstractBasis; - backend::AbstractDiffBackend=default_differential_backend(), -) - Ginv = inverse_local_metric(M, p, B) - Γ₁ = christoffel_symbols_first(M, p, B; backend=backend) - Ξ“β‚‚ = allocate(Γ₁) - @einsum Ξ“β‚‚[l, i, j] = Ginv[k, l] * Γ₁[i, j, k] - return Ξ“β‚‚ -end - """ connection(::MetricManifold) @@ -254,6 +238,8 @@ Return the [`LeviCivitaConnection`](@ref) for a metric manifold. """ connection(::MetricManifold) = LeviCivitaConnection() +default_retraction_method(M::MetricManifold) = default_retraction_method(M.manifold) + @doc raw""" det_local_metric(M::AbstractManifold, p, B::AbstractBasis) @@ -262,21 +248,26 @@ matrix ``G(p)`` representing the metric in the tangent space at ``p`` with as a See also [`local_metric`](@ref) """ -det_local_metric(::AbstractManifold, p, ::AbstractBasis) function det_local_metric(M::AbstractManifold, p, B::AbstractBasis) return det(local_metric(M, p, B)) end -@decorator_transparent_signature det_local_metric( - M::AbstractDecoratorManifold, - p, - B::AbstractBasis, -) +@trait_function det_local_metric(M::AbstractDecoratorManifold, p, B::AbstractBasis) + +function exp!(::TraitList{IsMetricManifold}, M::AbstractDecoratorManifold, q, p, X) + return retract!( + M, + q, + p, + X, + ODEExponentialRetraction(ManifoldsBase.default_retraction_method(M)), + ) +end + """ einstein_tensor(M::AbstractManifold, p, B::AbstractBasis; backend::AbstractDiffBackend = diff_badefault_differential_backendckend()) Compute the Einstein tensor of the manifold `M` at the point `p`, see [https://en.wikipedia.org/wiki/Einstein_tensor](https://en.wikipedia.org/wiki/Einstein_tensor) """ -einstein_tensor(::AbstractManifold, ::Any, ::AbstractBasis) function einstein_tensor( M::AbstractManifold, p, @@ -290,7 +281,7 @@ function einstein_tensor( G = Ric - g .* S / 2 return G end -@decorator_transparent_signature einstein_tensor( +@trait_function einstein_tensor( M::AbstractDecoratorManifold, p, B::AbstractBasis; @@ -298,10 +289,10 @@ end ) @doc raw""" - flat(N::MetricManifold{M,G}, p, X::FVector{TangentSpaceType}) + flat(N::MetricManifold{M,G}, p, X::TFVector) Compute the musical isomorphism to transform the tangent vector `X` from the -[`AbstractManifold`](@ref) `M` equipped with [`AbstractMetric`](@ref) `G` to a cotangent by +[`AbstractManifold`](https://juliamanifolds.github.io/Manifolds.jl/latest/interface.html#ManifoldsBase.AbstractManifold) `M` equipped with [`AbstractMetric`](@ref) `G` to a cotangent by computing ````math @@ -309,10 +300,11 @@ X^β™­= G_p X, ```` where ``G_p`` is the local matrix representation of `G`, see [`local_metric`](@ref) """ -flat(::MetricManifold, ::Any...) +flat(::MetricManifold, ::Any, ::TFVector) -@decorator_transparent_fallback function flat!( - M::MetricManifold, +function flat!( + ::TraitList{IsMetricManifold}, + M::AbstractDecoratorManifold, ΞΎ::CoTFVector, p, X::TFVector, @@ -321,13 +313,72 @@ flat(::MetricManifold, ::Any...) copyto!(ΞΎ.data, g * X.data) return ΞΎ end +function flat!( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + ΞΎ::CoTFVector, + p, + X::TFVector, +) where {𝔽,TM<:AbstractManifold,G<:AbstractMetric} + flat!(M.manifold, ΞΎ, p, X) + return ΞΎ +end + +function get_basis( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + p, + B::AbstractBasis, +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return get_basis(M.manifold, p, B) +end + +function get_coordinates( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + p, + X, + B::AbstractBasis, +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return get_coordinates(M.manifold, p, X, B) +end +function get_coordinates!( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + Y, + p, + X, + B::AbstractBasis, +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return get_coordinates!(M.manifold, Y, p, X, B) +end + +function get_vector( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + p, + c, + B::AbstractBasis, +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return get_vector(M.manifold, p, c, B) +end +function get_vector!( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + Y, + p, + c, + B::AbstractBasis, +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return get_vector!(M.manifold, Y, p, c, B) +end @doc raw""" inverse_local_metric(M::AbstractcManifold{𝔽}, p, B::AbstractBasis) Return the local matrix representation of the inverse metric (cometric) tensor -of the tangent space at `p` on the [`AbstractManifold`](@ref) `M` with respect -to the [`AbstractBasis`](@ref) basis `B`. +of the tangent space at `p` on the [`AbstractManifold`](https://juliamanifolds.github.io/Manifolds.jl/latest/interface.html#ManifoldsBase.AbstractManifold) `M` with respect +to the [`AbstractBasis`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/bases.html#ManifoldsBase.AbstractBasis) basis `B`. The metric tensor (see [`local_metric`](@ref)) is usually denoted by ``G = (g_{ij}) ∈ 𝔽^{dΓ—d}``, where ``d`` is the dimension of the manifold. @@ -338,50 +389,10 @@ inverse_local_metric(::AbstractManifold, ::Any, ::AbstractBasis) function inverse_local_metric(M::AbstractManifold, p, B::AbstractBasis) return inv(local_metric(M, p, B)) end -@decorator_transparent_signature inverse_local_metric( - M::AbstractDecoratorManifold, - p, - B::AbstractBasis, -) - -default_decorator_dispatch(M::MetricManifold) = default_metric_dispatch(M) - -""" - is_default_metric(M, G) - -Indicate whether the [`AbstractMetric`](@ref) `G` is the default metric for -the [`AbstractManifold`](@ref) `M`. This means that any occurence of -[`MetricManifold`](@ref)(M,G) where `typeof(is_default_metric(M,G)) = true` -falls back to just be called with `M` such that the [`AbstractManifold`](@ref) `M` -implicitly has this metric, for example if this was the first one implemented -or is the one most commonly assumed to be used. -""" -function is_default_metric(M::AbstractManifold, G::AbstractMetric) - return _extract_val(default_metric_dispatch(M, G)) -end - -default_metric_dispatch(::AbstractManifold, ::AbstractMetric) = Val(false) -function default_metric_dispatch(M::MetricManifold) - return default_metric_dispatch(base_manifold(M), metric(M)) -end - -""" - is_default_metric(MM::MetricManifold) - -Indicate whether the [`AbstractMetric`](@ref) `MM.G` is the default metric for -the [`AbstractManifold`](@ref) `MM.manifold,` within the [`MetricManifold`](@ref) `MM`. -This means that any occurence of -[`MetricManifold`](@ref)`(MM.manifold, MM.G)` where `is_default_metric(MM.manifold, MM.G)) = true` -falls back to just be called with `MM.manifold,` such that the [`AbstractManifold`](@ref) `MM.manifold` -implicitly has the metric `MM.G`, for example if this was the first one -implemented or is the one most commonly assumed to be used. -""" -function is_default_metric(M::MetricManifold) - return _extract_val(default_metric_dispatch(M)) -end +@trait_function inverse_local_metric(M::AbstractDecoratorManifold, p, B::AbstractBasis) function Base.convert(::Type{MetricManifold{𝔽,MT,GT}}, M::MT) where {𝔽,MT,GT} - return _convert_with_default(M, GT, default_metric_dispatch(M, GT())) + return _convert_with_default(M, GT, Val(is_default_metric(M, GT()))) end function _convert_with_default( @@ -401,13 +412,36 @@ function _convert_with_default( ) end +function exp( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + p, + X, +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return exp(M.manifold, p, X) +end +function exp!( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + q, + p, + X, +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return exp!(M.manifold, q, p, X) +end + +injectivity_radius(M::MetricManifold) = injectivity_radius(M.manifold) +function injectivity_radius(M::MetricManifold, m::AbstractRetractionMethod) + return injectivity_radius(M.manifold, m) +end + @doc raw""" inner(N::MetricManifold{M,G}, p, X, Y) Compute the inner product of `X` and `Y` from the tangent space at `p` on the -[`AbstractManifold`](@ref) `M` using the [`AbstractMetric`](@ref) `G`. If `G` is the default -metric (see [`is_default_metric`](@ref)) this is done using `inner(M, p, X, Y)`, -otherwise the [`local_metric`](@ref)`(M, p)` is employed as +[`AbstractManifold`](https://juliamanifolds.github.io/Manifolds.jl/latest/interface.html#ManifoldsBase.AbstractManifold) `M` using the [`AbstractMetric`](@ref) `G`. +If `M` has `G` as its [`IsDefaultMetric`](@ref) trait, +this is done using `inner(M, p, X, Y)`, otherwise the [`local_metric`](@ref)`(M, p)` is employed as ````math g_p(X, Y) = ⟨X, G_p Y⟩, @@ -416,8 +450,9 @@ where ``G_p`` is the loal matrix representation of the [`AbstractMetric`](@ref) """ inner(::MetricManifold, ::Any, ::Any, ::Any) -@decorator_transparent_fallback :intransparent function inner( - M::MetricManifold, +function inner( + ::TraitList{IsMetricManifold}, + M::AbstractDecoratorManifold, p, X::TFVector, Y::TFVector, @@ -426,12 +461,63 @@ inner(::MetricManifold, ::Any, ::Any, ::Any) error("calculating inner product of vectors from different bases is not supported") return dot(X.data, local_metric(M, p, X.basis) * Y.data) end +function inner( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + p, + X, + Y, +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return inner(M.manifold, p, X, Y) +end + +""" + is_default_metric(M::AbstractManifold, G::AbstractMetric) + +returns whether an [`AbstractMetric`](@ref) is the default metric on the manifold `M` or not. +This can be set by defining this function, or setting the [`IsDefaultMetric`](@ref) trait for an +[`AbstractDecoratorManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/decorator.html#ManifoldsBase.AbstractDecoratorManifold). +""" +is_default_metric(M::AbstractManifold, G::AbstractMetric) + +@trait_function is_default_metric(M::AbstractDecoratorManifold, G::AbstractMetric) +function is_default_metric( + ::TraitList{IsDefaultMetric{G}}, + ::AbstractDecoratorManifold, + ::G, +) where {G<:AbstractMetric} + return true +end +is_default_metric(M::MetricManifold) = is_default_metric(M.manifold, M.metric) +is_default_metric(::AbstractManifold, ::AbstractMetric) = false + +function is_point( + ::TraitList{IsMetricManifold}, + M::MetricManifold{𝔽,TM,G}, + p, + te=false; + kwargs..., +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return is_point(M.manifold, p, te; kwargs...) +end + +function is_vector( + ::TraitList{IsMetricManifold}, + M::MetricManifold{𝔽,TM,G}, + p, + X, + te=false, + cbp=true; + kwargs..., +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return is_vector(M.manifold, p, X, te, cbp; kwargs...) +end @doc raw""" local_metric(M::AbstractManifold{𝔽}, p, B::AbstractBasis) Return the local matrix representation at the point `p` of the metric tensor ``g`` with -respect to the [`AbstractBasis`](@ref) `B` on the [`AbstractManifold`](@ref) `M`. +respect to the [`AbstractBasis`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/bases.html#ManifoldsBase.AbstractBasis) `B` on the [`AbstractManifold`](https://juliamanifolds.github.io/Manifolds.jl/latest/interface.html#ManifoldsBase.AbstractManifold) `M`. Let ``d``denote the dimension of the manifold and $b_1,\ldots,b_d$ the basis vectors. Then the local matrix representation is a matrix ``G\in 𝔽^{n\times n}`` whose entries are given by ``g_{ij} = g_p(b_i,b_j), i,j\in\{1,…,d\}``. @@ -440,12 +526,16 @@ This yields the property for two tangent vectors (using Einstein summation conve ``X = X^ib_i, Y=Y^ib_i \in T_p\mathcal M`` we get ``g_p(X, Y) = g_{ij} X^i Y^j``. """ local_metric(::AbstractManifold, ::Any, ::AbstractBasis) -@decorator_transparent_signature local_metric( - M::AbstractDecoratorManifold, +@trait_function local_metric(M::AbstractDecoratorManifold, p, B::AbstractBasis; kwargs...) + +function local_metric( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, p, - B::AbstractBasis; - kwargs..., -) + B::AbstractBasis, +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return local_metric(M.manifold, p, B) +end @doc raw""" local_metric_jacobian( @@ -470,7 +560,7 @@ function local_metric_jacobian( βˆ‚g = reshape(_jacobian(q -> local_metric(M, q, B), p, backend), n, n, n) return βˆ‚g end -@decorator_transparent_signature local_metric_jacobian( +@trait_function local_metric_jacobian( M::AbstractDecoratorManifold, p, B::AbstractBasis; @@ -480,14 +570,32 @@ end @doc raw""" log(N::MetricManifold{M,G}, p, q) -Copute the logarithmic map on the [`AbstractManifold`](@ref) `M` equipped with the [`AbstractMetric`](@ref) `G`. +Copute the logarithmic map on the [`AbstractManifold`](https://juliamanifolds.github.io/Manifolds.jl/latest/interface.html#ManifoldsBase.AbstractManifold) `M` equipped with the [`AbstractMetric`](@ref) `G`. -If the metric was declared the default metric using [`is_default_metric`](@ref), this method +If the metric was declared the default metric using the [`IsDefaultMetric`](@ref) trait or [`is_default_metric`](@ref), this method falls back to `log(M,p,q)`. Otherwise, you have to provide an implementation for the non-default [`AbstractMetric`](@ref) `G` metric within its [`MetricManifold`](@ref)`{M,G}`. """ log(::MetricManifold, ::Any...) +function log( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + p, + q, +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return log(M.manifold, p, q) +end +function log!( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + X, + p, + q, +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return log!(M.manifold, X, p, q) +end + @doc raw""" log_local_metric_density(M::AbstractManifold, p, B::AbstractBasis) @@ -498,11 +606,9 @@ log_local_metric_density(::AbstractManifold, ::Any, ::AbstractBasis) function log_local_metric_density(M::AbstractManifold, p, B::AbstractBasis) return log(abs(det_local_metric(M, p, B))) / 2 end -@decorator_transparent_signature log_local_metric_density( - M::AbstractDecoratorManifold, - p, - B::AbstractBasis, -) +@trait_function log_local_metric_density(M::AbstractDecoratorManifold, p, B::AbstractBasis) + +manifold_dimension(M::MetricManifold) = manifold_dimension(M.manifold) @doc raw""" metric(M::MetricManifold) @@ -515,6 +621,65 @@ function metric(M::MetricManifold) return M.metric end +function norm(::TraitList{IsMetricManifold}, M::AbstractDecoratorManifold, p, X::TFVector) + return sqrt(dot(X.data, local_metric(M, p, X.basis) * X.data)) +end + +function parallel_transport_to( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + p, + X, + q, +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return parallel_transport_to(M.manifold, p, X, q) +end +function parallel_transport_to!( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + Y, + p, + X, + q, +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return parallel_transport_to!(M.manifold, Y, p, X, q) +end + +function project( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + p, +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return project(M.manifold, p) +end +function project!( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + q, + p, +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return project!(M.manifold, q, p) +end +function project( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + p, + X, +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return project(M.manifold, p, X) +end +function project!( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + Y, + p, + X, +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return project!(M.manifold, Y, p, X) +end + +representation_size(M::MetricManifold) = representation_size(M.manifold) + @doc raw""" ricci_curvature(M::AbstractManifold, p, B::AbstractBasis; backend::AbstractDiffBackend = default_differential_backend()) @@ -538,7 +703,7 @@ function ricci_curvature( S = sum(Ginv .* Ric) return S end -@decorator_transparent_signature ricci_curvature( +@trait_function ricci_curvature( M::AbstractDecoratorManifold, p, B::AbstractBasis; @@ -546,10 +711,10 @@ end ) @doc raw""" - sharp(N::MetricManifold{M,G}, p, ΞΎ::FVector{CotangentSpaceType}) + sharp(N::MetricManifold{M,G}, p, ΞΎ::CoTFVector) Compute the musical isomorphism to transform the cotangent vector `ΞΎ` from the -[`AbstractManifold`](@ref) `M` equipped with [`AbstractMetric`](@ref) `G` to a tangent by +[`AbstractManifold`](https://juliamanifolds.github.io/Manifolds.jl/latest/interface.html#ManifoldsBase.AbstractManifold) `M` equipped with [`AbstractMetric`](@ref) `G` to a tangent by computing ````math @@ -560,59 +725,153 @@ where ``G_p`` is the local matrix representation of `G`, i.e. one employs """ sharp(::MetricManifold, ::Any, ::CoTFVector) -function sharp!(M::N, X::TFVector, p, ΞΎ::CoTFVector) where {N<:MetricManifold} +function sharp!( + ::TraitList{IsMetricManifold}, + M::AbstractDecoratorManifold, + X::TFVector, + p, + ΞΎ::CoTFVector, +) Ginv = inverse_local_metric(M, p, X.basis) copyto!(X.data, Ginv * ΞΎ.data) return X end +function sharp!( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + X::TFVector, + p, + ΞΎ::CoTFVector, +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + sharp!(M.manifold, X, p, ΞΎ) + return X +end function Base.show(io::IO, M::MetricManifold) return print(io, "MetricManifold($(M.manifold), $(M.metric))") end +function Base.show(io::IO, i::IsDefaultMetric) + return print(io, "IsDefaultMetric($(i.metric))") +end + +function vector_transport_along( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + p, + X, + c::AbstractVector, + m::AbstractVectorTransportMethod=default_vector_transport_method(M), +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return vector_transport_along(M.manifold, p, X, c, m) +end +function vector_transport_along!( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + Y, + p, + X, + c::AbstractVector, + m::AbstractVectorTransportMethod=default_vector_transport_method(M), +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return vector_transport_along!(M.manifold, Y, p, X, c, m) +end -# -# Introduce transparency -# (a) new functions & other parents -for f in [ +function vector_transport_direction( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + p, + X, + d, + m::AbstractVectorTransportMethod=default_vector_transport_method(M), +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return vector_transport_direction(M.manifold, p, X, d, m) +end +function vector_transport_direction!( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + Y, + p, + X, + d, + m::AbstractVectorTransportMethod=default_vector_transport_method(M), +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return vector_transport_direction!(M.manifold, Y, p, X, d, m) +end + +function vector_transport_to( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + p, + X, + q, + m::AbstractVectorTransportMethod=default_vector_transport_method(M), +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return vector_transport_to(M.manifold, p, X, q, m) +end +function vector_transport_to!( + ::TraitList{IsDefaultMetric{G}}, + M::MetricManifold{𝔽,TM,G}, + Y, + p, + X, + q, + m::AbstractVectorTransportMethod=default_vector_transport_method(M), +) where {𝔽,G<:AbstractMetric,TM<:AbstractManifold} + return vector_transport_to!(M.manifold, Y, p, X, q, m) +end + +zero_vector(M::MetricManifold, p) = zero_vector(M.manifold, p) +zero_vector!(M::MetricManifold, X, p) = zero_vector!(M.manifold, X, p) + +is_metric_function(::Any) = false +for mf in [ + change_metric, + change_metric!, + change_representer, + change_representer!, christoffel_symbols_first, + christoffel_symbols_second, + christoffel_symbols_second_jacobian, det_local_metric, einstein_tensor, + exp, + exp!, + flat!, + gaussian_curvature, + get_basis, + get_coordinates, + get_coordinates!, + get_vector, + get_vector!, + get_vectors, + inner, inverse_local_metric, + inverse_retract, + inverse_retract!, local_metric, local_metric_jacobian, + log, + log!, log_local_metric_density, + norm, + parallel_transport_along, + parallel_transport_along!, + parallel_transport_direction, + parallel_transport_direction!, + parallel_transport_to, + parallel_transport_to!, + retract, + retract!, ricci_curvature, + ricci_tensor, + riemann_tensor, + sharp!, + vector_transport_along, + vector_transport_along!, + vector_transport_direction, + vector_transport_direction!, + vector_transport_to, + vector_transport_to!, ] - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - M::AbstractConnectionManifold, - args..., - ) - return Val(:parent) - end - end, - ) -end - -for f in [change_metric, change_representer, change_metric!, change_representer!] - eval( - quote - function decorator_transparent_dispatch( - ::typeof($f), - ::AbstractManifold, - args..., - ) - return Val(:parent) - end - end, - ) -end -function decorator_transparent_dispatch( - ::typeof(christoffel_symbols_second), - ::MetricManifold, - args..., -) - return Val(:parent) + @eval is_metric_function(::typeof($mf)) = true end diff --git a/src/manifolds/Multinomial.jl b/src/manifolds/Multinomial.jl index 8db4e6e63d..53e5aad0da 100644 --- a/src/manifolds/Multinomial.jl +++ b/src/manifolds/Multinomial.jl @@ -10,11 +10,11 @@ The multinomial manifold consists of `m` column vectors, where each column is of where $\mathbb{1}_k$ is the vector of length $k$ containing ones. This yields exactly the same metric as -considering the product metric of the probablity vectors, i.e. [`PowerManifold`](@ref) of the +considering the product metric of the probablity vectors, i.e. [`PowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.PowerManifold) of the $(n-1)$-dimensional [`ProbabilitySimplex`](@ref). The [`ProbabilitySimplex`](@ref) is stored internally within `M.manifold`, such that all functions of -[`AbstractPowerManifold`](@ref) can be used directly. +[`AbstractPowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.AbstractPowerManifold) can be used directly. # Constructor @@ -45,12 +45,6 @@ of `m` discrete probability distributions as columns from $\mathbb R^{n}$, i.e. """ check_point(::MultinomialMatrices, ::Any) function check_point(M::MultinomialMatrices{n,m}, p; kwargs...) where {n,m} - if size(p) != (n, m) - return DomainError( - length(p), - "The matrix in `p` ($(size(p))) does not match the dimensions of $(M).", - ) - end return check_point(PowerManifold(M.manifold, m), p; kwargs...) end @@ -62,18 +56,13 @@ This means, that `p` is valid, that `X` is of correct dimension and columnswise a tangent vector to the columns of `p` on the [`ProbabilitySimplex`](@ref). """ function check_vector(M::MultinomialMatrices{n,m}, p, X; kwargs...) where {n,m} - if size(X) != (n, m) - return DomainError( - length(X), - "The matrix `X` ($(size(X))) does not match the required dimension ($(representation_size(M))) for $(M).", - ) - end return check_vector(PowerManifold(M.manifold, m), p, X; kwargs...) end get_iterator(::MultinomialMatrices{n,m}) where {n,m} = Base.OneTo(m) @generated manifold_dimension(::MultinomialMatrices{n,m}) where {n,m} = (n - 1) * m +@generated power_dimensions(::MultinomialMatrices{n,m}) where {n,m} = (m,) @generated representation_size(::MultinomialMatrices{n,m}) where {n,m} = (n, m) diff --git a/src/manifolds/MultinomialDoublyStochastic.jl b/src/manifolds/MultinomialDoublyStochastic.jl index 8f549b2425..97444c9675 100644 --- a/src/manifolds/MultinomialDoublyStochastic.jl +++ b/src/manifolds/MultinomialDoublyStochastic.jl @@ -1,13 +1,15 @@ @doc raw""" - AbstractMultinomialDoublyStochastic{N} <: AbstractEmbeddedManifold{ℝ, DefaultIsometricEmbeddingType} + AbstractMultinomialDoublyStochastic{N} <: AbstractDecoratorManifold{ℝ} A common type for manifolds that are doubly stochastic, for example by direct constraint [`MultinomialDoubleStochastic`](@ref) or by symmetry [`MultinomialSymmetric`](@ref), -as long as they are also modeled as [`DefaultIsometricEmbeddingType`](@ref) -[`AbstractEmbeddedManifold`](@ref)s. +as long as they are also modeled as [`IsIsometricEmbeddedManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/decorator.html#ManifoldsBase.IsIsometricEmbeddedManifold). """ -abstract type AbstractMultinomialDoublyStochastic{N} <: - AbstractEmbeddedManifold{ℝ,DefaultIsometricEmbeddingType} end +abstract type AbstractMultinomialDoublyStochastic{N} <: AbstractDecoratorManifold{ℝ} end + +function active_traits(f, ::AbstractMultinomialDoublyStochastic, args...) + return merge_traits(IsIsometricEmbeddedManifold()) +end @doc raw""" MultinomialDoublyStochastic{n} <: AbstractMultinomialDoublyStochastic{N} @@ -63,9 +65,6 @@ Checks whether `p` is a valid point on the [`MultinomialDoubleStochastic`](@ref) i.e. is a matrix with positive entries whose rows and columns sum to one. """ function check_point(M::MultinomialDoubleStochastic{n}, p; kwargs...) where {n} - mpv = invoke(check_point, Tuple{supertype(typeof(M)),typeof(p)}, M, p; kwargs...) - mpv === nothing || return mpv - # positivity and columns are checked in the embedding, we further check r = sum(p, dims=2) if !isapprox(norm(r - ones(n, 1)), 0.0; kwargs...) return DomainError( @@ -83,16 +82,6 @@ This means, that `p` is valid, that `X` is of correct dimension and sums to zero column or row. """ function check_vector(M::MultinomialDoubleStochastic{n}, p, X; kwargs...) where {n} - mpv = invoke( - check_vector, - Tuple{supertype(typeof(M)),typeof(p),typeof(X)}, - M, - p, - X; - kwargs..., - ) - mpv === nothing || return mpv - # columns are checked in the embedding, we further check r = sum(X, dims=2) # check for stochastic rows if !isapprox(norm(r), 0.0; kwargs...) return DomainError( @@ -103,7 +92,7 @@ function check_vector(M::MultinomialDoubleStochastic{n}, p, X; kwargs...) where return nothing end -function decorated_manifold(::MultinomialDoubleStochastic{N}) where {N} +function get_embedding(::MultinomialDoubleStochastic{N}) where {N} return MultinomialMatrices(N, N) end @@ -204,36 +193,10 @@ refers to the elementwise exponentiation. """ retract(::MultinomialDoubleStochastic, ::Any, ::Any, ::ProjectionRetraction) -function retract!(M::MultinomialDoubleStochastic, q, p, X, ::ProjectionRetraction) +function retract_project!(M::MultinomialDoubleStochastic, q, p, X) return project!(M, q, p .* exp.(X ./ p)) end -""" - vector_transport_to(M::MultinomialDoubleStochastic, p, X, q) - -transport the tangent vector `X` at `p` to `q` by projecting it onto the tangent space -at `q`. -""" -vector_transport_to( - ::MultinomialDoubleStochastic, - ::Any, - ::Any, - ::Any, - ::ProjectionTransport, -) - -function vector_transport_to!( - M::MultinomialDoubleStochastic, - Y, - p, - X, - q, - ::ProjectionTransport, -) - project!(M, Y, q, X) - return Y -end - function Base.show(io::IO, ::MultinomialDoubleStochastic{n}) where {n} return print(io, "MultinomialDoubleStochastic($(n))") end diff --git a/src/manifolds/MultinomialSymmetric.jl b/src/manifolds/MultinomialSymmetric.jl index d6fd9844be..4e17609539 100644 --- a/src/manifolds/MultinomialSymmetric.jl +++ b/src/manifolds/MultinomialSymmetric.jl @@ -15,7 +15,7 @@ positive entries such that each column sums to one, i.e. where $\mathbf{1}_n$ is the vector of length $n$ containing ones. -It is modeled as an [`DefaultIsometricEmbeddingType`](@ref), [`AbstractEmbeddedManifold`](@ref) +It is modeled as [`IsIsometricEmbeddedManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/decorator.html#ManifoldsBase.IsIsometricEmbeddedManifold). via the [`AbstractMultinomialDoublyStochastic`](@ref) type, since it shares a few functions also with [`AbstractMultinomialDoublyStochastic`](@ref), most and foremost projection of a point from the embedding onto the manifold. @@ -59,11 +59,6 @@ Checks whether `p` is a valid point on the [`MultinomialSymmetric`](@ref)`(m,n)` i.e. is a symmetric matrix with positive entries whose rows sum to one. """ function check_point(M::MultinomialSymmetric{n}, p; kwargs...) where {n} - mpv = invoke(check_point, Tuple{supertype(typeof(M)),typeof(p)}, M, p; kwargs...) - mpv === nothing || return mpv - # the embedding checks for positivity and unit sum columns, by symmetry we would get - # the same for the rows, so checking symmetry is the only thing left, we can just use - # the corresponding manifold for that return check_point(SymmetricMatrices(n, ℝ), p) end @doc raw""" @@ -74,20 +69,10 @@ This means, that `p` is valid, that `X` is of correct dimension, symmetric, and along any row. """ function check_vector(M::MultinomialSymmetric{n}, p, X; kwargs...) where {n} - mpv = invoke( - check_vector, - Tuple{supertype(typeof(M)),typeof(p),typeof(X)}, - M, - p, - X; - kwargs..., - ) - mpv === nothing || return mpv - # from the embedding we know that columns sum to zero, only symmety is left, i.e. return check_vector(SymmetricMatrices(n, ℝ), p, X; kwargs...) end -function decorated_manifold(::MultinomialSymmetric{N}) where {N} +function get_embedding(::MultinomialSymmetric{N}) where {N} return MultinomialMatrices(N, N) end @@ -143,23 +128,10 @@ refers to the elementwise exponentiation. """ retract(::MultinomialSymmetric, ::Any, ::Any, ::ProjectionRetraction) -function retract!(M::MultinomialSymmetric, q, p, X, ::ProjectionRetraction) +function retract_project!(M::MultinomialSymmetric, q, p, X) return project!(M, q, p .* exp.(X ./ p)) end -""" - vector_transport_to(M::MultinomialSymmetric, p, X, q) - -transport the tangent vector `X` at `p` to `q` by projecting it onto the tangent space -at `q`. -""" -vector_transport_to(::MultinomialSymmetric, ::Any, ::Any, ::Any, ::ProjectionTransport) - -function vector_transport_to!(M::MultinomialSymmetric, Y, p, X, q, ::ProjectionTransport) - project!(M, Y, q, X) - return Y -end - function Base.show(io::IO, ::MultinomialSymmetric{n}) where {n} return print(io, "MultinomialSymmetric($(n))") end diff --git a/src/manifolds/Oblique.jl b/src/manifolds/Oblique.jl index bb7710067b..f03a213fe1 100644 --- a/src/manifolds/Oblique.jl +++ b/src/manifolds/Oblique.jl @@ -3,11 +3,11 @@ The oblique manifold $\mathcal{OB}(n,m)$ is the set of 𝔽-valued matrices with unit norm column endowed with the metric from the embedding. This yields exactly the same metric as -considering the product metric of the unit norm vectors, i.e. [`PowerManifold`](@ref) of the +considering the product metric of the unit norm vectors, i.e. [`PowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.PowerManifold) of the $(n-1)$-dimensional [`Sphere`](@ref). The [`Sphere`](@ref) is stored internally within `M.manifold`, such that all functions of -[`AbstractPowerManifold`](@ref) can be used directly. +[`AbstractPowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.AbstractPowerManifold) can be used directly. # Constructor @@ -36,12 +36,6 @@ of `m` unit columns from $\mathbb R^{n}$, i.e. each column is a point from """ check_point(::Oblique, ::Any) function check_point(M::Oblique{n,m}, p; kwargs...) where {n,m} - if size(p) != (n, m) - return DomainError( - length(p), - "The matrix in `p` ($(size(p))) does not match the dimension of $(M).", - ) - end return check_point(PowerManifold(M.manifold, m), p; kwargs...) end @doc raw""" @@ -52,12 +46,6 @@ This means, that `p` is valid, that `X` is of correct dimension and columnswise a tangent vector to the columns of `p` on the [`Sphere`](@ref). """ function check_vector(M::Oblique{n,m}, p, X; kwargs...) where {n,m} - if size(X) != (n, m) - return DomainError( - length(X), - "The matrix `X` ($(size(X))) does not match the required dimension ($(representation_size(M))) for $(M).", - ) - end return check_vector(PowerManifold(M.manifold, m), p, X; kwargs...) end @@ -66,23 +54,18 @@ get_iterator(::Oblique{n,m}) where {n,m} = Base.OneTo(m) @generated function manifold_dimension(::Oblique{n,m,𝔽}) where {n,m,𝔽} return (n * real_dimension(𝔽) - 1) * m end +power_dimensions(::Oblique{n,m}) where {n,m} = (m,) @generated representation_size(::Oblique{n,m}) where {n,m} = (n, m) @doc raw""" - vector_transport_to(M::Oblique, p, X, q, ::ParallelTransport) + parallel_transport_to(M::Oblique, p, X, q) Compute the parallel transport on the [`Oblique`](@ref) manifold by doing a column wise parallel transport on the [`Sphere`](@ref) -This is a shortcut to using [`PowerVectorTransport{ParallelTransport}`](@ref) -from the [`AbstractPowerManifold`](@ref). """ -vector_transport_to(::Oblique, ::Any, ::Any, ::Any, ::ParallelTransport) - -function vector_transport_to!(M::Oblique, Y, p, X, q, m::ParallelTransport) - return vector_transport_to!(M, Y, p, X, q, PowerVectorTransport(m)) -end +parallel_transport_to(::Oblique, p, X, q) function Base.show(io::IO, ::Oblique{n,m,𝔽}) where {n,m,𝔽} return print(io, "Oblique($(n),$(m); field = $(𝔽))") diff --git a/src/manifolds/PositiveNumbers.jl b/src/manifolds/PositiveNumbers.jl index c3b22362fa..fdab4b93e7 100644 --- a/src/manifolds/PositiveNumbers.jl +++ b/src/manifolds/PositiveNumbers.jl @@ -18,7 +18,7 @@ struct PositiveNumbers <: AbstractManifold{ℝ} end PositiveVectors(n) Generate the manifold of vectors with positive entries. -This manifold is modeled as a [`PowerManifold`](@ref) of [`PositiveNumbers`](@ref). +This manifold is modeled as a [`PowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.PowerManifold) of [`PositiveNumbers`](@ref). """ PositiveVectors(n::Integer) = PositiveNumbers()^n @@ -26,7 +26,7 @@ PositiveVectors(n::Integer) = PositiveNumbers()^n PositiveMatrices(m,n) Generate the manifold of matrices with positive entries. -This manifold is modeled as a [`PowerManifold`](@ref) of [`PositiveNumbers`](@ref). +This manifold is modeled as a [`PowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.PowerManifold) of [`PositiveNumbers`](@ref). """ PositiveMatrices(n::Integer, m::Integer) = PositiveNumbers()^(n, m) @@ -34,7 +34,7 @@ PositiveMatrices(n::Integer, m::Integer) = PositiveNumbers()^(n, m) PositiveArrays(n₁,nβ‚‚,...,nα΅’) Generate the manifold of `i`-dimensional arrays with positive entries. -This manifold is modeled as a [`PowerManifold`](@ref) of [`PositiveNumbers`](@ref). +This manifold is modeled as a [`PowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.PowerManifold) of [`PositiveNumbers`](@ref). """ PositiveArrays(n::Vararg{Int,I}) where {I} = PositiveNumbers()^(n) @@ -148,17 +148,6 @@ exp!(::PositiveNumbers, q, p, X) = (q .= p .* exp.(X ./ p)) Return the injectivity radius on the [`PositiveNumbers`](@ref) `M`, i.e. $\infty$. """ injectivity_radius(::PositiveNumbers) = Inf -injectivity_radius(::PositiveNumbers, ::ExponentialRetraction) = Inf -injectivity_radius(::PositiveNumbers, ::Any) = Inf -injectivity_radius(::PositiveNumbers, ::Any, ::ExponentialRetraction) = Inf -eval( - quote - @invoke_maker 1 AbstractManifold injectivity_radius( - M::PositiveNumbers, - rm::AbstractRetractionMethod, - ) - end, -) @doc raw""" inner(M::PositiveNumbers, p, X, Y) @@ -237,7 +226,7 @@ function Base.show( end @doc raw""" - vector_transport_to(M::PositiveNumbers, p, X, q, ::ParallelTransport) + parallel_transport_to(M::PositiveNumbers, p, X, q) Compute the parallel transport of `X` from the tangent space at `p` to the tangent space at `q` on the [`PositiveNumbers`](@ref) `M`. @@ -246,18 +235,12 @@ Compute the parallel transport of `X` from the tangent space at `p` to the tange \mathcal P_{q\gets p}(X) = X\cdot\frac{q}{p}. ```` """ -vector_transport_to(::PositiveNumbers, ::Any, ::Any, ::Any, ::ParallelTransport) -function vector_transport_to( - ::PositiveNumbers, - p::Real, - X::Real, - q::Real, - ::ParallelTransport, -) +parallel_transport_to(::PositiveNumbers, ::Any, ::Any, ::Any) +function parallel_transport_to(::PositiveNumbers, p::Real, X::Real, q::Real) return X * q / p end -function vector_transport_to!(::PositiveNumbers, Y, p, X, q, ::ParallelTransport) +function parallel_transport_to!(::PositiveNumbers, Y, p, X, q) return (Y .= X .* q ./ p) end diff --git a/src/manifolds/PowerManifold.jl b/src/manifolds/PowerManifold.jl index 1febe6e594..ac88913d84 100644 --- a/src/manifolds/PowerManifold.jl +++ b/src/manifolds/PowerManifold.jl @@ -14,7 +14,7 @@ struct ArrayPowerRepresentation <: AbstractPowerRepresentation end @doc raw""" PowerMetric <: AbstractMetric -Represent the [`AbstractMetric`](@ref) on an [`AbstractPowerManifold`](@ref), i.e. the inner +Represent the [`AbstractMetric`](@ref) on an `AbstractPowerManifold`, i.e. the inner product on the tangent space is the sum of the inner product of each elements tangent space of the power manifold. """ @@ -76,8 +76,6 @@ for PowerRepr in [PowerManifoldNested, PowerManifoldNestedReplacing] end end -default_metric_dispatch(::AbstractPowerManifold, ::PowerMetric) = Val(true) - """ change_representer(M::AbstractPowerManifold, ::AbstractMetric, p, X) @@ -124,7 +122,7 @@ end flat(M::AbstractPowerManifold, p, X) use the musical isomorphism to transform the tangent vector `X` from the tangent space at -`p` on an [`AbstractPowerManifold`](@ref) `M` to a cotangent vector. +`p` on an [`AbstractPowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.AbstractPowerManifold) `M` to a cotangent vector. This can be done elementwise for each entry of `X` (and `p`). """ flat(::AbstractPowerManifold, ::Any...) @@ -305,7 +303,7 @@ end sharp(M::AbstractPowerManifold, p, ΞΎ::RieszRepresenterCotangentVector) Use the musical isomorphism to transform the cotangent vector `ΞΎ` from the tangent space at -`p` on an [`AbstractPowerManifold`](@ref) `M` to a tangent vector. +`p` on an [`AbstractPowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.AbstractPowerManifold) `M` to a tangent vector. This can be done elementwise for every entry of `ΞΎ` (and `p`). """ sharp(::AbstractPowerManifold, ::Any...) @@ -335,7 +333,7 @@ Distributions.support(tvd::PowerFVectorDistribution) = FVectorSupport(tvd.type, Distributions.support(d::PowerPointDistribution) = MPointSupport(d.manifold) function vector_bundle_transport(fiber::VectorSpaceType, M::PowerManifold) - return PowerVectorTransport(ParallelTransport()) + return ParallelTransport() end @inline function _write( diff --git a/src/manifolds/ProbabilitySimplex.jl b/src/manifolds/ProbabilitySimplex.jl index c66aebff99..dd6226510b 100644 --- a/src/manifolds/ProbabilitySimplex.jl +++ b/src/manifolds/ProbabilitySimplex.jl @@ -1,5 +1,5 @@ @doc raw""" - ProbabilitySimplex{n} <: AbstractEmbeddedManifold{ℝ,DefaultEmbeddingType} + ProbabilitySimplex{n} <: AbstractDecoratorManifold{𝔽} The (relative interior of) the probability simplex is the set ````math @@ -29,24 +29,10 @@ This implementation follows the notation in [^Γ…strΓΆmPetraSchmitzerSchnΓΆrr2017 > doi: [10.1007/s10851-016-0702-4](https://doi.org/10.1007/s10851-016-0702-4) > arxiv: [1603.05285](https://arxiv.org/abs/1603.05285). """ -struct ProbabilitySimplex{n} <: AbstractEmbeddedManifold{ℝ,DefaultEmbeddingType} end +struct ProbabilitySimplex{n} <: AbstractDecoratorManifold{ℝ} end ProbabilitySimplex(n::Int) = ProbabilitySimplex{n}() -""" - SoftmaxRetraction <: AbstractRetractionMethod - -Describes a retraction that is based on the softmax function. -""" -struct SoftmaxRetraction <: AbstractRetractionMethod end - -""" - SoftmaxInverseRetraction <: AbstractInverseRetractionMethod - -Describes an inverse retraction that is based on the softmax function. -""" -struct SoftmaxInverseRetraction <: AbstractInverseRetractionMethod end - """ FisherRaoMetric <: AbstractMetric @@ -58,6 +44,8 @@ See for example the [`ProbabilitySimplex`](@ref). """ struct FisherRaoMetric <: AbstractMetric end +active_traits(f, ::ProbabilitySimplex, args...) = merge_traits(IsEmbeddedManifold()) + @doc raw""" change_representer(M::ProbabilitySimplex, ::EuclideanMetric, p, X) @@ -104,14 +92,6 @@ the embedding with positive entries that sum to one The tolerance for the last test can be set using the `kwargs...`. """ function check_point(M::ProbabilitySimplex, p; kwargs...) - mpv = invoke( - check_point, - Tuple{(typeof(get_embedding(M))),typeof(p)}, - get_embedding(M), - p; - kwargs..., - ) - mpv === nothing || return mpv if minimum(p) <= 0 return DomainError( minimum(p), @@ -136,15 +116,6 @@ after [`check_point`](@ref check_point(::ProbabilitySimplex, ::Any))`(M,p)`, The tolerance for the last test can be set using the `kwargs...`. """ function check_vector(M::ProbabilitySimplex, p, X; kwargs...) - mpv = invoke( - check_vector, - Tuple{typeof(get_embedding(M)),typeof(p),typeof(X)}, - get_embedding(M), - p, - X; - kwargs..., - ) - mpv === nothing || return mpv if !isapprox(sum(X), 0.0; kwargs...) return DomainError( sum(X), @@ -154,9 +125,7 @@ function check_vector(M::ProbabilitySimplex, p, X; kwargs...) return nothing end -decorated_manifold(M::ProbabilitySimplex) = Euclidean(representation_size(M)...; field=ℝ) - -default_metric_dispatch(::ProbabilitySimplex, ::FisherRaoMetric) = Val(true) +get_embedding(M::ProbabilitySimplex) = Euclidean(representation_size(M)...; field=ℝ) @doc raw""" distance(M,p,q) @@ -175,6 +144,9 @@ function distance(::ProbabilitySimplex, p, q) return 2 * acos(sumsqrt) end +embed(::ProbabilitySimplex, p) = p +embed(::ProbabilitySimplex, p, X) = X + @doc raw""" exp(M::ProbabilitySimplex,p,X) @@ -211,21 +183,11 @@ function injectivity_radius(::ProbabilitySimplex{n}, p) where {n} s = sum(p) - p[i] return 2 * acos(sqrt(s)) end -function injectivity_radius(M::ProbabilitySimplex, p, ::ExponentialRetraction) +function injectivity_radius(M::ProbabilitySimplex, p, ::AbstractRetractionMethod) return injectivity_radius(M, p) end -injectivity_radius(M::ProbabilitySimplex, p, ::SoftmaxRetraction) = injectivity_radius(M, p) injectivity_radius(M::ProbabilitySimplex) = 0 -injectivity_radius(M::ProbabilitySimplex, ::SoftmaxRetraction) = 0 -injectivity_radius(M::ProbabilitySimplex, ::ExponentialRetraction) = 0 -eval( - quote - @invoke_maker 1 AbstractManifold injectivity_radius( - M::ProbabilitySimplex, - rm::AbstractRetractionMethod, - ) - end, -) +injectivity_radius(M::ProbabilitySimplex, ::AbstractRetractionMethod) = 0 @doc raw""" inner(M::ProbabilitySimplex,p,X,Y) @@ -255,13 +217,7 @@ where $\mathbb{1}^{m,n}$ is the size `(m,n)` matrix containing ones, and $\log$ """ inverse_retract(::ProbabilitySimplex, ::Any, ::Any, ::SoftmaxInverseRetraction) -function inverse_retract!( - ::ProbabilitySimplex{n}, - X, - p, - q, - ::SoftmaxInverseRetraction, -) where {n} +function inverse_retract_softmax!(::ProbabilitySimplex{n}, X, p, q) where {n} X .= log.(q) .- log.(p) meanlogdiff = mean(X) X .-= meanlogdiff @@ -316,15 +272,7 @@ Compute the Riemannian [`mean`](@ref mean(M::AbstractManifold, args...)) of `x` """ mean(::ProbabilitySimplex, ::Any...) -function Statistics.mean!( - M::ProbabilitySimplex, - p, - x::AbstractVector, - w::AbstractVector; - kwargs..., -) - return mean!(M, p, x, w, GeodesicInterpolation(); kwargs...) -end +default_estimation_method(::ProbabilitySimplex, ::typeof(mean)) = GeodesicInterpolation() @doc raw""" project(M::ProbabilitySimplex, p, Y) @@ -392,7 +340,7 @@ where multiplication, exponentiation and division are meant elementwise. """ retract(::ProbabilitySimplex, ::Any, ::Any, ::SoftmaxRetraction) -function retract!(::ProbabilitySimplex, q, p, X, ::SoftmaxRetraction) +function retract_softmax!(::ProbabilitySimplex, q, p, X) s = zero(eltype(q)) @inbounds for i in eachindex(q, p, X) q[i] = p[i] * exp(X[i]) diff --git a/src/manifolds/ProductManifold.jl b/src/manifolds/ProductManifold.jl index 46fc7d5b43..ef90452df2 100644 --- a/src/manifolds/ProductManifold.jl +++ b/src/manifolds/ProductManifold.jl @@ -11,7 +11,7 @@ generates the product manifold $M_1 Γ— M_2 Γ— … Γ— M_n$. Alternatively, the same manifold can be contructed using the `Γ—` operator: `M_1 Γ— M_2 Γ— M_3`. """ -struct ProductManifold{𝔽,TM<:Tuple} <: AbstractManifold{𝔽} +struct ProductManifold{𝔽,TM<:Tuple} <: AbstractDecoratorManifold{𝔽} manifolds::TM end @@ -139,6 +139,10 @@ function ProductVectorTransport(methods::AbstractVectorTransportMethod...) return ProductVectorTransport{typeof(methods)}(methods) end +function active_traits(f, ::ProductManifold, args...) + return merge_traits(IsDefaultMetric(ProductMetric())) +end + function allocate_coordinates(M::AbstractManifold, p::ArrayPartition, T, n::Int) return allocate_coordinates(M, p.x[1], T, n) end @@ -183,7 +187,7 @@ end check_point(M::ProductManifold, p; kwargs...) Check whether `p` is a valid point on the [`ProductManifold`](@ref) `M`. -If `p` is not a point on `M` a [`CompositeManifoldError`](@ref) consisting of all error messages of the +If `p` is not a point on `M` a [`CompositeManifoldError`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#ManifoldsBase.CompositeManifoldError).consisting of all error messages of the components, for which the tests fail is returned. The tolerance for the last test can be set using the `kwargs...`. @@ -204,12 +208,60 @@ function check_point(M::ProductManifold, p; kwargs...) ) end +""" + check_size(M::ProductManifold, p; kwargs...) + +Check whether `p` is of valid size on the [`ProductManifold`](@ref) `M`. +If `p` has components of wrong size a [`CompositeManifoldError`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#ManifoldsBase.CompositeManifoldError).consisting of all error messages of the +components, for which the tests fail is returned. + +The tolerance for the last test can be set using the `kwargs...`. +""" +function check_size(M::ProductManifold, p::Union{ProductRepr,ArrayPartition}) + ts = ziptuples(Tuple(1:length(M.manifolds)), M.manifolds, submanifold_components(M, p)) + e = [(t[1], check_size(t[2:end]...)) for t in ts] + errors = filter((x) -> !(x[2] === nothing), e) + cerr = [ComponentManifoldError(er...) for er in errors] + (length(errors) > 1) && return CompositeManifoldError(cerr) + (length(errors) == 1) && return cerr[1] + return nothing +end +function check_size(M::ProductManifold, p; kwargs...) + return DomainError( + typeof(p), + "The point $p is not a point on $M, since currently only ProductRepr and ArrayPartition are supported types for points on arbitrary product manifolds", + ) +end +function check_size( + M::ProductManifold, + p::Union{ProductRepr,ArrayPartition}, + X::Union{ProductRepr,ArrayPartition}, +) + ts = ziptuples( + Tuple(1:length(M.manifolds)), + M.manifolds, + submanifold_components(M, p), + submanifold_components(M, X), + ) + e = [(t[1], check_size(t[2:end]...)) for t in ts] + errors = filter(x -> !(x[2] === nothing), e) + cerr = [ComponentManifoldError(er...) for er in errors] + (length(errors) > 1) && return CompositeManifoldError(cerr) + (length(errors) == 1) && return cerr[1] + return nothing +end +function check_size(M::ProductManifold, p, X; kwargs...) + return DomainError( + typeof(X), + "The vector $X is not a tangent vector to any tangent space on $M, since currently only ProductRepr and ArrayPartition are supported types for tangent vectors on arbitrary product manifolds", + ) +end """ check_vector(M::ProductManifold, p, X; kwargs... ) Check whether `X` is a tangent vector to `p` on the [`ProductManifold`](@ref) `M`, i.e. all projections to base manifolds must be respective tangent vectors. -If `X` is not a tangent vector to `p` on `M` a [`CompositeManifoldError`](@ref) consisting +If `X` is not a tangent vector to `p` on `M` a [`CompositeManifoldError`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#ManifoldsBase.CompositeManifoldError).consisting of all error messages of the components, for which the tests fail is returned. The tolerance for the last test can be set using the `kwargs...`. @@ -270,7 +322,7 @@ end cross(M, N) cross(M1, M2, M3,...) -Return the [`ProductManifold`](@ref) For two [`AbstractManifold`](@ref)s `M` and `N`, +Return the [`ProductManifold`](@ref) For two `AbstractManifold`s `M` and `N`, where for the case that one of them is a [`ProductManifold`](@ref) itself, the other is either prepended (if `N` is a product) or appenden (if `M`) is. If both are product manifold, they are combined into one product manifold, @@ -378,11 +430,6 @@ function get_basis(M::ProductManifold, p, B::DiagonalizingOrthonormalBasis) end return CachedBasis(B, ProductBasisData(vs)) end -for BT in PRODUCT_BASIS_LIST - eval(quote - @invoke_maker 3 AbstractBasis get_basis(M::ProductManifold, p, B::$BT) - end) -end """ get_component(M::ProductManifold, p, i) @@ -393,6 +440,13 @@ function get_component(M::ProductManifold, p, i) return submanifold_component(M, p, i) end +function get_coordinates(M::ProductManifold, p, X, B::AbstractBasis) + reps = map( + t -> get_coordinates(t..., B), + ziptuples(M.manifolds, submanifold_components(M, p), submanifold_components(M, X)), + ) + return vcat(reps...) +end function get_coordinates( M::ProductManifold, p, @@ -402,48 +456,13 @@ function get_coordinates( reps = map( get_coordinates, M.manifolds, - submanifold_components(p), - submanifold_components(X), + submanifold_components(M, p), + submanifold_components(M, X), B.data.parts, ) return vcat(reps...) end -for BT in PRODUCT_BASIS_LIST_CACHED - eval( - quote - @invoke_maker 4 ( - CachedBasis{𝔽,<:AbstractBasis{𝔽},<:ProductBasisData} where {𝔽} - ) get_coordinates(M::ProductManifold, p, X, B::$BT) - end, - ) -end -eval( - quote - @invoke_maker 1 AbstractManifold get_coordinates( - M::ProductManifold, - e::Identity, - X, - B::VeeOrthogonalBasis, - ) - end, -) - -function get_coordinates(M::ProductManifold, p, X, B::AbstractBasis) - reps = map( - t -> get_coordinates(t..., B), - ziptuples(M.manifolds, submanifold_components(M, p), submanifold_components(M, X)), - ) - return vcat(reps...) -end -for BT in PRODUCT_BASIS_LIST - eval( - quote - @invoke_maker 4 AbstractBasis get_coordinates(M::ProductManifold, p, X, B::$BT) - end, - ) -end - function get_coordinates!(M::ProductManifold, Xⁱ, p, X, B::AbstractBasis) dim = manifold_dimension(M) @assert length(Xⁱ) == dim @@ -484,116 +503,40 @@ function get_coordinates!( return Xⁱ end -for BT in PRODUCT_BASIS_LIST_CACHED - eval( - quote - @invoke_maker 5 ( - CachedBasis{𝔽,<:AbstractBasis{𝔽},<:ProductBasisData} where {𝔽} - ) get_coordinates!(M::ProductManifold, Xⁱ, p, X, B::$BT) - end, - ) -end -for BT in PRODUCT_BASIS_LIST - eval( - quote - @invoke_maker 5 AbstractBasis get_coordinates!( - M::ProductManifold, - Xⁱ, - p, - X, - B::$BT, - ) - end, - ) -end -eval( - quote - @invoke_maker 1 AbstractManifold get_coordinates!( - M::ProductManifold, - Y, - e::Identity, - X, - B::VeeOrthogonalBasis, - ) - end, -) - function _get_dim_ranges(dims::NTuple{N,Any}) where {N} dims_acc = accumulate(+, vcat(1, SVector(dims))) return ntuple(i -> (dims_acc[i]:(dims_acc[i] + dims[i] - 1)), Val(N)) end -eval( - quote - @invoke_maker 1 AbstractManifold get_vector( - M::ProductManifold, - e::Identity, - X, - B::VeeOrthogonalBasis, - ) - end, -) for TP in [ProductRepr, ArrayPartition] eval( quote - @invoke_maker 4 ( - CachedBasis{𝔽,<:AbstractBasis{𝔽},<:ProductBasisData} where {𝔽} - ) get_vector( - M::ProductManifold, - p::$TP, - X, - B::CachedBasis{ℝ,<:AbstractBasis{ℝ},<:ProductBasisData}, - ) function get_vector( M::ProductManifold, p::$TP, - X, - B::CachedBasis{𝔽,<:AbstractBasis{𝔽},<:ProductBasisData}, + Xⁱ, + B::AbstractBasis{𝔽,TangentSpaceType}, ) where {𝔽} - N = number_of_components(M) dims = map(manifold_dimension, M.manifolds) + @assert length(Xⁱ) == sum(dims) dim_ranges = _get_dim_ranges(dims) - parts = ntuple(N) do i - return get_vector( - M.manifolds[i], - submanifold_component(p, i), - view(X, dim_ranges[i]), - B.data.parts[i], - ) - end - return $TP(parts) + tXⁱ = map(dr -> (@inbounds view(Xⁱ, dr)), dim_ranges) + ts = ziptuples(M.manifolds, submanifold_components(M, p), tXⁱ) + return $TP(map((@inline t -> get_vector(t..., B)), ts)) end function get_vector( M::ProductManifold, p::$TP, - X, - B::AbstractBasis{𝔽,TangentSpaceType}, + Xⁱ, + B::CachedBasis{𝔽,<:AbstractBasis{𝔽},<:ProductBasisData}, ) where {𝔽} - N = number_of_components(M) dims = map(manifold_dimension, M.manifolds) + @assert length(Xⁱ) == sum(dims) dim_ranges = _get_dim_ranges(dims) - parts = ntuple(N) do i - return get_vector( - M.manifolds[i], - submanifold_component(p, i), - view(X, dim_ranges[i]), - B, - ) - end - return $TP(parts) - end - function get_vector(M::ProductManifold, p::$TP, Xⁱ, B::VeeOrthogonalBasis) - dim = manifold_dimension(M) - @assert length(Xⁱ) == dim - i = one(dim) - ts = ziptuples(M.manifolds, submanifold_components(M, p)) - mapped = map(ts) do t - dim = manifold_dimension(first(t)) - tXⁱ = @inbounds view(Xⁱ, i:(i + dim - 1)) - i += dim - return get_vector(t..., tXⁱ, B) - end - return $TP(mapped...) + tXⁱ = map(dr -> (@inbounds view(Xⁱ, dr)), dim_ranges) + ts = + ziptuples(M.manifolds, submanifold_components(M, p), tXⁱ, B.data.parts) + return $TP(map((@inline t -> get_vector(t...)), ts)) end end, ) @@ -638,30 +581,6 @@ function get_vector!( end return X end -eval( - quote - @invoke_maker 1 AbstractManifold get_vector!( - M::ProductManifold, - Xⁱ, - e::Identity, - X, - B::VeeOrthogonalBasis, - ) - end, -) - -for BT in PRODUCT_BASIS_LIST - eval( - quote - @invoke_maker 5 AbstractBasis get_vector!(M::ProductManifold, X, p, Xⁱ, B::$BT) - end, - ) -end -function get_vector!(M::ProductManifold, Y, p, X, B::CachedBasis) - return error( - "get_vector! called on $M with an incorrect CachedBasis. Expected a CachedBasis with ProductBasisData, given $B", - ) -end function get_vectors( M::ProductManifold, @@ -750,15 +669,6 @@ function injectivity_radius(M::ProductManifold, p, m::ProductRetraction) )..., ) end -eval( - quote - @invoke_maker 3 AbstractRetractionMethod injectivity_radius( - M::ProductManifold, - p, - B::ExponentialRetraction, - ) - end, -) injectivity_radius(M::ProductManifold) = min(map(injectivity_radius, M.manifolds)...) function injectivity_radius(M::ProductManifold, m::AbstractRetractionMethod) return min(map(manif -> injectivity_radius(manif, m), M.manifolds)...) @@ -766,14 +676,6 @@ end function injectivity_radius(M::ProductManifold, m::ProductRetraction) return min(map((lM, lm) -> injectivity_radius(lM, lm), M.manifolds, m.retractions)...) end -eval( - quote - @invoke_maker 2 AbstractRetractionMethod injectivity_radius( - M::ProductManifold, - B::ExponentialRetraction, - ) - end, -) @doc raw""" inner(M::ProductManifold, p, X, Y) @@ -803,20 +705,29 @@ so the encapsulated inverse retraction methods have to be available per factor. """ inverse_retract(::ProductManifold, ::Any, ::Any, ::Any, ::InverseProductRetraction) -function inverse_retract!(M::ProductManifold, X, p, q, method::InverseProductRetraction) - map( - inverse_retract!, - M.manifolds, - submanifold_components(M, X), - submanifold_components(M, p), - submanifold_components(M, q), - method.inverse_retractions, +for TP in [ProductRepr, ArrayPartition] + eval( + quote + function inverse_retract( + M::ProductManifold, + p::$TP, + q::$TP, + method::InverseProductRetraction, + ) + return $TP( + map( + inverse_retract, + M.manifolds, + submanifold_components(M, p), + submanifold_components(M, q), + method.inverse_retractions, + ), + ) + end + end, ) - return X end -default_metric_dispatch(::ProductManifold, ::ProductMetric) = Val(true) - function Base.isapprox(M::ProductManifold, p, q; kwargs...) return all( t -> isapprox(t...; kwargs...), @@ -958,6 +869,69 @@ function ProductPointDistribution(distributions::MPointDistribution...) return ProductPointDistribution(M, distributions...) end +for TP in [ProductRepr, ArrayPartition] + eval( + quote + function parallel_transport_direction( + M::ProductManifold, + p::$TP, + X::$TP, + d::$TP, + ) + return $TP( + map( + parallel_transport_direction, + M.manifolds, + submanifold_components(M, p), + submanifold_components(M, X), + submanifold_components(M, d), + ), + ) + end + end, + ) +end +function parallel_transport_direction!(M::ProductManifold, Y, p, X, d) + map( + parallel_transport_direction!, + M.manifolds, + submanifold_components(M, Y), + submanifold_components(M, p), + submanifold_components(M, X), + submanifold_components(M, d), + ) + return Y +end + +for TP in [ProductRepr, ArrayPartition] + eval( + quote + function parallel_transport_to(M::ProductManifold, p::$TP, X::$TP, q::$TP) + return $TP( + map( + parallel_transport_to, + M.manifolds, + submanifold_components(M, p), + submanifold_components(M, X), + submanifold_components(M, q), + ), + ) + end + end, + ) +end +function parallel_transport_to!(M::ProductManifold, Y, p, X, q) + map( + parallel_transport_to!, + M.manifolds, + submanifold_components(M, Y), + submanifold_components(M, p), + submanifold_components(M, X), + submanifold_components(M, q), + ) + return Y +end + function project(M::ProductManifold, p::ProductRepr) return ProductRepr(map(project, M.manifolds, submanifold_components(M, p))...) end @@ -1170,7 +1144,30 @@ method has to be one that is available on the manifolds. """ retract(::ProductManifold, ::Any...) -function retract!(M::ProductManifold, q, p, X, method::ProductRetraction) +for TP in [ProductRepr, ArrayPartition] + eval( + quote + function _retract( + M::ProductManifold, + p::$TP, + X::$TP, + method::ProductRetraction, + ) + return $TP( + map( + retract, + M.manifolds, + submanifold_components(M, p), + submanifold_components(M, X), + method.retractions, + ), + ) + end + end, + ) +end + +function _retract!(M::ProductManifold, q, p, X, method::ProductRetraction) map( retract!, M.manifolds, @@ -1189,7 +1186,7 @@ end """ set_component!(M::ProductManifold, q, p, i) -Set the `i`th component of a point `q` on a [`ProductManifold`](@ref) `M` to `p`, where `p` is a point on the [`AbstractManifold`](@ref) this factor of the product manifold consists of. +Set the `i`th component of a point `q` on a [`ProductManifold`](@ref) `M` to `p`, where `p` is a point on the [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) this factor of the product manifold consists of. """ function set_component!(M::ProductManifold, q, p, i) return copyto!(submanifold_component(M, q, i), p) @@ -1331,17 +1328,26 @@ end function vector_bundle_transport(::VectorSpaceType, M::ProductManifold) return ProductVectorTransport(map(_ -> ParallelTransport(), M.manifolds)) end -for T in [ManifoldsBase.VECTOR_TRANSPORT_DISAMBIGUATION..., AbstractVectorTransportMethod] + +for TP in [ProductRepr, ArrayPartition] eval( quote - function vector_transport_direction!(M::ProductManifold, Y, p, X, d, m::$T) - return vector_transport_direction!( - M, - Y, - p, - X, - d, - ProductVectorTransport(map(_ -> m, M.manifolds)), + function vector_transport_direction( + M::ProductManifold, + p::$TP, + X::$TP, + d::$TP, + m::ProductVectorTransport, + ) + return $TP( + map( + vector_transport_direction, + M.manifolds, + submanifold_components(M, p), + submanifold_components(M, X), + submanifold_components(M, d), + m.methods, + ), ) end end, @@ -1377,22 +1383,48 @@ base manifold. """ vector_transport_to(::ProductManifold, ::Any, ::Any, ::Any, ::ProductVectorTransport) -for T in [ManifoldsBase.VECTOR_TRANSPORT_DISAMBIGUATION..., AbstractVectorTransportMethod] +for TP in [ProductRepr, ArrayPartition] eval( quote - function vector_transport_to!(M::ProductManifold, Y, p, X, q, m::$T) - return vector_transport_to!( - M, - Y, - p, - X, - q, - ProductVectorTransport(map(_ -> m, M.manifolds)), + function vector_transport_to( + M::ProductManifold, + p::$TP, + X::$TP, + q::$TP, + m::ProductVectorTransport, + ) + return $TP( + map( + vector_transport_to, + M.manifolds, + submanifold_components(M, p), + submanifold_components(M, X), + submanifold_components(M, q), + m.methods, + ), + ) + end + function vector_transport_to( + M::ProductManifold, + p::$TP, + X::$TP, + q::$TP, + m::ParallelTransport, + ) + return $TP( + map( + (iM, ip, iX, id) -> vector_transport_to(iM, ip, iX, id, m), + M.manifolds, + submanifold_components(M, p), + submanifold_components(M, X), + submanifold_components(M, q), + ), ) end end, ) end + function vector_transport_to!(M::ProductManifold, Y, p, X, q, m::ProductVectorTransport) map( vector_transport_to!, @@ -1405,6 +1437,17 @@ function vector_transport_to!(M::ProductManifold, Y, p, X, q, m::ProductVectorTr ) return Y end +function vector_transport_to!(M::ProductManifold, Y, p, X, q, m::ParallelTransport) + map( + (iM, iY, ip, iX, id) -> vector_transport_to!(iM, iY, ip, iX, id, m), + M.manifolds, + submanifold_components(M, Y), + submanifold_components(M, p), + submanifold_components(M, X), + submanifold_components(M, q), + ), + return Y +end function zero_vector!(M::ProductManifold, X, p) map( diff --git a/src/manifolds/ProjectiveSpace.jl b/src/manifolds/ProjectiveSpace.jl index c5b98880c2..c3ba95a838 100644 --- a/src/manifolds/ProjectiveSpace.jl +++ b/src/manifolds/ProjectiveSpace.jl @@ -1,11 +1,10 @@ """ - AbstractProjectiveSpace{𝔽} <: AbstractEmbeddedManifold{𝔽,DefaultIsometricEmbeddingType} + AbstractProjectiveSpace{𝔽} <: AbstractDecoratorManifold{𝔽} An abstract type to represent a projective space over `𝔽` that is represented isometrically in the embedding. """ -abstract type AbstractProjectiveSpace{𝔽} <: - AbstractEmbeddedManifold{𝔽,DefaultIsometricEmbeddingType} end +abstract type AbstractProjectiveSpace{𝔽} <: AbstractDecoratorManifold{𝔽} end @doc raw""" ProjectiveSpace{n,𝔽} <: AbstractProjectiveSpace{𝔽} @@ -41,6 +40,10 @@ projective spaces. struct ProjectiveSpace{N,𝔽} <: AbstractProjectiveSpace{𝔽} end ProjectiveSpace(n::Int, field::AbstractNumbers=ℝ) = ProjectiveSpace{n,field}() +function active_traits(f, ::AbstractProjectiveSpace, args...) + return merge_traits(IsIsometricEmbeddedManifold()) +end + @doc raw""" ArrayProjectiveSpace{T<:Tuple,𝔽} <: AbstractProjectiveSpace{𝔽} @@ -94,14 +97,6 @@ that it has the same size as elements of the embedding and has unit Frobenius no The tolerance for the norm check can be set using the `kwargs...`. """ function check_point(M::AbstractProjectiveSpace, p; kwargs...) - mpv = invoke( - check_point, - Tuple{(typeof(get_embedding(M))),typeof(p)}, - get_embedding(M), - p; - kwargs..., - ) - mpv === nothing || return mpv if !isapprox(norm(p), 1; kwargs...) return DomainError( norm(p), @@ -120,15 +115,6 @@ tangent space of the embedding and that the Frobenius inner product $⟨p, X⟩_{\mathrm{F}} = 0$. """ function check_vector(M::AbstractProjectiveSpace, p, X; kwargs...) - mpv = invoke( - check_vector, - Tuple{typeof(get_embedding(M)),typeof(p),typeof(X)}, - get_embedding(M), - p, - X; - kwargs..., - ) - mpv === nothing || return mpv if !isapprox(dot(p, X), 0; kwargs...) return DomainError( dot(p, X), @@ -145,6 +131,9 @@ end get_embedding(M::AbstractProjectiveSpace) = decorated_manifold(M) +embed(::AbstractProjectiveSpace, p) = p +embed(::AbstractProjectiveSpace, p, X) = p + @doc raw""" distance(M::AbstractProjectiveSpace, p, q) @@ -187,12 +176,12 @@ formula for $Y$ is """ get_coordinates(::AbstractProjectiveSpace{ℝ}, p, X, ::DefaultOrthonormalBasis) -function get_coordinates!( +function get_coordinates_orthonormal!( M::AbstractProjectiveSpace{𝔽}, Y, p, X, - ::DefaultOrthonormalBasis{ℝ,TangentSpaceType}, + ::RealNumbers, ) where {𝔽} n = div(manifold_dimension(M), real_dimension(𝔽)) z = p[1] @@ -222,12 +211,12 @@ Y = \left(X - q\frac{2 \left\langle q, \begin{pmatrix}0 \\ X\end{pmatrix}\right\ """ get_vector(::AbstractProjectiveSpace, p, X, ::DefaultOrthonormalBasis{ℝ}) -function get_vector!( +function get_vector_orthonormal!( M::AbstractProjectiveSpace{𝔽}, Y, p, X, - ::DefaultOrthonormalBasis{ℝ,TangentSpaceType}, + ::RealNumbers, ) where {𝔽} n = div(manifold_dimension(M), real_dimension(𝔽)) z = p[1] @@ -241,25 +230,17 @@ function get_vector!( end injectivity_radius(::AbstractProjectiveSpace) = Ο€ / 2 -injectivity_radius(::AbstractProjectiveSpace, ::ExponentialRetraction) = Ο€ / 2 -injectivity_radius(::AbstractProjectiveSpace, ::Any) = Ο€ / 2 -injectivity_radius(::AbstractProjectiveSpace, ::Any, ::ExponentialRetraction) = Ο€ / 2 -eval( - quote - @invoke_maker 1 AbstractManifold injectivity_radius( - M::AbstractProjectiveSpace, - rm::AbstractRetractionMethod, - ) - end, -) +injectivity_radius(::AbstractProjectiveSpace, p) = Ο€ / 2 +injectivity_radius(::AbstractProjectiveSpace, ::AbstractRetractionMethod) = Ο€ / 2 +injectivity_radius(::AbstractProjectiveSpace, p, ::AbstractRetractionMethod) = Ο€ / 2 @doc raw""" inverse_retract(M::AbstractProjectiveSpace, p, q, method::ProjectionInverseRetraction) inverse_retract(M::AbstractProjectiveSpace, p, q, method::PolarInverseRetraction) inverse_retract(M::AbstractProjectiveSpace, p, q, method::QRInverseRetraction) -Compute the equivalent inverse retraction [`ProjectionInverseRetraction`](@ref), -[`PolarInverseRetraction`](@ref), and [`QRInverseRetraction`](@ref) on the +Compute the equivalent inverse retraction [`ProjectionInverseRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.ProjectionInverseRetraction), +[`PolarInverseRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.PolarInverseRetraction), and [`QRInverseRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.QRInverseRetraction) on the [`AbstractProjectiveSpace`](@ref) manifold `M`$=𝔽ℙ^n$, i.e. ````math \operatorname{retr}_p^{-1} q = q \frac{1}{⟨p, q⟩_{\mathrm{F}}} - p, @@ -279,13 +260,15 @@ inverse_retract( ::Union{ProjectionInverseRetraction,PolarInverseRetraction,QRInverseRetraction}, ) -function inverse_retract!( - ::AbstractProjectiveSpace, - X, - p, - q, - ::Union{ProjectionInverseRetraction,PolarInverseRetraction,QRInverseRetraction}, -) +function inverse_retract_qr!(::AbstractProjectiveSpace, X, p, q) + X .= q ./ dot(p, q) .- p + return X +end +function inverse_retract_polar!(::AbstractProjectiveSpace, X, p, q) + X .= q ./ dot(p, q) .- p + return X +end +function inverse_retract_project!(::AbstractProjectiveSpace, X, p, q) X .= q ./ dot(p, q) .- p return X end @@ -307,7 +290,7 @@ end log(M::AbstractProjectiveSpace, p, q) Compute the logarithmic map on [`AbstractProjectiveSpace`](@ref) `M`$ = 𝔽ℙ^n$, -i.e. the tangent vector whose corresponding [`geodesic`](@ref) starting from `p` +i.e. the tangent vector whose corresponding [`geodesic`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#ManifoldsBase.geodesic-Tuple{AbstractManifold,%20Any,%20Any}) starting from `p` reaches `q` after time 1 on `M`. The formula reads ````math @@ -359,14 +342,8 @@ using [`GeodesicInterpolationWithinRadius`](@ref). """ mean(::AbstractProjectiveSpace, ::Any...) -function Statistics.mean!( - M::AbstractProjectiveSpace, - p, - x::AbstractVector, - w::AbstractVector; - kwargs..., -) - return mean!(M, p, x, w, GeodesicInterpolationWithinRadius(Ο€ / 4); kwargs...) +function default_estimation_method(::AbstractProjectiveSpace, ::typeof(mean)) + return GeodesicInterpolationWithinRadius(Ο€ / 4) end function mid_point!(M::ProjectiveSpace, q, p1, p2) @@ -423,8 +400,8 @@ i.e., the representation size of the embedding. retract(M::AbstractProjectiveSpace, p, X, method::PolarRetraction) retract(M::AbstractProjectiveSpace, p, X, method::QRRetraction) -Compute the equivalent retraction [`ProjectionRetraction`](@ref), [`PolarRetraction`](@ref), -and [`QRRetraction`](@ref) on the [`AbstractProjectiveSpace`](@ref) manifold `M`$=𝔽ℙ^n$, +Compute the equivalent retraction [`ProjectionRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.ProjectionRetraction), [`PolarRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.PolarRetraction), +and [`QRRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.QRRetraction) on the [`AbstractProjectiveSpace`](@ref) manifold `M`$=𝔽ℙ^n$, i.e. ````math \operatorname{retr}_p X = \operatorname{proj}_p(p + X). @@ -441,13 +418,15 @@ retract( ::Union{ProjectionRetraction,PolarRetraction,QRRetraction}, ) -function retract!( - M::AbstractProjectiveSpace, - q, - p, - X, - ::Union{ProjectionRetraction,PolarRetraction,QRRetraction}, -) +function retract_polar!(M::AbstractProjectiveSpace, q, p, X) + q .= p .+ X + return project!(M, q, q) +end +function retract_project!(M::AbstractProjectiveSpace, q, p, X) + q .= p .+ X + return project!(M, q, q) +end +function retract_qr!(M::AbstractProjectiveSpace, q, p, X) q .= p .+ X return project!(M, q, q) end @@ -471,13 +450,13 @@ function uniform_distribution(M::ProjectiveSpace{n,ℝ}, p) where {n} end @doc raw""" - vector_transport_to(M::AbstractProjectiveSpace, p, X, q, method::ParallelTransport) + parallel_transport_to(M::AbstractProjectiveSpace, p, X, q) Parallel transport a vector `X` from the tangent space at a point `p` on the [`AbstractProjectiveSpace`](@ref) `M`$=𝔽ℙ^n$ to the tangent space at another point `q`. This implementation proceeds by transporting $X$ to $T_{q Ξ»} M$ using the same approach as -[`vector_transport_direction`](@ref vector_transport_direction(::AbstractProjectiveSpace, p, X, d, ::ParallelTransport)), +[`parallel_transport_direction`](@ref parallel_transport_direction(::AbstractProjectiveSpace, p, X, d)), where $Ξ» = \frac{⟨q, p⟩_{\mathrm{F}}}{|⟨q, p⟩_{\mathrm{F}}|} ∈ 𝔽$ is the unit scalar that takes $q$ to the member $q Ξ»$ of its equivalence class $[q]$ closest to $p$ in the embedding. @@ -490,9 +469,9 @@ where $d = \log_p q$ is the direction of the transport, $ΞΈ = \lVert d \rVert_p$ [`distance`](@ref distance(::AbstractProjectiveSpace, p, q)) between $p$ and $q$, and $\overline{β‹…}$ denotes complex or quaternionic conjugation. """ -vector_transport_to(::AbstractProjectiveSpace, ::Any, ::Any, ::Any, ::ParallelTransport) +parallel_transport_to(::AbstractProjectiveSpace, ::Any, ::Any, ::Any) -function vector_transport_to!(::AbstractProjectiveSpace, Y, p, X, q, ::ParallelTransport) +function parallel_transport_to!(::AbstractProjectiveSpace, Y, p, X, q) z = dot(q, p) Ξ» = nzsign(z) m = p .+ q .* Ξ» # un-normalized midpoint @@ -503,16 +482,16 @@ function vector_transport_to!(::AbstractProjectiveSpace, Y, p, X, q, ::ParallelT Y .= (X .- m .* factor) .* Ξ»' return Y end -function vector_transport_to!(M::AbstractProjectiveSpace, Y, p, X, q, ::ProjectionTransport) +function vector_transport_to_project!(M::AbstractProjectiveSpace, Y, p, X, q) project!(M, Y, q, X) return Y end @doc raw""" - vector_transport_direction(M::AbstractProjectiveSpace, p, X, d, method::ParallelTransport) + parallel_transport_direction(M::AbstractProjectiveSpace, p, X, d) Parallel transport a vector `X` from the tangent space at a point `p` on the -[`AbstractProjectiveSpace`](@ref) `M` along the [`geodesic`](@ref) in the direction +[`AbstractProjectiveSpace`](@ref) `M` along the [`geodesic`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#ManifoldsBase.geodesic-Tuple{AbstractManifold,%20Any,%20Any}) in the direction indicated by the tangent vector `d`, i.e. ````math \mathcal{P}_{\exp_p (d) ← p}(X) = X - \left(p \frac{\sin ΞΈ}{ΞΈ} + d \frac{1 - \cos ΞΈ}{ΞΈ^2}\right) ⟨d, X⟩_p, @@ -521,22 +500,9 @@ where $ΞΈ = \lVert d \rVert$, and $βŸ¨β‹…, β‹…βŸ©_p$ is the [`inner`](@ref) prod For the real projective space, this is equivalent to the same vector transport on the real [`AbstractSphere`](@ref). """ -vector_transport_direction( - ::AbstractProjectiveSpace, - ::Any, - ::Any, - ::Any, - ::ParallelTransport, -) +parallel_transport_direction(::AbstractProjectiveSpace, ::Any, ::Any, ::Any) -function vector_transport_direction!( - M::AbstractProjectiveSpace, - Y, - p, - X, - d, - ::ParallelTransport, -) +function parallel_transport_direction!(M::AbstractProjectiveSpace, Y, p, X, d) ΞΈ = norm(M, p, d) cosΞΈ = cos(ΞΈ) dX = inner(M, p, d, X) diff --git a/src/manifolds/Rotations.jl b/src/manifolds/Rotations.jl index 186791731c..6c6a86a9d1 100644 --- a/src/manifolds/Rotations.jl +++ b/src/manifolds/Rotations.jl @@ -38,6 +38,8 @@ function NormalRotationDistribution( return NormalRotationDistribution{TResult,typeof(M),typeof(d)}(M, d) end +active_traits(f, ::Rotations, args...) = merge_traits(IsEmbeddedManifold()) + @doc raw""" angles_4d_skew_sym_matrix(A) @@ -69,12 +71,6 @@ valid rotation. The tolerance for the last test can be set using the `kwargs...`. """ function check_point(M::Rotations{N}, p; kwargs...) where {N} - if size(p) != (N, N) - return DomainError( - size(p), - "The point $(p) does not lie on $M, since its size is not $((N, N)).", - ) - end if !isapprox(det(p), 1; kwargs...) return DomainError(det(p), "The determinant of $p has to be +1 but it is $(det(p))") end @@ -258,23 +254,16 @@ $X^{j (j - 3)/2 + k + 1} = X_{jk}$, for $j ∈ [4,n], k ∈ [1,j)$. get_coordinates(::Rotations, ::Any...) get_coordinates(::Rotations{2}, p, X, ::DefaultOrthogonalBasis{ℝ,TangentSpaceType}) = [X[2]] -function get_coordinates!( - ::Rotations{2}, - Xⁱ, - p, - X, - ::DefaultOrthogonalBasis{ℝ,TangentSpaceType}, -) +function get_coordinates_orthogonal(M::Rotations, p, X, N) + Y = allocate_result(M, get_coordinates, p, X, DefaultOrthogonalBasis(N)) + return get_coordinates_orthogonal!(M, Y, p, X, N) +end + +function get_coordinates_orthogonal!(::Rotations{2}, Xⁱ, p, X, ::RealNumbers) Xⁱ[1] = X[2] return Xⁱ end -function get_coordinates!( - M::Rotations{N}, - Xⁱ, - p, - X, - ::DefaultOrthogonalBasis{ℝ,TangentSpaceType}, -) where {N} +function get_coordinates_orthogonal!(::Rotations{N}, Xⁱ, p, X, ::RealNumbers) where {N} @inbounds begin Xⁱ[1] = X[3, 2] Xⁱ[2] = X[1, 3] @@ -288,19 +277,15 @@ function get_coordinates!( end return Xⁱ end -function get_coordinates!( - M::Rotations{N}, - Xⁱ, - p, - X, - ::DefaultOrthonormalBasis{ℝ,TangentSpaceType}, -) where {N} +function get_coordinates_orthonormal!(M::Rotations{N}, Xⁱ, p, X, num::RealNumbers) where {N} T = Base.promote_eltype(p, X) - get_coordinates!(M, Xⁱ, p, X, DefaultOrthogonalBasis()) + get_coordinates_orthogonal!(M, Xⁱ, p, X, num) Xⁱ .*= sqrt(T(2)) return Xⁱ end +get_embedding(::Rotations{N}) where {N} = Euclidean(N, N) + @doc raw""" get_vector(M::Rotations, p, Xⁱ, B::DefaultOrthogonalBasis) @@ -310,22 +295,15 @@ group $\mathrm{SO}(n)$ to the matrix representation $X$ of the tangent vector. S """ get_vector(::Rotations, ::Any...) -function get_vector!( - M::Rotations{2}, - X, - p, - Xⁱ, - B::DefaultOrthogonalBasis{ℝ,TangentSpaceType}, -) - return get_vector!(M, X, p, Xⁱ[1], B) -end -function get_vector!( - M::Rotations{2}, - X, - p, - Xⁱ::Real, - ::DefaultOrthogonalBasis{ℝ,TangentSpaceType}, -) +function get_vector_orthogonal(M::Rotations, p, c, N::RealNumbers) + Y = allocate_result(M, get_vector, p, c) + return get_vector_orthogonal!(M, Y, p, c, N) +end + +function get_vector_orthogonal!(M::Rotations{2}, X, p, Xⁱ, N::RealNumbers) + return get_vector_orthogonal!(M, X, p, Xⁱ[1], N) +end +function get_vector_orthogonal!(M::Rotations{2}, X, p, Xⁱ::Real, ::RealNumbers) @assert length(X) == 4 @inbounds begin X[1] = 0 @@ -335,13 +313,7 @@ function get_vector!( end return X end -function get_vector!( - M::Rotations{N}, - X, - p, - Xⁱ, - ::DefaultOrthogonalBasis{ℝ,TangentSpaceType}, -) where {N} +function get_vector_orthogonal!(M::Rotations{N}, X, p, Xⁱ, ::RealNumbers) where {N} @assert size(X) == (N, N) @assert length(Xⁱ) == manifold_dimension(M) @inbounds begin @@ -366,9 +338,9 @@ function get_vector!( end return X end -function get_vector!(M::Rotations, X, p, Xⁱ, ::DefaultOrthonormalBasis{ℝ,TangentSpaceType}) +function get_vector_orthonormal!(M::Rotations, X, p, Xⁱ, N::RealNumbers) T = Base.promote_eltype(p, X) - get_vector!(M, X, p, Xⁱ, DefaultOrthogonalBasis()) + get_vector_orthogonal!(M, X, p, Xⁱ, N) X ./= sqrt(T(2)) return X end @@ -385,23 +357,11 @@ Return the injectivity radius on the [`Rotations`](@ref) `M`, which is globally injectivity_radius(M::Rotations, p, ::PolarRetraction) -Return the radius of injectivity for the [`PolarRetraction`](@ref) on the +Return the radius of injectivity for the [`PolarRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.PolarRetraction) on the [`Rotations`](@ref) `M` which is $\frac{Ο€}{\sqrt{2}}$. """ injectivity_radius(::Rotations) = Ο€ * sqrt(2.0) -injectivity_radius(::Rotations, ::ExponentialRetraction) = Ο€ * sqrt(2.0) -eval( - quote - @invoke_maker 1 AbstractManifold injectivity_radius( - M::Rotations, - rm::AbstractRetractionMethod, - ) - end, -) -injectivity_radius(::Rotations, ::Any) = Ο€ * sqrt(2.0) -injectivity_radius(::Rotations, ::Any, ::ExponentialRetraction) = Ο€ * sqrt(2.0) -injectivity_radius(::Rotations, ::PolarRetraction) = Ο€ / sqrt(2.0) -injectivity_radius(::Rotations, p, ::PolarRetraction) = Ο€ / sqrt(2.0) +_injectivity_radius(::Rotations, ::PolarRetraction) = Ο€ / sqrt(2.0) @doc raw""" inner(M::Rotations, p, X, Y) @@ -416,7 +376,7 @@ g_p(X, Y) = \operatorname{tr}(X^\mathrm{T} Y), Tangent vectors are represented by matrices. """ -inner(M::Rotations, p, X, Y) = dot(X, Y) +inner(::Rotations, p, X, Y) = dot(X, Y) @doc raw""" inverse_retract(M, p, q, ::PolarInverseRetraction) @@ -424,7 +384,7 @@ inner(M::Rotations, p, X, Y) = dot(X, Y) Compute a vector from the tangent space $T_p\mathrm{SO}(n)$ of the point `p` on the [`Rotations`](@ref) manifold `M` with which the point `q` can be reached by the -[`PolarRetraction`](@ref) from the point `p` after time 1. +[`PolarRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.PolarRetraction) from the point `p` after time 1. The formula reads ````math @@ -443,11 +403,11 @@ inverse_retract(::Rotations, ::Any, ::Any, ::PolarInverseRetraction) Compute a vector from the tangent space $T_p\mathrm{SO}(n)$ of the point `p` on the [`Rotations`](@ref) manifold `M` with which the point `q` can be reached by the -[`QRRetraction`](@ref) from the point `q` after time 1. +[`QRRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.QRRetraction) from the point `q` after time 1. """ inverse_retract(::Rotations, ::Any, ::Any, ::QRInverseRetraction) -function inverse_retract!(M::Rotations, X, p, q, method::PolarInverseRetraction) +function inverse_retract_polar!(M::Rotations, X, p, q) A = transpose(p) * q Amat = A isa StaticMatrix ? A : convert(Matrix, A) H = copyto!(allocate(Amat), -2I) @@ -463,7 +423,7 @@ function inverse_retract!(M::Rotations, X, p, q, method::PolarInverseRetraction) end return project!(M, X, p, X) end -function inverse_retract!(M::Rotations{N}, X, p, q, ::QRInverseRetraction) where {N} +function inverse_retract_qr!(M::Rotations{N}, X, p, q) where {N} A = transpose(p) * q R = zero(X) for i in 1:N @@ -563,8 +523,8 @@ Compute the Riemannian [`mean`](@ref mean(M::AbstractManifold, args...)) of `x` """ mean(::Rotations, ::Any) -function Statistics.mean!(M::Rotations, q, x::AbstractVector, w::AbstractVector; kwargs...) - return mean!(M, q, x, w, GeodesicInterpolationWithinRadius(Ο€ / 2 / √2); kwargs...) +function default_estimation_method(::Rotations, ::typeof(mean)) + return GeodesicInterpolationWithinRadius(Ο€ / 2 / √2) end @doc raw""" @@ -653,7 +613,7 @@ where tangent vectors are represented by elements from the Lie group """ project(::Rotations, ::Any, ::Any) -project!(M::Rotations{N}, Y, p, X) where {N} = project!(SkewSymmetricMatrices(N), Y, X) +project!(::Rotations{N}, Y, p, X) where {N} = project!(SkewSymmetricMatrices(N), Y, X) @doc raw""" representation_size(M::Rotations) @@ -768,14 +728,14 @@ This is also the default retraction on the [`Rotations`](@ref) """ retract(::Rotations, ::Any, ::Any, ::QRRetraction) -function retract!(M::Rotations, q::AbstractArray{T}, p, X, method::QRRetraction) where {T} +function retract_qr!(::Rotations, q::AbstractArray{T}, p, X) where {T} A = p + p * X qr_decomp = qr(A) d = diag(qr_decomp.R) D = Diagonal(sign.(d .+ convert(T, 0.5))) return copyto!(q, qr_decomp.Q * D) end -function retract!(M::Rotations, q, p, X, method::PolarRetraction) +function retract_polar!(M::Rotations, q, p, X) A = p + p * X return project!(M, q, A; check_det=false) end @@ -785,7 +745,7 @@ Base.show(io::IO, ::Rotations{N}) where {N} = print(io, "Rotations($(N))") Distributions.support(d::NormalRotationDistribution) = MPointSupport(d.manifold) @doc raw""" - vector_transport_direction(M::Rotations, p, X, d) + parallel_transport_direction(M::Rotations, p, X, d) Compute parallel transport of vector `X` tangent at `p` on the [`Rotations`](@ref) manifold in the direction `d`. The formula, provided in [^Rentmeesters], reads: @@ -802,25 +762,37 @@ The formula simplifies to identity for 2-D rotations. > Riemannian manifolds,” in 2011 50th IEEE Conference on Decision and Control and > European Control Conference, Dec. 2011, pp. 7141–7146. doi: 10.1109/CDC.2011.6161280. """ -vector_transport_direction(M::Rotations, p, X, d) +parallel_transport_direction(M::Rotations, p, X, d) -function vector_transport_direction!(M::Rotations, Y, p, X, d, ::ParallelTransport) +function parallel_transport_direction!(M::Rotations, Y, p, X, d) expdhalf = exp(d / 2) q = exp(M, p, d) return copyto!(Y, transpose(q) * p * expdhalf * X * expdhalf) end -function vector_transport_direction!(M::Rotations{2}, Y, p, X, d, ::ParallelTransport) +function parallel_transport_direction!(::Rotations{2}, Y, p, X, d) return copyto!(Y, X) end +function parallel_transport_direction(M::Rotations, p, X, d) + expdhalf = exp(d / 2) + q = exp(M, p, d) + return transpose(q) * p * expdhalf * X * expdhalf +end +parallel_transport_direction(::Rotations{2}, p, X, d) = X -function vector_transport_to!(M::Rotations, Y, p, X, q, ::ParallelTransport) +function parallel_transport_to!(M::Rotations, Y, p, X, q) d = log(M, p, q) expdhalf = exp(d / 2) return copyto!(Y, transpose(q) * p * expdhalf * X * expdhalf) end -function vector_transport_to!(M::Rotations{2}, Y, p, X, q, ::ParallelTransport) +function parallel_transport_to!(::Rotations{2}, Y, p, X, q) return copyto!(Y, X) end +function parallel_transport_to(M::Rotations, p, X, q) + d = log(M, p, q) + expdhalf = exp(d / 2) + return transpose(q) * p * expdhalf * X * expdhalf +end +parallel_transport_to(::Rotations{2}, p, X, q) = X @doc raw""" zero_vector(M::Rotations, p) @@ -828,6 +800,6 @@ end Return the zero tangent vector from the tangent space art `p` on the [`Rotations`](@ref) as an element of the Lie group, i.e. the zero matrix. """ -zero_vector(M::Rotations, p) = zero(p) +zero_vector(::Rotations, p) = zero(p) -zero_vector!(M::Rotations, X, p) = fill!(X, 0) +zero_vector!(::Rotations, X, p) = fill!(X, 0) diff --git a/src/manifolds/SkewHermitian.jl b/src/manifolds/SkewHermitian.jl index d2fbed5055..7e2d85376d 100644 --- a/src/manifolds/SkewHermitian.jl +++ b/src/manifolds/SkewHermitian.jl @@ -1,7 +1,7 @@ @doc raw""" - SkewHermitianMatrices{n,𝔽} <: AbstractEmbeddedManifold{𝔽,TransparentIsometricEmbedding} + SkewHermitianMatrices{n,𝔽} <: AbstractDecoratorManifold{𝔽} -The [`AbstractManifold`](@ref) $ \operatorname{SkewHerm}(n)$ consisting of the real- or +The [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) $ \operatorname{SkewHerm}(n)$ consisting of the real- or complex-valued skew-hermitian matrices of size ``n Γ— n``, i.e. the set ````math @@ -22,8 +22,7 @@ which is also reflected in the Generate the manifold of ``n Γ— n`` skew-hermitian matrices. """ -struct SkewHermitianMatrices{n,𝔽} <: - AbstractEmbeddedManifold{𝔽,TransparentIsometricEmbedding} end +struct SkewHermitianMatrices{n,𝔽} <: AbstractDecoratorManifold{𝔽} end function SkewHermitianMatrices(n::Int, field::AbstractNumbers=ℝ) return SkewHermitianMatrices{n,field}() @@ -44,6 +43,10 @@ const SkewSymmetricMatrices{n} = SkewHermitianMatrices{n,ℝ} SkewSymmetricMatrices(n::Int) = SkewSymmetricMatrices{n}() @deprecate SkewSymmetricMatrices(n::Int, 𝔽) SkewHermitianMatrices(n, 𝔽) +function active_traits(f, ::SkewHermitianMatrices, args...) + return merge_traits(IsEmbeddedSubmanifold()) +end + function allocation_promotion_function( ::SkewHermitianMatrices{<:Any,β„‚}, ::typeof(get_vector), @@ -57,13 +60,11 @@ end Check whether `p` is a valid manifold point on the [`SkewHermitianMatrices`](@ref) `M`, i.e. whether `p` is a skew-hermitian matrix of size `(n,n)` with values from the corresponding -[`AbstractNumbers`](@ref) `𝔽`. +[`AbstractNumbers`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#number-system) `𝔽`. The tolerance for the skew-symmetry of `p` can be set using `kwargs...`. """ function check_point(M::SkewHermitianMatrices{n,𝔽}, p; kwargs...) where {n,𝔽} - mpv = check_point(decorated_manifold(M), p; kwargs...) - mpv === nothing || return mpv if !isapprox(p, -p'; kwargs...) return DomainError( norm(p + p'), @@ -78,27 +79,25 @@ end Check whether `X` is a tangent vector to manifold point `p` on the [`SkewHermitianMatrices`](@ref) `M`, i.e. `X` must be a skew-hermitian matrix of size `(n,n)` -and its values have to be from the correct [`AbstractNumbers`](@ref). +and its values have to be from the correct [`AbstractNumbers`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#number-system). The tolerance for the skew-symmetry of `p` and `X` can be set using `kwargs...`. """ function check_vector(M::SkewHermitianMatrices, p, X; kwargs...) return check_point(M, X; kwargs...) # manifold is its own tangent space end -decorated_manifold(M::SkewHermitianMatrices{N,𝔽}) where {N,𝔽} = Euclidean(N, N; field=𝔽) - function get_basis(M::SkewHermitianMatrices, p, B::DiagonalizingOrthonormalBasis) Ξ = get_basis(M, p, DefaultOrthonormalBasis()).data ΞΊ = zeros(real(eltype(p)), manifold_dimension(M)) return CachedBasis(B, ΞΊ, Ξ) end -function get_coordinates!( +function get_coordinates_orthonormal!( M::SkewSymmetricMatrices{N}, Y, p, X, - ::DefaultOrthonormalBasis{ℝ,TangentSpaceType}, + ::RealNumbers, ) where {N} dim = manifold_dimension(M) @assert size(Y) == (dim,) @@ -111,12 +110,12 @@ function get_coordinates!( end return Y end -function get_coordinates!( +function get_coordinates_orthonormal!( M::SkewHermitianMatrices{N,β„‚}, Y, p, X, - ::DefaultOrthonormalBasis{β„‚,TangentSpaceType}, + ::ComplexNumbers, ) where {N} dim = manifold_dimension(M) @assert size(Y) == (dim,) @@ -136,12 +135,14 @@ function get_coordinates!( return Y end -function get_vector!( +get_embedding(::SkewHermitianMatrices{N,𝔽}) where {N,𝔽} = Euclidean(N, N; field=𝔽) + +function get_vector_orthonormal!( M::SkewSymmetricMatrices{N}, Y, p, X, - ::DefaultOrthonormalBasis{ℝ,TangentSpaceType}, + ::RealNumbers, ) where {N} dim = manifold_dimension(M) @assert size(X) == (dim,) @@ -157,12 +158,12 @@ function get_vector!( end return Y end -function get_vector!( +function get_vector_orthonormal!( M::SkewHermitianMatrices{N,β„‚}, Y, p, X, - ::DefaultOrthonormalBasis{β„‚,TangentSpaceType}, + ::ComplexNumbers, ) where {N} dim = manifold_dimension(M) @assert size(X) == (dim,) @@ -191,7 +192,7 @@ system `𝔽`, i.e. \dim \mathrm{SkewHerm}(n,ℝ) = \frac{n(n+1)}{2} \dim_ℝ 𝔽 - n, ```` -where ``\dim_ℝ 𝔽`` is the [`real_dimension`](@ref) of ``𝔽``. The first term corresponds to +where ``\dim_ℝ 𝔽`` is the [`real_dimension`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.real_dimension-Tuple{ManifoldsBase.AbstractNumbers}) of ``𝔽``. The first term corresponds to only the upper triangular elements of the matrix being unique, and the second term corresponds to the constraint that the real part of the diagonal be zero. """ @@ -216,7 +217,7 @@ where $\cdot^{\mathrm{H}}$ denotes the Hermitian, i.e. complex conjugate transpo """ project(::SkewHermitianMatrices, ::Any) -function project!(M::SkewHermitianMatrices, q, p) +function project!(::SkewHermitianMatrices, q, p) q .= (p .- p') ./ 2 return q end @@ -236,6 +237,8 @@ project(::SkewHermitianMatrices, ::Any, ::Any) project!(M::SkewHermitianMatrices, Y, p, X) = project!(M, Y, X) +representation_size(::SkewHermitianMatrices{N}) where {N} = (N, N) + function Base.show(io::IO, ::SkewHermitianMatrices{n,F}) where {n,F} return print(io, "SkewHermitianMatrices($(n), $(F))") end diff --git a/src/manifolds/Spectrahedron.jl b/src/manifolds/Spectrahedron.jl index efd14f7f12..d4c89b3d64 100644 --- a/src/manifolds/Spectrahedron.jl +++ b/src/manifolds/Spectrahedron.jl @@ -1,5 +1,5 @@ @doc raw""" - Spectrahedron{N,K} <: AbstractEmbeddedManifold{ℝ,DefaultIsometricEmbeddingType} + Spectrahedron{N,K} <: AbstractDecoratorManifold{ℝ} The Spectrahedron manifold, also known as the set of correlation matrices (symmetric positive semidefinite matrices) of rank $k$ with unit trace. @@ -27,7 +27,7 @@ $Y\in ℝ^{n Γ— k}$ and reads as ````math T_p\mathcal S(n,k) = \bigl\{ X ∈ ℝ^{n Γ— n}\,|\,X = qY^{\mathrm{T}} + Yq^{\mathrm{T}} -\text{Β with } \operatorname{tr}(X) = \sum_{i=1}^{n}X_{ii} = 0 +\text{ with } \operatorname{tr}(X) = \sum_{i=1}^{n}X_{ii} = 0 \bigr\} ```` endowed with the [`Euclidean`](@ref) metric from the embedding, i.e. from the $ℝ^{n Γ— k}$ @@ -49,10 +49,12 @@ generates the manifold $\mathcal S(n,k) \subset ℝ^{n Γ— n}$. > doi: [10.1137/080731359](https://doi.org/10.1137/080731359), > arXiv: [0807.4423](http://arxiv.org/abs/0807.4423). """ -struct Spectrahedron{N,K} <: AbstractEmbeddedManifold{ℝ,DefaultIsometricEmbeddingType} end +struct Spectrahedron{N,K} <: AbstractDecoratorManifold{ℝ} end Spectrahedron(n::Int, k::Int) = Spectrahedron{n,k}() +active_traits(f, ::Spectrahedron, args...) = merge_traits(IsIsometricEmbeddedManifold()) + @doc raw""" check_point(M::Spectrahedron, q; kwargs...) @@ -65,8 +67,6 @@ Since $p$ is by construction positive semidefinite, this is not checked. The tolerances for positive semidefiniteness and unit trace can be set using the `kwargs...`. """ function check_point(M::Spectrahedron{N,K}, q; kwargs...) where {N,K} - mpv = invoke(check_point, Tuple{supertype(typeof(M)),typeof(q)}, M, q; kwargs...) - mpv === nothing || return mpv fro_n = norm(q) if !isapprox(fro_n, 1.0; kwargs...) return DomainError( @@ -88,15 +88,6 @@ The tolerance for the base point check and zero diagonal can be set using the `k Note that symmetry of $X$ holds by construction and is not explicitly checked. """ function check_vector(M::Spectrahedron{N,K}, q, Y; kwargs...) where {N,K} - mpv = invoke( - check_vector, - Tuple{supertype(typeof(M)),typeof(q),typeof(Y)}, - M, - q, - Y; - kwargs..., - ) - mpv === nothing || return mpv X = q * Y' + Y * q' n = tr(X) if !isapprox(n, 0.0; kwargs...) @@ -108,7 +99,7 @@ function check_vector(M::Spectrahedron{N,K}, q, Y; kwargs...) where {N,K} return nothing end -function decorated_manifold(M::Spectrahedron) +function get_embedding(M::Spectrahedron) return Euclidean(representation_size(M)...; field=ℝ) end @@ -155,7 +146,7 @@ compute a projection based retraction by projecting $q+Y$ back onto the manifold """ retract(::Spectrahedron, ::Any, ::Any, ::ProjectionRetraction) -retract!(M::Spectrahedron, r, q, Y, ::ProjectionRetraction) = project!(M, r, q + Y) +retract_project!(M::Spectrahedron, r, q, Y) = project!(M, r, q + Y) @doc raw""" representation_size(M::Spectrahedron) @@ -178,7 +169,7 @@ at `q`. """ vector_transport_to(::Spectrahedron, ::Any, ::Any, ::Any, ::ProjectionTransport) -function vector_transport_to!(M::Spectrahedron, Y, p, X, q, ::ProjectionTransport) +function vector_transport_to_project!(M::Spectrahedron, Y, p, X, q) project!(M, Y, q, X) return Y end diff --git a/src/manifolds/Sphere.jl b/src/manifolds/Sphere.jl index 0a7b4786ec..6412b5abd4 100644 --- a/src/manifolds/Sphere.jl +++ b/src/manifolds/Sphere.jl @@ -1,9 +1,13 @@ """ - AbstractSphere{𝔽} <: AbstractEmbeddedManifold{𝔽,DefaultIsometricEmbeddingType} + AbstractSphere{𝔽} <: AbstractDecoratorManifold{𝔽} An abstract type to represent a unit sphere that is represented isometrically in the embedding. """ -abstract type AbstractSphere{𝔽} <: AbstractEmbeddedManifold{𝔽,DefaultIsometricEmbeddingType} end +abstract type AbstractSphere{𝔽} <: AbstractDecoratorManifold{𝔽} end + +function active_traits(f, ::AbstractSphere, args...) + return merge_traits(IsIsometricEmbeddedManifold(), IsDefaultMetric(EuclideanMetric())) +end @doc raw""" Sphere{n,𝔽} <: AbstractSphere{𝔽} @@ -97,14 +101,6 @@ the embedding of unit length. The tolerance for the last test can be set using the `kwargs...`. """ function check_point(M::AbstractSphere, p; kwargs...) - mpv = invoke( - check_point, - Tuple{(typeof(get_embedding(M))),typeof(p)}, - get_embedding(M), - p; - kwargs..., - ) - mpv === nothing || return mpv if !isapprox(norm(p), 1.0; kwargs...) return DomainError( norm(p), @@ -123,15 +119,6 @@ and orthogonal to `p`. The tolerance for the last test can be set using the `kwargs...`. """ function check_vector(M::AbstractSphere, p, X; kwargs...) - mpv = invoke( - check_vector, - Tuple{typeof(get_embedding(M)),typeof(p),typeof(X)}, - get_embedding(M), - p, - X; - kwargs..., - ) - mpv === nothing || return mpv if !isapprox(abs(real(dot(p, X))), 0.0; kwargs...) return DomainError( abs(dot(p, X)), @@ -141,13 +128,6 @@ function check_vector(M::AbstractSphere, p, X; kwargs...) return nothing end -function decorated_manifold(M::AbstractSphere{𝔽}) where {𝔽} - return Euclidean(representation_size(M)...; field=𝔽) -end - -# Since on every tangent space the Euclidean matric (restricted to this space) is used, this should be fine -default_metric_dispatch(::AbstractSphere, ::EuclideanMetric) = Val(true) - @doc raw""" distance(M::AbstractSphere, p, q) @@ -161,6 +141,9 @@ d_{π•Š}(p,q) = \arccos(\Re(⟨p,q⟩)). """ distance(::AbstractSphere, p, q) = acos(clamp(real(dot(p, q)), -1, 1)) +embed(::AbstractSphere, p) = copy(p) +embed(::AbstractSphere, p, X) = copy(X) + @doc raw""" exp(M::AbstractSphere, p, X) @@ -181,7 +164,11 @@ function exp!(M::AbstractSphere, q, p, X) return q end -function get_basis(M::Sphere{n,ℝ}, p, B::DiagonalizingOrthonormalBasis{ℝ}) where {n} +function get_basis_diagonalizing( + M::Sphere{n,ℝ}, + p, + B::DiagonalizingOrthonormalBasis{ℝ}, +) where {n} A = zeros(n + 1, n + 1) A[1, :] = transpose(p) A[2, :] = transpose(B.frame_direction) @@ -212,13 +199,7 @@ denotes the Frobenius inner product, the formula for $Y$ is """ get_coordinates(::AbstractSphere{ℝ}, p, X, ::DefaultOrthonormalBasis) -function get_coordinates!( - M::AbstractSphere{ℝ}, - Y, - p, - X, - ::DefaultOrthonormalBasis{ℝ,TangentSpaceType}, -) +function get_coordinates_orthonormal!(M::AbstractSphere{ℝ}, Y, p, X, ::RealNumbers) n = manifold_dimension(M) p1 = p[1] cosΞΈ = abs(p1) @@ -229,7 +210,9 @@ function get_coordinates!( return Y end -get_embedding(M::AbstractSphere{𝔽}) where {𝔽} = decorated_manifold(M) +function get_embedding(M::AbstractSphere{𝔽}) where {𝔽} + return Euclidean(representation_size(M)...; field=𝔽) +end @doc raw""" get_vector(M::AbstractSphere{ℝ}, p, X, B::DefaultOrthonormalBasis) @@ -247,13 +230,7 @@ Y = X - q\frac{2 \left\langle q, \begin{pmatrix}0 \\ X\end{pmatrix}\right\rangle """ get_vector(::AbstractSphere{ℝ}, p, X, ::DefaultOrthonormalBasis) -function get_vector!( - M::AbstractSphere{ℝ}, - Y, - p, - X, - ::DefaultOrthonormalBasis{ℝ,TangentSpaceType}, -) +function get_vector_orthonormal!(M::AbstractSphere{ℝ}, Y, p, X, ::RealNumbers) n = manifold_dimension(M) p1 = p[1] cosΞΈ = abs(p1) @@ -273,23 +250,20 @@ Return the injectivity radius for the [`AbstractSphere`](@ref) `M`, which is glo injectivity_radius(M::Sphere, x, ::ProjectionRetraction) -Return the injectivity radius for the [`ProjectionRetraction`](@ref) on the +Return the injectivity radius for the [`ProjectionRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.ProjectionRetraction) on the [`AbstractSphere`](@ref), which is globally $\frac{Ο€}{2}$. """ injectivity_radius(::AbstractSphere) = Ο€ -injectivity_radius(::AbstractSphere, ::ExponentialRetraction) = Ο€ -injectivity_radius(::AbstractSphere, ::ProjectionRetraction) = Ο€ / 2 -injectivity_radius(::AbstractSphere, ::Any) = Ο€ -injectivity_radius(::AbstractSphere, ::Any, ::ExponentialRetraction) = Ο€ -injectivity_radius(::AbstractSphere, ::Any, ::ProjectionRetraction) = Ο€ / 2 -eval( - quote - @invoke_maker 1 AbstractManifold injectivity_radius( - M::AbstractSphere, - rm::AbstractRetractionMethod, - ) - end, -) +injectivity_radius(::AbstractSphere, p) = Ο€ +#avoid falling back but use the ones below +function injectivity_radius(M::AbstractSphere, m::AbstractRetractionMethod) + return _injectivity_radius(M, m) +end +function injectivity_radius(M::AbstractSphere, p, m::AbstractRetractionMethod) + return _injectivity_radius(M, p, m) +end +_injectivity_radius(::AbstractSphere, ::ExponentialRetraction) = Ο€ +_injectivity_radius(::AbstractSphere, ::ProjectionRetraction) = Ο€ / 2 @doc raw""" inverse_retract(M::AbstractSphere, p, q, ::ProjectionInverseRetraction) @@ -304,14 +278,14 @@ since $\Re(⟨p,X⟩) = 0$ and when $d_{π•Š^2}(p,q) ≀ \frac{Ο€}{2}$ that """ inverse_retract(::AbstractSphere, ::Any, ::Any, ::ProjectionInverseRetraction) -function inverse_retract!(::AbstractSphere, X, p, q, ::ProjectionInverseRetraction) +function inverse_retract_project!(::AbstractSphere, X, p, q) return (X .= q ./ real(dot(p, q)) .- p) end @doc raw""" local_metric(M::Sphere{n}, p, ::DefaultOrthonormalBasis) -return the local representation of the metric in a [`DefaultOrthonormalBasis`](@ref), namely +return the local representation of the metric in a [`DefaultOrthonormalBasis`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/bases.html#ManifoldsBase.DefaultOrthonormalBasis), namely the diagonal matrix of size ``nΓ—n`` with ones on the diagonal, since the metric is obtained from the embedding by restriction to the tangent space ``T_p\mathcal M`` at ``p``. """ @@ -375,14 +349,8 @@ Compute the Riemannian [`mean`](@ref mean(M::AbstractManifold, args...)) of `x` """ mean(::AbstractSphere, ::Any...) -function Statistics.mean!( - S::AbstractSphere, - p, - x::AbstractVector, - w::AbstractVector; - kwargs..., -) - return mean!(S, p, x, w, GeodesicInterpolationWithinRadius(Ο€ / 2); kwargs...) +function default_estimation_method(::AbstractSphere, ::typeof(mean)) + return GeodesicInterpolationWithinRadius(Ο€ / 2) end function mid_point!(S::Sphere, q, p1, p2) @@ -444,8 +412,8 @@ end Return the size points on the [`AbstractSphere`](@ref) `M` are represented as, i.e., the representation size of the embedding. """ -@generated representation_size(M::ArraySphere{N}) where {N} = size_to_tuple(N) -@generated representation_size(M::Sphere{N}) where {N} = (N + 1,) +@generated representation_size(::ArraySphere{N}) where {N} = size_to_tuple(N) +@generated representation_size(::Sphere{N}) where {N} = (N + 1,) @doc raw""" retract(M::AbstractSphere, p, X, ::ProjectionRetraction) @@ -458,7 +426,7 @@ Compute the retraction that is based on projection, i.e. """ retract(::AbstractSphere, ::Any, ::Any, ::ProjectionRetraction) -function retract!(M::AbstractSphere, q, p, X, ::ProjectionRetraction) +function retract_project!(M::AbstractSphere, q, p, X) q .= p .+ X return project!(M, q, q) end @@ -480,19 +448,19 @@ function uniform_distribution(M::Sphere{n,ℝ}, p) where {n} end @doc raw""" - vector_transport_to(M::AbstractSphere, p, X, q, ::ParallelTransport) + parallel_transport_to(M::AbstractSphere, p, X, q) Compute the parallel transport on the [`Sphere`](@ref) of the tangent vector `X` at `p` -to `q`, provided, the [`geodesic`](@ref) between `p` and `q` is unique. The formula reads +to `q`, provided, the [`geodesic`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#ManifoldsBase.geodesic-Tuple{AbstractManifold,%20Any,%20Any}) between `p` and `q` is unique. The formula reads ````math P_{p←q}(X) = X - \frac{\Re(⟨\log_p q,X⟩_p)}{d^2_π•Š(p,q)} \bigl(\log_p q + \log_q p \bigr). ```` """ -vector_transport_to(::AbstractSphere, ::Any, ::Any, ::Any, ::Any, ::ParallelTransport) +parallel_transport_to(::AbstractSphere, ::Any, ::Any, ::Any, ::Any) -function vector_transport_to!(::AbstractSphere, Y, p, X, q, ::ParallelTransport) +function parallel_transport_to!(::AbstractSphere, Y, p, X, q) m = p .+ q mnorm2 = real(dot(m, m)) factor = 2 * real(dot(X, q)) / mnorm2 @@ -536,7 +504,7 @@ function get_point!(::Sphere{n,ℝ}, p, ::StereographicAtlas, i::Symbol, x) wher return p end -function get_coordinates!( +function get_coordinates_induced_basis!( ::Sphere{n,ℝ}, Y, p, @@ -555,7 +523,7 @@ function get_coordinates!( return Y end -function get_vector!( +function get_vector_induced_basis!( M::Sphere{n,ℝ}, Y, p, diff --git a/src/manifolds/SphereSymmetricMatrices.jl b/src/manifolds/SphereSymmetricMatrices.jl index 96f4d40a85..3777a0901a 100644 --- a/src/manifolds/SphereSymmetricMatrices.jl +++ b/src/manifolds/SphereSymmetricMatrices.jl @@ -1,7 +1,7 @@ @doc raw""" SphereSymmetricMatrices{n,𝔽} <: AbstractEmbeddedManifold{ℝ,TransparentIsometricEmbedding} -The [`AbstractManifold`](@ref) consisting of the $n Γ— n$ symmetric matrices +The [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) consisting of the $n Γ— n$ symmetric matrices of unit Frobenius norm, i.e. ````math \mathcal{S}_{\text{sym}} :=\bigl\{p ∈ 𝔽^{n Γ— n}\ \big|\ p^{\mathrm{H}} = p, \lVert p \rVert = 1 \bigr\}, @@ -14,13 +14,16 @@ and the field $𝔽 ∈ \{ ℝ, β„‚\}$. Generate the manifold of `n`-by-`n` symmetric matrices of unit Frobenius norm. """ -struct SphereSymmetricMatrices{N,𝔽} <: - AbstractEmbeddedManifold{𝔽,TransparentIsometricEmbedding} end +struct SphereSymmetricMatrices{N,𝔽} <: AbstractDecoratorManifold{𝔽} end function SphereSymmetricMatrices(n::Int, field::AbstractNumbers=ℝ) return SphereSymmetricMatrices{n,field}() end +function active_traits(f, ::SphereSymmetricMatrices, arge...) + return merge_traits(IsEmbeddedSubmanifold()) +end + @doc raw""" check_point(M::SphereSymmetricMatrices{n,𝔽}, p; kwargs...) @@ -30,8 +33,6 @@ i.e. is an `n`-by-`n` symmetric matrix of unit Frobenius norm. The tolerance for the symmetry of `p` can be set using `kwargs...`. """ function check_point(M::SphereSymmetricMatrices{n,𝔽}, p; kwargs...) where {n,𝔽} - mpv = invoke(check_point, Tuple{supertype(typeof(M)),typeof(p)}, M, p; kwargs...) - mpv === nothing || return mpv if !isapprox(norm(p - p'), 0.0; kwargs...) return DomainError( norm(p - p'), @@ -51,15 +52,6 @@ of unit Frobenius norm. The tolerance for the symmetry of `p` and `X` can be set using `kwargs...`. """ function check_vector(M::SphereSymmetricMatrices{n,𝔽}, p, X; kwargs...) where {n,𝔽} - mpv = invoke( - check_vector, - Tuple{supertype(typeof(M)),typeof(p),typeof(X)}, - M, - p, - X; - kwargs..., - ) - mpv === nothing || return mpv if !isapprox(norm(X - X'), 0.0; kwargs...) return DomainError( norm(X - X'), @@ -69,7 +61,10 @@ function check_vector(M::SphereSymmetricMatrices{n,𝔽}, p, X; kwargs...) where return nothing end -function decorated_manifold(M::SphereSymmetricMatrices{n,𝔽}) where {n,𝔽} +embed(::SphereSymmetricMatrices, p) = p +embed(::SphereSymmetricMatrices, p, X) = X + +function get_embedding(::SphereSymmetricMatrices{n,𝔽}) where {n,𝔽} return ArraySphere(n, n; field=𝔽) end diff --git a/src/manifolds/Stiefel.jl b/src/manifolds/Stiefel.jl index 9ef3ff84fa..bad5405786 100644 --- a/src/manifolds/Stiefel.jl +++ b/src/manifolds/Stiefel.jl @@ -1,5 +1,5 @@ @doc raw""" - Stiefel{n,k,𝔽} <: AbstractEmbeddedManifold{𝔽,DefaultIsometricEmbeddingType} + Stiefel{n,k,𝔽} <: AbstractDecoratorManifold{𝔽} The Stiefel manifold consists of all $n Γ— k$, $n β‰₯ k$ unitary matrices, i.e. @@ -31,30 +31,13 @@ The manifold is named after Generate the (real-valued) Stiefel manifold of $n Γ— k$ dimensional orthonormal matrices. """ -struct Stiefel{n,k,𝔽} <: AbstractEmbeddedManifold{𝔽,DefaultIsometricEmbeddingType} end +struct Stiefel{n,k,𝔽} <: AbstractDecoratorManifold{𝔽} end Stiefel(n::Int, k::Int, field::AbstractNumbers=ℝ) = Stiefel{n,k,field}() -@doc raw""" - PadeRetraction{m} <: AbstractRetractionMethod - -A retraction based on the PadΓ© approximation of order $m$ -""" -struct PadeRetraction{m} <: AbstractRetractionMethod end - -function PadeRetraction(m::Int) - (m < 1) && error( - "The PadΓ© based retraction is only available for positive orders, not for order $m.", - ) - return PadeRetraction{m}() +function active_traits(f, ::Stiefel, args...) + return merge_traits(IsIsometricEmbeddedManifold(), IsDefaultMetric(EuclideanMetric())) end -@doc raw""" - CayleyRetraction <: AbstractRetractionMethod - -A retraction based on the Cayley transform, which is realized by using the -[`PadeRetraction`](@ref)`{1}`. -""" -const CayleyRetraction = PadeRetraction{1} function allocation_promotion_function(::Stiefel{n,k,β„‚}, ::Any, ::Tuple) where {n,k} return complex @@ -64,12 +47,12 @@ end check_point(M::Stiefel, p; kwargs...) Check whether `p` is a valid point on the [`Stiefel`](@ref) `M`=$\operatorname{St}(n,k)$, i.e. that it has the right -[`AbstractNumbers`](@ref) type and $p^{\mathrm{H}}p$ is (approximately) the identity, where $\cdot^{\mathrm{H}}$ is the +[`AbstractNumbers`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#number-system) type and $p^{\mathrm{H}}p$ is (approximately) the identity, where $\cdot^{\mathrm{H}}$ is the complex conjugate transpose. The settings for approximately can be set with `kwargs...`. """ function check_point(M::Stiefel{n,k,𝔽}, p; kwargs...) where {n,k,𝔽} - mpv = invoke(check_point, Tuple{supertype(typeof(M)),typeof(p)}, M, p; kwargs...) - mpv === nothing || return mpv + cks = check_size(M, p) + (cks === nothing) || return cks c = p' * p if !isapprox(c, one(c); kwargs...) return DomainError( @@ -84,21 +67,14 @@ end check_vector(M::Stiefel, p, X; kwargs...) Checks whether `X` is a valid tangent vector at `p` on the [`Stiefel`](@ref) -`M`=$\operatorname{St}(n,k)$, i.e. the [`AbstractNumbers`](@ref) fits and +`M`=$\operatorname{St}(n,k)$, i.e. the [`AbstractNumbers`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#number-system) fits and it (approximately) holds that $p^{\mathrm{H}}X + \overline{X^{\mathrm{H}}p} = 0$, where $\cdot^{\mathrm{H}}$ denotes the Hermitian and $\overline{\cdot}$ the (elementwise) complex conjugate. The settings for approximately can be set with `kwargs...`. """ function check_vector(M::Stiefel{n,k,𝔽}, p, X; kwargs...) where {n,k,𝔽} - mpv = invoke( - check_vector, - Tuple{supertype(typeof(M)),typeof(p),typeof(X)}, - M, - p, - X; - kwargs..., - ) - mpv === nothing || return mpv + cks = check_size(M, p, X) + cks === nothing || return cks if !isapprox(p' * X, -conj(X' * p); kwargs...) return DomainError( norm(p' * X + conj(X' * p)), @@ -108,7 +84,12 @@ function check_vector(M::Stiefel{n,k,𝔽}, p, X; kwargs...) where {n,k,𝔽} return nothing end -decorated_manifold(::Stiefel{N,K,𝔽}) where {N,K,𝔽} = Euclidean(N, K; field=𝔽) +embed(::Stiefel, p) = p +embed(::Stiefel, p, X) = X + +function get_embedding(::Stiefel{N,K,𝔽}) where {N,K,𝔽} + return Euclidean(N, K; field=𝔽) +end @doc raw""" inverse_retract(M::Stiefel, p, q, ::PolarInverseRetraction) @@ -229,7 +210,7 @@ function _stiefel_inv_retr_qr_mul_by_r!( return _stiefel_inv_retr_qr_mul_by_r_generic!(M, X, q, R, A) end -function inverse_retract!(::Stiefel, X, p, q, ::PolarInverseRetraction) +function inverse_retract_polar!(::Stiefel, X, p, q) A = p' * q H = -2 * one(p' * p) B = lyap(A, H) @@ -237,7 +218,7 @@ function inverse_retract!(::Stiefel, X, p, q, ::PolarInverseRetraction) X .-= p return X end -function inverse_retract!(M::Stiefel{n,k}, X, p, q, ::QRInverseRetraction) where {n,k} +function inverse_retract_qr!(M::Stiefel{n,k}, X, p, q) where {n,k} A = p' * q @boundscheck size(A) === (k, k) ElT = typeof(one(eltype(p)) * one(eltype(q))) @@ -332,7 +313,7 @@ the formula reads \operatorname{retr}_pX = \Bigl(I - \frac{1}{2}W_{p,X}\Bigr)^{-1}\Bigl(I + \frac{1}{2}W_{p,X}\Bigr)p. ```` -It is implemented as the case $m=1$ of the [`PadeRetraction`](@ref). +It is implemented as the case $m=1$ of the `PadeRetraction`. [^Zhu2017]: > X. Zhu: @@ -361,7 +342,7 @@ respectively. Then the PadΓ© approximation (of the matrix exponential $\exp(A)$) Defining further ````math W_{p,X} = \operatorname{P}_pXp^{\mathrm{H}} - pX^{\mathrm{H}}\operatorname{P_p} - \quad\text{where}  + \quad\text{where } \operatorname{P}_p = I - \frac{1}{2}pp^{\mathrm{H}} ```` the retraction reads @@ -379,7 +360,7 @@ retract(::Stiefel, ::Any, ::Any, ::PadeRetraction) @doc raw""" retract(M::Stiefel, p, X, ::PolarRetraction) -Compute the SVD-based retraction [`PolarRetraction`](@ref) on the +Compute the SVD-based retraction [`PolarRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.PolarRetraction) on the [`Stiefel`](@ref) manifold `M`. With $USV = p + X$ the retraction reads ````math @@ -391,7 +372,7 @@ retract(::Stiefel, ::Any, ::Any, ::PolarRetraction) @doc raw""" retract(M::Stiefel, p, X, ::QRRetraction) -Compute the QR-based retraction [`QRRetraction`](@ref) on the +Compute the QR-based retraction [`QRRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.QRRetraction) on the [`Stiefel`](@ref) manifold `M`. With $QR = p + X$ the retraction reads ````math @@ -415,7 +396,7 @@ retract(::Stiefel, ::Any, ::Any, ::QRRetraction) _qrfac_to_q(qrfac) = Matrix(qrfac.Q) _qrfac_to_q(qrfac::StaticArrays.QR) = qrfac.Q -function retract!(::Stiefel, q, p, X, ::PadeRetraction{m}) where {m} +function retract_pade!(::Stiefel, q, p, X, m) Pp = I - 1 // 2 * p * p' WpX = Pp * X * p' - p * X' * Pp pm = zeros(eltype(WpX), size(WpX)) @@ -436,11 +417,11 @@ function retract!(::Stiefel, q, p, X, ::PadeRetraction{m}) where {m} end return copyto!(q, (qm \ pm) * p) end -function retract!(::Stiefel, q, p, X, ::PolarRetraction) +function retract_polar!(::Stiefel, q, p, X) s = svd(p + X) return mul!(q, s.U, s.Vt) end -function retract!(::Stiefel, q, p, X, ::QRRetraction) +function retract_qr!(::Stiefel, q, p, X) qrfac = qr(p + X) d = diag(qrfac.R) D = Diagonal(sign.(sign.(d .+ 0.5))) @@ -455,8 +436,6 @@ i.e. `(n,k)`, which is the matrix dimensions. """ @generated representation_size(::Stiefel{n,k}) where {n,k} = (n, k) -Base.show(io::IO, ::CayleyRetraction) = print(io, "CayleyRetraction()") -Base.show(io::IO, ::PadeRetraction{m}) where {m} = print(io, "PadeRetraction($(m))") Base.show(io::IO, ::Stiefel{n,k,F}) where {n,k,F} = print(io, "Stiefel($(n), $(k), $(F))") """ @@ -486,7 +465,7 @@ end @doc raw""" vector_transport_direction(::Stiefel, p, X, d, ::DifferentiatedRetractionVectorTransport{CayleyRetraction}) -Compute the vector transport given by the differentiated retraction of the [`CayleyRetraction`](@ref), cf. [^Zhu2017] Equation (17). +Compute the vector transport given by the differentiated retraction of the [`CayleyRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.CayleyRetraction), cf. [^Zhu2017] Equation (17). The formula reads ````math @@ -496,12 +475,12 @@ The formula reads with ````math W_{p,X} = \operatorname{P}_pXp^{\mathrm{H}} - pX^{\mathrm{H}}\operatorname{P_p} - \quad\text{where}  + \quad\text{where } \operatorname{P}_p = I - \frac{1}{2}pp^{\mathrm{H}} ```` Since this is the differentiated retraction as a vector transport, the result will be in the -tangent space at $q=\operatorname{retr}_p(d)$ using the [`CayleyRetraction`](@ref). +tangent space at $q=\operatorname{retr}_p(d)$ using the [`CayleyRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.CayleyRetraction). """ vector_transport_direction( M::Stiefel, @@ -534,6 +513,7 @@ vector_transport_direction( ::Any, ::DifferentiatedRetractionVectorTransport{PolarRetraction}, ) + @doc raw""" vector_transport_direction(M::Stiefel, p, X, d, DifferentiatedRetractionVectorTransport{QRRetraction}) @@ -568,14 +548,7 @@ vector_transport_direction( ::DifferentiatedRetractionVectorTransport{QRRetraction}, ) -function vector_transport_direction!( - ::Stiefel, - Y, - p, - X, - d, - ::DifferentiatedRetractionVectorTransport{CayleyRetraction}, -) +function vector_transport_direction_diff!(::Stiefel, Y, p, X, d, ::CayleyRetraction) Pp = I - 1 // 2 * p * p' Wpd = Pp * d * p' - p * d' * Pp WpX = Pp * X * p' - p * X' * Pp @@ -583,27 +556,13 @@ function vector_transport_direction!( return copyto!(Y, (q1 \ WpX) * (q1 \ p)) end -function vector_transport_direction!( - M::Stiefel, - Y, - p, - X, - d, - ::DifferentiatedRetractionVectorTransport{PolarRetraction}, -) +function vector_transport_direction_diff!(M::Stiefel, Y, p, X, d, ::PolarRetraction) q = retract(M, p, d, PolarRetraction()) Iddsqrt = sqrt(I + d' * d) Ξ› = sylvester(Iddsqrt, Iddsqrt, -q' * X + X' * q) return copyto!(Y, q * Ξ› + (X - q * (q' * X)) / Iddsqrt) end -function vector_transport_direction!( - M::Stiefel, - Y, - p, - X, - d, - ::DifferentiatedRetractionVectorTransport{QRRetraction}, -) +function vector_transport_direction_diff!(M::Stiefel, Y, p, X, d, ::QRRetraction) q = retract(M, p, d, QRRetraction()) rf = UpperTriangular(qr(p + d).R) Xrf = X / rf @@ -683,27 +642,13 @@ projection it onto the tangent space at `q`. """ vector_transport_to(::Stiefel, ::Any, ::Any, ::Any, ::ProjectionTransport) -function vector_transport_to!( - M::Stiefel, - Y, - p, - X, - q, - ::DifferentiatedRetractionVectorTransport{PolarRetraction}, -) +function vector_transport_to_diff!(M::Stiefel, Y, p, X, q, ::PolarRetraction) d = inverse_retract(M, p, q, PolarInverseRetraction()) Iddsqrt = sqrt(I + d' * d) Ξ› = sylvester(Iddsqrt, Iddsqrt, -q' * X + X' * q) return copyto!(Y, q * Ξ› + (X - q * (q' * X)) / Iddsqrt) end -function vector_transport_to!( - M::Stiefel, - Y, - p, - X, - q, - ::DifferentiatedRetractionVectorTransport{QRRetraction}, -) +function vector_transport_to_diff!(M::Stiefel, Y, p, X, q, ::QRRetraction) d = inverse_retract(M, p, q, QRInverseRetraction()) rf = UpperTriangular(qr(p + d).R) Xrf = X / rf diff --git a/src/manifolds/StiefelEuclideanMetric.jl b/src/manifolds/StiefelEuclideanMetric.jl index cb5b807a0b..789147837a 100644 --- a/src/manifolds/StiefelEuclideanMetric.jl +++ b/src/manifolds/StiefelEuclideanMetric.jl @@ -1,8 +1,3 @@ -default_metric_dispatch(::Stiefel, ::EuclideanMetric) = Val(true) - -function decorator_transparent_dispatch(::typeof(distance), ::Stiefel, args...) - return Val(:intransparent) -end @doc raw""" exp(M::Stiefel, p, X) @@ -63,35 +58,31 @@ trangular entries of $a$ is set to $1$ its symmetric entry to $-1$ and we normal the factor $\frac{1}{\sqrt{2}}$ and for $b$ one can just use unit vectors reshaped to a matrix to obtain orthonormal set of parameters. """ -function get_basis( +get_basis(M::Stiefel{n,k,ℝ}, p, B::DefaultOrthonormalBasis{ℝ,TangentSpaceType}) where {n,k} + +function _get_basis( M::Stiefel{n,k,ℝ}, p, - B::DefaultOrthonormalBasis{ℝ,TangentSpaceType}, + B::DefaultOrthonormalBasis{ℝ,TangentSpaceType}; + kwargs..., ) where {n,k} - V = get_vectors(M, p, B) - return CachedBasis(B, V) + return CachedBasis(B, get_vectors(M, p, B)) end -function get_coordinates!( +function get_coordinates_orthonormal!( M::Stiefel{n,k,ℝ}, c, p, X, - B::DefaultOrthonormalBasis{ℝ,TangentSpaceType}, + N::RealNumbers, ) where {n,k} - V = get_vectors(M, p, B) + V = get_vectors(M, p, DefaultOrthonormalBasis(N)) c .= inner.(Ref(M), Ref(p), V, Ref(X)) return c end -function get_vector!( - M::Stiefel{n,k,ℝ}, - X, - p, - c, - B::DefaultOrthonormalBasis{ℝ,TangentSpaceType}, -) where {n,k} - V = get_vectors(M, p, B) +function get_vector_orthonormal!(M::Stiefel{n,k,ℝ}, X, p, c, N::RealNumbers) where {n,k} + V = get_vectors(M, p, DefaultOrthonormalBasis(N)) zero_vector!(M, X, p) length(c) < length(V) && error( "Coordinate vector too short. Excpected $(length(V)), but only got $(length(c)) entries.", diff --git a/src/manifolds/Symmetric.jl b/src/manifolds/Symmetric.jl index 1829c8570a..959e786cc3 100644 --- a/src/manifolds/Symmetric.jl +++ b/src/manifolds/Symmetric.jl @@ -1,7 +1,7 @@ @doc raw""" - SymmetricMatrices{n,𝔽} <: AbstractEmbeddedManifold{𝔽,TransparentIsometricEmbedding} + SymmetricMatrices{n,𝔽} <: AbstractDecoratorManifold{𝔽} -The [`AbstractManifold`](@ref) $ \operatorname{Sym}(n)$ consisting of the real- or complex-valued +The [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) $ \operatorname{Sym}(n)$ consisting of the real- or complex-valued symmetric matrices of size $n Γ— n$, i.e. the set ````math @@ -21,12 +21,16 @@ which is also reflected in the [`manifold_dimension`](@ref manifold_dimension(:: Generate the manifold of $n Γ— n$ symmetric matrices. """ -struct SymmetricMatrices{n,𝔽} <: AbstractEmbeddedManifold{𝔽,TransparentIsometricEmbedding} end +struct SymmetricMatrices{n,𝔽} <: AbstractDecoratorManifold{𝔽} end function SymmetricMatrices(n::Int, field::AbstractNumbers=ℝ) return SymmetricMatrices{n,field}() end +function active_traits(f, ::SymmetricMatrices, args...) + return merge_traits(IsEmbeddedSubmanifold()) +end + function allocation_promotion_function( M::SymmetricMatrices{<:Any,β„‚}, ::typeof(get_vector), @@ -40,13 +44,11 @@ end Check whether `p` is a valid manifold point on the [`SymmetricMatrices`](@ref) `M`, i.e. whether `p` is a symmetric matrix of size `(n,n)` with values from the corresponding -[`AbstractNumbers`](@ref) `𝔽`. +[`AbstractNumbers`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#number-system) `𝔽`. The tolerance for the symmetry of `p` can be set using `kwargs...`. """ function check_point(M::SymmetricMatrices{n,𝔽}, p; kwargs...) where {n,𝔽} - mpv = invoke(check_point, Tuple{supertype(typeof(M)),typeof(p)}, M, p; kwargs...) - mpv === nothing || return mpv if !isapprox(norm(p - p'), 0.0; kwargs...) return DomainError( norm(p - p'), @@ -61,20 +63,11 @@ end Check whether `X` is a tangent vector to manifold point `p` on the [`SymmetricMatrices`](@ref) `M`, i.e. `X` has to be a symmetric matrix of size `(n,n)` -and its values have to be from the correct [`AbstractNumbers`](@ref). +and its values have to be from the correct [`AbstractNumbers`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#number-system). The tolerance for the symmetry of `X` can be set using `kwargs...`. """ function check_vector(M::SymmetricMatrices{n,𝔽}, p, X; kwargs...) where {n,𝔽} - mpv = invoke( - check_vector, - Tuple{supertype(typeof(M)),typeof(p),typeof(X)}, - M, - p, - X; - kwargs..., - ) - mpv === nothing || return mpv if !isapprox(norm(X - X'), 0.0; kwargs...) return DomainError( norm(X - X'), @@ -84,7 +77,8 @@ function check_vector(M::SymmetricMatrices{n,𝔽}, p, X; kwargs...) where {n, return nothing end -decorated_manifold(M::SymmetricMatrices{N,𝔽}) where {N,𝔽} = Euclidean(N, N; field=𝔽) +embed(::SymmetricMatrices, p) = p +embed(::SymmetricMatrices, p, X) = X function get_basis(M::SymmetricMatrices, p, B::DiagonalizingOrthonormalBasis) Ξ = get_basis(M, p, DefaultOrthonormalBasis()).data @@ -92,12 +86,12 @@ function get_basis(M::SymmetricMatrices, p, B::DiagonalizingOrthonormalBasis) return CachedBasis(B, ΞΊ, Ξ) end -function get_coordinates!( +function get_coordinates_orthonormal!( M::SymmetricMatrices{N,ℝ}, Y, p, X, - ::DefaultOrthonormalBasis{ℝ,TangentSpaceType}, + ::RealNumbers, ) where {N} dim = manifold_dimension(M) @assert size(Y) == (dim,) @@ -111,12 +105,12 @@ function get_coordinates!( end return Y end -function get_coordinates!( +function get_coordinates_orthonormal!( M::SymmetricMatrices{N,β„‚}, Y, p, X, - ::DefaultOrthonormalBasis{β„‚,TangentSpaceType}, + ::ComplexNumbers, ) where {N} dim = manifold_dimension(M) @assert size(Y) == (dim,) @@ -135,12 +129,14 @@ function get_coordinates!( return Y end -function get_vector!( +get_embedding(::SymmetricMatrices{N,𝔽}) where {N,𝔽} = Euclidean(N, N; field=𝔽) + +function get_vector_orthonormal!( M::SymmetricMatrices{N,ℝ}, Y, p, X, - ::DefaultOrthonormalBasis{ℝ,TangentSpaceType}, + ::RealNumbers, ) where {N} dim = manifold_dimension(M) @assert size(X) == (dim,) @@ -154,12 +150,12 @@ function get_vector!( end return Y end -function get_vector!( +function get_vector_orthonormal!( M::SymmetricMatrices{N,β„‚}, Y, p, X, - ::DefaultOrthonormalBasis{β„‚,TangentSpaceType}, + ::ComplexNumbers, ) where {N} dim = manifold_dimension(M) @assert size(X) == (dim,) diff --git a/src/manifolds/SymmetricPositiveDefinite.jl b/src/manifolds/SymmetricPositiveDefinite.jl index 9ccdda7cd3..e12e70dbca 100644 --- a/src/manifolds/SymmetricPositiveDefinite.jl +++ b/src/manifolds/SymmetricPositiveDefinite.jl @@ -1,5 +1,5 @@ @doc raw""" - SymmetricPositiveDefinite{N} <: AbstractEmbeddedManifold{ℝ,DefaultEmbeddingType} + SymmetricPositiveDefinite{N} <: AbstractDecoratorManifold{𝔽} The manifold of symmetric positive definite matrices, i.e. @@ -26,10 +26,14 @@ i.e. the set of symmetric matrices, generates the manifold $\mathcal P(n) \subset ℝ^{n Γ— n}$ """ -struct SymmetricPositiveDefinite{N} <: AbstractEmbeddedManifold{ℝ,DefaultEmbeddingType} end +struct SymmetricPositiveDefinite{N} <: AbstractDecoratorManifold{ℝ} end SymmetricPositiveDefinite(n::Int) = SymmetricPositiveDefinite{n}() +function active_traits(f, ::SymmetricPositiveDefinite, args...) + return merge_traits(IsEmbeddedManifold(), IsDefaultMetric(LinearAffineMetric())) +end + @doc raw""" check_point(M::SymmetricPositiveDefinite, p; kwargs...) @@ -38,8 +42,6 @@ of size `(N,N)`, symmetric and positive definite. The tolerance for the second to last test can be set using the `kwargs...`. """ function check_point(M::SymmetricPositiveDefinite{N}, p; kwargs...) where {N} - mpv = invoke(check_point, Tuple{supertype(typeof(M)),typeof(p)}, M, p; kwargs...) - mpv === nothing || return mpv if !isapprox(norm(p - transpose(p)), 0.0; kwargs...) return DomainError( norm(p - transpose(p)), @@ -65,15 +67,6 @@ Lie group. The tolerance for the last test can be set using the `kwargs...`. """ function check_vector(M::SymmetricPositiveDefinite{N}, p, X; kwargs...) where {N} - mpv = invoke( - check_vector, - Tuple{supertype(typeof(M)),typeof(p),typeof(X)}, - M, - p, - X; - kwargs..., - ) - mpv === nothing || return mpv if !isapprox(norm(X - transpose(X)), 0.0; kwargs...) return DomainError( X, @@ -83,10 +76,13 @@ function check_vector(M::SymmetricPositiveDefinite{N}, p, X; kwargs...) where {N return nothing end -function decorated_manifold(M::SymmetricPositiveDefinite) +function get_embedding(M::SymmetricPositiveDefinite) return Euclidean(representation_size(M)...; field=ℝ) end +embed(::SymmetricPositiveDefinite, p) = p +embed(::SymmetricPositiveDefinite, p, X) = X + @doc raw""" injectivity_radius(M::SymmetricPositiveDefinite[, p]) injectivity_radius(M::MetricManifold{SymmetricPositiveDefinite,LinearAffineMetric}[, p]) @@ -97,17 +93,9 @@ Since `M` is a Hadamard manifold with respect to the [`LinearAffineMetric`](@ref [`LogCholeskyMetric`](@ref), the injectivity radius is globally $∞$. """ injectivity_radius(::SymmetricPositiveDefinite) = Inf -injectivity_radius(::SymmetricPositiveDefinite, ::ExponentialRetraction) = Inf -injectivity_radius(::SymmetricPositiveDefinite, ::Any) = Inf -injectivity_radius(::SymmetricPositiveDefinite, ::Any, ::ExponentialRetraction) = Inf -eval( - quote - @invoke_maker 1 AbstractManifold injectivity_radius( - M::SymmetricPositiveDefinite, - rm::AbstractRetractionMethod, - ) - end, -) +injectivity_radius(::SymmetricPositiveDefinite, p) = Inf +injectivity_radius(::SymmetricPositiveDefinite, ::AbstractRetractionMethod) = Inf +injectivity_radius(::SymmetricPositiveDefinite, p, ::AbstractRetractionMethod) = Inf @doc raw""" manifold_dimension(M::SymmetricPositiveDefinite) @@ -136,21 +124,15 @@ Compute the Riemannian [`mean`](@ref mean(M::AbstractManifold, args...)) of `x` """ mean(::SymmetricPositiveDefinite, ::Any) -function Statistics.mean!( - M::SymmetricPositiveDefinite, - p, - x::AbstractVector, - w::AbstractVector; - kwargs..., -) - return mean!(M, p, x, w, GeodesicInterpolation(); kwargs...) +function default_estimation_method(::SymmetricPositiveDefinite, ::typeof(mean)) + return GeodesicInterpolation() end @doc raw""" project(M::SymmetricPositiveDefinite, p, X) project a matrix from the embedding onto the tangent space $T_p\mathcal P(n)$ of the -[SymmetricPositiveDefinite](@ref) matrices, i.e. the set of symmetric matrices. +[`SymmetricPositiveDefinite`](@ref) matrices, i.e. the set of symmetric matrices. """ project(::SymmetricPositiveDefinite, p, X) diff --git a/src/manifolds/SymmetricPositiveDefiniteLinearAffine.jl b/src/manifolds/SymmetricPositiveDefiniteLinearAffine.jl index 2ce111c834..ebaa54088f 100644 --- a/src/manifolds/SymmetricPositiveDefiniteLinearAffine.jl +++ b/src/manifolds/SymmetricPositiveDefiniteLinearAffine.jl @@ -53,8 +53,6 @@ function change_metric!(::SymmetricPositiveDefinite, Y, ::EuclideanMetric, p, X) return Y end -default_metric_dispatch(::SymmetricPositiveDefinite, ::LinearAffineMetric) = Val(true) - @doc raw""" distance(M::SymmetricPositiveDefinite, p, q) distance(M::MetricManifold{SymmetricPositiveDefinite,LinearAffineMetric}, p, q) @@ -110,8 +108,8 @@ function exp!(::SymmetricPositiveDefinite{N}, q, p, X) where {N} end @doc raw""" - [Ξ,ΞΊ] = get_basis(M::SymmetricPositiveDefinite, p, B::DiagonalizingOrthonormalBasis) - [Ξ,ΞΊ] = get_basis(M::MetricManifold{SymmetricPositiveDefinite{N},LinearAffineMetric}, p, B::DiagonalizingOrthonormalBasis) + [Ξ,ΞΊ] = get_basis_diagonalizing(M::SymmetricPositiveDefinite, p, B::DiagonalizingOrthonormalBasis) + [Ξ,ΞΊ] = get_basis_diagonalizing(M::MetricManifold{SymmetricPositiveDefinite{N},LinearAffineMetric}, p, B::DiagonalizingOrthonormalBasis) Return a orthonormal basis `Ξ` as a vector of tangent vectors (of length [`manifold_dimension`](@ref) of `M`) in the tangent space of `p` on the @@ -122,7 +120,7 @@ with eigenvalues `ΞΊ` and where the direction `B.frame_direction` ``V`` has curv The construction is based on an ONB for the symmetric matrices similar to [`get_basis(::SymmetricPositiveDefinite, p, ::DefaultOrthonormalBasis`](@ref get_basis(M::SymmetricPositiveDefinite,p,B::DefaultOrthonormalBasis{<:Any,ManifoldsBase.TangentSpaceType})) just that the ONB here is build from the eigen vectors of ``p^{\frac{1}{2}}Vp^{\frac{1}{2}}``. """ -function get_basis( +function get_basis_diagonalizing( ::SymmetricPositiveDefinite{N}, p, B::DiagonalizingOrthonormalBasis, @@ -177,10 +175,12 @@ We then form the ONB by \Xi_{i,j} = p^{\frac{1}{2}}\Delta_{i,j}p^{\frac{1}{2}},\qquad i=1,\ldots,n, j=i,\ldots,n. ``` """ -function get_basis( +get_basis(::SymmetricPositiveDefinite, p, B::DefaultOrthonormalBasis) + +function get_basis_orthonormal( M::SymmetricPositiveDefinite{N}, p, - B::DefaultOrthonormalBasis{<:Any,ManifoldsBase.TangentSpaceType}, + Ns::RealNumbers, ) where {N} e = eigen(Symmetric(p)) U = e.vectors @@ -198,7 +198,7 @@ function get_basis( @inbounds Ξ[k] .= s * pSqrt * Ξ[k] * pSqrt k += 1 end - return CachedBasis(B, Ξ) + return CachedBasis(DefaultOrthonormalBasis(Ns), Ξ) end @doc raw""" @@ -214,12 +214,12 @@ where $k$ is trhe linearized index of the $i=1,\ldots,n, j=i,\ldots,n$. """ get_coordinates(::SymmetricPositiveDefinite, c, p, X, ::DefaultOrthonormalBasis) -function get_coordinates!( +function get_coordinates_orthonormal!( M::SymmetricPositiveDefinite{N}, c, p, X, - ::DefaultOrthonormalBasis{ℝ,TangentSpaceType}, + ::RealNumbers, ) where {N} dim = manifold_dimension(M) @assert size(c) == (dim,) @@ -263,12 +263,12 @@ where $k$ is the linearized index of the $i=1,\ldots,n, j=i,\ldots,n$. """ get_vector(::SymmetricPositiveDefinite, X, p, c, ::DefaultOrthonormalBasis) -function get_vector!( +function get_vector_orthonormal!( ::SymmetricPositiveDefinite{N}, X, p, c, - ::DefaultOrthonormalBasis{ℝ,TangentSpaceType}, + ::RealNumbers, ) where {N} @assert size(c) == (div(N * (N + 1), 2),) @assert size(X) == (N, N) @@ -326,7 +326,7 @@ where $\operatorname{Log}$ denotes to the matrix logarithm. """ log(::SymmetricPositiveDefinite, ::Any...) -function log!(M::SymmetricPositiveDefinite{N}, X, p, q) where {N} +function log!(::SymmetricPositiveDefinite{N}, X, p, q) where {N} e = eigen(Symmetric(p)) U = e.vectors S = max.(e.values, floatmin(eltype(e.values))) @@ -342,8 +342,8 @@ function log!(M::SymmetricPositiveDefinite{N}, X, p, q) where {N} end @doc raw""" - vector_transport_to(M::SymmetricPositiveDefinite, p, X, q, ::ParallelTransport) - vector_transport_to(M::MetricManifold{SymmetricPositiveDefinite,LinearAffineMetric}, p, X, y, ::ParallelTransport) + parallel_transport_to(M::SymmetricPositiveDefinite, p, X, q) + parallel_transport_to(M::MetricManifold{SymmetricPositiveDefinite,LinearAffineMetric}, p, X, y) Compute the parallel transport of `X` from the tangent space at `p` to the tangent space at `q` on the [`SymmetricPositiveDefinite`](@ref) as a @@ -366,16 +366,9 @@ where $\operatorname{Exp}$ denotes the matrix exponential and `log` the logarithmic map on [`SymmetricPositiveDefinite`](@ref) (again with respect to the [`LinearAffineMetric`](@ref)). """ -vector_transport_to(::SymmetricPositiveDefinite, ::Any, ::Any, ::Any, ::ParallelTransport) +parallel_transport_to(::SymmetricPositiveDefinite, ::Any, ::Any, ::Any) -function vector_transport_to!( - M::SymmetricPositiveDefinite{N}, - Y, - p, - X, - q, - ::ParallelTransport, -) where {N} +function parallel_transport_to!(M::SymmetricPositiveDefinite{N}, Y, p, X, q) where {N} distance(M, p, q) < 2 * eps(eltype(p)) && copyto!(Y, X) e = eigen(Symmetric(p)) U = e.vectors diff --git a/src/manifolds/SymmetricPositiveDefiniteLogCholesky.jl b/src/manifolds/SymmetricPositiveDefiniteLogCholesky.jl index c8d9cb7451..0f3b186e73 100644 --- a/src/manifolds/SymmetricPositiveDefiniteLogCholesky.jl +++ b/src/manifolds/SymmetricPositiveDefiniteLogCholesky.jl @@ -40,7 +40,7 @@ $⌊\cdotβŒ‹$ denbotes the strictly lower triangular matrix of its argument, and $\lVert\cdot\rVert_{\mathrm{F}}$ the Frobenius norm. """ function distance( - M::MetricManifold{ℝ,SymmetricPositiveDefinite{N},LogCholeskyMetric}, + ::MetricManifold{ℝ,SymmetricPositiveDefinite{N},LogCholeskyMetric}, p, q, ) where {N} @@ -65,7 +65,7 @@ denotes the lower triangular matrix with the diagonal multiplied by $\frac{1}{2} exp(::MetricManifold{ℝ,SymmetricPositiveDefinite,LogCholeskyMetric}, ::Any...) function exp!( - M::MetricManifold{ℝ,SymmetricPositiveDefinite{N},LogCholeskyMetric}, + ::MetricManifold{ℝ,SymmetricPositiveDefinite{N},LogCholeskyMetric}, q, p, X, @@ -92,7 +92,7 @@ $a_z(W) = z (z^{-1}Wz^{-\mathrm{T}})_{\frac{1}{2}}$, and $(\cdot)_\frac{1}{2}$ denotes the lower triangular matrix with the diagonal multiplied by $\frac{1}{2}$ """ function inner( - M::MetricManifold{ℝ,SymmetricPositiveDefinite{N},LogCholeskyMetric}, + ::MetricManifold{ℝ,SymmetricPositiveDefinite{N},LogCholeskyMetric}, p, X, Y, @@ -117,7 +117,7 @@ of $q$ and the just mentioned logarithmic map is the one on [`CholeskySpace`](@r log(::MetricManifold{ℝ,SymmetricPositiveDefinite,LogCholeskyMetric}, ::Any...) function log!( - M::MetricManifold{ℝ,SymmetricPositiveDefinite{N},LogCholeskyMetric}, + ::MetricManifold{ℝ,SymmetricPositiveDefinite{N},LogCholeskyMetric}, X, p, q, @@ -149,24 +149,22 @@ transport on [`CholeskySpace`](@ref) from $x$ to $y$. The formula hear reads \mathcal P_{q←p}X = yV^{\mathrm{T}} + Vy^{\mathrm{T}}. ```` """ -vector_transport_to( +parallel_transport_to( ::MetricManifold{ℝ,SymmetricPositiveDefinite,LogCholeskyMetric}, ::Any, ::Any, ::Any, - ::ParallelTransport, ) -function vector_transport_to!( - M::MetricManifold{ℝ,SymmetricPositiveDefinite{N},LogCholeskyMetric}, +function parallel_transport_to!( + ::MetricManifold{ℝ,SymmetricPositiveDefinite{N},LogCholeskyMetric}, Y, p, X, q, - m::ParallelTransport, ) where {N} y = cholesky(q).L (x, W) = spd_to_cholesky(p, X) - vector_transport_to!(CholeskySpace{N}(), Y, x, W, y, m) + parallel_transport_to!(CholeskySpace{N}(), Y, x, W, y) return tangent_cholesky_to_tangent_spd!(y, Y) end diff --git a/src/manifolds/SymmetricPositiveDefiniteLogEuclidean.jl b/src/manifolds/SymmetricPositiveDefiniteLogEuclidean.jl index a0a1f7bc9a..3b1d93b096 100644 --- a/src/manifolds/SymmetricPositiveDefiniteLogEuclidean.jl +++ b/src/manifolds/SymmetricPositiveDefiniteLogEuclidean.jl @@ -21,7 +21,7 @@ where $\operatorname{Log}$ denotes the matrix logarithm and $\lVert\cdot\rVert_{\mathrm{F}}$ denotes the matrix Frobenius norm. """ function distance( - M::MetricManifold{ℝ,SymmetricPositiveDefinite{N},LogEuclideanMetric}, + ::MetricManifold{ℝ,SymmetricPositiveDefinite{N},LogEuclideanMetric}, p, q, ) where {N} diff --git a/src/manifolds/SymmetricPositiveSemidefiniteFixedRank.jl b/src/manifolds/SymmetricPositiveSemidefiniteFixedRank.jl index 08c8b14860..a2bc32a6ac 100644 --- a/src/manifolds/SymmetricPositiveSemidefiniteFixedRank.jl +++ b/src/manifolds/SymmetricPositiveSemidefiniteFixedRank.jl @@ -1,7 +1,7 @@ @doc raw""" - SymmetricPositiveSemidefiniteFixedRank{n,k,𝔽} <: AbstractEmbeddedManifold{𝔽,DefaultIsometricEmbeddingType} + SymmetricPositiveSemidefiniteFixedRank{n,k,𝔽} <: AbstractDecoratorManifold{𝔽} -The [`AbstractManifold`](@ref) $ \operatorname{SPS}_k(n)$ consisting of the real- or complex-valued +The [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) $ \operatorname{SPS}_k(n)$ consisting of the real- or complex-valued symmetric positive semidefinite matrices of size $n Γ— n$ and rank $k$, i.e. the set ````math @@ -53,19 +53,22 @@ over the `field` of real numbers `ℝ` or complex numbers `β„‚`. > doi: [10.1137/18m1231389](https://doi.org/10.1137/18m1231389), > preprint: [sites.uclouvain.be/absil/2018.06](https://sites.uclouvain.be/absil/2018.06). """ -struct SymmetricPositiveSemidefiniteFixedRank{n,k,𝔽} <: - AbstractEmbeddedManifold{𝔽,DefaultIsometricEmbeddingType} end +struct SymmetricPositiveSemidefiniteFixedRank{n,k,𝔽} <: AbstractDecoratorManifold{𝔽} end function SymmetricPositiveSemidefiniteFixedRank(n::Int, k::Int, field::AbstractNumbers=ℝ) return SymmetricPositiveSemidefiniteFixedRank{n,k,field}() end +function active_traits(f, ::SymmetricPositiveSemidefiniteFixedRank, args...) + return merge_traits(IsIsometricEmbeddedManifold()) +end + @doc raw""" check_point(M::SymmetricPositiveSemidefiniteFixedRank{n,𝔽}, q; kwargs...) Check whether `q` is a valid manifold point on the [`SymmetricPositiveSemidefiniteFixedRank`](@ref) `M`, i.e. whether `p=q*q'` is a symmetric matrix of size `(n,n)` with values from the corresponding -[`AbstractNumbers`](@ref) `𝔽`. +[`AbstractNumbers`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#number-system) `𝔽`. The symmetry of `p` is not explicitly checked since by using `q` p is symmetric by construction. The tolerance for the symmetry of `p` can and the rank of `q*q'` be set using `kwargs...`. """ @@ -74,8 +77,6 @@ function check_point( q; kwargs..., ) where {n,k,𝔽} - mpv = invoke(check_point, Tuple{supertype(typeof(M)),typeof(q)}, M, q; kwargs...) - mpv === nothing || return mpv p = q * q' r = rank(p * p'; kwargs...) if r < k @@ -92,27 +93,13 @@ end Check whether `X` is a tangent vector to manifold point `p` on the [`SymmetricPositiveSemidefiniteFixedRank`](@ref) `M`, i.e. `X` has to be a symmetric matrix of size `(n,n)` -and its values have to be from the correct [`AbstractNumbers`](@ref). -The tolerance for the symmetry of `X` can be set using `kwargs...`. +and its values have to be from the correct [`AbstractNumbers`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#number-system). + +Due to the reduced representation this is fulfilled as soon as the matrix is of correct size. """ -function check_vector( - M::SymmetricPositiveSemidefiniteFixedRank{n,k,𝔽}, - q, - Y; - kwargs..., -) where {n,k,𝔽} - mpv = invoke( - check_vector, - Tuple{supertype(typeof(M)),typeof(q),typeof(Y)}, - M, - q, - Y; - kwargs..., - ) - return mpv -end +check_vector(M::SymmetricPositiveSemidefiniteFixedRank, q, Y; kwargs...) -function decorated_manifold(::SymmetricPositiveSemidefiniteFixedRank{N,K,𝔽}) where {N,K,𝔽} +function get_embedding(::SymmetricPositiveSemidefiniteFixedRank{N,K,𝔽}) where {N,K,𝔽} return Euclidean(N, K; field=𝔽) end @@ -241,14 +228,7 @@ vector_transport_to( ::ProjectionTransport, ) -function vector_transport_to!( - M::SymmetricPositiveSemidefiniteFixedRank, - Y, - p, - X, - q, - ::ProjectionTransport, -) +function vector_transport_to_project!(M::SymmetricPositiveSemidefiniteFixedRank, Y, p, X, q) project!(M, Y, q, X) return Y end diff --git a/src/manifolds/Symplectic.jl b/src/manifolds/Symplectic.jl index 44e73792c5..ec69706431 100644 --- a/src/manifolds/Symplectic.jl +++ b/src/manifolds/Symplectic.jl @@ -38,10 +38,14 @@ Generate the (real-valued) symplectic manifold of ``2n \times 2n`` symplectic ma The constructor for the [`Symplectic`](@ref) manifold accepts the even column/row embedding dimension ``2n`` for the real symplectic manifold, ``ℝ^{2n Γ— 2n}``. """ -struct Symplectic{n,𝔽} <: AbstractEmbeddedManifold{𝔽,DefaultIsometricEmbeddingType} end +struct Symplectic{n,𝔽} <: AbstractDecoratorManifold{𝔽} end + +function active_traits(f, ::Symplectic, args...) + return merge_traits(IsEmbeddedManifold(), IsDefaultMetric(RealSymplecticMetric())) +end function Symplectic(n::Int, field::AbstractNumbers=ℝ) - n % 2 == 0 || throw(ArgumentError("The dimensionality of the symplectic manifold + n % 2 == 0 || throw(ArgumentError("The dimension of the symplectic manifold embedding space must be even. Was odd, n % 2 == $(n % 2).")) return Symplectic{div(n, 2),field}() end @@ -73,8 +77,6 @@ as an inner product over the embedding space ``ℝ^{2n \times 2n}``, i.e. """ struct ExtendedSymplecticMetric <: AbstractMetric end -struct CayleyInverseRetraction <: AbstractInverseRetractionMethod end - @doc raw""" SymplecticMatrix{T} @@ -195,7 +197,7 @@ end check_point(M::Symplectic, p; kwargs...) Check whether `p` is a valid point on the [`Symplectic`](@ref) `M`=$\operatorname{Sp}(2n)$, -i.e. that it has the right [`AbstractNumbers`](@ref) type and $p^{+}p$ is (approximately) +i.e. that it has the right [`AbstractNumbers`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#number-system) type and $p^{+}p$ is (approximately) the identity, where $A^{+} = Q_{2n}^TA^TQ_{2n}$ is the symplectic inverse, with ````math Q_{2n} = @@ -207,11 +209,6 @@ Q_{2n} = The tolerance can be set with `kwargs...` (e.g. `atol = 1.0e-14`). """ function check_point(M::Symplectic{n,ℝ}, p; kwargs...) where {n,ℝ} - abstract_embedding_type = supertype(typeof(M)) - - mpv = invoke(check_point, Tuple{abstract_embedding_type,typeof(p)}, M, p; kwargs...) - mpv === nothing || return mpv - # Perform check that the matrix lives on the real symplectic manifold: expected_zero = norm(inv(M, p) * p - LinearAlgebra.I) if !isapprox(expected_zero, zero(eltype(p)); kwargs...) @@ -230,7 +227,7 @@ end check_vector(M::Symplectic, p, X; kwargs...) Checks whether `X` is a valid tangent vector at `p` on the [`Symplectic`](@ref) -`M`=``\operatorname{Sp}(2n)``, i.e. the [`AbstractNumbers`](@ref) fits and +`M`=``\operatorname{Sp}(2n)``, i.e. the [`AbstractNumbers`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#number-system) fits and it (approximately) holds that ``p^{T}Q_{2n}X + X^{T}Q_{2n}p = 0``, where ````math @@ -245,21 +242,8 @@ The tolerance can be set with `kwargs...` (e.g. `atol = 1.0e-14`). check_vector(::Symplectic, ::Any...) function check_vector(M::Symplectic{n}, p, X; kwargs...) where {n} - abstract_embedding_type = supertype(typeof(M)) - - mpv = invoke( - check_vector, - Tuple{abstract_embedding_type,typeof(p),typeof(X)}, - M, - p, - X; - kwargs..., - ) - mpv === nothing || return mpv - Q = SymplecticMatrix(p, X) tangent_requirement_norm = norm(X' * Q * p + p' * Q * X, 2) - if !isapprox(tangent_requirement_norm, 0.0; kwargs...) return DomainError( tangent_requirement_norm, @@ -272,10 +256,6 @@ function check_vector(M::Symplectic{n}, p, X; kwargs...) where {n} return nothing end -decorated_manifold(::Symplectic{n,ℝ}) where {n} = Euclidean(2n, 2n; field=ℝ) - -default_metric_dispatch(::Symplectic{n,ℝ}, ::RealSymplecticMetric) where {n,ℝ} = Val(true) - ManifoldsBase.default_inverse_retraction_method(::Symplectic) = CayleyInverseRetraction() ManifoldsBase.default_retraction_method(::Symplectic) = CayleyRetraction() @@ -326,6 +306,9 @@ function distance(M::Symplectic{n}, p, q) where {n} return norm(log(symplectic_inverse_times(M, p, q))) end +embed(::Symplectic, p) = p +embed(::Symplectic, p, X) = X + @doc raw""" exp(M::Symplectic, p, X) exp!(M::Symplectic, q, p, X) @@ -355,6 +338,8 @@ function exp!(M::Symplectic, q, p, X) return q end +get_embedding(::Symplectic{n,ℝ}) where {n} = Euclidean(2n, 2n; field=ℝ) + @doc raw""" gradient(M::Symplectic, f, p, backend::RiemannianProjectionBackend; extended_metric=true) @@ -514,7 +499,7 @@ function inv!(::Symplectic{n,ℝ}, A) where {n} end @doc raw""" - inverse_retract!(M::Symplectic, X, p, q, ::CayleyInverseRetraction) + inverse_retract(M::Symplectic, p, q, ::CayleyInverseRetraction) Compute the Cayley Inverse Retraction ``X = \mathcal{L}_p^{\operatorname{Sp}}(q)`` such that the Cayley Retraction from ``p`` along ``X`` lands at ``q``, i.e. @@ -547,7 +532,9 @@ If that is the case, the inverse cayley retration at ``p`` applied to ``q`` is > The real symplectic Stiefel and Grassmann manifolds: metrics, geodesics and applications > arXiv preprint arXiv:2108.12447, 2021 (https://arxiv.org/abs/2108.12447) """ -function inverse_retract!(M::Symplectic, X, p, q, ::CayleyInverseRetraction) +inverse_retract(::Symplectic, p, q, ::CayleyInverseRetraction) + +function inverse_retract_caley!(M::Symplectic, X, p, q) U_inv = lu(add_scaled_I!(symplectic_inverse_times(M, p, q), 1)) V_inv = lu(add_scaled_I!(symplectic_inverse_times(M, q, p), 1)) @@ -765,7 +752,9 @@ denotes the PadΓ© (1, 1) approximation to ``\operatorname{exp}(z)``. > Monatshefte f{\"u}r Mathematik, Springer > doi [10.1007/s00605-020-01369-9](https://doi.org/10.1007/s00605-020-01369-9) """ -function retract!(M::Symplectic, q, p, X, ::CayleyRetraction) +retract(M::Symplectic, p, X) + +function retract_caley!(M::Symplectic, q, p, X) p_star_X = symplectic_inverse_times(M, p, X) divisor = lu(2 * I - p_star_X) diff --git a/src/manifolds/SymplecticStiefel.jl b/src/manifolds/SymplecticStiefel.jl index 605e6e0374..6634b0f8d1 100644 --- a/src/manifolds/SymplecticStiefel.jl +++ b/src/manifolds/SymplecticStiefel.jl @@ -39,13 +39,15 @@ The constructor for the [`SymplecticStiefel`](@ref) manifold accepts the even co dimension ``2n`` and an even number of columns ``2k`` for the real symplectic Stiefel manifold with elements ``p \in ℝ^{2n Γ— 2k}``. """ -struct SymplecticStiefel{n,k,𝔽} <: AbstractEmbeddedManifold{𝔽,DefaultIsometricEmbeddingType} end +struct SymplecticStiefel{n,k,𝔽} <: AbstractDecoratorManifold{𝔽} end function SymplecticStiefel(two_n::Int, two_k::Int, field::AbstractNumbers=ℝ) return SymplecticStiefel{div(two_n, 2),div(two_k, 2),field}() end -decorated_manifold(::SymplecticStiefel{n,k,ℝ}) where {n,k} = Euclidean(2n, 2k; field=ℝ) +function active_traits(f, ::SymplecticStiefel, args...) + return merge_traits(IsEmbeddedManifold(), IsDefaultMetric(RealSymplecticMetric())) +end function ManifoldsBase.default_inverse_retraction_method(::SymplecticStiefel) return CayleyInverseRetraction() @@ -79,7 +81,7 @@ end Check whether `p` is a valid point on the [`SymplecticStiefel`](@ref), $\operatorname{SpSt}(2n, 2k)$ manifold. -That is, the point has the right [`AbstractNumbers`](@ref) type and $p^{+}p$ is +That is, the point has the right [`AbstractNumbers`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#number-system) type and $p^{+}p$ is (approximately) the identity, where for $A \in \mathbb{R}^{2n \times 2k}$, $A^{+} = Q_{2k}^TA^TQ_{2n}$ is the symplectic inverse, with @@ -93,11 +95,6 @@ Q_{2n} = The tolerance can be set with `kwargs...` (e.g. `atol = 1.0e-14`). """ function check_point(M::SymplecticStiefel{n,k}, p; kwargs...) where {n,k} - abstract_embedding_type = supertype(typeof(M)) - - mpv = invoke(check_point, Tuple{abstract_embedding_type,typeof(p)}, M, p; kwargs...) - mpv === nothing || return mpv - # Perform check that the matrix lives on the real symplectic manifold: expected_zero = norm(inv(M, p) * p - I) if !isapprox(expected_zero, 0; kwargs...) @@ -137,20 +134,7 @@ The tolerance can be set with `kwargs...` (e.g. `atol = 1.0e-14`). check_vector(::SymplecticStiefel, ::Any...) function check_vector(M::SymplecticStiefel{n,k,field}, p, X; kwargs...) where {n,k,field} - abstract_embedding_type = supertype(typeof(M)) - - mpv = invoke( - check_vector, - Tuple{abstract_embedding_type,typeof(p),typeof(X)}, - M, - p, - X; - kwargs..., - ) - mpv === nothing || return mpv - # From Bendokat-Zimmermann: T_pSpSt(2n, 2k) = \{p*H | H^{+} = -H \} - H = inv(M, p) * X # ∈ ℝ^{2k Γ— 2k}, should be Hamiltonian. H_star = inv(Symplectic(2k, field), H) hamiltonian_identity_norm = norm(H + H_star) @@ -292,6 +276,8 @@ function exp!(M::SymplecticStiefel{n,k}, q, p, X) where {n,k} return q end +get_embedding(::SymplecticStiefel{n,k,ℝ}) where {n,k} = Euclidean(2n, 2k; field=ℝ) + @doc raw""" gradient(::SymplecticStiefel, f, p, backend::RiemannianProjectionBackend) gradient!(::SymplecticStiefel, f, X, p, backend::RiemannianProjectionBackend) @@ -453,7 +439,7 @@ If that is the case, the inverse cayley retration at ``p`` applied to ``q`` is """ inverse_retract(::SymplecticStiefel, p, q, ::CayleyInverseRetraction) -function inverse_retract!(M::SymplecticStiefel, X, p, q, ::CayleyInverseRetraction) +function inverse_retract_caley!(M::SymplecticStiefel, X, p, q) U_inv = lu!(add_scaled_I!(symplectic_inverse_times(M, p, q), 1)) V_inv = lu!(add_scaled_I!(symplectic_inverse_times(M, q, p), 1)) @@ -592,7 +578,7 @@ It is this expression we compute inplace of `q`. """ retract(::SymplecticStiefel, p, X, ::CayleyRetraction) -function retract!(M::SymplecticStiefel, q, p, X, ::CayleyRetraction) +function retract_caley!(M::SymplecticStiefel, q, p, X) # Define intermediate matrices for later use: A = symplectic_inverse_times(M, p, X) diff --git a/src/manifolds/Torus.jl b/src/manifolds/Torus.jl index a766513326..6a5b934b78 100644 --- a/src/manifolds/Torus.jl +++ b/src/manifolds/Torus.jl @@ -4,7 +4,7 @@ The n-dimensional torus is the $n$-dimensional product of the [`Circle`](@ref). The [`Circle`](@ref) is stored internally within `M.manifold`, such that all functions of -[`AbstractPowerManifold`](@ref) can be used directly. +[`AbstractPowerManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/manifolds.html#ManifoldsBase.AbstractPowerManifold) can be used directly. """ struct Torus{N} <: AbstractPowerManifold{ℝ,Circle{ℝ},ArrayPowerRepresentation} manifold::Circle{ℝ} @@ -22,12 +22,6 @@ its entries is a valid point on the [`Circle`](@ref) and the length of `x` is `n """ check_point(::Torus, ::Any) function check_point(M::Torus{N}, p; kwargs...) where {N} - if length(p) != N - return DomainError( - length(p), - "The number of elements in `p` ($(length(p))) does not match the dimension of the torus ($(N)).", - ) - end return check_point(PowerManifold(M.manifold, N), p; kwargs...) end @doc raw""" @@ -38,19 +32,15 @@ This means, that `p` is valid, that `X` is of correct dimension and elementwise a tangent vector to the elements of `p` on the [`Circle`](@ref). """ function check_vector(M::Torus{N}, p, X; kwargs...) where {N} - if length(X) != N - return DomainError( - length(X), - "The number of elements in `X` ($(length(X))) does not match the dimension of the torus ($(N)).", - ) - end return check_vector(PowerManifold(M.manifold, N), p, X; kwargs...) end get_iterator(::Torus{N}) where {N} = 1:N -@generated manifold_dimension(::Torus{N}) where {N} = N +manifold_dimension(::Torus{N}) where {N} = N -@generated representation_size(::Torus{N}) where {N} = (N,) +power_dimensions(::Torus{N}) where {N} = (N,) + +representation_size(::Torus{N}) where {N} = (N,) Base.show(io::IO, ::Torus{N}) where {N} = print(io, "Torus($(N))") diff --git a/src/manifolds/Tucker.jl b/src/manifolds/Tucker.jl index 8e1accc125..1fa3993bdb 100644 --- a/src/manifolds/Tucker.jl +++ b/src/manifolds/Tucker.jl @@ -420,11 +420,11 @@ end @doc raw""" Base.foreach(f, M::Tucker, p::TuckerPoint, basis::AbstractBasis, indices=1:manifold_dimension(M)) -Let `basis` be and [`AbstractBasis`](@ref) at a point `p` on `M`. Suppose `f` is a function +Let `basis` be and [`AbstractBasis`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/bases.html#ManifoldsBase.AbstractBasis) at a point `p` on `M`. Suppose `f` is a function that takes an index and a vector as an argument. This function applies `f` to `i` and the `i`th basis vector sequentially for each `i` in `indices`. -Using a [`CachedBasis`](@ref) may speed up the computation. +Using a [`CachedBasis`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/bases.html#ManifoldsBase.CachedBasis) may speed up the computation. **NOTE**: The i'th basis vector is overwritten in each iteration. If any information about the vector is to be stored, `f` must make a copy. @@ -614,13 +614,7 @@ inverse_retract( ::ProjectionInverseRetraction, ) -function inverse_retract!( - β„³::Tucker, - X, - 𝔄::TuckerPoint, - 𝔅::TuckerPoint, - ::ProjectionInverseRetraction, -) +function inverse_retract_project!(β„³::Tucker, X, 𝔄::TuckerPoint, 𝔅::TuckerPoint) diffVector = embed(β„³, 𝔅) - embed(β„³, 𝔄) return project!(β„³, X, 𝔄, diffVector) end @@ -696,12 +690,11 @@ retraction produces a boundary point, which is outside the manifold. """ retract(::Tucker, ::Any, ::Any, ::PolarRetraction) -function retract!( +function retract_polar!( ::Tucker, q::TuckerPoint, p::TuckerPoint{T,D}, x::TuckerTVector, - ::PolarRetraction, ) where {T,D} U = p.hosvd.U V = x.UΜ‡ diff --git a/src/manifolds/VectorBundle.jl b/src/manifolds/VectorBundle.jl index 56dc51c8a4..93bc75e93d 100644 --- a/src/manifolds/VectorBundle.jl +++ b/src/manifolds/VectorBundle.jl @@ -73,14 +73,14 @@ const TangentSpaceAtPoint{M} = TangentSpaceAtPoint(M::AbstractManifold, p) Return an object of type [`VectorSpaceAtPoint`](@ref) representing tangent -space at `p` on the [`AbstractManifold`](@ref) `M`. +space at `p` on the [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) `M`. """ TangentSpaceAtPoint(M::AbstractManifold, p) = VectorSpaceAtPoint(TangentBundleFibers(M), p) """ TangentSpace(M::AbstractManifold, p) -Return a [`TangentSpaceAtPoint`](@ref) representing tangent space at `p` on the [`AbstractManifold`](@ref) `M`. +Return a [`TangentSpaceAtPoint`](@ref) representing tangent space at `p` on the [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) `M`. """ TangentSpace(M::AbstractManifold, p) = VectorSpaceAtPoint(TangentBundleFibers(M), p) @@ -117,7 +117,7 @@ end """ VectorBundle{𝔽,TVS<:VectorSpaceType,TM<:AbstractManifold{𝔽}} <: AbstractManifold{𝔽} -Vector bundle on a [`AbstractManifold`](@ref) `M` of type [`VectorSpaceType`](@ref). +Vector bundle on a [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) `M` of type [`VectorSpaceType`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/bases.html#ManifoldsBase.VectorSpaceType). # Constructor @@ -286,7 +286,7 @@ function exp!(B::VectorBundle, q, p, X) return q end function exp!(M::TangentSpaceAtPoint, q, p, X) - copyto!(q, p + X) + copyto!(M.fiber.manifold, q, p + X) return q end @@ -318,25 +318,6 @@ function get_basis(M::VectorBundle, p, B::DiagonalizingOrthonormalBasis) return CachedBasis(B, VectorBundleBasisData(b1, b2)) end -for BT in [ - DefaultOrthonormalBasis, - DefaultOrthonormalBasis{<:Any,TangentSpaceType}, - ProjectedOrthonormalBasis{:gram_schmidt,ℝ}, - ProjectedOrthonormalBasis{:svd,ℝ}, -] - eval(quote - @invoke_maker 3 AbstractBasis get_basis(M::VectorBundle, p, B::$BT) - end) - eval( - quote - @invoke_maker 3 AbstractBasis{<:Any,TangentSpaceType} get_basis( - M::TangentSpaceAtPoint, - p, - B::$BT, - ) - end, - ) -end function get_basis(M::TangentBundleFibers, p, B::AbstractBasis{<:Any,TangentSpaceType}) return get_basis(M.manifold, p, B) end @@ -344,6 +325,32 @@ function get_basis(M::TangentSpaceAtPoint, p, B::AbstractBasis{<:Any,TangentSpac return get_basis(M.fiber.manifold, M.point, B) end +function get_coordinates(M::TangentBundleFibers, p, X, B::AbstractBasis) + return get_coordinates(M.manifold, p, X, B) +end + +function get_coordinates!(M::TangentBundleFibers, Y, p, X, B::AbstractBasis) + return get_coordinates!(M.manifold, Y, p, X, B) +end + +function get_coordinates(M::TangentSpaceAtPoint, p, X, B::AbstractBasis) + return get_coordinates(M.fiber.manifold, M.point, X, B) +end + +function get_coordinates!(M::TangentSpaceAtPoint, Y, p, X, B::AbstractBasis) + return get_coordinates!(M.fiber.manifold, Y, M.point, X, B) +end + +function get_coordinates(M::VectorBundle, p, X, B::AbstractBasis) + px, Vx = submanifold_components(M.manifold, p) + VXM, VXF = submanifold_components(M.manifold, X) + n = manifold_dimension(M.manifold) + return vcat( + get_coordinates(M.manifold, px, VXM, B), + get_coordinates(M.fiber, px, VXF, B), + ) +end + function get_coordinates!(M::VectorBundle, Y, p, X, B::AbstractBasis) px, Vx = submanifold_components(M.manifold, p) VXM, VXF = submanifold_components(M.manifold, X) @@ -352,86 +359,67 @@ function get_coordinates!(M::VectorBundle, Y, p, X, B::AbstractBasis) get_coordinates!(M.fiber, view(Y, (n + 1):length(Y)), px, VXF, B) return Y end -function get_coordinates!( + +function get_coordinates( M::VectorBundle, - Y, p, X, B::CachedBasis{𝔽,<:AbstractBasis{𝔽},<:VectorBundleBasisData}, ) where {𝔽} px, Vx = submanifold_components(M.manifold, p) VXM, VXF = submanifold_components(M.manifold, X) - n = manifold_dimension(M.manifold) - get_coordinates!(M.manifold, view(Y, 1:n), px, VXM, B.data.base_basis) - get_coordinates!(M.fiber, view(Y, (n + 1):length(Y)), px, VXF, B.data.vec_basis) - return Y -end -for BT in [ - DefaultBasis, - DefaultOrthogonalBasis, - DefaultOrthonormalBasis, - ProjectedOrthonormalBasis{:gram_schmidt,ℝ}, - ProjectedOrthonormalBasis{:svd,ℝ}, - VeeOrthogonalBasis, -] - eval( - quote - @invoke_maker 5 AbstractBasis get_coordinates!( - M::VectorBundle, - Y, - p, - X, - B::$BT, - ) - end, - ) - eval( - quote - @invoke_maker 5 AbstractBasis{<:Any,TangentSpaceType} get_coordinates!( - M::TangentSpaceAtPoint, - Y, - p, - X, - B::$BT, - ) - end, + return vcat( + get_coordinates(M.manifold, px, VXM, B.data.base_basis), + get_coordinates(M.fiber, px, VXF, B.data.vec_basis), ) end -function get_coordinates!(M::VectorBundle, Y, p, X, B::CachedBasis) - return error( - "get_coordinates! called on $M with an incorrect CachedBasis. Expected a CachedBasis with VectorBundleBasisData, given $B", - ) -end -function get_coordinates!(M::TangentSpaceAtPoint, Y, p, X, B::CachedBasis) - return get_coordinates!(M.fiber.manifold, Y, M.point, X, B) -end function get_coordinates!( - M::TangentBundleFibers, + M::VectorBundle, Y, p, X, - B::ManifoldsBase.all_uncached_bases{TangentSpaceType}, -) - return get_coordinates!(M.manifold, Y, p, X, B) + B::CachedBasis{𝔽,<:AbstractBasis{𝔽},<:VectorBundleBasisData}, +) where {𝔽} + px, Vx = submanifold_components(M.manifold, p) + VXM, VXF = submanifold_components(M.manifold, X) + n = manifold_dimension(M.manifold) + get_coordinates!(M.manifold, view(Y, 1:n), px, VXM, B.data.base_basis) + get_coordinates!(M.fiber, view(Y, (n + 1):length(Y)), px, VXF, B.data.vec_basis) + return Y end -function get_coordinates!( - M::TangentSpaceAtPoint, - Y, - p, - X, - B::ManifoldsBase.all_uncached_bases{TangentSpaceType}, -) - return get_coordinates!(M.fiber.manifold, Y, M.point, X, B) + +function get_vector(M::VectorBundle, p, X, B::AbstractBasis) + n = manifold_dimension(M.manifold) + xp1 = submanifold_component(p, Val(1)) + return ProductRepr( + get_vector(M.manifold, xp1, X[1:n], B), + get_vector(M.fiber, xp1, X[(n + 1):end], B), + ) end -function get_vector!(M::VectorBundle, Y, p, X, B::DefaultOrthonormalBasis) +function get_vector!(M::VectorBundle, Y, p, X, B::AbstractBasis) n = manifold_dimension(M.manifold) xp1 = submanifold_component(p, Val(1)) get_vector!(M.manifold, submanifold_component(Y, Val(1)), xp1, X[1:n], B) get_vector!(M.fiber, submanifold_component(Y, Val(2)), xp1, X[(n + 1):end], B) return Y end + +function get_vector( + M::VectorBundle, + p, + X, + B::CachedBasis{𝔽,<:AbstractBasis{𝔽},<:VectorBundleBasisData}, +) where {𝔽} + n = manifold_dimension(M.manifold) + xp1 = submanifold_component(p, Val(1)) + return ProductRepr( + get_vector(M.manifold, xp1, X[1:n], B.data.base_basis), + get_vector(M.fiber, xp1, X[(n + 1):end], B.data.vec_basis), + ) +end + function get_vector!( M::VectorBundle, Y, @@ -457,45 +445,18 @@ function get_vector!( ) return Y end -function get_vector!( - M::TangentBundleFibers, - Y, - p, - X, - B::ManifoldsBase.all_uncached_bases{TangentSpaceType}, -) - return get_vector!(M.manifold, Y, p, X, B) + +function get_vector(M::TangentBundleFibers, p, X, B::AbstractBasis) + return get_vector(M.manifold, p, X, B) end -function get_vector!( - M::TangentSpaceAtPoint, - Y, - p, - X, - B::ManifoldsBase.all_uncached_bases{TangentSpaceType}, -) - return get_vector!(M.fiber.manifold, Y, M.point, X, B) +function get_vector(M::TangentSpaceAtPoint, p, X, B::AbstractBasis) + return get_vector(M.fiber.manifold, M.point, X, B) end -for BT in [ - DefaultBasis, - DefaultOrthogonalBasis, - DefaultOrthonormalBasis, - ProjectedOrthonormalBasis{:gram_schmidt,ℝ}, - ProjectedOrthonormalBasis{:svd,ℝ}, - VeeOrthogonalBasis, -] - eval( - quote - @invoke_maker 5 AbstractBasis{<:Any,TangentSpaceType} get_vector!( - M::TangentSpaceAtPoint, - Y, - p, - X, - B::$BT, - ) - end, - ) + +function get_vector!(M::TangentBundleFibers, Y, p, X, B::AbstractBasis) + return get_vector!(M.manifold, Y, p, X, B) end -function get_vector!(M::TangentSpaceAtPoint, Y, p, X, B::CachedBasis) +function get_vector!(M::TangentSpaceAtPoint, Y, p, X, B::AbstractBasis) return get_vector!(M.fiber.manifold, Y, M.point, X, B) end @@ -516,7 +477,6 @@ function get_vectors( end return vs end - function get_vectors(M::VectorBundleFibers, p, B::CachedBasis) return get_vectors(M.manifold, p, B) end @@ -675,6 +635,19 @@ LinearAlgebra.norm(B::VectorBundleFibers, p, X) = sqrt(inner(B, p, X, X)) LinearAlgebra.norm(B::VectorBundleFibers{<:TangentSpaceType}, p, X) = norm(B.manifold, p, X) LinearAlgebra.norm(M::VectorSpaceAtPoint, p, X) = norm(M.fiber.manifold, M.point, X) +function parallel_transport_to!(M::VectorBundle, Y, p, X, q) + px, pVx = submanifold_components(M.manifold, p) + VXM, VXF = submanifold_components(M.manifold, X) + VYM, VYF = submanifold_components(M.manifold, Y) + qx, qVx = submanifold_components(M.manifold, q) + parallel_transport_to!(M.manifold, VYM, px, VXM, qx) + parallel_transport_to!(M.manifold, VYF, px, VXF, qx) + return Y +end +function parallel_transport_to!(M::TangentSpaceAtPoint, Y, p, X, q) + return copyto!(M.fiber.manifold, Y, p, X) +end + @doc raw""" project(B::VectorBundle, p) @@ -826,11 +799,6 @@ end function representation_size(B::VectorBundleFibers{<:TCoTSpaceType}) return representation_size(B.manifold) end -function representation_size(B::VectorBundle) - len_manifold = prod(representation_size(B.manifold)) - len_vs = prod(representation_size(B.fiber)) - return (len_manifold + len_vs,) -end function representation_size(B::TangentSpaceAtPoint) return representation_size(B.fiber.manifold) end @@ -937,6 +905,17 @@ Compute the vector transport the tangent vector `X`at `p` to `q` on the [`VectorBundle`](@ref) `M` using the [`VectorBundleVectorTransport`](@ref) `m`. """ vector_transport_to(::VectorBundle, ::Any, ::Any, ::Any, ::VectorBundleVectorTransport) + +function _vector_transport_to(M::VectorBundle, p, X, q, m::VectorBundleVectorTransport) + px, pVx = submanifold_components(M.manifold, p) + VXM, VXF = submanifold_components(M.manifold, X) + qx, qVx = submanifold_components(M.manifold, q) + return ProductRepr( + vector_transport_to(M.manifold, px, VXM, qx, m.method_point), + vector_transport_to(M.manifold, px, VXF, qx, m.method_vector), + ) +end + function vector_transport_to(M::VectorBundle, p, X, q) return vector_transport_to(M, p, X, q, M.vector_transport) end @@ -963,25 +942,15 @@ function vector_transport_to!( ) return vector_transport_to!(M, Y, p, X, q, VectorBundleVectorTransport(m, m)) end -for VT in ManifoldsBase.VECTOR_TRANSPORT_DISAMBIGUATION - eval( - quote - @invoke_maker 6 AbstractVectorTransportMethod vector_transport_to!( - M::TangentBundle, - Y, - p, - X, - q, - B::$VT, - ) - end, - ) -end -function vector_transport_to!(M::TangentSpaceAtPoint, Y, p, X, q) - return copyto!(Y, X) -end -function vector_transport_to!(M::TangentSpaceAtPoint, Y, p, X, q, ::ParallelTransport) - return copyto!(Y, X) +function vector_transport_to!( + M::TangentSpaceAtPoint, + Y, + p, + X, + q, + m::AbstractVectorTransportMethod, +) + return copyto!(M.fiber.manifold, Y, p, X) end """ diff --git a/src/nlsolve.jl b/src/nlsolve.jl index 957bac52a5..14e25d6eb5 100644 --- a/src/nlsolve.jl +++ b/src/nlsolve.jl @@ -1,41 +1,27 @@ @doc raw""" - inverse_retract(M, p, q method::NLsolveInverseRetraction; kwargs...) + inverse_retract(M, p, q method::NLSolveInverseRetraction; kwargs...) Approximate the inverse of the retraction specified by `method.retraction` from `p` with -respect to `q` on the [`AbstractManifold`](@ref) `M` using NLsolve. This inverse retraction is +respect to `q` on the [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) `M` using NLsolve. This inverse retraction is not guaranteed to succeed and probably will not unless `q` is close to `p` and the initial guess `X0` is close. If the solver fails to converge, an [`OutOfInjectivityRadiusError`](@ref) is raised. -See [`NLsolveInverseRetraction`](@ref) for configurable parameters. +See [`NLSolveInverseRetraction`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/retractions.html#ManifoldsBase.NLSolveInverseRetraction) for configurable parameters. """ -inverse_retract(::AbstractManifold, p, q, ::NLsolveInverseRetraction; kwargs...) +inverse_retract(::AbstractManifold, p, q, ::NLSolveInverseRetraction; kwargs...) -function inverse_retract!( +function inverse_retract_nlsolve!( M::AbstractManifold, X, p, q, - method::NLsolveInverseRetraction; + m::NLSolveInverseRetraction; kwargs..., ) - X0 = method.X0 === nothing ? zero_vector(M, p) : method.X0 - res = _inverse_retract_nlsolve( - M, - p, - q, - method.retraction, - X0, - method.project_tangent, - method.project_point, - method.nlsolve_kwargs; - kwargs..., - ) - if !res.f_converged - @debug res - throw(OutOfInjectivityRadiusError()) - end + X0 = m.X0 === nothing ? zero_vector(M, p) : m.X0 + res = _inverse_retract_nlsolve(M, p, q, m; kwargs...) return copyto!(X, res.zero) end @@ -43,25 +29,23 @@ function _inverse_retract_nlsolve( M::AbstractManifold, p, q, - retraction, - X0, - project_tangent, - project_point, - nlsolve_kwargs; + m::NLSolveInverseRetraction; kwargs..., ) + X0 = m.X0 === nothing ? zero_vector(M, p) : m.X0 function f!(F, X) - project_tangent && project!(M, X, p, X) - retract!(M, F, p, project(M, p, X), retraction; kwargs...) - project_point && project!(M, q, q) + m.project_tangent && project!(M, X, p, X) + retract!(M, F, p, project(M, p, X), m.retraction; kwargs...) + m.project_point && project!(M, q, q) F .-= q return F end isdefined(Manifolds, :NLsolve) || - @warn "To use NLsolveInverseRetraction, NLsolve must be loaded using `using NLsolve`." - res = NLsolve.nlsolve(f!, X0; nlsolve_kwargs...) + @warn "To use NLSolveInverseRetraction, NLsolve must be loaded using `using NLsolve`." + res = NLsolve.nlsolve(f!, X0; m.nlsolve_kwargs...) + if !res.f_converged + @debug res + throw(OutOfInjectivityRadiusError()) + end return res end -function inverse_retract!(M::AbstractPowerManifold, X, q, p, m::NLsolveInverseRetraction) - return inverse_retract!(M, X, q, p, InversePowerRetraction(m)) -end diff --git a/src/product_representations.jl b/src/product_representations.jl index 6d7893ce38..f4ddc5410b 100644 --- a/src/product_representations.jl +++ b/src/product_representations.jl @@ -23,7 +23,7 @@ end Get the projected components of `p` on the submanifolds of `M`. The components are returned in a Tuple. """ submanifold_components(::Any...) -@inline submanifold_components(M::AbstractManifold, p) = submanifold_components(p) +@inline submanifold_components(::AbstractManifold, p) = submanifold_components(p) @inline submanifold_components(p) = p.parts @inline submanifold_components(p::ArrayPartition) = p.x diff --git a/src/statistics.jl b/src/statistics.jl index 785f4d8a53..05606ba48b 100644 --- a/src/statistics.jl +++ b/src/statistics.jl @@ -56,7 +56,7 @@ t_k &= \frac{w_k}{\sum_{i=1}^k w_i}\\ where $x_k$ are points, $w_k$ are weights, $ΞΌ_k$ is the $k$th estimate of the mean, and $Ξ³_x(y; t)$ is the point at time $t$ along the -[`shortest_geodesic`](@ref shortest_geodesic(::AbstractManifold, ::Any, ::Any, ::Real)) +[`shortest_geodesic`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/functions.html#ManifoldsBase.shortest_geodesic-Tuple{AbstractManifold,%20Any,%20Any}) between points $x,y ∈ \mathcal M$. The algorithm terminates when all $x_k$ have been considered. In the [`Euclidean`](@ref) case, this exactly computes the weighted mean. @@ -142,6 +142,35 @@ function Base.show(io::IO, method::GeodesicInterpolationWithinRadius) return print(io, "GeodesicInterpolationWithinRadius($(method.radius))") end +""" + default_estimation_method(M::AbstractManifold, f) + +Specify a default [`AbstractEstimationMethod`](@ref) for an `AbstractManifold` +for a function `f`, e.g. the `median` or the `mean`. + +Note that his function is decorated, so it can inherit from the embedding, for example for the +`IsEmbeddedSubmanifold` trait. +""" +default_estimation_method(M::AbstractManifold, f) + +for mf in [mean, median, cov, var, mean_and_std, mean_and_var] + @eval @trait_function default_estimation_method( + M::AbstractDecoratorManifold, + f::typeof($mf), + ) (no_empty,) + eval( + quote + function default_estimation_method( + ::TraitList{IsEmbeddedSubmanifold}, + M::AbstractDecoratorManifold, + f::typeof($mf), + ) + return default_estimation_method(get_embedding(M), f) + end + end, + ) +end + """ Statistics.cov( M::AbstractManifold, @@ -181,7 +210,7 @@ function Statistics.cov( tangent_space_covariance_estimator::CovarianceEstimator=SimpleCovariance(; corrected=true, ), - mean_estimation_method::AbstractEstimationMethod=GradientDescentEstimation(), + mean_estimation_method::AbstractEstimationMethod=default_estimation_method(M, cov), inverse_retraction_method::AbstractInverseRetractionMethod=default_inverse_retraction_method( M, ), @@ -197,11 +226,16 @@ function Statistics.cov( ) end +function default_estimation_method(::EmptyTrait, ::AbstractDecoratorManifold, ::typeof(cov)) + return GradientDescentEstimation() +end +default_estimation_method(::AbstractManifold, ::typeof(cov)) = GradientDescentEstimation() + @doc raw""" mean(M::AbstractManifold, x::AbstractVector[, w::AbstractWeights]; kwargs...) Compute the (optionally weighted) Riemannian center of mass also known as -Karcher mean of the vector `x` of points on the [`AbstractManifold`](@ref) `M`, defined +Karcher mean of the vector `x` of points on the [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) `M`, defined as the point that satisfies the minimizer ````math \argmin_{y ∈ \mathcal M} \frac{1}{2 \sum_{i=1}^n w_i} \sum_{i=1}^n w_i\mathrm{d}_{\mathcal M}^2(y,x_i), @@ -213,7 +247,7 @@ In the general case, the [`GradientDescentEstimation`](@ref) is used to compute M::AbstractManifold, x::AbstractVector, [w::AbstractWeights,] - method::AbstractEstimationMethod; + method::AbstractEstimationMethod=default_estimation_method(M); kwargs..., ) @@ -226,8 +260,8 @@ Compute the mean using the specified `method`. method::GradientDescentEstimation; p0=x[1], stop_iter=100, - retraction::AbstractRetractionMethod = ExponentialRetraction(), - inverse_retraction::AbstractInverseRetractionMethod = LogarithmicInverseRetraction(), + retraction::AbstractRetractionMethod = default_retraction_method(M), + inverse_retraction::AbstractInverseRetractionMethod = default_retraction_method(M), kwargs..., ) @@ -267,23 +301,28 @@ mean(::AbstractManifold, ::Any...) function Statistics.mean( M::AbstractManifold, x::AbstractVector, - method::AbstractEstimationMethod...; + method::AbstractEstimationMethod=default_estimation_method(M, mean); kwargs..., ) y = allocate_result(M, mean, x[1]) - return mean!(M, y, x, method...; kwargs...) + return mean!(M, y, x, method; kwargs...) end function Statistics.mean( M::AbstractManifold, x::AbstractVector, w::AbstractVector, - method::AbstractEstimationMethod...; + method::AbstractEstimationMethod=default_estimation_method(M, mean); kwargs..., ) y = allocate_result(M, mean, x[1]) - return mean!(M, y, x, w, method...; kwargs...) + return mean!(M, y, x, w, method; kwargs...) end +function default_estimation_method(::EmptyTrait, ::AbstractManifold, ::typeof(mean)) + return GradientDescentEstimation() +end; +default_estimation_method(::AbstractManifold, ::typeof(mean)) = GradientDescentEstimation(); + @doc raw""" mean!(M::AbstractManifold, y, x::AbstractVector[, w::AbstractWeights]; kwargs...) mean!( @@ -298,24 +337,16 @@ end Compute the [`mean`](@ref mean(::AbstractManifold, args...)) in-place in `y`. """ mean!(::AbstractManifold, ::Any...) + function Statistics.mean!( M::AbstractManifold, y, x::AbstractVector, - method::AbstractEstimationMethod...; + method::AbstractEstimationMethod=default_estimation_method(M, mean); kwargs..., ) w = _unit_weights(length(x)) - return mean!(M, y, x, w, method...; kwargs...) -end -function Statistics.mean!( - M::AbstractManifold, - y, - x::AbstractVector, - w::AbstractVector; - kwargs..., -) - return mean!(M, y, x, w, GradientDescentEstimation(); kwargs...) + return mean!(M, y, x, w, method; kwargs...) end function Statistics.mean!( M::AbstractManifold, @@ -325,8 +356,10 @@ function Statistics.mean!( ::GradientDescentEstimation; p0=x[1], stop_iter=100, - retraction::AbstractRetractionMethod=ExponentialRetraction(), - inverse_retraction::AbstractInverseRetractionMethod=LogarithmicInverseRetraction(), + retraction::AbstractRetractionMethod=default_retraction_method(M), + inverse_retraction::AbstractInverseRetractionMethod=default_inverse_retraction_method( + M, + ), kwargs..., ) n = length(x) @@ -364,8 +397,8 @@ end [w::AbstractWeights,] method::GeodesicInterpolation; shuffle_rng=nothing, - retraction::AbstractRetractionMethod = ExponentialRetraction(), - inverse_retraction::AbstractInverseRetractionMethod = LogarithmicInverseRetraction(), + retraction::AbstractRetractionMethod = default_retraction_method(M), + inverse_retraction::AbstractInverseRetractionMethod = default_inverse_retraction_method(M), kwargs..., ) @@ -388,8 +421,10 @@ function Statistics.mean!( w::AbstractVector, ::GeodesicInterpolation; shuffle_rng::Union{AbstractRNG,Nothing}=nothing, - retraction::AbstractRetractionMethod=ExponentialRetraction(), - inverse_retraction::AbstractInverseRetractionMethod=LogarithmicInverseRetraction(), + retraction::AbstractRetractionMethod=default_retraction_method(M), + inverse_retraction::AbstractInverseRetractionMethod=default_inverse_retraction_method( + M, + ), kwargs..., ) n = length(x) @@ -469,8 +504,10 @@ function Statistics.mean!( ::CyclicProximalPointEstimation; p0=x[1], stop_iter=1000000, - retraction::AbstractRetractionMethod=ExponentialRetraction(), - inverse_retraction::AbstractInverseRetractionMethod=LogarithmicInverseRetraction(), + retraction::AbstractRetractionMethod=default_retraction_method(M), + inverse_retraction::AbstractInverseRetractionMethod=default_inverse_retraction_method( + M, + ), kwargs..., ) n = length(x) @@ -500,31 +537,6 @@ function Statistics.mean!( return q end -@decorator_transparent_signature Statistics.mean!( - M::AbstractDecoratorManifold, - y, - x::AbstractVector, - w::AbstractVector; - kwargs..., -) - -function decorator_transparent_dispatch( - ::typeof(mean!), - ::AbstractEmbeddedManifold, - args...; - kwargs..., -) - return Val(:parent) -end -function decorator_transparent_dispatch( - ::typeof(mean!), - ::AbstractEmbeddedManifold{𝔽,<:TransparentIsometricEmbedding}, - args...; - kwargs..., -) where {𝔽} - return Val(:transparent) -end - """ mean( M::AbstractManifold, @@ -556,7 +568,10 @@ function Statistics.mean!( x::AbstractVector, w::AbstractVector, ::ExtrinsicEstimation; - extrinsic_method::AbstractEstimationMethod=GeodesicInterpolation(), + extrinsic_method::AbstractEstimationMethod=default_estimation_method( + get_embedding(M), + mean, + ), kwargs..., ) embedded_x = map(p -> embed(M, p), x) @@ -576,7 +591,7 @@ end ) Compute the (optionally weighted) Riemannian median of the vector `x` of points on the -[`AbstractManifold`](@ref) `M`, defined as the point that satisfies the minimizer +[`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) `M`, defined as the point that satisfies the minimizer ````math \argmin_{y ∈ \mathcal M} \frac{1}{\sum_{i=1}^n w_i} \sum_{i=1}^n w_i\mathrm{d}_{\mathcal M}(y,x_i), ```` @@ -590,6 +605,17 @@ Compute the median using the specified `method`. """ Statistics.median(::AbstractManifold, ::Any...) +function default_estimation_method( + ::EmptyTrait, + ::AbstractDecoratorManifold, + ::typeof(median), +) + return CyclicProximalPointEstimation() +end +function default_estimation_method(::AbstractManifold, ::typeof(median)) + return CyclicProximalPointEstimation() +end + """ median( M::AbstractManifold, @@ -598,8 +624,8 @@ Statistics.median(::AbstractManifold, ::Any...) method::CyclicProximalPointEstimation; p0=x[1], stop_iter=1000000, - retraction::AbstractRetractionMethod = ExponentialRetraction(), - inverse_retraction::AbstractInverseRetractionMethod = LogarithmicInverseRetraction(), + retraction::AbstractRetractionMethod = default_retraction_method(M), + inverse_retraction::AbstractInverseRetractionMethod = default_inverse_retraction_method(M), kwargs..., ) @@ -663,8 +689,8 @@ Statistics.median( Ξ± = 1.0, p0=x[1], stop_iter=2000, - retraction::AbstractRetractionMethod = ExponentialRetraction(), - inverse_retraction::AbstractInverseRetractionMethod = LogarithmicInverseRetraction(), + retraction::AbstractRetractionMethod = default_retraction_method(M), + inverse_retraction::AbstractInverseRetractionMethod = default_inverse_retraction_method(M), kwargs..., ) @@ -719,21 +745,21 @@ Statistics.median( function Statistics.median( M::AbstractManifold, x::AbstractVector, - method::AbstractEstimationMethod...; + method::AbstractEstimationMethod=default_estimation_method(M, median); kwargs..., ) y = allocate_result(M, median, x[1]) - return median!(M, y, x, method...; kwargs...) + return median!(M, y, x, method; kwargs...) end function Statistics.median( M::AbstractManifold, x::AbstractVector, w::AbstractVector, - method::AbstractEstimationMethod...; + method::AbstractEstimationMethod=default_estimation_method(M, median); kwargs..., ) y = allocate_result(M, median, x[1]) - return median!(M, y, x, w, method...; kwargs...) + return median!(M, y, x, w, method; kwargs...) end @doc raw""" @@ -754,20 +780,11 @@ function Statistics.median!( M::AbstractManifold, q, x::AbstractVector, - method::AbstractEstimationMethod...; + method::AbstractEstimationMethod=default_estimation_method(M, median); kwargs..., ) w = _unit_weights(length(x)) - return median!(M, q, x, w, method...; kwargs...) -end -function Statistics.median!( - M::AbstractManifold, - y, - x::AbstractVector, - w::AbstractVector; - kwargs..., -) - return median!(M, y, x, w, CyclicProximalPointEstimation(); kwargs...) + return median!(M, q, x, w, method; kwargs...) end function Statistics.median!( M::AbstractManifold, @@ -777,8 +794,10 @@ function Statistics.median!( ::CyclicProximalPointEstimation; p0=x[1], stop_iter=1000000, - retraction::AbstractRetractionMethod=ExponentialRetraction(), - inverse_retraction::AbstractInverseRetractionMethod=LogarithmicInverseRetraction(), + retraction::AbstractRetractionMethod=default_retraction_method(M), + inverse_retraction::AbstractInverseRetractionMethod=default_inverse_retraction_method( + M, + ), kwargs..., ) n = length(x) @@ -814,7 +833,10 @@ function Statistics.median!( x::AbstractVector, w::AbstractVector, ::ExtrinsicEstimation; - extrinsic_method::AbstractEstimationMethod=CyclicProximalPointEstimation(), + extrinsic_method::AbstractEstimationMethod=default_estimation_method( + get_embedding(M), + median, + ), kwargs..., ) embedded_x = map(p -> embed(M, p), x) @@ -832,8 +854,10 @@ function Statistics.median!( p0=x[1], stop_iter=2000, Ξ±=1.0, - retraction::AbstractRetractionMethod=ExponentialRetraction(), - inverse_retraction::AbstractInverseRetractionMethod=LogarithmicInverseRetraction(), + retraction::AbstractRetractionMethod=default_retraction_method(M), + inverse_retraction::AbstractInverseRetractionMethod=default_inverse_retraction_method( + M, + ), kwargs..., ) n = length(x) @@ -866,37 +890,12 @@ function Statistics.median!( return q end -@decorator_transparent_signature Statistics.median!( - M::AbstractDecoratorManifold, - y, - x::AbstractVector, - w::AbstractVector; - kwargs..., -) - -function decorator_transparent_dispatch( - ::typeof(median!), - ::AbstractEmbeddedManifold, - args...; - kwargs..., -) - return Val(:parent) -end -function decorator_transparent_dispatch( - ::typeof(median!), - ::AbstractEmbeddedManifold{𝔽,<:TransparentIsometricEmbedding}, - args...; - kwargs..., -) where {𝔽} - return Val(:transparent) -end - @doc raw""" var(M, x, m=mean(M, x); corrected=true) var(M, x, w::AbstractWeights, m=mean(M, x, w); corrected=false) compute the (optionally weighted) variance of a `Vector` `x` of `n` data points -on the [`AbstractManifold`](@ref) `M`, i.e. +on the [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) `M`, i.e. ````math \frac{1}{c} \sum_{i=1}^n w_i d_{\mathcal M}^2 (x_i,m), @@ -944,7 +943,7 @@ end std(M, x, w::AbstractWeights, m=mean(M, x, w); corrected=false, kwargs...) compute the optionally weighted standard deviation of a `Vector` `x` of `n` data -points on the [`AbstractManifold`](@ref) `M`, i.e. +points on the [`AbstractManifold`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.AbstractManifold) `M`, i.e. ````math \sqrt{\frac{1}{c} \sum_{i=1}^n w_i d_{\mathcal M}^2 (x_i,m)}, @@ -979,24 +978,34 @@ function StatsBase.mean_and_var( M::AbstractManifold, x::AbstractVector, w::AbstractWeights, - method::AbstractEstimationMethod...; + method::AbstractEstimationMethod=default_estimation_method(M, mean); corrected=false, kwargs..., ) - m = mean(M, x, w, method...; kwargs...) + m = mean(M, x, w, method; kwargs...) v = var(M, x, w, m; corrected=corrected) return m, v end function StatsBase.mean_and_var( M::AbstractManifold, x::AbstractVector, - method::AbstractEstimationMethod...; + method::AbstractEstimationMethod=default_estimation_method(M, mean_and_var); corrected=true, kwargs..., ) n = length(x) w = _unit_weights(n) - return mean_and_var(M, x, w, method...; corrected=corrected, kwargs...) + return mean_and_var(M, x, w, method; corrected=corrected, kwargs...) +end +function default_estimation_method( + ::EmptyTrait, + M::AbstractDecoratorManifold, + ::typeof(mean_and_var), +) + return default_estimation_method(M, mean) +end +function default_estimation_method(M::AbstractManifold, ::typeof(mean_and_var)) + return default_estimation_method(M, mean) end @doc raw""" @@ -1006,8 +1015,8 @@ end [w::AbstractWeights,] method::GeodesicInterpolation; shuffle_rng::Union{AbstractRNG,Nothing} = nothing, - retraction::AbstractRetractionMethod = ExponentialRetraction(), - inverse_retraction::AbstractInverseRetractionMethod = LogarithmicInverseRetraction(), + retraction::AbstractRetractionMethod = default_retraction_method(M), + inverse_retraction::AbstractInverseRetractionMethod = default_inverse_retraction_method(M), kwargs..., ) -> (mean, var) @@ -1032,8 +1041,10 @@ function StatsBase.mean_and_var( ::GeodesicInterpolation; shuffle_rng::Union{AbstractRNG,Nothing}=nothing, corrected=false, - retraction::AbstractRetractionMethod=ExponentialRetraction(), - inverse_retraction::AbstractInverseRetractionMethod=LogarithmicInverseRetraction(), + retraction::AbstractRetractionMethod=default_retraction_method(M), + inverse_retraction::AbstractInverseRetractionMethod=default_inverse_retraction_method( + M, + ), kwargs..., ) n = length(x) @@ -1137,6 +1148,9 @@ function StatsBase.mean_and_std(M::AbstractManifold, args...; kwargs...) m, v = mean_and_var(M, args...; kwargs...) return m, sqrt(v) end +function default_estimation_method(M::AbstractManifold, ::typeof(mean_and_std)) + return default_estimation_method(M, mean) +end """ moment(M::AbstractManifold, x::AbstractVector, k::Int[, w::AbstractWeights], m=mean(M, x[, w])) diff --git a/src/tests/tests_forwarddiff.jl b/src/tests/tests_forwarddiff.jl deleted file mode 100644 index 665233a79d..0000000000 --- a/src/tests/tests_forwarddiff.jl +++ /dev/null @@ -1,15 +0,0 @@ - -function test_forwarddiff(M::AbstractManifold, pts, tv) - return for (p, X) in zip(pts, tv) - exp_f(t) = distance(M, p, exp(M, p, t[1] * X)) - d12 = norm(M, p, X) - for t in 0.1:0.1:0.9 - Test.@test d12 β‰ˆ ForwardDiff.derivative(exp_f, t) - end - - retract_f(t) = distance(M, p, retract(M, p, t[1] * X)) - for t in 0.1:0.1:0.9 - Test.@test ForwardDiff.derivative(retract_f, t) β‰₯ 0 - end - end -end diff --git a/src/tests/tests_general.jl b/src/tests/tests_general.jl index fd2c96379a..db8d9b91f1 100644 --- a/src/tests/tests_general.jl +++ b/src/tests/tests_general.jl @@ -19,21 +19,21 @@ that lie on it (contained in `pts`). # Arguments - `basis_has_specialized_diagonalizing_get = false`: if true, assumes that - [`DiagonalizingOrthonormalBasis`](@ref) given in `basis_types` has + `DiagonalizingOrthonormalBasis` given in `basis_types` has [`get_coordinates`](@ref) and [`get_vector`](@ref) that work without caching. - `basis_types_to_from = ()`: basis types that will be tested based on [`get_coordinates`](@ref) and [`get_vector`](@ref). -- `basis_types_vecs = ()` : basis types that will be tested based on [`get_vectors`](@ref). +- `basis_types_vecs = ()` : basis types that will be tested based on `get_vectors` - `default_inverse_retraction_method = ManifoldsBase.LogarithmicInverseRetraction()`: - default method for inverse retractions ([`log`](@ref)). + default method for inverse retractions (`log`. - `default_retraction_method = ManifoldsBase.ExponentialRetraction()`: default method for - retractions ([`exp`](@ref)). + retractions (`exp`). - `exp_log_atol_multiplier = 0`: change absolute tolerance of exp/log tests (0 use default, i.e. deactivate atol and use rtol). - `exp_log_rtol_multiplier = 1`: change the relative tolerance of exp/log tests (1 use default). This is deactivated if the `exp_log_atol_multiplier` is nonzero. - `expected_dimension_type = Integer`: expected type of value returned by - [`manifold_dimension`](@ref). + `manifold_dimension`. - `inverse_retraction_methods = []`: inverse retraction methods that will be tested. - `is_mutating = true`: whether mutating variants of functions should be tested. - `is_point_atol_multiplier = 0`: determines atol of `is_point` checks. @@ -52,8 +52,6 @@ that lie on it (contained in `pts`). - `retraction_methods = []`: retraction methods that will be tested. - `test_atlases = []`: Vector or tuple of atlases that should be tested. - `test_exp_log = true`: if true, check that [`exp`](@ref) is the inverse of [`log`](@ref). -- `test_forward_diff = true`: if true, automatic differentiation using - ForwardDiff is tested. - `test_injectivity_radius = true`: whether implementation of [`injectivity_radius`](@ref) should be tested. - `test_inplace = false` : if true check if inplace variants work if they are activated, @@ -66,8 +64,6 @@ that lie on it (contained in `pts`). - `test_project_point = false`: test projections onto the manifold. - `test_project_tangent = false` : test projections on tangent spaces. - `test_representation_size = true` : test repersentation size of points/tvectprs. -- `test_reverse_diff = true`: if true, automatic differentiation using - ReverseDiff is tested. - `test_tangent_vector_broadcasting = true` : test boradcasting operators on TangentSpace. - `test_vector_spaces = true` : test Vector bundle of this manifold. - `test_default_vector_transport = false` : test the default vector transport (usually @@ -106,18 +102,20 @@ function test_manifold( retraction_rtol_multiplier=1, test_atlases=(), test_exp_log=true, - test_forward_diff=true, test_is_tangent=true, test_injectivity_radius=true, test_inplace=false, test_musical_isomorphisms=false, test_mutating_rand=false, + parallel_transport=false, + parallel_transport_to=parallel_transport, + parallel_transport_along=parallel_transport, + parallel_transport_direction=parallel_transport, test_project_point=false, test_project_tangent=false, test_rand_point=false, test_rand_tvector=false, test_representation_size=true, - test_reverse_diff=true, test_riesz_representer=false, test_tangent_vector_broadcasting=true, test_default_vector_transport=false, @@ -218,6 +216,13 @@ function test_manifold( Test.@test isapprox(M, pts[2], exp(M, pts[1], X1); atol=atolp1p2, rtol=rtolp1p2) Test.@test isapprox(M, pts[1], exp(M, pts[1], X1, 0); atol=atolp1p2, rtol=rtolp1p2) Test.@test isapprox(M, pts[2], exp(M, pts[1], X1, 1); atol=atolp1p2, rtol=rtolp1p2) + if is_mutating + q2 = allocate(pts[1]) + exp!(M, q2, pts[1], X1) + Test.@test isapprox(M, pts[2], q2; atol=atolp1p2, rtol=rtolp1p2) + exp!(M, q2, pts[1], X1, 0) + Test.@test isapprox(M, pts[1], q2; atol=atolp1p2, rtol=rtolp1p2) + end if VERSION >= v"1.5" && isa(M, Union{Grassmann,GeneralizedStiefel}) # TODO: investigate why this is so imprecise on newer Julia versions on CI Test.@test isapprox( @@ -291,6 +296,15 @@ function test_manifold( end end + parallel_transport && test_parallel_transport( + M, + pts; + along=parallel_transport_along, + to=parallel_transport_to, + direction=parallel_transport_direction, + mutating=is_mutating, + ) + Test.@testset "(inverse &) retraction tests" begin for (p, X) in zip(pts, tv) epsx = find_eps(p) @@ -652,14 +666,6 @@ function test_manifold( end end - test_forward_diff && Test.@testset "ForwardDiff support" begin - test_forwarddiff(M, pts, tv) - end - - test_reverse_diff && Test.@testset "ReverseDiff support" begin - test_reversediff(M, pts, tv) - end - test_musical_isomorphisms && Test.@testset "Musical isomorphisms" begin if default_inverse_retraction_method !== nothing tv_m = inverse_retract(M, pts[1], pts[2], default_inverse_retraction_method) @@ -808,3 +814,77 @@ function test_manifold( end return nothing end + +""" + test_parallel_transport(M,P; along=false, to=true, diretion=true) + +Generic tests for parallel transport on `M`given at least two pointsin `P`. + +The single functions to transport `along` (a curve), `to` (a point) or (towards a) `direction` +are sub-tests that can be activated by the keywords arguemnts + +!!! Note +Since the interface to specify curves is not yet provided, the along keyword does not have an effect yet +""" +function test_parallel_transport( + M::AbstractManifold, + P, + Ξ=inverse_retract.( + Ref(M), + P[1:(end - 1)], + P[2:end], + Ref(default_inverse_retraction_method(M)), + ); + along=false, + to=true, + direction=true, + mutating=true, +) + length(P) < 2 && + error("The Parallel Transport test set requires at least 2 points in P") + Test.@testset "Test Parallel Transport" begin + along && @warn "parallel transport along test not yet implemented" + Test.@testset "To (a point)" begin # even with to =false this displays no tests + if to + for i in 1:(length(P) - 1) + p = P[i] + q = P[i + 1] + X = Ξ[i] + Y1 = parallel_transport_to(M, p, X, q) + if mutating + Y2 = similar(X) + parallel_transport_to!(M, Y2, p, X, q) + # test that mutating and allocating to the same + Test.@test isapprox(M, q, Y1, Y2) + parallel_transport_to!(M, Y2, q, Y1, p) + # Test that transporting there and back again yields the identity + Test.@test isapprox(M, q, X, Y2) + parallel_transport_to!(M, Y1, q, Y1, p) + # Test that inplace does not have side effects + else + Y1 = parallel_transport_to(M, q, Y1, p) + end + Test.@test isapprox(M, q, X, Y1) + end + end + end + Test.@testset "(Tangent Vector) Direction" begin + if direction + for i in 1:(length(P) - 1) + p = P[i] + X = Ξ[i] + Y1 = parallel_transport_direction(M, p, X, X) + q = exp(M, p, X) + if mutating + Y2 = similar(X) + parallel_transport_direction!(M, Y2, p, X, X) + # test that mutating and allocating to the same + Test.@test isapprox(M, q, Y1, Y2) + end + # Test that Y is a tangent vector at q + Test.@test is_vector(M, p, Y1, true) + end + end + end + end +end diff --git a/src/tests/tests_group.jl b/src/tests/tests_group.jl index faa8fd8e36..4ce3cb70fb 100644 --- a/src/tests/tests_group.jl +++ b/src/tests/tests_group.jl @@ -306,7 +306,7 @@ function test_group( Test.@test log_lie!(G, X, Identity(G)) === X g = allocate(g_pts[1]) Test.@test exp_lie!(G, g, X) === g - Test.@test is_identity(G, g; atol=atol) + Test.@test is_identity(G, g; atol=atol) || "is_identity($G, $g; atol=$atol)" end end @@ -406,7 +406,7 @@ function test_group( ) end end - if invariant_metric_dispatch(G, RightAction()) === Val(true) + if has_invariant_metric(G, RightAction()) Test.@testset "right-invariant" begin Test.@test has_approx_invariant_metric( G, @@ -513,7 +513,7 @@ function test_action( test_switch_direction=true, ) G = base_group(A) - M = g_manifold(A) + M = group_manifold(A) e = Identity(G) Test.@testset "Basic action properties" begin diff --git a/src/tests/tests_reversediff.jl b/src/tests/tests_reversediff.jl deleted file mode 100644 index 04a71f822c..0000000000 --- a/src/tests/tests_reversediff.jl +++ /dev/null @@ -1,14 +0,0 @@ -function test_reversediff(M::AbstractManifold, pts, tv) - return for (p, X) in zip(pts, tv) - exp_f(t) = distance(M, p, exp(M, p, t[1] * X)) - d12 = norm(M, p, X) - for t in 0.1:0.1:0.9 - Test.@test d12 β‰ˆ ReverseDiff.gradient(exp_f, [t])[1] - end - - retract_f(t) = distance(M, p, retract(M, p, t[1] * X)) - for t in 0.1:0.1:0.9 - Test.@test ReverseDiff.gradient(retract_f, [t])[1] β‰₯ 0 - end - end -end diff --git a/src/utils.jl b/src/utils.jl index 3c248bb718..6bf64c37d9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -137,7 +137,7 @@ mul!_safe(Y, A, B) = (Y === A || Y === B) ? copyto!(Y, A * B) : mul!(Y, A, B) realify(X::AbstractMatrix{T𝔽}, 𝔽::AbstractNumbers) -> Y::AbstractMatrix{<:Real} Given a matrix $X ∈ 𝔽^{n Γ— n}$, compute $Y ∈ ℝ^{m Γ— m}$, where $m = n \operatorname{dim}_𝔽$, -and $\operatorname{dim}_𝔽$ is the [`real_dimension`](@ref) of the number field $𝔽$, using +and $\operatorname{dim}_𝔽$ is the [`real_dimension`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.real_dimension-Tuple{ManifoldsBase.AbstractNumbers}) of the number field $𝔽$, using the map $Ο• \colon X ↦ Y$, that preserves the matrix product, so that for all $C,D ∈ 𝔽^{n Γ— n}$, ````math @@ -185,7 +185,7 @@ end unrealify!(X::AbstractMatrix{T𝔽}, Y::AbstractMatrix{<:Real}, 𝔽::AbstractNumbers[, n]) Given a real matrix $Y ∈ ℝ^{m Γ— m}$, where $m = n \operatorname{dim}_𝔽$, and -$\operatorname{dim}_𝔽$ is the [`real_dimension`](@ref) of the number field $𝔽$, compute +$\operatorname{dim}_𝔽$ is the [`real_dimension`](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#ManifoldsBase.real_dimension-Tuple{ManifoldsBase.AbstractNumbers}) of the number field $𝔽$, compute in-place its equivalent matrix $X ∈ 𝔽^{n Γ— n}$. Note that this function does not check that $Y$ has a valid structure to be un-realified. diff --git a/test/ambiguities.jl b/test/ambiguities.jl index 9ab3122546..150e811ccd 100644 --- a/test/ambiguities.jl +++ b/test/ambiguities.jl @@ -1,10 +1,11 @@ @testset "Ambiguities" begin - if VERSION.prerelease == () && !Sys.iswindows() && VERSION < v"1.7.0" + if VERSION.prerelease == () && !Sys.iswindows() && VERSION < v"1.8.0" mbs = Test.detect_ambiguities(ManifoldsBase) # Interims solution until we follow what was proposed in # https://discourse.julialang.org/t/avoid-ambiguities-with-individual-number-element-identity/62465/2 fmbs = filter(x -> !any(has_type_in_signature.(x, Identity)), mbs) - FMBS_LIMIT = 22 + FMBS_LIMIT = 15 + println("Number of ManifoldsBase.jl ambiguities: $(length(fmbs))") @test length(fmbs) <= FMBS_LIMIT if length(fmbs) > FMBS_LIMIT for amb in fmbs @@ -16,7 +17,8 @@ # Interims solution until we follow what was proposed in # https://discourse.julialang.org/t/avoid-ambiguities-with-individual-number-element-identity/62465/2 fms = filter(x -> !any(has_type_in_signature.(x, Identity)), ms) - FMS_LIMIT = 23 + FMS_LIMIT = 17 + println("Number of Manifolds.jl ambiguities: $(length(fms))") if length(fms) > FMS_LIMIT for amb in fms println(amb) diff --git a/test/approx_inverse_retraction.jl b/test/approx_inverse_retraction.jl index 83fa282874..7d9b40ea12 100644 --- a/test/approx_inverse_retraction.jl +++ b/test/approx_inverse_retraction.jl @@ -6,19 +6,19 @@ include("utils.jl") Random.seed!(10) @testset "approximate inverse retractions" begin - @testset "NLsolveInverseRetraction" begin + @testset "NLSolveInverseRetraction" begin @testset "constructor" begin X = randn(3) - @test NLsolveInverseRetraction <: ApproximateInverseRetraction - m1 = NLsolveInverseRetraction(ExponentialRetraction()) + @test NLSolveInverseRetraction <: ApproximateInverseRetraction + m1 = NLSolveInverseRetraction(ExponentialRetraction()) @test m1.retraction === ExponentialRetraction() @test m1.X0 === nothing @test !m1.project_tangent @test !m1.project_point @test isempty(m1.nlsolve_kwargs) - m2 = NLsolveInverseRetraction( + m2 = NLSolveInverseRetraction( PolarRetraction(), [1.0, 2.0, 3.0]; project_tangent=true, @@ -36,7 +36,7 @@ Random.seed!(10) p = [1.0, 2.0, 3.0] q = [4.0, 5.0, 6.0] retr_method = ExponentialRetraction() - inv_retr_method = NLsolveInverseRetraction(retr_method) + inv_retr_method = NLSolveInverseRetraction(retr_method) X = inverse_retract(M, p, q, inv_retr_method) @test is_vector(M, p, X) @test X isa Vector{Float64} @@ -47,7 +47,7 @@ Random.seed!(10) p = [[1.0, 2.0], [3.0, 4.0]] q = [[5.0, 6.0], [7.0, 8.0]] retr_method = ExponentialRetraction() - inv_retr_method = NLsolveInverseRetraction(retr_method) + inv_retr_method = NLSolveInverseRetraction(retr_method) X = inverse_retract(M, p, q, inv_retr_method) @test is_vector(M, p, X) @test X isa Vector{Vector{Float64}} @@ -61,7 +61,7 @@ Random.seed!(10) # vector must be nonzero to converge X0 = randn(3) .* eps() inv_retr_method = - NLsolveInverseRetraction(ProjectionRetraction(), X0; project_point=true) + NLSolveInverseRetraction(ProjectionRetraction(), X0; project_point=true) X = inverse_retract(M, p, q, inv_retr_method) @test is_vector(M, p, X; atol=1e-9) @test X β‰ˆ X_exp atol = 1e-8 @@ -73,17 +73,18 @@ Random.seed!(10) ) end - @testset "Circle(β„‚)" begin - M = Circle(β„‚) - p = [1.0 * im] - X = [p[1] * im * (Ο€ / 4)] - q = exp(M, p, X) - X_exp = log(M, p, q) - inv_retr_method = - NLsolveInverseRetraction(ExponentialRetraction(); project_point=true) - X = inverse_retract(M, p, q, inv_retr_method) - @test is_vector(M, p, X; atol=1e-8) - @test X β‰ˆ X_exp - end + # Requires https://github.com/JuliaNLSolvers/NLSolversBase.jl/pull/141 + # @testset "Circle(β„‚)" begin + # M = Circle(β„‚) + # p = fill(1.0 * im) + # X = fill(p[1] * im * (Ο€ / 4)) + # q = exp(M, p, X) + # X_exp = log(M, p, q) + # inv_retr_method = + # NLSolveInverseRetraction(ExponentialRetraction(); project_point=true) + # X = inverse_retract(M, p, q, inv_retr_method) + # @test is_vector(M, p, X; atol=1e-8) + # @test X β‰ˆ X_exp + # end end end diff --git a/test/differentiation.jl b/test/differentiation.jl index 71950b3229..c8d3b558b6 100644 --- a/test/differentiation.jl +++ b/test/differentiation.jl @@ -1,5 +1,6 @@ include("utils.jl") using Manifolds: + default_differential_backend, _derivative, _derivative!, differential, @@ -9,7 +10,16 @@ using Manifolds: _gradient, _gradient!, _jacobian, - _jacobian! + _jacobian!, + set_default_differential_backend! + +# differentiation +using Manifolds: + AbstractDiffBackend, + AbstractRiemannianDiffBackend, + ExplicitEmbeddedBackend, + TangentDiffBackend, + RiemannianProjectionBackend import Manifolds: gradient @@ -18,14 +28,13 @@ function Manifolds.gradient(::AbstractManifold, f, p, ::TestRiemannianBackend) return collect(1.0:length(p)) end -using FiniteDifferences, FiniteDiff +using FiniteDifferences using LinearAlgebra: Diagonal, dot @testset "Differentiation backend" begin fd51 = Manifolds.FiniteDifferencesBackend() @testset "default_differential_backend" begin - #ForwardDiff is loaded first in utils. - @test default_differential_backend() === Manifolds.ForwardDiffBackend() + @test default_differential_backend() isa Manifolds.FiniteDifferencesBackend @test length(fd51.method.grid) == 5 # check method order @@ -35,57 +44,6 @@ using LinearAlgebra: Diagonal, dot @test default_differential_backend() == fd71 end - using ForwardDiff - - fwd_diff = Manifolds.ForwardDiffBackend() - @testset "ForwardDiff" begin - @test default_differential_backend() isa Manifolds.FiniteDifferencesBackend - - @test set_default_differential_backend!(fwd_diff) == fwd_diff - @test default_differential_backend() == fwd_diff - @test set_default_differential_backend!(fd51) isa Manifolds.FiniteDifferencesBackend - @test default_differential_backend() isa Manifolds.FiniteDifferencesBackend - - set_default_differential_backend!(fwd_diff) - @test default_differential_backend() == fwd_diff - set_default_differential_backend!(fd51) - end - - using FiniteDiff - - finite_diff = Manifolds.FiniteDiffBackend() - @testset "FiniteDiff" begin - @test default_differential_backend() isa Manifolds.FiniteDifferencesBackend - - @test set_default_differential_backend!(finite_diff) == finite_diff - @test default_differential_backend() == finite_diff - @test set_default_differential_backend!(fd51) isa Manifolds.FiniteDifferencesBackend - @test default_differential_backend() isa Manifolds.FiniteDifferencesBackend - - set_default_differential_backend!(finite_diff) - @test default_differential_backend() == finite_diff - set_default_differential_backend!(fd51) - end - - using ReverseDiff - - reverse_diff = Manifolds.ReverseDiffBackend() - @testset "ReverseDiff" begin - @test default_differential_backend() isa Manifolds.FiniteDifferencesBackend - - @test set_default_differential_backend!(reverse_diff) == reverse_diff - @test default_differential_backend() == reverse_diff - @test set_default_differential_backend!(fd51) isa Manifolds.FiniteDifferencesBackend - @test default_differential_backend() isa Manifolds.FiniteDifferencesBackend - - set_default_differential_backend!(reverse_diff) - @test default_differential_backend() == reverse_diff - set_default_differential_backend!(fd51) - end - - using Zygote - zygote_diff = Manifolds.ZygoteDiffBackend() - @testset "gradient" begin set_default_differential_backend!(fd51) r2 = Euclidean(2) @@ -97,45 +55,30 @@ using LinearAlgebra: Diagonal, dot return y end f2(x) = 3 * x[1] * x[2] + x[2]^3 + @test _jacobian(c1, 0.0) β‰ˆ [1.0; 0.0] - @testset "Inference" begin - X = [-1.0, -1.0] - @test (@inferred _derivative(c1, 0.0, Manifolds.ForwardDiffBackend())) β‰ˆ - [1.0, 0.0] - @test (@inferred _derivative!(c1, X, 0.0, Manifolds.ForwardDiffBackend())) === X - @test X β‰ˆ [1.0, 0.0] - - @test (@inferred _derivative(c1, 0.0, finite_diff)) β‰ˆ [1.0, 0.0] - @test (@inferred _gradient(f1, [1.0, -1.0], finite_diff)) β‰ˆ [1.0, -2.0] - end - - @testset for backend in [fd51, fwd_diff, finite_diff] + @testset for backend in [fd51] set_default_differential_backend!(backend) @test _derivative(c1, 0.0) β‰ˆ [1.0, 0.0] X = [-1.0, -1.0] @test _derivative!(c1, X, 0.0) === X @test isapprox(X, [1.0, 0.0]) end - @testset for backend in [fd51, fwd_diff, finite_diff, reverse_diff, zygote_diff] + @testset for backend in [fd51] set_default_differential_backend!(backend) X = [-1.0, -1.0] @test _gradient(f1, [1.0, -1.0]) β‰ˆ [1.0, -2.0] @test _gradient!(f1, X, [1.0, -1.0]) === X @test X β‰ˆ [1.0, -2.0] end - @testset for backend in [finite_diff] - set_default_differential_backend!(backend) - X = [-0.0 -0.0] - @test _jacobian(f1, [1.0, -1.0]) β‰ˆ [1.0 -2.0] - # The following seems not to worf for :central, but it does for forward - fdf = Manifolds.FiniteDiffBackend(Val(:forward)) - @test_broken _jacobian!(f1!, X, [1.0, -1.0], fdf) === X - @test_broken X β‰ˆ [1.0 -2.0] - end set_default_differential_backend!(Manifolds.NoneDiffBackend()) - @testset for backend in [fd51, Manifolds.ForwardDiffBackend()] + @testset for backend in [fd51] @test _derivative(c1, 0.0, backend) β‰ˆ [1.0, 0.0] @test _gradient(f1, [1.0, -1.0], backend) β‰ˆ [1.0, -2.0] + @test _jacobian(c1, 0.0, backend) β‰ˆ [1.0; 0.0] + jac = [NaN; NaN] + _jacobian!(c1, jac, 0.0, backend) + @test jac β‰ˆ [1.0; 0.0] end set_default_differential_backend!(fd51) @@ -147,17 +90,18 @@ rb_onb_default = TangentDiffBackend( Manifolds.ExponentialRetraction(), Manifolds.LogarithmicInverseRetraction(), DefaultOrthonormalBasis(), + DefaultOrthonormalBasis(), ) rb_onb_fd51 = TangentDiffBackend(Manifolds.FiniteDifferencesBackend()) -rb_onb_fwd_diff = TangentDiffBackend(Manifolds.ForwardDiffBackend()) - -rb_onb_finite_diff = TangentDiffBackend(Manifolds.FiniteDiffBackend()) - rb_onb_default2 = TangentDiffBackend( default_differential_backend(); - basis=CachedBasis( + basis_arg=CachedBasis( + DefaultOrthonormalBasis(), + [[0.0, -1.0, 0.0], [sqrt(2) / 2, 0.0, -sqrt(2) / 2]], + ), + basis_val=CachedBasis( DefaultOrthonormalBasis(), [[0.0, -1.0, 0.0], [sqrt(2) / 2, 0.0, -sqrt(2) / 2]], ), @@ -177,7 +121,7 @@ rb_proj = Manifolds.RiemannianProjectionBackend(default_differential_backend()) differential!(s2, c1, X, Ο€ / 4, rb_onb_default) @test isapprox(s2, c1(Ο€ / 4), X, Xval) - @testset for backend in [rb_onb_fd51, rb_onb_fwd_diff, rb_onb_finite_diff] + @testset for backend in [rb_onb_fd51] @test isapprox(s2, c1(Ο€ / 4), differential(s2, c1, Ο€ / 4, backend), Xval) X = similar(p) differential!(s2, c1, X, Ο€ / 4, backend) @@ -209,6 +153,36 @@ end @test X == [1.0, 2.0, 3.0] end +@testset "Riemannian Jacobians" begin + s2 = Sphere(2) + f1(p) = p + + q = [sqrt(2) / 2, 0, sqrt(2) / 2] + X = similar(q) + @test isapprox( + s2, + q, + Manifolds.jacobian(s2, s2, f1, q, rb_onb_default), + [1.0 0.0; 0.0 1.0], + ) + + q2 = [1.0, 0.0, 0.0] + f2(X) = [0.0 0.0 0.0; 0.0 2.0 -1.0; 0.0 -3.0 1.0] * X + Tq2s2 = TangentSpaceAtPoint(s2, q2) + @test isapprox( + Manifolds.jacobian(Tq2s2, Tq2s2, f2, zero_vector(s2, q2), rb_onb_default), + [2.0 -1.0; -3.0 1.0], + ) + + q3 = [0.0, 1.0, 0.0] + f3(X) = [0.0 2.0 1.0; 0.0 0.0 0.0; 0.0 5.0 1.0] * X + Tq3s2 = TangentSpaceAtPoint(s2, q3) + @test isapprox( + Manifolds.jacobian(Tq2s2, Tq3s2, f3, zero_vector(s2, q2), rb_onb_default), + [-2.0 -1.0; 5.0 1.0], + ) +end + @testset "EmbeddedBackend" begin A = [1 0 0; 0 2 0; 0 0 3.0] p = 1 / sqrt(2.0) .* [1.0, 1.0, 0.0] diff --git a/test/groups/circle_group.jl b/test/groups/circle_group.jl index a6413edf60..23e5b1f7ad 100644 --- a/test/groups/circle_group.jl +++ b/test/groups/circle_group.jl @@ -1,34 +1,31 @@ include("../utils.jl") include("group_utils.jl") -using Manifolds: invariant_metric_dispatch, default_metric_dispatch - @testset "Circle group" begin G = CircleGroup() @test repr(G) == "CircleGroup()" @test base_manifold(G) === Circle{β„‚}() - @test (@inferred invariant_metric_dispatch(G, LeftAction())) === Val(true) - @test (@inferred invariant_metric_dispatch(G, RightAction())) === Val(true) - @test (@inferred Manifolds.biinvariant_metric_dispatch(G)) === Val(true) - @test (@inferred default_metric_dispatch(MetricManifold(G, EuclideanMetric()))) === - Val(true) @test has_invariant_metric(G, LeftAction()) @test has_invariant_metric(G, RightAction()) @test has_biinvariant_metric(G) @test is_default_metric(MetricManifold(G, EuclideanMetric())) - + @test is_group_manifold(G) @testset "identity overloads" begin ig = Identity(G) @test inv(G, ig) === ig - q = [1.0 * im] - X = [Complex(0.5)] + q = fill(1.0 * im) + X = fill(Complex(0.5)) @test translate_diff(G, ig, q, X) === X @test identity_element(G) === 1.0 @test identity_element(G, 1.0f0) === 1.0f0 - @test identity_element(G, [1.0f0]) == [1.0f0] + @test identity_element(G, fill(1.0f0)) == fill(1.0f0) + @test !is_point(G, Identity(AdditionOperation())) + ef = Identity(AdditionOperation()) + @test_throws DomainError is_point(G, ef, true) + @test_throws DomainError is_vector(G, ef, X, true; check_base_point=true) end @testset "scalar points" begin @@ -49,11 +46,11 @@ using Manifolds: invariant_metric_dispatch, default_metric_dispatch ) end - @testset "vector points" begin - pts = [[1.0 + 0.0im], [0.0 + 1.0im], [(1.0 + 1.0im) / √2]] - Xpts = [[0.0 + 0.5im], [0.0 - 1.5im]] - @test compose(G, pts[2], pts[1]) β‰ˆ pts[2] .* pts[1] - @test translate_diff(G, pts[2], pts[1], Xpts[1]) β‰ˆ pts[2] .* Xpts[1] + @testset "array points" begin + pts = [fill(1.0 + 0.0im), fill(0.0 + 1.0im), fill((1.0 + 1.0im) / √2)] + Xpts = [fill(0.0 + 0.5im), fill(0.0 - 1.5im)] + @test compose(G, pts[2], pts[1]) β‰ˆ fill(pts[2] .* pts[1]) + @test translate_diff(G, pts[2], pts[1], Xpts[1]) β‰ˆ fill(pts[2] .* Xpts[1]) test_group( G, pts, @@ -68,20 +65,18 @@ using Manifolds: invariant_metric_dispatch, default_metric_dispatch end @testset "Group forwards to decorated" begin - pts = [[1.0 + 0.0im], [0.0 + 1.0im], [(1.0 + 1.0im) / √2]] + pts = [1.0 + 0.0im, 0.0 + 1.0im, (1.0 + 1.0im) / √2] test_manifold( G, pts, - basis_types_to_from=(Manifolds.VeeOrthogonalBasis(), DefaultOrthonormalBasis()), - test_forward_diff=false, - test_reverse_diff=false, test_vector_spaces=false, test_project_tangent=true, test_musical_isomorphisms=false, test_default_vector_transport=true, - is_mutating=true, + is_mutating=false, exp_log_atol_multiplier=2.0, is_tangent_atol_multiplier=2.0, + mid_point12=nothing, ) end end @@ -92,11 +87,6 @@ end @test base_manifold(G) === Circle{ℝ}() - @test (@inferred invariant_metric_dispatch(G, LeftAction())) === Val(true) - @test (@inferred invariant_metric_dispatch(G, RightAction())) === Val(true) - @test (@inferred Manifolds.biinvariant_metric_dispatch(G)) === Val(true) - @test (@inferred default_metric_dispatch(MetricManifold(G, EuclideanMetric()))) === - Val(true) @test has_invariant_metric(G, LeftAction()) @test has_invariant_metric(G, RightAction()) @test has_biinvariant_metric(G) @@ -105,16 +95,16 @@ end @testset "identity overloads" begin ig = Identity(G) @test inv(G, ig) === ig - q = [0.0] - X = [0.5] + q = fill(0.0) + X = fill(0.5) @test translate_diff(G, ig, q, X) === X @test identity_element(G) === 0.0 @test identity_element(G, 1.0f0) === 0.0f0 - @test identity_element(G, [0.0f0]) == [0.0f0] + @test identity_element(G, fill(0.0f0)) == fill(0.0f0) end - @testset "scalar points" begin + @testset "points" begin pts = [1.0, 0.5, -3.0] Xpts = [-2.0, 0.5, 2.0] @test compose(G, pts[2], pts[1]) β‰ˆ pts[2] + pts[1] @@ -132,10 +122,10 @@ end ) end - @testset "vector points" begin - pts = [[1.0], [0.5], [-3.0]] - Xpts = [[-2.0], [0.5], [2.0]] - @test compose(G, pts[2], pts[1]) β‰ˆ pts[2] .+ pts[1] + @testset "array points" begin + pts = [fill(1.0), fill(0.5), fill(-3.0)] + Xpts = [fill(-2.0), fill(0.5), fill(2.0)] + @test compose(G, pts[2], pts[1]) β‰ˆ fill(pts[2] .+ pts[1]) @test translate_diff(G, pts[2], pts[1], Xpts[1]) β‰ˆ Xpts[1] test_group( G, @@ -150,21 +140,19 @@ end ) end - @testset "Group forwards to decorated" begin - pts = [[1.0], [0.5], [-3.0]] + @testset "Group forwards" begin + pts = [1.0, 0.5, -3.0] test_manifold( G, pts, - basis_types_to_from=(Manifolds.VeeOrthogonalBasis(), DefaultOrthonormalBasis()), - test_forward_diff=false, - test_reverse_diff=false, test_vector_spaces=false, test_project_tangent=true, test_musical_isomorphisms=false, test_default_vector_transport=true, - is_mutating=true, + is_mutating=false, exp_log_atol_multiplier=2.0, is_tangent_atol_multiplier=2.0, + mid_point12=nothing, ) end end diff --git a/test/groups/connections.jl b/test/groups/connections.jl index 5f60cebf5a..2b7a9ee09f 100644 --- a/test/groups/connections.jl +++ b/test/groups/connections.jl @@ -29,13 +29,22 @@ using Manifolds: connection end @testset "Parallel transport" begin + Y = similar(X) @test isapprox(SO3, q, X, vector_transport_to(SO3minus, SO3e, X, q)) + @test isapprox(SO3, q, X, vector_transport_to!(SO3minus, Y, SO3e, X, q)) @test isapprox(SO3, q, q * X / q, vector_transport_to(SO3plus, SO3e, X, q)) + @test isapprox(SO3, q, q * X / q, vector_transport_to!(SO3plus, Y, SO3e, X, q)) @test isapprox( SO3, q, vector_transport_to(SO3, e, X, q), vector_transport_to(SO3zero, SO3e, X, q), ) + @test isapprox( + SO3, + q, + vector_transport_to(SO3, e, X, q), + vector_transport_to!(SO3zero, Y, SO3e, X, q), + ) end end diff --git a/test/groups/general_linear.jl b/test/groups/general_linear.jl index e79e8678a9..7f35058d28 100644 --- a/test/groups/general_linear.jl +++ b/test/groups/general_linear.jl @@ -8,12 +8,12 @@ using NLsolve @test G === GeneralLinear(3, ℝ) @test repr(G) == "GeneralLinear(3, ℝ)" @test base_manifold(G) === GeneralLinear(3) - @test decorated_manifold(G) === Euclidean(3, 3) + @test get_embedding(G) === Euclidean(3, 3) @test number_system(G) === ℝ @test manifold_dimension(G) == 9 @test representation_size(G) == (3, 3) Gc = GeneralLinear(2, β„‚) - @test decorated_manifold(Gc) === Euclidean(2, 2; field=β„‚) + @test get_embedding(Gc) === Euclidean(2, 2; field=β„‚) @test repr(Gc) == "GeneralLinear(2, β„‚)" @test number_system(Gc) == β„‚ @test manifold_dimension(Gc) == 8 @@ -23,25 +23,6 @@ using NLsolve @test number_system(Gh) == ℍ @test manifold_dimension(Gh) == 4 * 16 @test representation_size(Gh) == (4, 4) - - @test (@inferred invariant_metric_dispatch(G, LeftAction())) === Val(true) - @test (@inferred invariant_metric_dispatch(G, RightAction())) === Val(false) - @test is_default_metric( - MetricManifold(G, InvariantMetric(EuclideanMetric(), LeftAction())), - ) === true - @test @inferred(Manifolds.default_metric_dispatch(G, EuclideanMetric())) === - Val(true) - @test @inferred( - Manifolds.default_metric_dispatch( - G, - InvariantMetric(EuclideanMetric(), LeftAction()), - ) - ) === Val(true) - @test @inferred( - Manifolds.default_metric_dispatch( - MetricManifold(G, InvariantMetric(EuclideanMetric(), LeftAction())), - ) - ) === Val(true) @test Manifolds.allocation_promotion_function(Gc, exp!, (1,)) === complex q = identity_element(G) @@ -80,14 +61,14 @@ using NLsolve @testset "Real" begin G = GeneralLinear(3) - @test_throws DomainError is_point(G, randn(2, 3), true) - @test_throws DomainError is_point(G, randn(2, 2), true) - @test_throws DomainError is_point(G, randn(ComplexF64, 3, 3), true) + @test_throws ManifoldDomainError is_point(G, randn(2, 3), true) + @test_throws ManifoldDomainError is_point(G, randn(2, 2), true) + @test_throws ManifoldDomainError is_point(G, randn(ComplexF64, 3, 3), true) @test_throws DomainError is_point(G, zeros(3, 3), true) @test_throws DomainError is_point(G, Float64[0 0 0; 0 1 1; 1 1 1], true) @test is_point(G, Float64[0 0 1; 0 1 1; 1 1 1], true) @test is_point(G, Identity(G), true) - @test_throws DomainError is_vector( + @test_throws ManifoldDomainError is_vector( G, Float64[0 1 1; 0 1 1; 1 0 0], randn(3, 3), @@ -131,8 +112,6 @@ using NLsolve test_manifold( G, gpts; - test_reverse_diff=false, - test_forward_diff=false, test_project_point=true, test_injectivity_radius=false, test_project_tangent=true, @@ -156,14 +135,14 @@ using NLsolve @testset "Complex" begin G = GeneralLinear(2, β„‚) - @test_throws DomainError is_point(G, randn(ComplexF64, 2, 3), true) - @test_throws DomainError is_point(G, randn(ComplexF64, 3, 3), true) + @test_throws ManifoldDomainError is_point(G, randn(ComplexF64, 2, 3), true) + @test_throws ManifoldDomainError is_point(G, randn(ComplexF64, 3, 3), true) @test_throws DomainError is_point(G, zeros(2, 2), true) @test_throws DomainError is_point(G, ComplexF64[1 im; 1 im], true) @test is_point(G, ComplexF64[1 1; im 1], true) @test is_point(G, Identity(G), true) - @test_throws DomainError is_point(G, Float64[0 0 0; 0 1 1; 1 1 1], true) - @test_throws DomainError is_vector( + @test_throws ManifoldDomainError is_point(G, Float64[0 0 0; 0 1 1; 1 1 1], true) + @test_throws ManifoldDomainError is_vector( G, ComplexF64[im im; im im], randn(ComplexF64, 2, 2), @@ -196,8 +175,6 @@ using NLsolve test_manifold( G, gpts; - test_reverse_diff=false, - test_forward_diff=false, test_project_point=true, test_injectivity_radius=false, test_project_tangent=true, diff --git a/test/groups/group_operation_action.jl b/test/groups/group_operation_action.jl index 9a92b2718e..492a910051 100644 --- a/test/groups/group_operation_action.jl +++ b/test/groups/group_operation_action.jl @@ -9,7 +9,7 @@ include("group_utils.jl") types = [Matrix{Float64}] - @test g_manifold(A_left) === G + @test group_manifold(A_left) === G @test base_group(A_left) == G @test repr(A_left) == "GroupOperationAction($(repr(G)), LeftAction())" @test repr(A_right) == "GroupOperationAction($(repr(G)), RightAction())" @@ -58,7 +58,7 @@ include("group_utils.jl") hat(M, p, [0.5, 0.5, 0.5]), ] - @test g_manifold(A_left) === G + @test group_manifold(A_left) === G @test base_group(A_left) == G @test repr(A_left) == "GroupOperationAction($(repr(G)), LeftAction())" @test repr(A_right) == "GroupOperationAction($(repr(G)), RightAction())" diff --git a/test/groups/group_utils.jl b/test/groups/group_utils.jl index 5f23904177..2a307ab276 100644 --- a/test/groups/group_utils.jl +++ b/test/groups/group_utils.jl @@ -2,13 +2,30 @@ struct NotImplementedOperation <: AbstractGroupOperation end struct NotImplementedManifold <: AbstractManifold{ℝ} end -struct NotImplementedGroupDecorator{M} <: - AbstractDecoratorManifold{ℝ,TransparentGroupDecoratorType} +struct NotImplementedGroupDecorator{𝔽,M<:AbstractManifold{𝔽}} <: + AbstractDecoratorManifold{𝔽} manifold::M end +function active_traits(f, M::NotImplementedGroupDecorator, args...) + return merge_traits(active_traits(f, M.manifold, args...), IsExplicitDecorator()) +end + +function Manifolds.decorated_manifold(M::NotImplementedGroupDecorator) + return M.manifold +end -struct DefaultTransparencyGroup{M,A<:AbstractGroupOperation} <: - AbstractGroupManifold{ℝ,A,DefaultGroupDecoratorType} +struct DefaultTransparencyGroup{𝔽,M<:AbstractManifold{𝔽},A<:AbstractGroupOperation} <: + AbstractDecoratorManifold{𝔽} manifold::M op::A end +function active_traits(f, M::DefaultTransparencyGroup, args...) + return merge_traits( + Manifolds.IsGroupManifold(M.op), + active_traits(f, M.manifold, args...), + ) +end + +function Manifolds.decorated_manifold(M::DefaultTransparencyGroup) + return M.manifold +end diff --git a/test/groups/groups_general.jl b/test/groups/groups_general.jl index f7ad6e3d2e..1e4aaa6e38 100644 --- a/test/groups/groups_general.jl +++ b/test/groups/groups_general.jl @@ -1,14 +1,10 @@ using StaticArrays: identity_perm -using Manifolds: decorator_transparent_dispatch using Base: decode_overlong include("../utils.jl") include("group_utils.jl") @testset "General group tests" begin - @test length(methods(has_biinvariant_metric)) == 1 - @test length(methods(has_invariant_metric)) == 1 - @test length(methods(has_biinvariant_metric)) == 1 @testset "Not implemented operation" begin G = GroupManifold(NotImplementedManifold(), NotImplementedOperation()) @test repr(G) == @@ -18,57 +14,56 @@ include("group_utils.jl") eg = Identity(G) @test repr(eg) === "Identity(NotImplementedOperation)" @test number_eltype(eg) == Bool + @test !is_group_manifold(NotImplementedManifold(), NotImplementedOperation()) @test is_identity(G, eg) # identity transparent @test_throws MethodError identity_element(G) # but for a NotImplOp there is no concrete id. @test isapprox(G, eg, eg) - @test_throws MethodError is_identity(G, 1) # same rror as before i.e. dispatch isapprox works - @test length(methods(is_group_decorator)) == 1 + @test !isapprox(G, Identity(AdditionOperation()), eg) + @test !isapprox(G, Identity(AdditionOperation()), eg) + @test !isapprox( + G, + Identity(AdditionOperation()), + Identity(MultiplicationOperation()), + ) + @test_throws DomainError is_point(G, Identity(AdditionOperation()), true) + @test is_point(G, eg) + @test_throws MethodError is_identity(G, 1) # same error as before i.e. dispatch isapprox works + @test Manifolds.check_size(G, eg) === nothing + @test Manifolds.check_size( + Manifolds.EmptyTrait(), + MetricManifold(NotImplementedManifold(), EuclideanMetric()), + eg, + ) isa DomainError @test Identity(NotImplementedOperation()) === eg @test Identity(NotImplementedOperation) === eg @test !is_point(G, Identity(AdditionOperation())) @test !isapprox(G, eg, Identity(AdditionOperation())) + @test !isapprox(G, Identity(AdditionOperation()), eg) - @test Manifolds.is_group_decorator(G) - @test Manifolds.decorator_group_dispatch(G) === Val{true}() - @test Manifolds.default_decorator_dispatch(G) === Val{false}() - @test !Manifolds.is_group_decorator(NotImplementedManifold()) - @test Manifolds.decorator_group_dispatch(NotImplementedManifold()) === Val{false}() - - @test Manifolds.decorator_transparent_dispatch(compose, G, p, p, p) === - Val{:intransparent}() - @test Manifolds.decorator_transparent_dispatch(compose!, G, p, p, p) === - Val{:intransparent}() - @test Manifolds.decorator_transparent_dispatch(exp_lie, G, p, p) === - Val{:intransparent}() - @test Manifolds.decorator_transparent_dispatch(log_lie, G, p, p) === - Val{:intransparent}() - @test Manifolds.decorator_transparent_dispatch(translate_diff!, G, X, p, p, X) === - Val{:intransparent}() - @test base_group(G) === G @test NotImplementedOperation(NotImplementedManifold()) === G @test (NotImplementedOperation())(NotImplementedManifold()) === G - @test_throws ErrorException base_group( - MetricManifold(Euclidean(3), EuclideanMetric()), - ) @test_throws ErrorException hat(Rotations(3), eg, [1, 2, 3]) + @test_throws ErrorException hat!(Rotations(3), randn(3, 3), eg, [1, 2, 3]) # If you force it, you get a not that readable MethodError @test_throws MethodError hat( GroupManifold(Rotations(3), NotImplementedOperation()), eg, [1, 2, 3], ) + @test_throws ErrorException vee(Rotations(3), eg, [1, 2, 3]) + @test_throws ErrorException vee!(Rotations(3), randn(3), eg, [1, 2, 3]) @test_throws MethodError vee( GroupManifold(Rotations(3), NotImplementedOperation()), eg, [1, 2, 3], ) - @test_throws ErrorException inv!(G, p, p) + @test_throws MethodError inv!(G, p, p) @test_throws MethodError inv!(G, p, eg) - @test_throws ErrorException inv(G, p) + @test_throws MethodError inv(G, p) # no function defined to return the identity array representation @test_throws MethodError copyto!(G, p, eg) @@ -90,50 +85,32 @@ include("group_utils.jl") @test_throws MethodError translate!(G, p, p, p, LeftAction()) @test_throws MethodError translate!(G, p, p, p, RightAction()) - @test_throws ErrorException inverse_translate(G, p, p) - @test_throws ErrorException inverse_translate(G, p, p, LeftAction()) - @test_throws ErrorException inverse_translate(G, p, p, RightAction()) - @test_throws ErrorException inverse_translate!(G, p, p, p) - @test_throws ErrorException inverse_translate!(G, p, p, p, LeftAction()) - @test_throws ErrorException inverse_translate!(G, p, p, p, RightAction()) + @test_throws MethodError inverse_translate(G, p, p) + @test_throws MethodError inverse_translate(G, p, p, LeftAction()) + @test_throws MethodError inverse_translate(G, p, p, RightAction()) + @test_throws MethodError inverse_translate!(G, p, p, p) + @test_throws MethodError inverse_translate!(G, p, p, p, LeftAction()) + @test_throws MethodError inverse_translate!(G, p, p, p, RightAction()) - @test_throws ErrorException translate_diff(G, p, p, X) - @test_throws ErrorException translate_diff(G, p, p, X, LeftAction()) - @test_throws ErrorException translate_diff(G, p, p, X, RightAction()) - @test_throws ErrorException translate_diff!(G, X, p, p, X) - @test_throws ErrorException translate_diff!(G, X, p, p, X, LeftAction()) - @test_throws ErrorException translate_diff!(G, X, p, p, X, RightAction()) + @test_throws MethodError translate_diff(G, p, p, X) + @test_throws MethodError translate_diff(G, p, p, X, LeftAction()) + @test_throws MethodError translate_diff(G, p, p, X, RightAction()) + @test_throws MethodError translate_diff!(G, X, p, p, X) + @test_throws MethodError translate_diff!(G, X, p, p, X, LeftAction()) + @test_throws MethodError translate_diff!(G, X, p, p, X, RightAction()) - @test_throws ErrorException inverse_translate_diff(G, p, p, X) - @test_throws ErrorException inverse_translate_diff(G, p, p, X, LeftAction()) - @test_throws ErrorException inverse_translate_diff(G, p, p, X, RightAction()) - @test_throws ErrorException inverse_translate_diff!(G, X, p, p, X) - @test_throws ErrorException inverse_translate_diff!(G, X, p, p, X, LeftAction()) - @test_throws ErrorException inverse_translate_diff!(G, X, p, p, X, RightAction()) + @test_throws MethodError inverse_translate_diff(G, p, p, X) + @test_throws MethodError inverse_translate_diff(G, p, p, X, LeftAction()) + @test_throws MethodError inverse_translate_diff(G, p, p, X, RightAction()) + @test_throws MethodError inverse_translate_diff!(G, X, p, p, X) + @test_throws MethodError inverse_translate_diff!(G, X, p, p, X, LeftAction()) + @test_throws MethodError inverse_translate_diff!(G, X, p, p, X, RightAction()) - @test_throws ErrorException exp_lie(G, X) - @test_throws ErrorException exp_lie!(G, p, X) + @test_throws MethodError exp_lie(G, X) + @test_throws MethodError exp_lie!(G, p, X) # no transparency error, but _log_lie missing @test_throws MethodError log_lie(G, p) @test_throws MethodError log_lie!(G, X, p) - - for f in [translate, translate!] - @test Manifolds.decorator_transparent_dispatch(f, G) === Val{:intransparent}() - end - for f in - [get_vector, get_coordinates, inverse_translate_diff!, inverse_translate_diff] - @test Manifolds.decorator_transparent_dispatch(f, G) === Val{:transparent}() - end - for f in [exp_lie!, exp_lie, log_lie, log_lie!] - @test Manifolds.decorator_transparent_dispatch(f, G, p, p) === - Val{:intransparent}() - end - @test Manifolds.decorator_transparent_dispatch(isapprox, G, eg, p) === - Val{:transparent}() - @test Manifolds.decorator_transparent_dispatch(isapprox, G, p, eg) === - Val{:transparent}() - @test Manifolds.decorator_transparent_dispatch(isapprox, G, eg, eg) === - Val{:transparent}() end @testset "Action direction" begin @@ -259,7 +236,7 @@ include("group_utils.jl") @test y β‰ˆ p X = [1.0 2.0; 3.0 4.0] @test exp_lie!(G, y, X) === y - @test_throws ErrorException exp_lie!(G, y, :a) + @test_throws MethodError exp_lie!(G, y, :a) @test y β‰ˆ exp(X) Y = allocate(X) @test log_lie!(G, Y, y) === Y @@ -283,26 +260,6 @@ include("group_utils.jl") @test e_mul * e_add === e_add @test mul!(e_mul, e_mul, e_mul) === e_mul end - - @testset "Transparency tests" begin - G = DefaultTransparencyGroup(Euclidean(3), AdditionOperation()) - p = ones(3) - q = 2 * p - X = zeros(3) - Y = similar(X) - for f in - [vector_transport_along!, vector_transport_direction!, vector_transport_to!] - @test ManifoldsBase.decorator_transparent_dispatch( - f, - G, - Y, - p, - X, - q, - ParallelTransport(), - ) == Val(:intransparent) - end - end end struct NotImplementedAction <: AbstractGroupAction{LeftAction} end @@ -314,8 +271,6 @@ struct NotImplementedAction <: AbstractGroupAction{LeftAction} end a = [1.0, 2.0] X = [1.0, 2.0] - @test_throws ErrorException base_group(A) - @test_throws ErrorException g_manifold(A) @test_throws ErrorException apply(A, a, p) @test_throws ErrorException apply!(A, p, a, p) @test_throws ErrorException inverse_apply(A, a, p) diff --git a/test/groups/metric.jl b/test/groups/metric.jl index f435a670f0..5dfd0888aa 100644 --- a/test/groups/metric.jl +++ b/test/groups/metric.jl @@ -2,20 +2,24 @@ include("../utils.jl") include("group_utils.jl") using OrdinaryDiffEq -import Manifolds: invariant_metric_dispatch, default_metric_dispatch, local_metric +import Manifolds: local_metric struct TestInvariantMetricBase <: AbstractMetric end -function local_metric( - ::MetricManifold{𝔽,<:AbstractManifold,TestInvariantMetricBase}, - ::Identity, - ::DefaultOrthonormalBasis, +function active_traits( + f, + M::MetricManifold{𝔽,<:AbstractManifold,TestInvariantMetricBase}, + args..., ) where {𝔽} - return Diagonal([1.0, 2.0, 3.0]) + return merge_traits( + HasLeftInvariantMetric(), + IsMetricManifold(), + active_traits(f, M.manifold, args...), + ) end function local_metric( - ::MetricManifold{𝔽,<:AbstractManifold,<:InvariantMetric{TestInvariantMetricBase}}, - p, + ::MetricManifold{𝔽,<:AbstractManifold,TestInvariantMetricBase}, + ::Identity, ::DefaultOrthonormalBasis, ) where {𝔽} return Diagonal([1.0, 2.0, 3.0]) @@ -23,13 +27,17 @@ end struct TestBiInvariantMetricBase <: AbstractMetric end -function invariant_metric_dispatch( - ::MetricManifold{𝔽,<:AbstractManifold,<:InvariantMetric{TestBiInvariantMetricBase}}, - ::ActionDirection, +function active_traits( + f, + M::MetricManifold{𝔽,<:AbstractManifold,TestBiInvariantMetricBase}, + args..., ) where {𝔽} - return Val(true) + return merge_traits( + HasBiinvariantMetric(), + IsMetricManifold(), + active_traits(f, M.manifold, args...), + ) end - function local_metric( ::MetricManifold{𝔽,<:AbstractManifold,<:TestBiInvariantMetricBase}, ::Identity, @@ -42,60 +50,12 @@ struct TestInvariantMetricManifold <: AbstractManifold{ℝ} end struct TestDefaultInvariantMetricManifold <: AbstractManifold{ℝ} end -function default_metric_dispatch( - ::MetricManifold{ - ℝ, - TestDefaultInvariantMetricManifold, - RightInvariantMetric{TestInvariantMetricBase}, - }, -) - return Val(true) +function ManifoldsBase.active_traits(f, ::TestDefaultInvariantMetricManifold, args...) + return merge_traits(HasRightInvariantMetric()) end -invariant_metric_dispatch(::TestDefaultInvariantMetricManifold, ::RightAction) = Val(true) - @testset "Invariant metrics" begin base_metric = TestInvariantMetricBase() - metric = InvariantMetric(base_metric) - lmetric = LeftInvariantMetric(base_metric) - rmetric = RightInvariantMetric(base_metric) - - @test InvariantMetric(base_metric) === InvariantMetric(base_metric, LeftAction()) - @test lmetric === InvariantMetric(base_metric, LeftAction()) - @test rmetric === InvariantMetric(base_metric, RightAction()) - @test sprint(show, lmetric) == "LeftInvariantMetric(TestInvariantMetricBase())" - @test sprint(show, rmetric) == "RightInvariantMetric(TestInvariantMetricBase())" - - @test direction(lmetric) === LeftAction() - @test direction(rmetric) === RightAction() - - G = MetricManifold(TestInvariantMetricManifold(), lmetric) - @test (@inferred invariant_metric_dispatch(G, LeftAction())) === Val(true) - @test (@inferred invariant_metric_dispatch(G, RightAction())) === Val(false) - - G = MetricManifold(TestInvariantMetricManifold(), rmetric) - @test (@inferred invariant_metric_dispatch(G, RightAction())) === Val(true) - @test (@inferred invariant_metric_dispatch(G, LeftAction())) === Val(false) - - @test Manifolds.invariant_metric_dispatch( - TestInvariantMetricManifold(), - RightAction(), - ) === Val{false}() - @test Manifolds.invariant_metric_dispatch( - TestInvariantMetricManifold(), - LeftAction(), - ) === Val{false}() - - G = MetricManifold( - TestDefaultInvariantMetricManifold(), - LeftInvariantMetric(TestInvariantMetricBase()), - ) - @test !is_default_metric(G) - G = MetricManifold( - TestDefaultInvariantMetricManifold(), - RightInvariantMetric(TestInvariantMetricBase()), - ) - @test is_default_metric(G) e = Matrix{Float64}(I, 3, 3) @testset "inner/norm" begin @@ -109,14 +69,9 @@ invariant_metric_dispatch(::TestDefaultInvariantMetricManifold, ::RightAction) = X = hat(SO3, Identity(SO3), fX.data) Y = hat(SO3, Identity(SO3), fY.data) - G = MetricManifold(SO3, lmetric) + G = MetricManifold(SO3, base_metric) @test inner(G, p, fX, fY) β‰ˆ dot(fX.data, Diagonal([1.0, 2.0, 3.0]) * fY.data) @test norm(G, p, fX) β‰ˆ sqrt(inner(G, p, fX, fX)) - - G = MetricManifold(SO3, rmetric) - @test_broken inner(G, p, fX, fY) β‰ˆ - dot(p * X * p', Diagonal([1.0, 2.0, 3.0]) * p * Y * p') - @test_broken norm(G, p, fX) β‰ˆ sqrt(inner(G, p, fX, fX)) end @testset "log/exp bi-invariant" begin @@ -126,21 +81,32 @@ invariant_metric_dispatch(::TestDefaultInvariantMetricManifold, ::RightAction) = p = exp(hat(SO3, pe, [1.0, 2.0, 3.0])) q = exp(hat(SO3, pe, [3.0, 4.0, 1.0])) X = hat(SO3, e, [2.0, 3.0, 4.0]) + Y = similar(X) + p2 = similar(p) - G = MetricManifold(SO3, InvariantMetric(TestBiInvariantMetricBase(), LeftAction())) + G = MetricManifold(SO3, TestBiInvariantMetricBase()) @test isapprox(SO3, exp(G, p, X), exp(SO3, p, X)) + exp!(G, p2, p, X) + @test isapprox(SO3, p2, exp(SO3, p, X)) @test isapprox(SO3, p, log(G, p, q), log(SO3, p, q); atol=1e-6) + log!(G, Y, p, q) + @test isapprox(SO3, p, Y, log(SO3, p, q); atol=1e-6) - G = MetricManifold(SO3, InvariantMetric(TestBiInvariantMetricBase(), RightAction())) + G = MetricManifold(SO3, TestBiInvariantMetricBase()) @test isapprox(SO3, exp(G, p, X), exp(SO3, p, X)) @test isapprox(SO3, p, log(G, p, q), log(SO3, p, q); atol=1e-6) + + @test is_group_manifold(G) + @test is_group_manifold(G, MultiplicationOperation()) + @test !isapprox(G, e, Identity(AdditionOperation())) + @test has_biinvariant_metric(G) + @test !has_biinvariant_metric(Sphere(2)) end - @testset "exp Ο„-invariant" begin - T3 = TranslationGroup(3) - p = [1.0, 2.0, 3.0] - X = [3.0, 5.0, 6.0] - @test_broken isapprox(T3, exp(MetricManifold(T3, lmetric), p, X), p .+ X) - @test_broken isapprox(T3, exp(MetricManifold(T3, rmetric), p, X), p .+ X) + @testset "invariant metric direction" begin + @test direction(HasRightInvariantMetric()) === RightAction() + @test direction(HasLeftInvariantMetric()) === LeftAction() + @test direction(HasRightInvariantMetric) === RightAction() + @test direction(HasLeftInvariantMetric) === LeftAction() end end diff --git a/test/groups/product_group.jl b/test/groups/product_group.jl index 6c7257264a..5da40d2e46 100644 --- a/test/groups/product_group.jl +++ b/test/groups/product_group.jl @@ -15,9 +15,6 @@ include("group_utils.jl") @test sprint(show, G) == "ProductGroup($(SOn), $(Tn))" @test sprint(show, "text/plain", G) == "ProductGroup with 2 subgroups:\n $(SOn)\n $(Tn)" x = Matrix{Float64}(I, 3, 3) - for f in [exp_lie!, log_lie!] - @test Manifolds.decorator_transparent_dispatch(f, G, x, x) === Val{:transparent}() - end t = Vector{Float64}.([1:2, 2:3, 3:4]) Ο‰ = [[1.0, 2.0, 3.0], [3.0, 2.0, 1.0], [1.0, 3.0, 2.0]] tuple_pts = [(exp(Rn, x, hat(Rn, x, Ο‰i)), ti) for (Ο‰i, ti) in zip(Ο‰, t)] diff --git a/test/groups/rotation_action.jl b/test/groups/rotation_action.jl index 8ef5649f5b..cb23130987 100644 --- a/test/groups/rotation_action.jl +++ b/test/groups/rotation_action.jl @@ -16,7 +16,7 @@ include("group_utils.jl") types_m = [Vector{Float64}] - @test g_manifold(A_left) == Euclidean(2) + @test group_manifold(A_left) == Euclidean(2) @test base_group(A_left) == G @test isa(A_left, AbstractGroupAction{LeftAction}) @test base_manifold(G) == M @@ -64,7 +64,7 @@ end types_m = [Vector{Float64}] - @test g_manifold(A) == Euclidean(3) + @test group_manifold(A) == Euclidean(3) @test base_group(A) == G @test isa(A, AbstractGroupAction{LeftAction}) @test base_manifold(G) == M diff --git a/test/groups/special_euclidean.jl b/test/groups/special_euclidean.jl index 237e5c532b..3ff00aa737 100644 --- a/test/groups/special_euclidean.jl +++ b/test/groups/special_euclidean.jl @@ -102,7 +102,7 @@ Random.seed!(10) X2[1:n, 1:n] .= X[1:n, 1:n] X2[1:n, end] .= X[1:n, end] X2[end, end] = X[end, end] - @test_throws CompositeManifoldError is_vector(G, p, X2, true) + @test_throws DomainError is_vector(G, p, X2, true) p[n + 1, n + 1] = 0.1 @test_throws DomainError is_point(G, p, true) p2 = zeros(n + 2, n + 2) @@ -110,7 +110,7 @@ Random.seed!(10) p2[1:n, 1:n] .= p[1:n, 1:n] p2[1:n, end] .= p[1:n, end] p2[end, end] = p[end, end] - @test_throws CompositeManifoldError is_point(G, p2, true) + @test_throws DomainError is_point(G, p2, true) # exp/log_lie for ProductGroup on arrays X = copy(G, p, X_pts[1]) p3 = exp_lie(G, X) @@ -133,18 +133,26 @@ Random.seed!(10) e = Identity(G) Xe = log_lie(G, p) Xc = vee(G, e, Xe) - @test_throws ErrorException vee(M, e, Xe) + @test_throws MethodError vee(M, e, Xe) w = similar(Xc) vee!(G, w, e, Xe) @test isapprox(Xc, w) - @test_throws ErrorException vee!(M, w, e, Xe) + @test_throws MethodError vee!(M, w, e, Xe) + + w = similar(Xc) + vee!(G, w, identity_element(G), Xe) + @test isapprox(Xc, w) Ye = hat(G, e, Xc) - @test_throws ErrorException hat(M, e, Xc) + @test_throws MethodError hat(M, e, Xc) isapprox(G, e, Xe, Ye) Ye2 = copy(G, p, X) hat!(G, Ye2, e, Xc) - @test_throws ErrorException hat!(M, Ye, e, Xc) + @test_throws MethodError hat!(M, Ye, e, Xc) + @test isapprox(G, e, Ye, Ye2) + + Ye2 = copy(G, p, X) + hat!(G, Ye2, identity_element(G), Xc) @test isapprox(G, e, Ye, Ye2) end end diff --git a/test/groups/special_linear.jl b/test/groups/special_linear.jl index 75f9223776..7e30722dab 100644 --- a/test/groups/special_linear.jl +++ b/test/groups/special_linear.jl @@ -8,12 +8,12 @@ using NLsolve @test G === SpecialLinear(3, ℝ) @test repr(G) == "SpecialLinear(3, ℝ)" @test base_manifold(G) === SpecialLinear(3) - @test decorated_manifold(G) == GeneralLinear(3) + @test get_embedding(G) == GeneralLinear(3) @test number_system(G) === ℝ @test manifold_dimension(G) == 8 @test representation_size(G) == (3, 3) Gc = SpecialLinear(2, β„‚) - @test decorated_manifold(Gc) == GeneralLinear(2, β„‚) + @test get_embedding(Gc) == GeneralLinear(2, β„‚) @test repr(Gc) == "SpecialLinear(2, β„‚)" @test number_system(Gc) == β„‚ @test manifold_dimension(Gc) == 6 @@ -24,38 +24,20 @@ using NLsolve @test manifold_dimension(Gh) == 4 * 15 @test representation_size(Gh) == (4, 4) - @test (@inferred invariant_metric_dispatch(G, LeftAction())) === Val(true) - @test (@inferred invariant_metric_dispatch(G, RightAction())) === Val(false) - @test is_default_metric( - MetricManifold(G, InvariantMetric(EuclideanMetric(), LeftAction())), - ) === true - @test @inferred(Manifolds.default_metric_dispatch(G, EuclideanMetric())) === - Val(true) - @test @inferred( - Manifolds.default_metric_dispatch( - G, - InvariantMetric(EuclideanMetric(), LeftAction()), - ) - ) === Val(true) - @test @inferred( - Manifolds.default_metric_dispatch( - MetricManifold(G, InvariantMetric(EuclideanMetric(), LeftAction())), - ) - ) === Val(true) @test Manifolds.allocation_promotion_function(Gc, exp!, (1,)) === complex end @testset "Real" begin G = SpecialLinear(3) - @test_throws DomainError is_point(G, randn(2, 3), true) - @test_throws DomainError is_point(G, Float64[2 1; 1 1], true) - @test_throws DomainError is_point(G, [1 0 im; im 0 0; 0 -1 0], true) - @test_throws DomainError is_point(G, zeros(3, 3), true) + @test_throws ManifoldDomainError is_point(G, randn(2, 3), true) + @test_throws ManifoldDomainError is_point(G, Float64[2 1; 1 1], true) + @test_throws ManifoldDomainError is_point(G, [1 0 im; im 0 0; 0 -1 0], true) + @test_throws ManifoldDomainError is_point(G, zeros(3, 3), true) @test_throws DomainError is_point(G, Float64[1 3 3; 1 1 2; 1 2 3], true) @test is_point(G, Float64[1 1 1; 2 2 1; 2 3 3], true) @test is_point(G, Identity(G), true) - @test_throws DomainError is_vector( + @test_throws ManifoldDomainError is_vector( G, Float64[2 3 2; 3 1 2; 1 1 1], randn(3, 3), @@ -108,8 +90,6 @@ using NLsolve test_manifold( G, gpts; - test_reverse_diff=false, - test_forward_diff=false, test_injectivity_radius=false, test_project_point=true, test_project_tangent=true, @@ -146,13 +126,17 @@ using NLsolve @testset "Complex" begin G = SpecialLinear(2, β„‚) - @test_throws DomainError is_point(G, randn(ComplexF64, 2, 3), true) + @test_throws ManifoldDomainError is_point(G, randn(ComplexF64, 2, 3), true) @test_throws DomainError is_point(G, randn(2, 2), true) - @test_throws DomainError is_point(G, ComplexF64[1 0 im; im 0 0; 0 -1 0], true) + @test_throws ManifoldDomainError is_point( + G, + ComplexF64[1 0 im; im 0 0; 0 -1 0], + true, + ) @test_throws DomainError is_point(G, ComplexF64[1 im; im 1], true) @test is_point(G, ComplexF64[im 1; -2 im], true) @test is_point(G, Identity(G), true) - @test_throws DomainError is_vector( + @test_throws ManifoldDomainError is_vector( G, ComplexF64[-1+im -1; -im 1], ComplexF64[1-im 1+im; 1 -1+im], @@ -199,8 +183,6 @@ using NLsolve test_manifold( G, gpts; - test_reverse_diff=false, - test_forward_diff=false, test_injectivity_radius=false, test_project_point=true, test_project_tangent=true, diff --git a/test/groups/special_orthogonal.jl b/test/groups/special_orthogonal.jl index 348655c75a..bc537c260c 100644 --- a/test/groups/special_orthogonal.jl +++ b/test/groups/special_orthogonal.jl @@ -7,20 +7,7 @@ include("group_utils.jl") M = base_manifold(G) @test M === Rotations(3) x = Matrix(I, 3, 3) - - @test (@inferred invariant_metric_dispatch(G, LeftAction())) === Val(true) - @test (@inferred invariant_metric_dispatch(G, RightAction())) === Val(true) - @test (@inferred Manifolds.biinvariant_metric_dispatch(G)) === Val(true) - @test is_default_metric(MetricManifold(G, EuclideanMetric())) === true - @test is_default_metric( - MetricManifold(G, InvariantMetric(EuclideanMetric(), LeftAction())), - ) === true - @test is_default_metric( - MetricManifold(G, InvariantMetric(EuclideanMetric(), RightAction())), - ) === true - @test Manifolds.default_metric_dispatch(G, EuclideanMetric()) === Val{true}() - @test Manifolds.default_metric_dispatch(MetricManifold(G, EuclideanMetric())) === - Val{true}() + @test is_default_metric(MetricManifold(G, EuclideanMetric())) types = [Matrix{Float64}] Ο‰ = [[1.0, 2.0, 3.0], [3.0, 2.0, 1.0], [1.0, 3.0, 2.0]] @@ -60,9 +47,6 @@ include("group_utils.jl") @testset "Decorator forwards to group" begin DM = NotImplementedGroupDecorator(G) - @test (@inferred Manifolds.decorator_group_dispatch(DM)) === Val(true) - @test Manifolds.is_group_decorator(DM) - @test base_group(DM) === G test_group(DM, pts, vpts, vpts; test_diff=true) end @@ -84,7 +68,6 @@ include("group_utils.jl") test_manifold( G, pts; - test_reverse_diff=false, test_injectivity_radius=false, test_project_tangent=true, test_musical_isomorphisms=false, diff --git a/test/groups/translation_action.jl b/test/groups/translation_action.jl index 29cf9a275d..3996707295 100644 --- a/test/groups/translation_action.jl +++ b/test/groups/translation_action.jl @@ -15,7 +15,7 @@ include("group_utils.jl") types_m = [Matrix{Float64}] - @test g_manifold(A) == M + @test group_manifold(A) == M @test base_group(A) == G @test base_manifold(G) == M diff --git a/test/groups/translation_group.jl b/test/groups/translation_group.jl index 7bf87c28e2..f67b331076 100644 --- a/test/groups/translation_group.jl +++ b/test/groups/translation_group.jl @@ -7,12 +7,12 @@ include("group_utils.jl") @test repr(G) == "TranslationGroup(2, 3; field = ℝ)" @test repr(TranslationGroup(2, 3; field=β„‚)) == "TranslationGroup(2, 3; field = β„‚)" - @test (@inferred invariant_metric_dispatch(G, LeftAction())) === Val(true) - @test (@inferred invariant_metric_dispatch(G, RightAction())) === Val(true) - @test (@inferred Manifolds.biinvariant_metric_dispatch(G)) === Val(true) + @test has_invariant_metric(G, LeftAction()) + @test has_invariant_metric(G, RightAction()) + @test has_biinvariant_metric(G) @test is_default_metric(MetricManifold(G, EuclideanMetric())) === true - @test Manifolds.default_metric_dispatch(MetricManifold(G, EuclideanMetric())) === - Val{true}() + @test is_group_manifold(G) + @test !is_group_manifold(G.manifold) types = [Matrix{Float64}] @test base_manifold(G) === Euclidean(2, 3) @test log_lie(G, Identity(G)) == zeros(2, 3) # log_lie with Identity on Addition group. diff --git a/test/groups/validation_group.jl b/test/groups/validation_group.jl index eaacfb369b..ba35e1e695 100644 --- a/test/groups/validation_group.jl +++ b/test/groups/validation_group.jl @@ -4,9 +4,7 @@ include("../utils.jl") G = SpecialOrthogonal(3) M = Rotations(3) AG = ValidationManifold(G) - @test base_group(AG) === G - @test (@inferred Manifolds.decorator_group_dispatch(AG)) === Val(true) - @test Manifolds.is_group_decorator(AG) + @test is_group_manifold(AG) eg = Matrix{Float64}(I, 3, 3) Ο‰ = [[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]] diff --git a/test/manifolds/centered_matrices.jl b/test/manifolds/centered_matrices.jl index d3ad47bded..c4ffba7e6b 100644 --- a/test/manifolds/centered_matrices.jl +++ b/test/manifolds/centered_matrices.jl @@ -11,16 +11,15 @@ include("../utils.jl") @testset "Real Centered Matrices Basics" begin @test repr(M) == "CenteredMatrices(3, 2, ℝ)" @test representation_size(M) == (3, 2) - @test base_manifold(M) === M @test typeof(get_embedding(M)) === Euclidean{Tuple{3,2},ℝ} @test check_point(M, A) === nothing - @test_throws DomainError is_point(M, B, true) - @test_throws DomainError is_point(M, C, true) + @test_throws ManifoldDomainError is_point(M, B, true) + @test_throws ManifoldDomainError is_point(M, C, true) @test_throws DomainError is_point(M, D, true) @test check_vector(M, A, A) === nothing @test_throws DomainError is_vector(M, A, D, true) - @test_throws DomainError is_vector(M, D, A, true) - @test_throws DomainError is_vector(M, A, B, true) + @test_throws ManifoldDomainError is_vector(M, D, A, true) + @test_throws ManifoldDomainError is_vector(M, A, B, true) @test manifold_dimension(M) == 4 @test A == project!(M, A, A) @test A == project(M, A, A) @@ -35,7 +34,6 @@ include("../utils.jl") M, [A, E, F], test_injectivity_radius=false, - test_reverse_diff=false, test_project_tangent=true, test_musical_isomorphisms=true, test_default_vector_transport=true, @@ -52,7 +50,6 @@ include("../utils.jl") M_complex, [C, G, H], test_injectivity_radius=false, - test_reverse_diff=false, test_project_tangent=true, test_musical_isomorphisms=true, test_default_vector_transport=true, diff --git a/test/manifolds/cholesky_space.jl b/test/manifolds/cholesky_space.jl index e161da3d56..6f04a81d33 100644 --- a/test/manifolds/cholesky_space.jl +++ b/test/manifolds/cholesky_space.jl @@ -28,8 +28,6 @@ include("../utils.jl") SchildsLadderTransport(), PoleLadderTransport(), ], - test_forward_diff=false, - test_reverse_diff=false, test_vee_hat=false, exp_log_atol_multiplier=8.0, test_inplace=true, diff --git a/test/manifolds/circle.jl b/test/manifolds/circle.jl index aa0c6a8496..dfd77e3b81 100644 --- a/test/manifolds/circle.jl +++ b/test/manifolds/circle.jl @@ -9,9 +9,14 @@ using Manifolds: TFVector, CoTFVector @test representation_size(M) == () @test manifold_dimension(M) == 1 @test !is_point(M, 9.0) + @test !is_point(M, zeros(3, 3)) @test_throws DomainError is_point(M, 9.0, true) + @test_throws DomainError is_point(M, zeros(3, 3), true) @test !is_vector(M, 9.0, 0.0) + @test !is_vector(M, zeros(3, 3), zeros(3, 3)) @test_throws DomainError is_vector(M, 9.0, 0.0, true) + @test_throws DomainError is_vector(M, zeros(3, 3), zeros(3, 3), true) + @test_throws DomainError is_vector(M, 0.0, zeros(3, 3), true) @test is_vector(M, 0.0, 0.0) @test get_coordinates(M, Ref(0.0), Ref(2.0), DefaultOrthonormalBasis())[] β‰ˆ 2.0 @test get_coordinates( @@ -41,26 +46,23 @@ using Manifolds: TFVector, CoTFVector y = [0.0] get_coordinates!(M, y, Ref(0.0), Ref(2.0), DiagonalizingOrthonormalBasis(Ref(1.0))) @test y β‰ˆ [2.0] - @test get_vector(M, Ref(0.0), Ref(2.0), DefaultOrthonormalBasis())[] β‰ˆ 2.0 - @test get_vector(M, Ref(0.0), Ref(2.0), DiagonalizingOrthonormalBasis(Ref(1.0)))[] β‰ˆ + @test get_vector(M, Ref(0.0), [2.0], DefaultOrthonormalBasis())[] β‰ˆ 2.0 + @test get_vector(M, [0.0], [2.0], DefaultOrthonormalBasis())[] β‰ˆ 2.0 + @test get_vector(M, Ref(0.0), [2.0], DiagonalizingOrthonormalBasis(Ref(1.0)))[] β‰ˆ 2.0 - @test get_vector(M, Ref(0.0), Ref(-2.0), DiagonalizingOrthonormalBasis(Ref(1.0)))[] β‰ˆ + @test get_vector(M, Ref(0.0), [-2.0], DiagonalizingOrthonormalBasis(Ref(1.0)))[] β‰ˆ -2.0 - @test get_vector(M, Ref(0.0), Ref(2.0), DiagonalizingOrthonormalBasis(Ref(-1.0)))[] β‰ˆ + @test get_vector(M, Ref(0.0), [2.0], DiagonalizingOrthonormalBasis(Ref(-1.0)))[] β‰ˆ -2.0 - @test get_vector( - M, - Ref(0.0), - Ref(-2.0), - DiagonalizingOrthonormalBasis(Ref(-1.0)), - )[] β‰ˆ 2.0 + @test get_vector(M, Ref(0.0), [-2.0], DiagonalizingOrthonormalBasis(Ref(-1.0)))[] β‰ˆ + 2.0 @test number_of_coordinates(M, DiagonalizingOrthonormalBasis(Ref(-1.0))) == 1 @test number_of_coordinates(M, DefaultOrthonormalBasis()) == 1 rrcv = Manifolds.RieszRepresenterCotangentVector(M, 0.0, 1.0) @test flat(M, 0.0, 1.0) == rrcv @test sharp(M, 0.0, rrcv) == 1.0 B_cot = Manifolds.dual_basis(M, 0.0, DefaultOrthonormalBasis()) - @test get_coordinates(M, 0.0, rrcv, B_cot) β‰ˆ 1.0 + @test get_coordinates(M, 0.0, rrcv, B_cot) β‰ˆ @SVector [1.0] @test get_vector(M, 0.0, 1.0, B_cot) isa Manifolds.RieszRepresenterCotangentVector a = fill(NaN) get_coordinates!(M, a, 0.0, rrcv, B_cot) @@ -77,26 +79,26 @@ using Manifolds: TFVector, CoTFVector @test mean(M, [-Ο€ / 2, 0.0, Ο€]) β‰ˆ -Ο€ / 2 @test mean(M, [-Ο€ / 2, 0.0, Ο€], [1.0, 1.0, 1.0]) == -Ο€ / 2 z = project(M, 1.5 * Ο€) - z2 = [0.0] + z2 = fill(0.0) project!(M, z2, 1.5 * Ο€) @test z2[1] == z @test project(M, z) == z @test project(M, 1.0, 2.0) == 2.0 end TEST_STATIC_SIZED && @testset "Real Circle and static sized arrays" begin - v = MVector(0.0) - x = SVector(0.0) - log!(M, v, x, SVector(Ο€ / 4)) - @test norm(M, x, v) β‰ˆ Ο€ / 4 - @test is_vector(M, x, v) - @test is_vector(M, [], v) + X = @MArray fill(0.0) + p = @SArray fill(0.0) + log!(M, X, p, @SArray fill(Ο€ / 4)) + @test norm(M, p, X) β‰ˆ Ο€ / 4 + @test is_vector(M, p, X) + @test is_vector(M, [], X) @test project(M, 1.0) == 1.0 - x = MVector(0.0) - project!(M, x, x) - @test x == MVector(0.0) - x .+= 2 * Ο€ - project!(M, x, x) - @test x == MVector(0.0) + p = @MArray fill(0.0) + project!(M, p, p) + @test p == @MArray fill(0.0) + p .+= 2 * Ο€ + project!(M, p, p) + @test p == @MArray fill(0.0) @test project(M, 0.0, 1.0) == 1.0 end types = [Float64] @@ -114,8 +116,6 @@ using Manifolds: TFVector, CoTFVector test_manifold( M, pts, - test_forward_diff=false, - test_reverse_diff=false, test_vector_spaces=false, test_project_tangent=true, test_musical_isomorphisms=true, @@ -125,12 +125,10 @@ using Manifolds: TFVector, CoTFVector test_rand_point=true, test_rand_tvector=true, ) - ptsS = SVector.(pts) + ptsS = map(p -> (@SArray fill(p)), pts) test_manifold( M, ptsS, - test_forward_diff=false, - test_reverse_diff=false, test_project_tangent=true, test_musical_isomorphisms=true, test_default_vector_transport=true, @@ -147,6 +145,25 @@ using Manifolds: TFVector, CoTFVector ) end end + @testset "Mutating Rand for real Circle" begin + p = fill(NaN) + X = fill(NaN) + rand!(M, p) + @test is_point(M, p) + rand!(M, X; vector_at=p) + @test is_vector(M, p, X) + + rng = MersenneTwister() + rand!(rng, M, p) + @test is_point(M, p) + rand!(rng, M, X; vector_at=p) + @test is_vector(M, p, X) + end + @testset "Test sym_rem" begin + p = 4.0 # not a point + p = sym_rem(p) # modulo to a point + @test is_point(M, p) + end Mc = Circle(β„‚) @testset "Complex Circle Basics" begin @test repr(Mc) == "Circle(β„‚)" @@ -165,22 +182,22 @@ using Manifolds: TFVector, CoTFVector @test sharp(Mc, 0.0 + 0.0im, rrcv) == 1.0im @test norm(Mc, 1.0, log(Mc, 1.0, -1.0)) β‰ˆ Ο€ @test is_vector(Mc, 1.0, log(Mc, 1.0, -1.0)) - v = MVector(0.0 + 0.0im) - x = SVector(1.0 + 0.0im) - log!(Mc, v, x, SVector(-1.0 + 0.0im)) - @test norm(Mc, SVector(1.0), v) β‰ˆ Ο€ - @test is_vector(Mc, x, v) + X = @MArray fill(0.0 + 0.0im) + p = @SArray fill(1.0 + 0.0im) + log!(Mc, X, p, @SArray fill(-1.0 + 0.0im)) + @test norm(Mc, (@SArray fill(1.0)), X) β‰ˆ Ο€ + @test is_vector(Mc, p, X) @test project(Mc, 1.0) == 1.0 - project(Mc, 1 / sqrt(2.0) + 1 / sqrt(2.0) * im) == - 1 / sqrt(2.0) + 1 / sqrt(2.0) * im - x = MVector(1.0 + 0.0im) - project!(Mc, x, x) - @test x == MVector(1.0 + 0.0im) - x .*= 2 - project!(Mc, x, x) - @test x == MVector(1.0 + 0.0im) + @test project(Mc, 1 / sqrt(2.0) + 1 / sqrt(2.0) * im) β‰ˆ + 1 / sqrt(2.0) + 1 / sqrt(2.0) * im + p = @MArray fill(1.0 + 0.0im) + project!(Mc, p, p) + @test p == @MArray fill(1.0 + 0.0im) + p .*= 2 + project!(Mc, p, p) + @test p == @MArray fill(1.0 + 0.0im) - angles = map(x -> exp(x * im), [-Ο€ / 2, 0.0, Ο€]) + angles = map(pp -> exp(pp * im), [-Ο€ / 2, 0.0, Ο€]) @test mean(Mc, angles) β‰ˆ exp(-Ο€ * im / 2) @test mean(Mc, angles, [1.0, 1.0, 1.0]) β‰ˆ exp(-Ο€ * im / 2) @test_throws ErrorException mean(Mc, [-1.0 + 0im, 1.0 + 0im]) @@ -196,8 +213,6 @@ using Manifolds: TFVector, CoTFVector test_manifold( Mc, pts, - test_forward_diff=false, - test_reverse_diff=false, test_vector_spaces=false, test_project_tangent=true, test_musical_isomorphisms=true, @@ -207,12 +222,10 @@ using Manifolds: TFVector, CoTFVector exp_log_atol_multiplier=2.0, is_tangent_atol_multiplier=2.0, ) - ptsS = SVector.(pts) + ptsS = map(p -> (@SArray fill(p)), pts) test_manifold( Mc, ptsS, - test_forward_diff=false, - test_reverse_diff=false, test_project_tangent=true, test_musical_isomorphisms=true, test_default_vector_transport=true, diff --git a/test/manifolds/elliptope.jl b/test/manifolds/elliptope.jl index f0488b5dbb..de9a9c50f2 100644 --- a/test/manifolds/elliptope.jl +++ b/test/manifolds/elliptope.jl @@ -40,8 +40,6 @@ include("../utils.jl") M, pts, test_injectivity_radius=false, - test_reverse_diff=false, - test_forward_diff=false, test_project_tangent=true, test_project_point=true, test_exp_log=false, diff --git a/test/manifolds/essential_manifold.jl b/test/manifolds/essential_manifold.jl index 048a8024ed..3c961f3045 100644 --- a/test/manifolds/essential_manifold.jl +++ b/test/manifolds/essential_manifold.jl @@ -20,6 +20,7 @@ include("../utils.jl") np2 = [nr, nr] np3 = [r1, r2, r3] @test !is_point(M, r1) + # first two components of r1 are not rotations @test_throws DomainError is_point(M, r1, true) @test_throws DomainError is_point(M, np3, true) @test is_point(M, p1) @@ -42,8 +43,6 @@ include("../utils.jl") test_manifold( M, [p1, p2, p3], - test_forward_diff=false, - test_reverse_diff=false, test_vector_spaces=true, test_project_point=true, projection_atol_multiplier=10, @@ -55,14 +54,13 @@ include("../utils.jl") mid_point12=nothing, exp_log_atol_multiplier=4, test_inplace=true, + parallel_transport=true, ) end @testset "Unsigned Essential" begin test_manifold( EssentialManifold(false), [p1, p2, p3], - test_forward_diff=false, - test_reverse_diff=false, test_vector_spaces=true, test_project_point=true, projection_atol_multiplier=10, @@ -73,6 +71,7 @@ include("../utils.jl") test_exp_log=true, mid_point12=nothing, exp_log_atol_multiplier=4, + parallel_transport=true, ) end end diff --git a/test/manifolds/euclidean.jl b/test/manifolds/euclidean.jl index 71f35b1ddc..8817e0a013 100644 --- a/test/manifolds/euclidean.jl +++ b/test/manifolds/euclidean.jl @@ -10,9 +10,6 @@ using Manifolds: induced_basis @test repr(Ec) == "Euclidean(3; field = β„‚)" @test repr(Euclidean(2, 3; field=ℍ)) == "Euclidean(2, 3; field = ℍ)" @test Manifolds.allocation_promotion_function(Ec, get_vector, ()) === complex - @test is_default_metric(EM) - @test is_default_metric(E, Manifolds.EuclideanMetric()) - @test Manifolds.default_metric_dispatch(E, Manifolds.EuclideanMetric()) === Val{true}() p = zeros(3) A = Manifolds.RetractionAtlas() B = induced_basis(EM, A, p, TangentSpace) @@ -28,6 +25,26 @@ using Manifolds: induced_basis @test Y == X @test embed(E, p, X) == X + # temp: explicit test for induced basis + B = induced_basis(E, RetractionAtlas(), 0, ManifoldsBase.TangentSpaceType()) + @test get_coordinates(E, p, X, B) == X + get_coordinates!(E, Y, p, X, B) + @test Y == X + @test get_vector(E, p, Y, B) == X + Y2 = similar(X) + get_vector!(E, Y2, p, Y, B) + @test Y2 == X + + Y = parallel_transport_along(E, p, X, [p]) + @test Y == X + parallel_transport_along!(E, Y, p, X, [p]) + @test Y == X + + Y = vector_transport_along(E, p, X, [p]) + @test Y == X + vector_transport_along!(E, Y, p, X, [p]) + @test Y == X + # real manifold does not allow complex values @test_throws DomainError is_point(Ec, [:a, :b, :b], true) @test_throws DomainError is_point(E, [1.0, 1.0im, 0.0], true) @@ -67,6 +84,7 @@ using Manifolds: induced_basis DefaultOrthonormalBasis(), DefaultOrthonormalBasis(β„‚), DiagonalizingOrthonormalBasis([1.0, 2.0, 3.0]), + DiagonalizingOrthonormalBasis([1.0, 2.0, 3.0], β„‚), ) else () @@ -81,7 +99,6 @@ using Manifolds: induced_basis test_manifold( M, pts, - test_reverse_diff=isa(T, Vector), test_project_point=true, test_project_tangent=true, test_musical_isomorphisms=true, @@ -122,11 +139,11 @@ using Manifolds: induced_basis test_manifold( Ec, pts, - test_reverse_diff=isa(T, Vector), test_project_tangent=true, test_musical_isomorphisms=true, test_default_vector_transport=true, test_vee_hat=false, + parallel_transport=true, ) end end @@ -146,8 +163,6 @@ using Manifolds: induced_basis test_manifold( M, pts, - test_forward_diff=false, - test_reverse_diff=false, test_vector_spaces=false, test_project_tangent=true, test_musical_isomorphisms=true, diff --git a/test/manifolds/fixed_rank.jl b/test/manifolds/fixed_rank.jl index 7286c1c137..14b527549b 100644 --- a/test/manifolds/fixed_rank.jl +++ b/test/manifolds/fixed_rank.jl @@ -4,7 +4,8 @@ include("../utils.jl") M = FixedRankMatrices(3, 2, 2) M2 = FixedRankMatrices(3, 2, 1) Mc = FixedRankMatrices(3, 2, 2, β„‚) - p = SVDMPoint([1.0 0.0; 0.0 1.0; 0.0 0.0]) + pE = [1.0 0.0; 0.0 1.0; 0.0 0.0] + p = SVDMPoint(pE) p2 = SVDMPoint([1.0 0.0; 0.0 1.0; 0.0 0.0], 1) X = UMVTVector([0.0 0.0; 0.0 0.0; 1.0 1.0], [1.0 0.0; 0.0 1.0], zeros(2, 2)) @test repr(M) == "FixedRankMatrices(3, 2, 2, ℝ)" @@ -52,6 +53,8 @@ include("../utils.jl") @test !is_point(M, SVDMPoint([1.0 0.0; 0.0 0.0], 2)) @test_throws DomainError is_point(M, SVDMPoint([1.0 0.0; 0.0 0.0], 2), true) @test is_point(M2, p2) + @test_throws DomainError is_point(M2, [1.0 0.0; 0.0 1.0; 0.0 0.0], true) + @test Manifolds.check_point(M2, [1.0 0.0; 0.0 1.0; 0.0 0.0]) isa DomainError @test !is_vector( M, @@ -59,7 +62,12 @@ include("../utils.jl") UMVTVector(zeros(2, 1), zeros(1, 2), zeros(2, 2)), ) @test !is_vector(M, SVDMPoint([1.0 0.0; 0.0 0.0], 2), X) - @test_throws DomainError is_vector(M, SVDMPoint([1.0 0.0; 0.0 0.0], 2), X, true) + @test_throws ManifoldDomainError is_vector( + M, + SVDMPoint([1.0 0.0; 0.0 0.0], 2), + X, + true, + ) @test !is_vector(M, p, UMVTVector(p.U, X.M, p.Vt, 2)) @test_throws DomainError is_vector(M, p, UMVTVector(p.U, X.M, p.Vt, 2), true) @test !is_vector(M, p, UMVTVector(X.U, X.M, p.Vt, 2)) @@ -67,6 +75,12 @@ include("../utils.jl") @test is_point(M, p) @test is_vector(M, p, X) + + q = embed(M, p) + @test pE == q + q2 = similar(q) + embed!(M, q2, p) + @test q == q2 end types = [[Matrix{Float64}, Vector{Float64}, Matrix{Float64}]] TEST_FLOAT32 && push!(types, [Matrix{Float32}, Vector{Float32}, Matrix{Float32}]) @@ -109,7 +123,7 @@ include("../utils.jl") @test y2 == y3 @test is_point(M, p) - xM = p.U * Diagonal(p.S) * p.Vt + xM = embed(M, p) @test is_point(M, xM) @test !is_point(M, xM[1:2, :]) @test_throws DomainError is_point(M, xM[1:2, :], true) @@ -165,7 +179,7 @@ include("../utils.jl") wb = w .+ X .* 2 @test wb isa UMVTVector @test wb == w + X * 2 - wb .= 2 .* w .+ X + @test (wb .= 2 .* w .+ X) == 2 * w + X @test wb == 2 * w + X wb .= w @test wb == w @@ -173,6 +187,9 @@ include("../utils.jl") N = get_embedding(M) B = embed(M, p, X) @test isapprox(N, p, B, p.U * X.M * p.Vt + X.U * p.Vt + p.U * X.Vt) + BB = similar(B) + embed!(M, BB, p, X) + @test isapprox(M, p, B, BB) v2 = project(M, p, B) @test isapprox(M, p, X, v2) end @@ -190,11 +207,9 @@ include("../utils.jl") default_retraction_method=PolarRetraction(), test_is_tangent=false, test_default_vector_transport=false, - test_forward_diff=false, - test_reverse_diff=false, test_vector_spaces=false, test_vee_hat=false, - test_tangent_vector_broadcasting=false, #broadcast not so easy for 3 matrix type + test_tangent_vector_broadcasting=true, projection_atol_multiplier=15, retraction_methods=[PolarRetraction()], vector_transport_methods=[ProjectionTransport()], diff --git a/test/manifolds/generalized_grassmann.jl b/test/manifolds/generalized_grassmann.jl index b8c6905321..94aaf044f0 100644 --- a/test/manifolds/generalized_grassmann.jl +++ b/test/manifolds/generalized_grassmann.jl @@ -5,17 +5,30 @@ include("../utils.jl") B = [1.0 0.0 0.0; 0.0 4.0 0.0; 0.0 0.0 1.0] M = GeneralizedGrassmann(3, 2, B) p = [1.0 0.0; 0.0 0.5; 0.0 0.0] + X = zeros(3, 2) + X[1, :] .= 1.0 @testset "Basics" begin @test repr(M) == "GeneralizedGrassmann(3, 2, [1.0 0.0 0.0; 0.0 4.0 0.0; 0.0 0.0 1.0], ℝ)" @test representation_size(M) == (3, 2) @test manifold_dimension(M) == 2 @test base_manifold(M) === M - @test_throws DomainError is_point(M, [1.0, 0.0, 0.0, 0.0], true) - @test_throws DomainError is_point(M, 1im * [1.0 0.0; 0.0 1.0; 0.0 0.0], true) + @test_throws ManifoldDomainError is_point(M, [1.0, 0.0, 0.0, 0.0], true) + @test_throws ManifoldDomainError is_point( + M, + 1im * [1.0 0.0; 0.0 1.0; 0.0 0.0], + true, + ) + @test_throws ManifoldDomainError is_point(M, 2 * p, true) @test !is_vector(M, p, [0.0, 0.0, 1.0, 0.0]) - @test_throws DomainError is_vector(M, p, [0.0, 0.0, 1.0, 0.0], true) - @test_throws DomainError is_vector(M, p, 1 * im * zero_vector(M, p), true) + @test_throws ManifoldDomainError is_vector(M, p, [0.0, 0.0, 1.0, 0.0], true) + @test_throws ManifoldDomainError is_vector( + M, + p, + 1 * im * zero_vector(M, p), + true, + ) + @test_throws ManifoldDomainError is_vector(M, p, X, true) @test injectivity_radius(M) == Ο€ / 2 @test injectivity_radius(M, ExponentialRetraction()) == Ο€ / 2 @test injectivity_radius(M, p) == Ο€ / 2 @@ -73,9 +86,9 @@ include("../utils.jl") @testset "Type $T" for T in types pts = convert.(T, [p, q, r]) @test !is_point(M, 2 * p) - @test_throws DomainError !is_point(M, 2 * r, true) + @test_throws ManifoldDomainError !is_point(M, 2 * r, true) @test !is_vector(M, p, q) - @test_throws DomainError is_vector(M, p, q, true) + @test_throws ManifoldDomainError is_vector(M, p, q, true) test_manifold( M, pts, @@ -86,8 +99,6 @@ include("../utils.jl") test_is_tangent=true, test_project_tangent=true, test_default_vector_transport=false, - test_forward_diff=false, - test_reverse_diff=false, projection_atol_multiplier=15.0, retraction_atol_multiplier=10.0, is_tangent_atol_multiplier=4 * 10.0^2, diff --git a/test/manifolds/generalized_stiefel.jl b/test/manifolds/generalized_stiefel.jl index 26b99f3c1c..ffd7a1a4f2 100644 --- a/test/manifolds/generalized_stiefel.jl +++ b/test/manifolds/generalized_stiefel.jl @@ -4,7 +4,9 @@ include("../utils.jl") @testset "Real" begin B = [1.0 0.0 0.0; 0.0 4.0 0.0; 0.0 0.0 1.0] M = GeneralizedStiefel(3, 2, B) - x = [1.0 0.0; 0.0 0.5; 0.0 0.0] + p = [1.0 0.0; 0.0 0.5; 0.0 0.0] + X = zeros(3, 2) + X[1, :] .= 1 @testset "Basics" begin @test repr(M) == "GeneralizedStiefel(3, 2, [1.0 0.0 0.0; 0.0 4.0 0.0; 0.0 0.0 1.0], ℝ)" @@ -12,28 +14,41 @@ include("../utils.jl") @test manifold_dimension(M) == 3 @test base_manifold(M) === M @test_throws DomainError is_point(M, [1.0, 0.0, 0.0, 0.0], true) - @test_throws DomainError is_point(M, 1im * [1.0 0.0; 0.0 1.0; 0.0 0.0], true) - @test !is_vector(M, x, [0.0, 0.0, 1.0, 0.0]) - @test_throws DomainError is_vector(M, x, [0.0, 0.0, 1.0, 0.0], true) - @test_throws DomainError is_vector(M, x, 1 * im * zero_vector(M, x), true) + @test_throws ManifoldDomainError is_point( + M, + 1im * [1.0 0.0; 0.0 1.0; 0.0 0.0], + true, + ) + @test_throws DomainError is_point(M, 2 * p, true) + @test !is_vector(M, p, [0.0, 0.0, 1.0, 0.0]) + @test_throws DomainError is_vector(M, p, [0.0, 0.0, 1.0, 0.0], true) + @test_throws ManifoldDomainError is_vector( + M, + p, + 1 * im * zero_vector(M, p), + true, + ) + @test_throws DomainError is_vector(M, p, X, true) + @test default_retraction_method(M) == ProjectionRetraction() end @testset "Embedding and Projection" begin - y = similar(x) - z = embed(M, x) - @test z == x - embed!(M, y, x) + @test get_embedding(GeneralizedStiefel(3, 2)) == Euclidean(3, 2) + y = similar(p) + z = embed(M, p) + @test z == p + embed!(M, y, p) @test y == z a = [1.0 0.0; 0.0 2.0; 0.0 0.0] @test !is_point(M, a) b = similar(a) c = project(M, a) - @test c == x + @test c == p project!(M, b, a) - @test b == x + @test b == p X = [0.0 0.0; 0.0 0.0; -1.0 1.0] Y = similar(X) - Z = embed(M, x, X) - embed!(M, Y, x, X) + Z = embed(M, p, X) + embed!(M, Y, p, X) @test Y == X @test Z == X end @@ -42,28 +57,28 @@ include("../utils.jl") TEST_STATIC_SIZED && push!(types, MMatrix{3,2,Float64,6}) X = [0.0 0.0; 0.0 0.0; 1.0 1.0] Y = [0.0 0.0; 0.0 0.0; -1.0 1.0] - @test inner(M, x, X, Y) == 0 - y = retract(M, x, X) - z = retract(M, x, Y) + @test inner(M, p, X, Y) == 0 + y = retract(M, p, X) + z = retract(M, p, Y) @test is_point(M, y) @test is_point(M, z) - a = project(M, x + X) - b = retract(M, x, X) - c = retract(M, x, X, ProjectionRetraction()) - d = retract(M, x, X, PolarRetraction()) + a = project(M, p + X) + b = retract(M, p, X) + c = retract(M, p, X, ProjectionRetraction()) + d = retract(M, p, X, PolarRetraction()) @test a == b @test c == d @test b == c e = similar(a) - retract!(M, e, x, X) + retract!(M, e, p, X) @test e == a - @test vector_transport_to(M, x, X, y, ProjectionTransport()) == project(M, y, X) + @test vector_transport_to(M, p, X, y, ProjectionTransport()) == project(M, y, X) @testset "Type $T" for T in types - pts = convert.(T, [x, y, z]) - @test !is_point(M, 2 * x) - @test_throws DomainError !is_point(M, 2 * x, true) - @test !is_vector(M, x, y) - @test_throws DomainError is_vector(M, x, y, true) + pts = convert.(T, [p, y, z]) + @test !is_point(M, 2 * p) + @test_throws DomainError !is_point(M, 2 * p, true) + @test !is_vector(M, p, y) + @test_throws DomainError is_vector(M, p, y, true) test_manifold( M, pts, @@ -74,8 +89,6 @@ include("../utils.jl") test_is_tangent=true, test_project_tangent=true, test_default_vector_transport=false, - test_forward_diff=false, - test_reverse_diff=false, projection_atol_multiplier=15.0, retraction_atol_multiplier=10.0, is_tangent_atol_multiplier=4 * 10.0^2, diff --git a/test/manifolds/graph.jl b/test/manifolds/graph.jl index 15e7bbbd16..d0db4ca5f7 100644 --- a/test/manifolds/graph.jl +++ b/test/manifolds/graph.jl @@ -24,12 +24,7 @@ include("../utils.jl") @test incident_log(N, x) == [x[2] - x[1], x[1] - x[2] + x[3] - x[2], x[2] - x[3]] pts = [x, y, z] - test_manifold( - N, - pts; - test_representation_size=false, - test_reverse_diff=false, #VERSION > v"1.2", - ) + test_manifold(N, pts; test_representation_size=false) @test sprint(show, "text/plain", N) == """ GraphManifold Graph: @@ -51,7 +46,6 @@ include("../utils.jl") NE, [x[1:2], y[1:2], z[1:2]]; test_representation_size=false, - test_reverse_diff=false, #VERSION > v"1.2", test_inplace=true, ) @test sprint(show, "text/plain", NE) == """ diff --git a/test/manifolds/grassmann.jl b/test/manifolds/grassmann.jl index ac2166bf1a..3ef1f4a421 100644 --- a/test/manifolds/grassmann.jl +++ b/test/manifolds/grassmann.jl @@ -9,28 +9,32 @@ include("../utils.jl") @test manifold_dimension(M) == 2 @test !is_point(M, [1.0, 0.0, 0.0, 0.0]) @test !is_vector(M, [1.0 0.0; 0.0 1.0; 0.0 0.0], [0.0, 0.0, 1.0, 0.0]) - @test_throws DomainError is_point(M, [2.0 0.0; 0.0 1.0; 0.0 0.0], true) - @test_throws DomainError is_vector( + @test_throws ManifoldDomainError is_point(M, [2.0 0.0; 0.0 1.0; 0.0 0.0], true) + @test_throws ManifoldDomainError is_vector( M, [2.0 0.0; 0.0 1.0; 0.0 0.0], zeros(3, 2), true, ) - @test_throws DomainError is_vector( + @test_throws ManifoldDomainError is_vector( M, [1.0 0.0; 0.0 1.0; 0.0 0.0], ones(3, 2), true, ) @test is_point(M, [1.0 0.0; 0.0 1.0; 0.0 0.0], true) - @test_throws DomainError is_point(M, 1im * [1.0 0.0; 0.0 1.0; 0.0 0.0], true) + @test_throws ManifoldDomainError is_point( + M, + 1im * [1.0 0.0; 0.0 1.0; 0.0 0.0], + true, + ) @test is_vector( M, [1.0 0.0; 0.0 1.0; 0.0 0.0], zero_vector(M, [1.0 0.0; 0.0 1.0; 0.0 0.0]), true, ) - @test_throws DomainError is_vector( + @test_throws ManifoldDomainError is_vector( M, [1.0 0.0; 0.0 1.0; 0.0 0.0], 1im * zero_vector(M, [1.0 0.0; 0.0 1.0; 0.0 0.0]), @@ -51,12 +55,12 @@ include("../utils.jl") TEST_FLOAT32 && push!(types, Matrix{Float32}) basis_types = (ProjectedOrthonormalBasis(:gram_schmidt),) @testset "Type $T" for T in types - x = [1.0 0.0; 0.0 1.0; 0.0 0.0] - v = [0.0 0.0; 0.0 0.0; 0.0 1.0] - y = exp(M, x, v) - w = [0.0 1.0; -1.0 0.0; 1.0 0.0] - z = exp(M, x, w) - pts = convert.(T, [x, y, z]) + p1 = [1.0 0.0; 0.0 1.0; 0.0 0.0] + X = [0.0 0.0; 0.0 0.0; 0.0 1.0] + p2 = exp(M, p1, X) + Y = [0.0 1.0; -1.0 0.0; 1.0 0.0] + p3 = exp(M, p1, Y) + pts = convert.(T, [p1, p2, p3]) test_manifold( M, pts, @@ -65,8 +69,6 @@ include("../utils.jl") test_project_tangent=true, test_default_vector_transport=false, point_distributions=[Manifolds.uniform_distribution(M, pts[1])], - test_forward_diff=false, - test_reverse_diff=false, test_vee_hat=false, test_rand_point=true, test_rand_tvector=true, @@ -82,15 +84,15 @@ include("../utils.jl") ) @testset "inner/norm" begin - v1 = inverse_retract(M, pts[1], pts[2], PolarInverseRetraction()) - v2 = inverse_retract(M, pts[1], pts[3], PolarInverseRetraction()) + X1 = inverse_retract(M, pts[1], pts[2], PolarInverseRetraction()) + X2 = inverse_retract(M, pts[1], pts[3], PolarInverseRetraction()) - @test real(inner(M, pts[1], v1, v2)) β‰ˆ real(inner(M, pts[1], v2, v1)) - @test imag(inner(M, pts[1], v1, v2)) β‰ˆ -imag(inner(M, pts[1], v2, v1)) - @test imag(inner(M, pts[1], v1, v1)) β‰ˆ 0 + @test real(inner(M, pts[1], X1, X2)) β‰ˆ real(inner(M, pts[1], X2, X1)) + @test imag(inner(M, pts[1], X1, X2)) β‰ˆ -imag(inner(M, pts[1], X2, X1)) + @test imag(inner(M, pts[1], X1, X1)) β‰ˆ 0 - @test norm(M, pts[1], v1) isa Real - @test norm(M, pts[1], v1) β‰ˆ sqrt(inner(M, pts[1], v1, v1)) + @test norm(M, pts[1], X1) isa Real + @test norm(M, pts[1], X1) β‰ˆ sqrt(inner(M, pts[1], X1, X1)) end end @@ -104,14 +106,15 @@ include("../utils.jl") end @testset "vector transport" begin - x = [1.0 0.0; 0.0 1.0; 0.0 0.0] - v = [0.0 0.0; 0.0 0.0; 0.0 1.0] - y = exp(M, x, v) - @test vector_transport_to(M, x, v, y, ProjectionTransport()) == project(M, y, v) + p1 = [1.0 0.0; 0.0 1.0; 0.0 0.0] + X = [0.0 0.0; 0.0 0.0; 0.0 1.0] + p2 = exp(M, p1, X) + @test vector_transport_to(M, p1, X, p2, ProjectionTransport()) == + project(M, p2, X) @test is_vector( M, - y, - vector_transport_to(M, x, v, y, ProjectionTransport()), + p2, + vector_transport_to(M, p1, X, p2, ProjectionTransport()), true; atol=10^-15, ) @@ -127,14 +130,14 @@ include("../utils.jl") @test !is_point(M, [1.0, 0.0, 0.0, 0.0]) @test !is_vector(M, [1.0 0.0; 0.0 1.0; 0.0 0.0], [0.0, 0.0, 1.0, 0.0]) @test Manifolds.allocation_promotion_function(M, exp!, (1,)) == complex - @test_throws DomainError is_point(M, [2.0 0.0; 0.0 1.0; 0.0 0.0], true) - @test_throws DomainError is_vector( + @test_throws ManifoldDomainError is_point(M, [2.0 0.0; 0.0 1.0; 0.0 0.0], true) + @test_throws ManifoldDomainError is_vector( M, [2.0 0.0; 0.0 1.0; 0.0 0.0], zeros(3, 2), true, ) - @test_throws DomainError is_vector( + @test_throws ManifoldDomainError is_vector( M, [1.0 0.0; 0.0 1.0; 0.0 0.0], ones(3, 2), @@ -150,12 +153,12 @@ include("../utils.jl") end types = [Matrix{ComplexF64}] @testset "Type $T" for T in types - x = [0.5+0.5im 0.5+0.5im; 0.5+0.5im -0.5-0.5im; 0.0 0.0] - v = [0.0 0.0; 0.0 0.0; 0.0 1.0] - y = exp(M, x, v) - w = [0.0 1.0; -1.0 0.0; 1.0 0.0] - z = exp(M, x, w) - pts = convert.(T, [x, y, z]) + p1 = [0.5+0.5im 0.5+0.5im; 0.5+0.5im -0.5-0.5im; 0.0 0.0] + X = [0.0 0.0; 0.0 0.0; 0.0 1.0] + p2 = exp(M, p1, X) + Y = [0.0 1.0; -1.0 0.0; 1.0 0.0] + p3 = exp(M, p1, Y) + pts = convert.(T, [p1, p2, p3]) test_manifold( M, pts, @@ -163,8 +166,6 @@ include("../utils.jl") test_injectivity_radius=false, test_project_tangent=true, test_default_vector_transport=false, - test_forward_diff=false, - test_reverse_diff=false, test_vee_hat=false, retraction_methods=[PolarRetraction(), QRRetraction()], inverse_retraction_methods=[ @@ -178,15 +179,15 @@ include("../utils.jl") ) @testset "inner/norm" begin - v1 = inverse_retract(M, pts[1], pts[2], PolarInverseRetraction()) - v2 = inverse_retract(M, pts[1], pts[3], PolarInverseRetraction()) + X1 = inverse_retract(M, pts[1], pts[2], PolarInverseRetraction()) + X2 = inverse_retract(M, pts[1], pts[3], PolarInverseRetraction()) - @test real(inner(M, pts[1], v1, v2)) β‰ˆ real(inner(M, pts[1], v2, v1)) - @test imag(inner(M, pts[1], v1, v2)) β‰ˆ -imag(inner(M, pts[1], v2, v1)) - @test imag(inner(M, pts[1], v1, v1)) β‰ˆ 0 + @test real(inner(M, pts[1], X1, X2)) β‰ˆ real(inner(M, pts[1], X2, X1)) + @test imag(inner(M, pts[1], X1, X2)) β‰ˆ -imag(inner(M, pts[1], X2, X1)) + @test isapprox(imag(inner(M, pts[1], X1, X1)), 0; atol=1e-30) - @test norm(M, pts[1], v1) isa Real - @test norm(M, pts[1], v1) β‰ˆ sqrt(inner(M, pts[1], v1, v1)) + @test norm(M, pts[1], X1) isa Real + @test norm(M, pts[1], X1) β‰ˆ sqrt(inner(M, pts[1], X1, X1)) end end end @@ -196,7 +197,7 @@ include("../utils.jl") p = reshape([im, 0.0, 0.0], 3, 1) @test is_point(G, p) X = reshape([-0.5; 0.5; 0], 3, 1) - @test_throws DomainError is_vector(G, p, X, true) + @test_throws ManifoldDomainError is_vector(G, p, X, true) Y = project(G, p, X) @test is_vector(G, p, Y) end diff --git a/test/manifolds/hyperbolic.jl b/test/manifolds/hyperbolic.jl index dfff1d690d..b2c10a6664 100644 --- a/test/manifolds/hyperbolic.jl +++ b/test/manifolds/hyperbolic.jl @@ -18,12 +18,15 @@ include("../utils.jl") @test_throws DomainError is_point(M, [2.0, 0.0, 0.0], true) @test !is_point(M, [2.0, 0.0, 0.0]) @test !is_vector(M, [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]) - @test Manifolds.default_metric_dispatch(M, MinkowskiMetric()) === Val{true}() - @test_throws DomainError is_vector(M, [1.0, 0.0, 0.0], [1.0, 0.0, 0.0], true) + @test_throws ManifoldDomainError is_vector( + M, + [1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + true, + ) @test !is_vector(M, [0.0, 0.0, 1.0], [1.0, 0.0, 1.0]) @test_throws DomainError is_vector(M, [0.0, 0.0, 1.0], [1.0, 0.0, 1.0], true) @test is_default_metric(M, MinkowskiMetric()) - @test Manifolds.default_metric_dispatch(M, MinkowskiMetric()) === Val{true}() @test manifold_dimension(M) == 2 for (P, T) in zip( @@ -68,10 +71,15 @@ include("../utils.jl") @test convert(AbstractVector, pB) == p # convert back yields again p @test convert(HyperboloidPoint, pB).value == pH.value @test_throws DomainError is_point(M, PoincareBallPoint([0.9, 0.0, 0.0]), true) - @test_throws DomainError is_point(M, PoincareBallPoint([1.0, 0.0]), true) + @test_throws DomainError is_point(M, PoincareBallPoint([1.1, 0.0]), true) @test is_vector(M, pB, PoincareBallTVector([2.0, 2.0])) - + @test_throws DomainError is_vector( + M, + pB, + PoincareBallTVector([2.0, 2.0, 3.0]), + true, + ) pS = convert(PoincareHalfSpacePoint, p) pS2 = convert(PoincareHalfSpacePoint, pB) pS3 = convert(PoincareHalfSpacePoint, pH) @@ -104,7 +112,6 @@ include("../utils.jl") @test isapprox(M, p1, X1, X2) # Test broadcast @test 2 .* X1 == T(2 .* X1.value) - @test 2 .* p1 == P(2 .* p1.value) @test copy(X1) == X1 @test copy(X1) !== X1 X1s = similar(X1) @@ -166,8 +173,6 @@ include("../utils.jl") exp_log_atol_multiplier=10.0, retraction_methods=(ExponentialRetraction(),), test_vee_hat=false, - test_forward_diff=is_plain_array, - test_reverse_diff=is_plain_array, test_tangent_vector_broadcasting=is_plain_array, test_vector_spaces=is_plain_array, test_inplace=true, @@ -190,13 +195,45 @@ include("../utils.jl") p2 = HyperboloidPoint(p) X2 = HyperboloidTVector(X) q2 = HyperboloidPoint(similar(p)) - @test embed(M, p2).value == p2.value + q3 = similar(p) + @test embed(M, p2) == p2.value embed!(M, q2, p2) + embed!(M, q3, p2) @test q2.value == p2.value - @test embed(M, p2, X2).value == X2.value + @test q3 == p2.value + @test embed(M, p2, X2) == X2.value Y2 = HyperboloidTVector(similar(X)) + Y3 = similar(X) embed!(M, Y2, p2, X2) @test Y2.value == X2.value + embed!(M, Y3, p2, X2) + @test Y3 == X2.value + # check embed for PoincareBall + p4 = convert(PoincareBallPoint, p) + X4 = convert(PoincareBallTVector, p, X) + q4 = embed(M, p4) + @test isapprox(q4, zeros(2)) + q4b = similar(q4) + embed!(M, q4b, p4) + @test q4b == q4 + Y4 = embed(M, p4, X4) + @test Y4 == X4.value + Y4b = similar(Y4) + embed!(M, Y4b, p4, X4) + @test Y4 == Y4b + # check embed for PoincareHalfSpace + p5 = convert(PoincareHalfSpacePoint, p) + X5 = convert(PoincareHalfSpaceTVector, p, X) + q5 = embed(M, p5) + @test isapprox(q5, [0.0; 1.0]) + q5b = similar(q5) + embed!(M, q5b, p5) + @test q5b == q5 + Y5 = embed(M, p5, X5) + @test Y5 == X5.value + Y5b = similar(Y5) + embed!(M, Y5b, p5, X5) + @test Y5 == Y5b end @testset "Hyperbolic mean test" begin pts = [ @@ -229,6 +266,8 @@ include("../utils.jl") @test is_vector(M, p, X) c = get_coordinates(M, p, X, B) @test c β‰ˆ [0.5, 1.0] + c2 = similar(c) + get_coordinates!(M, c2, p, X, DefaultOrthonormalBasis()) B2 = DiagonalizingOrthonormalBasis(X) V2 = get_vectors(M, p, get_basis(M, p, B2)) @test V2[1] β‰ˆ X ./ norm(M, p, X) diff --git a/test/manifolds/lorentz.jl b/test/manifolds/lorentz.jl index d84fe66d13..59e0362818 100644 --- a/test/manifolds/lorentz.jl +++ b/test/manifolds/lorentz.jl @@ -1,7 +1,16 @@ include("../utils.jl") -@testset "Minkowski Metric" begin - N = Euclidean(3) - M = MetricManifold(N, MinkowskiMetric()) - @test local_metric(M, zeros(3)) == Diagonal([1.0, 1.0, -1.0]) +@testset "Lorentz Manifold" begin + M = Lorentz(3) + @testset "Minkowski Metric" begin + N = MetricManifold(Euclidean(3), MinkowskiMetric()) + @test N == M + @test local_metric(M, zeros(3)) == Diagonal([1.0, 1.0, -1.0]) + # check minkowski metric is called + p = zeros(3) + X = [1.0, 2.0, 3.0] + Y = [2.0, 3.0, 4.0] + @test minkowski_metric(X, Y) == -4 + @test inner(M, p, X, Y) == minkowski_metric(X, Y) + end end diff --git a/test/manifolds/multinomial_doubly_stochastic.jl b/test/manifolds/multinomial_doubly_stochastic.jl index a5886c72a0..5e2c4efcca 100644 --- a/test/manifolds/multinomial_doubly_stochastic.jl +++ b/test/manifolds/multinomial_doubly_stochastic.jl @@ -9,14 +9,14 @@ include("../utils.jl") @test is_point(M, p) @test is_vector(M, p, X) pf1 = [0.1 0.9 0.1; 0.1 0.9 0.1; 0.1 0.1 0.9] #not sum 1 - @test_throws CompositeManifoldError is_point(M, pf1, true) + @test_throws ManifoldDomainError is_point(M, pf1, true) pf2r = [0.1 0.9 0.1; 0.8 0.05 0.15; 0.1 0.05 0.75] @test_throws DomainError is_point(M, pf2r, true) - @test_throws CompositeManifoldError is_point(M, pf2r', true) + @test_throws ManifoldDomainError is_point(M, pf2r', true) pf3 = [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0] # contains nonpositive entries - @test_throws CompositeManifoldError is_point(M, pf3, true) + @test_throws ManifoldDomainError is_point(M, pf3, true) Xf2c = [-0.1 0.0 0.1; -0.2 0.1 0.1; 0.2 -0.1 -0.1] #nonzero columns - @test_throws CompositeManifoldError is_vector(M, p, Xf2c, true) + @test_throws ManifoldDomainError is_vector(M, p, Xf2c, true) @test_throws DomainError is_vector(M, p, Xf2c', true) @test representation_size(M) == (3, 3) pE = similar(p) @@ -50,8 +50,6 @@ include("../utils.jl") M, pts, test_injectivity_radius=false, - test_reverse_diff=false, - test_forward_diff=false, test_project_tangent=true, test_exp_log=false, test_default_vector_transport=true, diff --git a/test/manifolds/multinomial_matrices.jl b/test/manifolds/multinomial_matrices.jl index d578f2a956..25d7f191c6 100644 --- a/test/manifolds/multinomial_matrices.jl +++ b/test/manifolds/multinomial_matrices.jl @@ -35,8 +35,6 @@ include("../utils.jl") M, [x, y, z], test_injectivity_radius=false, - test_forward_diff=false, - test_reverse_diff=false, test_vector_spaces=true, test_project_tangent=false, test_musical_isomorphisms=true, diff --git a/test/manifolds/multinomial_symmetric.jl b/test/manifolds/multinomial_symmetric.jl index 8f0e7a6f82..cfbee347e0 100644 --- a/test/manifolds/multinomial_symmetric.jl +++ b/test/manifolds/multinomial_symmetric.jl @@ -9,15 +9,15 @@ include("../utils.jl") @test is_point(M, p) @test is_vector(M, p, X) pf1 = [0.1 0.9 0.1; 0.1 0.9 0.1; 0.1 0.1 0.9] #not symmetric - @test_throws CompositeManifoldError is_point(M, pf1, true) + @test_throws ManifoldDomainError is_point(M, pf1, true) pf2 = [0.8 0.1 0.1; 0.1 0.8 0.1; 0.1 0.1 0.9] # cols do not sum to 1 - @test_throws ComponentManifoldError is_point(M, pf2, true) + @test_throws ManifoldDomainError is_point(M, pf2, true) pf3 = [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0] # contains nonpositive entries - @test_throws CompositeManifoldError is_point(M, pf3, true) + @test_throws ManifoldDomainError is_point(M, pf3, true) Xf1 = [0.0 1.0 -1.0; 0.0 0.0 0.0; 0.0 0.0 0.0] # not symmetric - @test_throws CompositeManifoldError is_vector(M, p, Xf1, true) + @test_throws ManifoldDomainError is_vector(M, p, Xf1, true) Xf2 = [0.0 -1.0 0.0; -1.0 0.0 0.0; 0.0 0.0 0.0] # nonzero sums - @test_throws CompositeManifoldError is_vector(M, p, Xf2, true) + @test_throws ManifoldDomainError is_vector(M, p, Xf2, true) @test representation_size(M) == (3, 3) pE = similar(p) embed!(M, pE, p) @@ -55,8 +55,6 @@ include("../utils.jl") M, pts, test_injectivity_radius=false, - test_reverse_diff=false, - test_forward_diff=false, test_project_tangent=true, test_exp_log=false, test_default_vector_transport=true, diff --git a/test/manifolds/oblique.jl b/test/manifolds/oblique.jl index 1b66a94977..db59ba1c72 100644 --- a/test/manifolds/oblique.jl +++ b/test/manifolds/oblique.jl @@ -31,12 +31,10 @@ include("../utils.jl") y = [1.0 0.0 0.0; 1/sqrt(2) 1/sqrt(2) 0.0]' z = [1/sqrt(2) 1/sqrt(2) 0.0; 1.0 0.0 0.0]' basis_types = (DefaultOrthonormalBasis(),) - transports = [ParallelTransport(), PowerVectorTransport(ParallelTransport())] + transports = [ParallelTransport()] test_manifold( M, [x, y, z], - test_forward_diff=false, - test_reverse_diff=false, test_vector_spaces=true, test_project_tangent=false, test_musical_isomorphisms=true, diff --git a/test/manifolds/positive_numbers.jl b/test/manifolds/positive_numbers.jl index 3e4b00ee1a..91c9e56f0e 100644 --- a/test/manifolds/positive_numbers.jl +++ b/test/manifolds/positive_numbers.jl @@ -38,8 +38,6 @@ include("../utils.jl") test_manifold( M, pts, - test_forward_diff=false, - test_reverse_diff=false, test_vector_spaces=false, test_project_tangent=true, test_musical_isomorphisms=true, @@ -56,8 +54,6 @@ include("../utils.jl") test_manifold( M2, pts2, - test_forward_diff=false, - test_reverse_diff=false, test_vector_spaces=false, test_project_tangent=true, test_musical_isomorphisms=true, diff --git a/test/manifolds/power_manifold.jl b/test/manifolds/power_manifold.jl index 346567113b..545c879c2d 100644 --- a/test/manifolds/power_manifold.jl +++ b/test/manifolds/power_manifold.jl @@ -1,7 +1,6 @@ include("../utils.jl") using HybridArrays, Random -using Manifolds: default_metric_dispatch using StaticArrays: Dynamic Random.seed!(42) @@ -55,8 +54,6 @@ end @test Ms^(5,) === Ms1 @test Mr^(5, 7) === Mr2 - @test is_default_metric(Ms1, PowerMetric()) - @test default_metric_dispatch(Ms1, PowerMetric()) === Val{true}() types_s1 = [Array{Float64,2}, HybridArray{Tuple{3,Dynamic()},Float64,2}] types_s2 = [Array{Float64,3}, HybridArray{Tuple{3,Dynamic(),Dynamic()},Float64,3}] @@ -68,9 +65,8 @@ end types_r2 = [Array{Float64,4}, HybridArray{Tuple{3,3,Dynamic(),Dynamic()},Float64,4}] types_rn2 = [Matrix{Matrix{Float64}}] - retraction_methods = [Manifolds.PowerRetraction(ManifoldsBase.ExponentialRetraction())] - inverse_retraction_methods = - [Manifolds.InversePowerRetraction(ManifoldsBase.LogarithmicInverseRetraction())] + retraction_methods = [ManifoldsBase.ExponentialRetraction()] + inverse_retraction_methods = [ManifoldsBase.LogarithmicInverseRetraction()] sphere_dist = Manifolds.uniform_distribution(Ms, @SVector [1.0, 0.0, 0.0]) power_s1_pt_dist = @@ -171,7 +167,7 @@ end end @testset "power vector transport" begin - m = PowerVectorTransport(ParallelTransport()) + m = ParallelTransport() p = repeat([1.0, 0.0, 0.0], 1, 5) q = repeat([0.0, 1.0, 0.0], 1, 5) X = log(Ms1, p, q) @@ -195,16 +191,12 @@ end test_manifold( Ms1, pts1; - test_reverse_diff=true, test_musical_isomorphisms=true, test_injectivity_radius=false, test_default_vector_transport=true, test_project_point=true, test_project_tangent=true, vector_transport_methods=[ - PowerVectorTransport(ParallelTransport()), - PowerVectorTransport(SchildsLadderTransport()), - PowerVectorTransport(PoleLadderTransport()), ParallelTransport(), SchildsLadderTransport(), PoleLadderTransport(), @@ -231,7 +223,6 @@ end test_manifold( Ms2, pts2; - test_reverse_diff=true, test_musical_isomorphisms=true, test_injectivity_radius=false, test_vee_hat=true, @@ -254,7 +245,6 @@ end test_manifold( Mr1, pts1; - test_reverse_diff=false, test_injectivity_radius=false, test_musical_isomorphisms=true, test_vee_hat=true, @@ -278,7 +268,6 @@ end test_manifold( Mrn1, pts1; - test_reverse_diff=false, test_injectivity_radius=false, test_musical_isomorphisms=true, test_vee_hat=true, @@ -303,7 +292,6 @@ end test_manifold( Mr2, pts2; - test_reverse_diff=false, test_injectivity_radius=false, test_musical_isomorphisms=true, test_vee_hat=true, @@ -325,7 +313,6 @@ end test_manifold( Mrn2, pts2; - test_reverse_diff=false, test_injectivity_radius=false, test_musical_isomorphisms=true, test_vee_hat=true, @@ -349,8 +336,6 @@ end test_manifold( MT, pts_t; - test_reverse_diff=false, - test_forward_diff=false, test_injectivity_radius=false, test_musical_isomorphisms=true, test_vee_hat=true, @@ -393,8 +378,6 @@ end test_manifold( MT, pts_t; - test_reverse_diff=false, - test_forward_diff=false, test_injectivity_radius=false, test_musical_isomorphisms=true, retraction_methods=retraction_methods, diff --git a/test/manifolds/probability_simplex.jl b/test/manifolds/probability_simplex.jl index 5d2e09dc70..ec8ed7e19e 100644 --- a/test/manifolds/probability_simplex.jl +++ b/test/manifolds/probability_simplex.jl @@ -9,17 +9,15 @@ include("../utils.jl") Y = [-0.1, 0.05, 0.05] @test is_point(M, p) @test_throws DomainError is_point(M, p .+ 1, true) - @test_throws DomainError is_point(M, [0], true) + @test_throws ManifoldDomainError is_point(M, [0], true) @test_throws DomainError is_point(M, -ones(3), true) @test manifold_dimension(M) == 2 @test is_vector(M, p, X) @test is_vector(M, p, Y) - @test_throws DomainError is_vector(M, p .+ 1, X, true) - @test_throws DomainError is_vector(M, p, zeros(4), true) + @test_throws ManifoldDomainError is_vector(M, p .+ 1, X, true) + @test_throws ManifoldDomainError is_vector(M, p, zeros(4), true) @test_throws DomainError is_vector(M, p, Y .+ 1, true) - @test Manifolds.default_metric_dispatch(M, Manifolds.FisherRaoMetric()) === Val{true}() - @test injectivity_radius(M, p) == injectivity_radius(M, p, ExponentialRetraction()) @test injectivity_radius(M, p, SoftmaxRetraction()) == injectivity_radius(M, p) @test injectivity_radius(M, ExponentialRetraction()) == 0 @@ -54,8 +52,6 @@ include("../utils.jl") test_project_tangent=true, test_musical_isomorphisms=true, test_vee_hat=false, - test_forward_diff=false, - test_reverse_diff=false, is_tangent_atol_multiplier=5.0, inverse_retraction_methods=[SoftmaxInverseRetraction()], retraction_methods=[SoftmaxRetraction()], diff --git a/test/manifolds/product_manifold.jl b/test/manifolds/product_manifold.jl index 3b32d2ec5b..91fce44f61 100644 --- a/test/manifolds/product_manifold.jl +++ b/test/manifolds/product_manifold.jl @@ -15,6 +15,11 @@ using RecursiveArrayTools: ArrayPartition @test Mse[1] == M1 @test Mse[2] == M2 @test injectivity_radius(Mse) β‰ˆ Ο€ + @test injectivity_radius( + Mse, + ProductRetraction(ExponentialRetraction(), ExponentialRetraction()), + ) β‰ˆ Ο€ + @test injectivity_radius(Mse, ExponentialRetraction()) β‰ˆ Ο€ @test injectivity_radius( Mse, ProductRepr([0.0, 1.0, 0.0], [0.0, 0.0]), @@ -26,12 +31,16 @@ using RecursiveArrayTools: ArrayPartition ExponentialRetraction(), ) β‰ˆ Ο€ @test is_default_metric(Mse, ProductMetric()) - @test Manifolds.default_metric_dispatch(Mse, ProductMetric()) === Val{true}() @test Manifolds.number_of_components(Mse) == 2 # test that arrays are not points @test_throws DomainError is_point(Mse, [1, 2], true) + @test check_point(Mse, [1, 2]) isa DomainError @test_throws DomainError is_vector(Mse, 1, [1, 2], true; check_base_point=false) + @test check_vector(Mse, 1, [1, 2]; check_base_point=false) isa DomainError + #default fallbacks for check_size, Product not working with Arrays + @test Manifolds.check_size(Mse, zeros(2)) isa DomainError + @test Manifolds.check_size(Mse, zeros(2), zeros(3)) isa DomainError types = [Vector{Float64}] TEST_FLOAT32 && push!(types, Vector{Float32}) TEST_STATIC_SIZED && push!(types, MVector{5,Float64}) @@ -261,8 +270,6 @@ using RecursiveArrayTools: ArrayPartition Mser, pts, test_injectivity_radius=false, - test_forward_diff=false, - test_reverse_diff=false, is_tangent_atol_multiplier=1, exp_log_atol_multiplier=1, test_inplace=true, @@ -315,12 +322,50 @@ using RecursiveArrayTools: ArrayPartition @test isapprox(Mse, q, Y, Z) end end + @testset "Parallel transport" begin + p = ProductRepr([1.0, 0.0, 0.0], [0.0, 0.0]) + q = ProductRepr([0.0, 1.0, 0.0], [2.0, 0.0]) + X = log(Mse, p, q) + # to + Y = parallel_transport_to(Mse, p, X, q) + Z1 = parallel_transport_to( + Mse.manifolds[1], + submanifold_component.([p, X, q], Ref(1))..., + ) + Z2 = parallel_transport_to( + Mse.manifolds[2], + submanifold_component.([p, X, q], Ref(2))..., + ) + Z = ProductRepr(Z1, Z2) + @test isapprox(Mse, q, Y, Z) + Ym = allocate(Y) + parallel_transport_to!(Mse, Ym, p, X, q) + @test isapprox(Mse, q, Y, Z) + + # direction + Y = parallel_transport_direction(Mse, p, X, X) + Z1 = parallel_transport_direction( + Mse.manifolds[1], + submanifold_component.([p, X, X], Ref(1))..., + ) + Z2 = parallel_transport_direction( + Mse.manifolds[2], + submanifold_component.([p, X, X], Ref(2))..., + ) + Z = ProductRepr(Z1, Z2) + @test isapprox(Mse, q, Y, Z) + Ym = allocate(Y) + parallel_transport_direction!(Mse, Ym, p, X, X) + @test isapprox(Mse, q, Ym, Z) + end @testset "ProductRepr" begin @test (@inferred convert( ProductRepr{Tuple{T,Float64,T} where T}, ProductRepr(9, 10, 11), )) == ProductRepr(9, 10.0, 11) + @test (@inferred convert(ProductRepr, ProductRepr(9, 10, 11))) === + ProductRepr(9, 10, 11) p = ProductRepr([1.0, 0.0, 0.0], [0.0, 0.0]) @test p == ProductRepr([1.0, 0.0, 0.0], [0.0, 0.0]) @@ -330,6 +375,8 @@ using RecursiveArrayTools: ArrayPartition @test submanifold_component(p, Val(1)) === p.parts[1] @test submanifold_components(Mse, p) === p.parts @test submanifold_components(p) === p.parts + @test allocate(p, Int, 10) isa Vector{Int} + @test length(allocate(p, Int, 10)) == 10 end @testset "ArrayPartition" begin @@ -378,6 +425,13 @@ using RecursiveArrayTools: ArrayPartition @test injectivity_radius(Mse, pts[1], ExponentialRetraction()) β‰ˆ Ο€ @test injectivity_radius(Mse, ExponentialRetraction()) β‰ˆ Ο€ + @test ManifoldsBase.allocate_coordinates( + Mse, + pts[1], + Float64, + number_of_coordinates(Mse, DefaultOrthogonalBasis()), + ) isa Vector{Float64} + test_manifold( Mse, pts; @@ -391,8 +445,6 @@ using RecursiveArrayTools: ArrayPartition test_musical_isomorphisms=true, musical_isomorphism_bases=[DefaultOrthonormalBasis()], test_tangent_vector_broadcasting=true, - test_forward_diff=true, - test_reverse_diff=true, test_project_tangent=true, test_project_point=true, test_mutating_rand=false, @@ -467,12 +519,20 @@ using RecursiveArrayTools: ArrayPartition @testset "Basis-related errors" begin a = ProductRepr([1.0, 0.0, 0.0], [0.0, 0.0]) - @test_throws ErrorException get_vector!( + B = CachedBasis(DefaultOrthonormalBasis(), ProductBasisData(([],))) + @test_throws AssertionError get_vector!( + Mse, + a, + ProductRepr([1.0, 0.0, 0.0], [0.0, 0.0]), + [1.0, 2.0, 3.0, 4.0, 5.0], # this is one element too long, hence assertionerror + B, + ) + @test_throws MethodError get_vector!( Mse, a, ProductRepr([1.0, 0.0, 0.0], [0.0, 0.0]), - [1.0, 2.0, 3.0, 4.0, 5.0], - CachedBasis(DefaultOrthonormalBasis(), []), + [1.0, 2.0, 3.0, 4.0], + B, # empty elements yield a submanifold MethodError ) end diff --git a/test/manifolds/projective_space.jl b/test/manifolds/projective_space.jl index a6aa8547bb..304a412e44 100644 --- a/test/manifolds/projective_space.jl +++ b/test/manifolds/projective_space.jl @@ -48,8 +48,6 @@ include("../utils.jl") tvector_distributions=[ Manifolds.normal_tvector_distribution(M, pts[1], 1.0), ], - test_forward_diff=false, - test_reverse_diff=false, basis_types_vecs=( DiagonalizingOrthonormalBasis([0.0, 1.0, 2.0]), basis_types..., @@ -151,8 +149,6 @@ include("../utils.jl") SchildsLadderTransport(), PoleLadderTransport(), ], - test_forward_diff=false, - test_reverse_diff=false, basis_types_to_from=(DefaultOrthonormalBasis(),), test_vee_hat=false, retraction_methods=[ @@ -211,8 +207,8 @@ include("../utils.jl") ) @test !is_vector( M, - Quaternion[1.0 + 0im, 0.0, 0.0], - Quaternion[-0.5im, 0.0, 0.0], + Quaternion.([1.0 + 0im, 0.0, 0.0]), + Quaternion.([-0.5im, 0.0, 0.0]), ) @test_throws DomainError is_vector( M, @@ -222,8 +218,8 @@ include("../utils.jl") ) @test_throws DomainError is_vector( M, - Quaternion[1.0 + 0im, 0.0, 0.0], - Quaternion[-0.5im, 0.0, 0.0], + Quaternion.([1.0 + 0im, 0.0, 0.0]), + Quaternion.([-0.5im, 0.0, 0.0]), true, ) @test injectivity_radius(M) == Ο€ / 2 @@ -261,8 +257,6 @@ include("../utils.jl") SchildsLadderTransport(), PoleLadderTransport(), ], - test_forward_diff=false, - test_reverse_diff=false, basis_types_to_from=(DefaultOrthonormalBasis(),), test_vee_hat=false, retraction_methods=[ diff --git a/test/manifolds/rotations.jl b/test/manifolds/rotations.jl index b2c53da74c..779720aa36 100644 --- a/test/manifolds/rotations.jl +++ b/test/manifolds/rotations.jl @@ -11,6 +11,7 @@ include("../utils.jl") Ο€ * sqrt(2.0) @test injectivity_radius(M, PolarRetraction()) β‰ˆ Ο€ / sqrt(2) @test injectivity_radius(M, [1.0 0.0; 0.0 1.0], PolarRetraction()) β‰ˆ Ο€ / sqrt(2) + @test get_embedding(M) == Euclidean(2, 2) types = [Matrix{Float64}] TEST_FLOAT32 && push!(types, Matrix{Float32}) TEST_STATIC_SIZED && push!(types, MMatrix{2,2,Float64,4}) @@ -42,7 +43,6 @@ include("../utils.jl") test_manifold( M, pts; - test_reverse_diff=false, test_injectivity_radius=false, test_project_tangent=true, test_musical_isomorphisms=true, @@ -106,8 +106,6 @@ include("../utils.jl") test_manifold( SOn, pts; - test_forward_diff=n == 3, - test_reverse_diff=false, test_injectivity_radius=false, test_musical_isomorphisms=true, test_mutating_rand=true, @@ -157,8 +155,6 @@ include("../utils.jl") X = Manifolds.hat(SOn, Matrix(1.0I, n, n), Xf) p = exp(X) @test p β‰ˆ exp(SOn, one(p), X) - @test ForwardDiff.derivative(t -> exp(SOn, one(p), t * X), 0) β‰ˆ - X p2 = exp(log(SOn, one(p), p)) @test isapprox(p, p2; atol=1e-6) end diff --git a/test/manifolds/skewhermitian.jl b/test/manifolds/skewhermitian.jl index 116759d5bd..8fcbc743a1 100644 --- a/test/manifolds/skewhermitian.jl +++ b/test/manifolds/skewhermitian.jl @@ -25,13 +25,13 @@ end @test typeof(get_embedding(M)) === Euclidean{Tuple{3,3},ℝ} @test check_point(M, B_skewsym) === nothing @test_throws DomainError is_point(M, A, true) - @test_throws DomainError is_point(M, C, true) + @test_throws ManifoldDomainError is_point(M, C, true) @test_throws DomainError is_point(M, D, true) @test check_vector(M, B_skewsym, B_skewsym) === nothing @test_throws DomainError is_vector(M, B_skewsym, A, true) - @test_throws DomainError is_vector(M, A, B_skewsym, true) + @test_throws ManifoldDomainError is_vector(M, A, B_skewsym, true) @test_throws DomainError is_vector(M, B_skewsym, D, true) - @test_throws DomainError is_vector( + @test_throws ManifoldDomainError is_vector( M, B_skewsym, 1 * im * zero_vector(M, B_skewsym), @@ -58,7 +58,6 @@ end M, pts, test_injectivity_radius=false, - test_reverse_diff=isa(T, Vector), test_project_tangent=true, test_musical_isomorphisms=true, test_default_vector_transport=true, @@ -88,7 +87,6 @@ end M_complex, pts_complex, test_injectivity_radius=false, - test_reverse_diff=isa(T, Vector), test_project_tangent=true, test_musical_isomorphisms=true, test_default_vector_transport=true, diff --git a/test/manifolds/spectrahedron.jl b/test/manifolds/spectrahedron.jl index 5c3d351791..70ccb0e517 100644 --- a/test/manifolds/spectrahedron.jl +++ b/test/manifolds/spectrahedron.jl @@ -43,8 +43,6 @@ include("../utils.jl") M, pts, test_injectivity_radius=false, - test_reverse_diff=false, - test_forward_diff=false, test_project_tangent=true, test_exp_log=false, test_default_vector_transport=true, diff --git a/test/manifolds/sphere.jl b/test/manifolds/sphere.jl index 591140f976..33f70d68ba 100644 --- a/test/manifolds/sphere.jl +++ b/test/manifolds/sphere.jl @@ -13,6 +13,8 @@ using ManifoldsBase: TFVector @test injectivity_radius(M, ExponentialRetraction()) == Ο€ @test injectivity_radius(M, ProjectionRetraction()) == Ο€ / 2 @test base_manifold(M) === M + @test is_default_metric(M, EuclideanMetric()) + @test !is_default_metric(M, LinearAffineMetric()) @test !is_point(M, [1.0, 0.0, 0.0, 0.0]) @test !is_vector(M, [1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]) @test_throws DomainError is_point(M, [2.0, 0.0, 0.0], true) @@ -37,7 +39,6 @@ using ManifoldsBase: TFVector test_manifold( M, pts, - test_reverse_diff=isa(T, Vector), test_project_tangent=true, test_musical_isomorphisms=true, test_default_vector_transport=true, diff --git a/test/manifolds/sphere_symmetric_matrices.jl b/test/manifolds/sphere_symmetric_matrices.jl index 02d99ce939..e3960b83ee 100644 --- a/test/manifolds/sphere_symmetric_matrices.jl +++ b/test/manifolds/sphere_symmetric_matrices.jl @@ -16,16 +16,16 @@ include("../utils.jl") @test base_manifold(M) === M @test typeof(get_embedding(M)) === ArraySphere{Tuple{3,3},ℝ} @test check_point(M, A) === nothing - @test_throws DomainError is_point(M, B, true) - @test_throws DomainError is_point(M, C, true) + @test_throws ManifoldDomainError is_point(M, B, true) + @test_throws ManifoldDomainError is_point(M, C, true) @test_throws DomainError is_point(M, D, true) - @test_throws DomainError is_point(M, E, true) + @test_throws ManifoldDomainError is_point(M, E, true) @test check_vector(M, A, zeros(3, 3)) === nothing - @test_throws DomainError is_vector(M, A, B, true) - @test_throws DomainError is_vector(M, A, C, true) - @test_throws DomainError is_vector(M, A, D, true) - @test_throws DomainError is_vector(M, D, A, true) - @test_throws DomainError is_vector(M, A, E, true) + @test_throws ManifoldDomainError is_vector(M, A, B, true) + @test_throws ManifoldDomainError is_vector(M, A, C, true) + @test_throws ManifoldDomainError is_vector(M, A, D, true) + @test_throws ManifoldDomainError is_vector(M, D, A, true) + @test_throws ManifoldDomainError is_vector(M, A, E, true) @test_throws DomainError is_vector(M, J, K, true) @test manifold_dimension(M) == 5 A2 = similar(A) @@ -41,8 +41,6 @@ include("../utils.jl") M, [A, F, G], test_injectivity_radius=false, - test_forward_diff=false, - test_reverse_diff=false, test_vector_spaces=true, test_project_tangent=true, test_musical_isomorphisms=true, @@ -64,8 +62,6 @@ include("../utils.jl") M_complex, [C, H, I], test_injectivity_radius=false, - test_forward_diff=false, - test_reverse_diff=false, test_vector_spaces=true, test_project_tangent=true, test_musical_isomorphisms=true, diff --git a/test/manifolds/stiefel.jl b/test/manifolds/stiefel.jl index c009aca1a9..f5c596763f 100644 --- a/test/manifolds/stiefel.jl +++ b/test/manifolds/stiefel.jl @@ -1,7 +1,5 @@ include("../utils.jl") -using Manifolds: default_metric_dispatch - @testset "Stiefel" begin @testset "Real" begin M = Stiefel(3, 2) @@ -9,14 +7,23 @@ using Manifolds: default_metric_dispatch @testset "Basics" begin @test repr(M) == "Stiefel(3, 2, ℝ)" x = [1.0 0.0; 0.0 1.0; 0.0 0.0] - @test (@inferred default_metric_dispatch(M2)) === Val(true) + @test is_default_metric(M, EuclideanMetric()) @test representation_size(M) == (3, 2) @test manifold_dimension(M) == 3 base_manifold(M) === M - @test_throws DomainError is_point(M, [1.0, 0.0, 0.0, 0.0], true) - @test_throws DomainError is_point(M, 1im * [1.0 0.0; 0.0 1.0; 0.0 0.0], true) + @test_throws ManifoldDomainError is_point(M, [1.0, 0.0, 0.0, 0.0], true) + @test_throws ManifoldDomainError is_point( + M, + 1im * [1.0 0.0; 0.0 1.0; 0.0 0.0], + true, + ) @test !is_vector(M, x, [0.0, 0.0, 1.0, 0.0]) - @test_throws DomainError is_vector(M, x, 1 * im * zero_vector(M, x), true) + @test_throws ManifoldDomainError is_vector( + M, + x, + 1 * im * zero_vector(M, x), + true, + ) end @testset "Embedding and Projection" begin x = [1.0 0.0; 0.0 1.0; 0.0 0.0] @@ -87,7 +94,7 @@ using Manifolds: default_metric_dispatch x = [1.0 0.0; 0.0 1.0; 0.0 0.0] y = exp(M, x, [0.0 0.0; 0.0 0.0; 1.0 1.0]) z = exp(M, x, [0.0 0.0; 0.0 0.0; -1.0 1.0]) - @test_throws ErrorException distance(M, x, y) + @test_throws MethodError distance(M, x, y) @test isapprox( M, retract( @@ -104,7 +111,7 @@ using Manifolds: default_metric_dispatch @test !is_point(M, 2 * x) @test_throws DomainError !is_point(M, 2 * x, true) @test !is_vector(M, 2 * x, v) - @test_throws DomainError !is_vector(M, 2 * x, v, true) + @test_throws ManifoldDomainError !is_vector(M, 2 * x, v, true) @test !is_vector(M, x, y) @test_throws DomainError is_vector(M, x, y, true) test_manifold( @@ -119,8 +126,6 @@ using Manifolds: default_metric_dispatch test_project_tangent=true, test_default_vector_transport=false, point_distributions=[Manifolds.uniform_distribution(M, pts[1])], - test_forward_diff=false, - test_reverse_diff=false, test_vee_hat=false, projection_atol_multiplier=15.0, retraction_atol_multiplier=10.0, @@ -136,8 +141,8 @@ using Manifolds: default_metric_dispatch QRInverseRetraction(), ], vector_transport_methods=[ - DifferentiatedRetractionVectorTransport{PolarRetraction}(), - DifferentiatedRetractionVectorTransport{QRRetraction}(), + DifferentiatedRetractionVectorTransport(PolarRetraction()), + DifferentiatedRetractionVectorTransport(QRRetraction()), ProjectionTransport(), ], vector_transport_retractions=[ @@ -201,7 +206,7 @@ using Manifolds: default_metric_dispatch @test !is_point(M, 2 * x) @test_throws DomainError !is_point(M, 2 * x, true) @test !is_vector(M, 2 * x, v) - @test_throws DomainError !is_vector(M, 2 * x, v, true) + @test_throws ManifoldDomainError !is_vector(M, 2 * x, v, true) @test !is_vector(M, x, y) @test_throws DomainError is_vector(M, x, y, true) test_manifold( @@ -213,8 +218,6 @@ using Manifolds: default_metric_dispatch test_is_tangent=true, test_project_tangent=true, test_default_vector_transport=false, - test_forward_diff=false, - test_reverse_diff=false, test_vee_hat=false, projection_atol_multiplier=15.0, retraction_atol_multiplier=10.0, @@ -225,8 +228,8 @@ using Manifolds: default_metric_dispatch QRInverseRetraction(), ], vector_transport_methods=[ - DifferentiatedRetractionVectorTransport{PolarRetraction}(), - DifferentiatedRetractionVectorTransport{QRRetraction}(), + DifferentiatedRetractionVectorTransport(PolarRetraction()), + DifferentiatedRetractionVectorTransport(QRRetraction()), ProjectionTransport(), ], vector_transport_retractions=[ @@ -289,7 +292,7 @@ using Manifolds: default_metric_dispatch p, X, X, - DifferentiatedRetractionVectorTransport{CayleyRetraction}(), + DifferentiatedRetractionVectorTransport(CayleyRetraction()), ) @test is_vector(M, q1, Y2; atol=10^-15) r2 = PadeRetraction(2) diff --git a/test/manifolds/symmetric.jl b/test/manifolds/symmetric.jl index 26b7332644..abf0300477 100644 --- a/test/manifolds/symmetric.jl +++ b/test/manifolds/symmetric.jl @@ -19,13 +19,18 @@ include("../utils.jl") @test typeof(get_embedding(M)) === Euclidean{Tuple{3,3},ℝ} @test check_point(M, B_sym) === nothing @test_throws DomainError is_point(M, A, true) - @test_throws DomainError is_point(M, C, true) - @test_throws DomainError is_point(M, D, true) + @test_throws ManifoldDomainError is_point(M, C, true) + @test_throws ManifoldDomainError is_point(M, D, true) #embedding changes type @test check_vector(M, B_sym, B_sym) === nothing @test_throws DomainError is_vector(M, B_sym, A, true) - @test_throws DomainError is_vector(M, A, B_sym, true) - @test_throws DomainError is_vector(M, B_sym, D, true) - @test_throws DomainError is_vector(M, B_sym, 1 * im * zero_vector(M, B_sym), true) + @test_throws ManifoldDomainError is_vector(M, A, B_sym, true) + @test_throws ManifoldDomainError is_vector(M, B_sym, D, true) + @test_throws ManifoldDomainError is_vector( + M, + B_sym, + 1 * im * zero_vector(M, B_sym), + true, + ) @test manifold_dimension(M) == 6 @test manifold_dimension(M_complex) == 9 @test A_sym2 == project!(M, A_sym, A_sym) @@ -48,7 +53,6 @@ include("../utils.jl") M, pts, test_injectivity_radius=false, - test_reverse_diff=isa(T, Vector), test_project_tangent=true, test_musical_isomorphisms=true, test_default_vector_transport=true, @@ -64,7 +68,6 @@ include("../utils.jl") M_complex, pts, test_injectivity_radius=false, - test_reverse_diff=isa(T, Vector), test_project_tangent=true, test_musical_isomorphisms=true, test_default_vector_transport=true, diff --git a/test/manifolds/symmetric_positive_definite.jl b/test/manifolds/symmetric_positive_definite.jl index 8b7e918d13..d5c9563ce7 100644 --- a/test/manifolds/symmetric_positive_definite.jl +++ b/test/manifolds/symmetric_positive_definite.jl @@ -1,7 +1,5 @@ include("../utils.jl") -using Manifolds: default_metric_dispatch - @testset "Symmetric Positive Definite Matrices" begin M1 = SymmetricPositiveDefinite(3) @test repr(M1) == "SymmetricPositiveDefinite(3)" @@ -9,17 +7,6 @@ using Manifolds: default_metric_dispatch M3 = MetricManifold(SymmetricPositiveDefinite(3), Manifolds.LogCholeskyMetric()) M4 = MetricManifold(SymmetricPositiveDefinite(3), Manifolds.LogEuclideanMetric()) - @test (@inferred default_metric_dispatch(M2)) === Val(true) - @test (@inferred default_metric_dispatch(M1, Manifolds.LinearAffineMetric())) === - Val(true) - @test (@inferred default_metric_dispatch(M1, Manifolds.LogCholeskyMetric())) === - Val(false) - @test (@inferred default_metric_dispatch(M3)) === Val(false) - @test is_default_metric(M2) - @test is_default_metric(M1, Manifolds.LinearAffineMetric()) - @test !is_default_metric(M1, Manifolds.LogCholeskyMetric()) - @test !is_default_metric(M3) - @test injectivity_radius(M1) == Inf @test injectivity_radius(M1, one(zeros(3, 3))) == Inf @test injectivity_radius(M1, ExponentialRetraction()) == Inf @@ -62,8 +49,6 @@ using Manifolds: default_metric_dispatch test_default_vector_transport=true, vector_transport_methods=typeof(M) == SymmetricPositiveDefinite{3} ? [ParallelTransport()] : [], - test_forward_diff=false, - test_reverse_diff=false, test_vee_hat=M === M2, exp_log_atol_multiplier=exp_log_atol_multiplier, basis_types_vecs=basis_types, diff --git a/test/manifolds/symmetric_positive_semidefinite_fixed_rank.jl b/test/manifolds/symmetric_positive_semidefinite_fixed_rank.jl index 3dcccf2dd2..5f56398ca5 100644 --- a/test/manifolds/symmetric_positive_semidefinite_fixed_rank.jl +++ b/test/manifolds/symmetric_positive_semidefinite_fixed_rank.jl @@ -29,8 +29,6 @@ include("../utils.jl") pts, exp_log_atol_multiplier=5, is_tangent_atol_multiplier=5, - test_forward_diff=false, - test_reverse_diff=false, test_project_tangent=true, test_inplace=true, ) diff --git a/test/manifolds/symplectic.jl b/test/manifolds/symplectic.jl index 56b1f4cd97..a1d9dd9743 100644 --- a/test/manifolds/symplectic.jl +++ b/test/manifolds/symplectic.jl @@ -74,7 +74,6 @@ include("../utils.jl") @test repr(Sp_2) == "Symplectic{$(2), ℝ}()" @test representation_size(Sp_2) == (2, 2) @test base_manifold(Sp_2) === Sp_2 - @test (@inferred Manifolds.default_metric_dispatch(Metr_Sp_2)) === Val(true) @test is_point(Sp_2, p_2) @test_throws DomainError is_point(Sp_2, p_2 + I, true) @@ -140,7 +139,7 @@ include("../utils.jl") @test isapprox( distance(Sp_2, p_2, q_2), approximate_p_q_geodesic_distance; - atol=1.0e-16, + atol=1e-14, ) # Project tangent vector into (T_pSp)^{\perp}: @@ -200,8 +199,6 @@ include("../utils.jl") is_point_atol_multiplier=1.0e8, is_tangent_atol_multiplier=1.0e6, retraction_atol_multiplier=1.0e4, - test_reverse_diff=false, - test_forward_diff=false, test_project_tangent=true, test_injectivity_radius=false, test_exp_log=false, @@ -222,8 +219,6 @@ include("../utils.jl") is_point_atol_multiplier=1.0e8, is_tangent_atol_multiplier=1.0e6, retraction_atol_multiplier=1.0e4, - test_reverse_diff=false, - test_forward_diff=false, test_project_tangent=true, test_injectivity_radius=false, test_exp_log=false, @@ -244,8 +239,6 @@ include("../utils.jl") is_point_atol_multiplier=1.0e7, is_tangent_atol_multiplier=1.0e6, retraction_atol_multiplier=1.0e4, - test_reverse_diff=false, - test_forward_diff=false, test_project_tangent=false, # Cannot solve 'sylvester' for MMatrix-type. test_injectivity_radius=false, test_exp_log=false, @@ -260,32 +253,32 @@ include("../utils.jl") analytical_grad_f(p) = (1 / 2) * (p * Q_grad * p * Q_grad + p * p') p_grad = convert(Array{Float64}, points[1]) - ad_diff = RiemannianProjectionBackend(Manifolds.ForwardDiffBackend()) + fd_diff = RiemannianProjectionBackend(Manifolds.FiniteDifferencesBackend()) @test isapprox( - Manifolds.gradient(Sp_6, test_f, p_grad, ad_diff), + Manifolds.gradient(Sp_6, test_f, p_grad, fd_diff), analytical_grad_f(p_grad); - atol=1.0e-16, + atol=1.0e-9, ) @test isapprox( - Manifolds.gradient(Sp_6, test_f, p_grad, ad_diff; extended_metric=false), + Manifolds.gradient(Sp_6, test_f, p_grad, fd_diff; extended_metric=false), analytical_grad_f(p_grad); - atol=1.0e-12, + atol=1.0e-9, ) grad_f_p = similar(p_grad) - Manifolds.gradient!(Sp_6, test_f, grad_f_p, p_grad, ad_diff) - @test isapprox(grad_f_p, analytical_grad_f(p_grad); atol=1.0e-16) + Manifolds.gradient!(Sp_6, test_f, grad_f_p, p_grad, fd_diff) + @test isapprox(grad_f_p, analytical_grad_f(p_grad); atol=1.0e-9) Manifolds.gradient!( Sp_6, test_f, grad_f_p, p_grad, - ad_diff; + fd_diff; extended_metric=false, ) - @test isapprox(grad_f_p, analytical_grad_f(p_grad); atol=1.0e-12) + @test isapprox(grad_f_p, analytical_grad_f(p_grad); atol=1.0e-9) end end diff --git a/test/manifolds/symplecticstiefel.jl b/test/manifolds/symplecticstiefel.jl index b07663a399..9b4b5bb922 100644 --- a/test/manifolds/symplecticstiefel.jl +++ b/test/manifolds/symplecticstiefel.jl @@ -246,8 +246,6 @@ end is_point_atol_multiplier=1.0e4, is_tangent_atol_multiplier=1.0e3, retraction_atol_multiplier=1.0e1, - test_reverse_diff=false, - test_forward_diff=false, test_project_tangent=(type != MMatrix{6,4,Float64,24}), test_injectivity_radius=false, test_exp_log=false, @@ -266,8 +264,6 @@ end is_point_atol_multiplier=1.0e11, is_tangent_atol_multiplier=1.0e2, retraction_atol_multiplier=1.0e4, - test_reverse_diff=false, - test_forward_diff=false, test_project_tangent=(type != MMatrix{6,4,Float64,24}), test_injectivity_radius=false, test_exp_log=false, @@ -290,17 +286,17 @@ end return Q_grad * p * (euc_grad_f') * Q_grad * p + euc_grad_f * p' * p end p_grad = convert(Array{Float64}, points[1]) - ad_diff = RiemannianProjectionBackend(Manifolds.ForwardDiffBackend()) + fd_diff = RiemannianProjectionBackend(Manifolds.FiniteDifferencesBackend()) @test isapprox( - Manifolds.gradient(SpSt_6_4, test_f, p_grad, ad_diff), + Manifolds.gradient(SpSt_6_4, test_f, p_grad, fd_diff), analytical_grad_f(p_grad); - atol=1.0e-16, + atol=1.0e-9, ) grad_f_p = similar(p_grad) - Manifolds.gradient!(SpSt_6_4, test_f, grad_f_p, p_grad, ad_diff) - @test isapprox(grad_f_p, analytical_grad_f(p_grad); atol=1.0e-16) + Manifolds.gradient!(SpSt_6_4, test_f, grad_f_p, p_grad, fd_diff) + @test isapprox(grad_f_p, analytical_grad_f(p_grad); atol=1.0e-9) end end end diff --git a/test/manifolds/torus.jl b/test/manifolds/torus.jl index 0158320ffd..6b9bc0ac66 100644 --- a/test/manifolds/torus.jl +++ b/test/manifolds/torus.jl @@ -14,6 +14,7 @@ include("../utils.jl") @test_throws DomainError is_point(M, 9.0, true) @test !is_point(M, [9.0; 9.0]) @test_throws CompositeManifoldError is_point(M, [9.0 9.0], true) + @test_throws CompositeManifoldError is_point(M, [9.0, 9.0], true) @test !is_vector(M, [9.0; 9.0], 0.0) @test_throws DomainError is_vector(M, 9.0, 0.0, true) # point false and checked @test !is_vector(M, [9.0; 9.0], [0.0; 0.0]) @@ -26,8 +27,6 @@ include("../utils.jl") test_manifold( M, [x, y, z], - test_forward_diff=false, - test_reverse_diff=false, test_vector_spaces=true, test_project_tangent=false, test_musical_isomorphisms=true, diff --git a/test/manifolds/tucker.jl b/test/manifolds/tucker.jl index 3d8d8308a7..926d97ada5 100644 --- a/test/manifolds/tucker.jl +++ b/test/manifolds/tucker.jl @@ -172,8 +172,6 @@ include("../utils.jl") test_is_tangent=false, test_project_tangent=false, test_default_vector_transport=false, - test_forward_diff=false, - test_reverse_diff=false, test_vector_spaces=false, test_vee_hat=false, test_tangent_vector_broadcasting=true, diff --git a/test/manifolds/vector_bundle.jl b/test/manifolds/vector_bundle.jl index 95ca7d7c22..a19a90f891 100644 --- a/test/manifolds/vector_bundle.jl +++ b/test/manifolds/vector_bundle.jl @@ -29,7 +29,7 @@ struct TestVectorSpaceType <: VectorSpaceType end @test sprint(show, TB) == "TangentBundle(Sphere(2, ℝ))" @test base_manifold(TB) == M @test manifold_dimension(TB) == 2 * manifold_dimension(M) - @test representation_size(TB) == (6,) + @test representation_size(TB) === nothing CTB = CotangentBundle(M) @test sprint(show, CTB) == "CotangentBundle(Sphere(2, ℝ))" @test sprint(show, VectorBundle(TestVectorSpaceType(), M)) == @@ -58,8 +58,6 @@ struct TestVectorSpaceType <: VectorSpaceType end TB, pts_tb, test_injectivity_radius=false, - test_reverse_diff=isa(T, Vector), - test_forward_diff=isa(T, Vector), test_tangent_vector_broadcasting=false, test_vee_hat=true, test_project_tangent=true, @@ -69,6 +67,7 @@ struct TestVectorSpaceType <: VectorSpaceType end basis_types_vecs=basis_types, projection_atol_multiplier=4, test_inplace=true, + test_representation_size=false, test_rand_point=true, test_rand_tvector=true, ) @@ -87,8 +86,6 @@ struct TestVectorSpaceType <: VectorSpaceType end TpM, pts_TpM, test_injectivity_radius=true, - test_reverse_diff=isa(T, Vector), - test_forward_diff=isa(T, Vector), test_tangent_vector_broadcasting=true, test_vee_hat=false, test_project_tangent=true, @@ -158,14 +155,6 @@ struct TestVectorSpaceType <: VectorSpaceType end @test_throws ErrorException Manifolds.project!(vbf, [1, 2, 3], [1, 2, 3], [1, 2, 3]) @test_throws ErrorException zero_vector!(vbf, [1, 2, 3], [1, 2, 3]) @test_throws MethodError vector_space_dimension(vbf) - a = fill(0.0, 6) - @test_throws ErrorException get_coordinates!( - TangentBundle(M), - a, - ProductRepr([1.0, 0.0, 0.0], [0.0, 0.0, 0.0]), - ProductRepr([1.0, 0.0, 0.0], [0.0, 0.0, 1.0]), - CachedBasis(DefaultOrthonormalBasis(), []), - ) end @testset "log and exp on tangent bundle for power and product manifolds" begin @@ -186,7 +175,7 @@ struct TestVectorSpaceType <: VectorSpaceType end ) @test isapprox(N2, p2_2, exp(N2, p1_2, log(N2, p1_2, p2_2))) - ppt = PowerVectorTransport(ParallelTransport()) + ppt = ParallelTransport() tbvt = Manifolds.VectorBundleVectorTransport(ppt, ppt) @test TangentBundle(M, tbvt).vector_transport === tbvt @test CotangentBundle(M, tbvt).vector_transport === tbvt diff --git a/test/metric.jl b/test/metric.jl index 86bce8d56d..dea4c2e77d 100644 --- a/test/metric.jl +++ b/test/metric.jl @@ -1,14 +1,24 @@ -using FiniteDifferences, ForwardDiff +using FiniteDifferences using LinearAlgebra: I using StatsBase: AbstractWeights, pweights -import Manifolds: mean!, median!, InducedBasis, induced_basis, get_chart_index, connection - +using ManifoldsBase: TraitList +import ManifoldsBase: default_retraction_method +import Manifolds: solve_exp_ode +using Manifolds: + FiniteDifferencesBackend, + InducedBasis, + connection, + get_chart_index, + induced_basis, + mean!, + median! include("utils.jl") struct TestEuclidean{N} <: AbstractManifold{ℝ} end struct TestEuclideanMetric <: AbstractMetric end struct TestScaledEuclideanMetric <: AbstractMetric end struct TestRetraction <: AbstractRetractionMethod end +struct TestConnection <: AbstractAffineConnection end ManifoldsBase.default_retraction_method(::TestEuclidean) = TestRetraction() function ManifoldsBase.default_retraction_method( @@ -19,6 +29,7 @@ end Manifolds.manifold_dimension(::TestEuclidean{N}) where {N} = N function Manifolds.local_metric( + ::TraitList{<:IsMetricManifold}, M::MetricManifold{ℝ,<:TestEuclidean,<:TestEuclideanMetric}, ::Any, ::InducedBasis, @@ -26,6 +37,7 @@ function Manifolds.local_metric( return Diagonal(1.0:manifold_dimension(M)) end function Manifolds.local_metric( + ::TraitList{IsMetricManifold}, M::MetricManifold{ℝ,<:TestEuclidean,<:TestEuclideanMetric}, ::Any, ::T, @@ -33,48 +45,64 @@ function Manifolds.local_metric( return Diagonal(1.0:manifold_dimension(M)) end function Manifolds.local_metric( + ::TraitList{IsMetricManifold}, M::MetricManifold{ℝ,<:TestEuclidean,<:TestScaledEuclideanMetric}, ::Any, ::T, ) where {T<:ManifoldsBase.AbstractOrthogonalBasis} return 2 .* Diagonal(1.0:manifold_dimension(M)) end -function Manifolds.get_coordinates!( +function Manifolds.get_coordinates_orthogonal( + M::MetricManifold{ℝ,<:TestEuclidean,<:TestEuclideanMetric}, + ::Any, + X, + ::ManifoldsBase.AbstractNumbers, +) + return 1 ./ [1.0:manifold_dimension(M)...] .* X +end +function Manifolds.get_coordinates_orthogonal!( M::MetricManifold{ℝ,<:TestEuclidean,<:TestEuclideanMetric}, c, ::Any, X, - ::DefaultOrthogonalBasis{ℝ,<:ManifoldsBase.TangentSpaceType}, + ::ManifoldsBase.AbstractNumbers, ) c .= 1 ./ [1.0:manifold_dimension(M)...] .* X return c end -function Manifolds.get_vector!( +function Manifolds.get_vector_orthonormal!( M::MetricManifold{ℝ,<:TestEuclidean,<:TestEuclideanMetric}, X, ::Any, c, - ::DefaultOrthogonalBasis{ℝ,<:ManifoldsBase.TangentSpaceType}, + ::ManifoldsBase.AbstractNumbers, ) X .= [1.0:manifold_dimension(M)...] .* c return X end -function Manifolds.get_coordinates!( +function Manifolds.get_coordinates_orthogonal!( M::MetricManifold{ℝ,<:TestEuclidean,<:TestScaledEuclideanMetric}, c, ::Any, X, - ::DefaultOrthogonalBasis, ) c .= 1 ./ (2 .* [1.0:manifold_dimension(M)...]) .* X return c end -function Manifolds.get_vector!( +function Manifolds.get_vector_orthogonal!( + M::MetricManifold{ℝ,<:TestEuclidean,<:TestScaledEuclideanMetric}, + ::Any, + c, + ::ManifoldsBase.AbstractNumbers, +) + return 2 .* [1.0:manifold_dimension(M)...] .* c +end +function Manifolds.get_vector_orthogonal!( M::MetricManifold{ℝ,<:TestEuclidean,<:TestScaledEuclideanMetric}, X, ::Any, c, - ::DefaultOrthogonalBasis, + ::ManifoldsBase.AbstractNumbers, ) X .= 2 .* [1.0:manifold_dimension(M)...] .* c return X @@ -113,10 +141,11 @@ Manifolds.project!(::BaseManifold, q, p) = (q .= p) Manifolds.injectivity_radius(::BaseManifold) = Inf Manifolds.injectivity_radius(::BaseManifold, ::Any) = Inf Manifolds.injectivity_radius(::BaseManifold, ::AbstractRetractionMethod) = Inf -Manifolds.injectivity_radius(::BaseManifold, ::ExponentialRetraction) = Inf +Manifolds._injectivity_radius(::BaseManifold, ::ExponentialRetraction) = Inf Manifolds.injectivity_radius(::BaseManifold, ::Any, ::AbstractRetractionMethod) = Inf -Manifolds.injectivity_radius(::BaseManifold, ::Any, ::ExponentialRetraction) = Inf +Manifolds._injectivity_radius(::BaseManifold, ::Any, ::ExponentialRetraction) = Inf function Manifolds.local_metric( + ::TraitList{<:IsMetricManifold}, ::MetricManifold{ℝ,BaseManifold{N},BaseManifoldMetric{N}}, p, ::InducedBasis, @@ -131,7 +160,7 @@ function Manifolds.exp!( ) where {N} return exp!(base_manifold(M), q, p, X) end -function Manifolds.vector_transport_to!(::BaseManifold, Y, p, X, q, ::ParallelTransport) +function Manifolds.parallel_transport_to!(::BaseManifold, Y, p, X, q) return (Y .= X) end function Manifolds.get_basis( @@ -141,61 +170,31 @@ function Manifolds.get_basis( ) where {N} return CachedBasis(B, [(Matrix{eltype(p)}(I, N, N)[:, i]) for i in 1:N]) end -function Manifolds.get_coordinates!( +function Manifolds.get_coordinates_orthonormal!( ::BaseManifold, Y, p, X, - ::DefaultOrthonormalBasis{<:Any,ManifoldsBase.TangentSpaceType}, + ::ManifoldsBase.AbstractNumbers, ) return Y .= X end -function Manifolds.get_vector!( +function Manifolds.get_vector_orthonormal!( ::BaseManifold, Y, p, X, - ::DefaultOrthonormalBasis{<:Any,ManifoldsBase.TangentSpaceType}, + ::ManifoldsBase.AbstractNumbers, ) return Y .= X end -Manifolds.default_metric_dispatch(::BaseManifold, ::DefaultBaseManifoldMetric) = Val(true) +Manifolds.is_default_metric(::BaseManifold, ::DefaultBaseManifoldMetric) = true function Manifolds.projected_distribution(M::BaseManifold, d) return ProjectedPointDistribution(M, d, project!, rand(d)) end function Manifolds.projected_distribution(M::BaseManifold, d, p) return ProjectedPointDistribution(M, d, project!, p) end -function Manifolds.mean!(::BaseManifold, y, x::AbstractVector, w::AbstractVector; kwargs...) - return fill!(y, 1) -end -function Manifolds.median!( - ::BaseManifold, - y, - x::AbstractVector, - w::AbstractVector; - kwargs..., -) - return fill!(y, 2) -end -function Manifolds.mean!( - ::MetricManifold{ℝ,BaseManifold{N},BaseManifoldMetric{N}}, - y, - x::AbstractVector, - w::AbstractVector; - kwargs..., -) where {N} - return fill!(y, 3) -end -function Manifolds.median!( - ::MetricManifold{ℝ,BaseManifold{N},BaseManifoldMetric{N}}, - y, - x::AbstractVector, - w::AbstractVector; - kwargs..., -) where {N} - return fill!(y, 4) -end function Manifolds.flat!( ::BaseManifold, @@ -215,25 +214,65 @@ function Manifolds.sharp!( v.data .= w.data ./ 2 return v end - +function solve_exp_ode( + ::ConnectionManifold{ℝ,TestEuclidean{N},TestConnection}, + p, + X; + kwargs..., +) where {N} + return X +end +function Manifolds.vector_transport_along!( + M::BaseManifold, + Y, + p, + X, + c::AbstractVector, + m::AbstractVectorTransportMethod=default_vector_transport_method(M), +) + Y .= c + return Y +end @testset "Metrics" begin # some tests failed due to insufficient accuracy for a particularly bad RNG state Random.seed!(42) @testset "Metric Basics" begin - #one for MetricManifold, one for AbstractManifold & Metric - @test length(methods(is_default_metric)) == 2 + @test repr(MetricManifold(Euclidean(3), EuclideanMetric())) === + "MetricManifold(Euclidean(3; field = ℝ), EuclideanMetric())" + @test repr(IsDefaultMetric(EuclideanMetric())) === + "IsDefaultMetric(EuclideanMetric())" + end + @testset "Connection Trait" begin + M = ConnectionManifold(Euclidean(3), LeviCivitaConnection()) + @test is_default_connection(M) + @test decorated_manifold(M) == Euclidean(3) + @test is_default_connection(Euclidean(3), LeviCivitaConnection()) + @test !is_default_connection(TestEuclidean{3}(), LeviCivitaConnection()) + c = IsDefaultConnection(LeviCivitaConnection()) + @test ManifoldsBase.parent_trait(c) == Manifolds.IsConnectionManifold() end @testset "solve_exp_ode error message" begin E = TestEuclidean{3}() g = TestEuclideanMetric() M = MetricManifold(E, g) - + default_retraction_method(::TestEuclidean) = TestRetraction() p = [1.0, 2.0, 3.0] X = [2.0, 3.0, 4.0] + q = similar(X) @test_throws ErrorException exp(M, p, X) + @test_throws ErrorException exp!(M, q, p, X) + + N = ConnectionManifold(E, LeviCivitaConnection()) + @test_throws ErrorException exp(N, p, X) + @test_throws ErrorException exp!(N, q, p, X) + using OrdinaryDiffEq - exp(M, p, X) + @test is_point(M, exp(M, p, X)) + + # a small trick to check that retract_exp_ode! returns the right value on ConnectionManifolds + N2 = ConnectionManifold(E, TestConnection()) + @test exp(N2, p, X) == X end @testset "Local Metric Error message" begin M = MetricManifold(BaseManifold{2}(), NotImplementedMetric()) @@ -306,17 +345,17 @@ end @test gaussian_curvature(M, p, B_chart_p; backend=fdm) β‰ˆ 0 atol = 1e-6 @test einstein_tensor(M, p, B_chart_p; backend=fdm) β‰ˆ zeros(n, n) atol = 1e-6 - fwd_diff = Manifolds.ForwardDiffBackend() - @test christoffel_symbols_first(M, p, B_chart_p; backend=fwd_diff) β‰ˆ + fd_diff = Manifolds.FiniteDifferencesBackend() + @test christoffel_symbols_first(M, p, B_chart_p; backend=fd_diff) β‰ˆ zeros(n, n, n) atol = 1e-6 - @test christoffel_symbols_second(M, p, B_chart_p; backend=fwd_diff) β‰ˆ + @test christoffel_symbols_second(M, p, B_chart_p; backend=fd_diff) β‰ˆ zeros(n, n, n) atol = 1e-6 - @test riemann_tensor(M, p, B_chart_p; backend=fwd_diff) β‰ˆ zeros(n, n, n, n) atol = + @test riemann_tensor(M, p, B_chart_p; backend=fd_diff) β‰ˆ zeros(n, n, n, n) atol = 1e-6 - @test ricci_tensor(M, p, B_chart_p; backend=fwd_diff) β‰ˆ zeros(n, n) atol = 1e-6 - @test ricci_curvature(M, p, B_chart_p; backend=fwd_diff) β‰ˆ 0 atol = 1e-6 - @test gaussian_curvature(M, p, B_chart_p; backend=fwd_diff) β‰ˆ 0 atol = 1e-6 - @test einstein_tensor(M, p, B_chart_p; backend=fwd_diff) β‰ˆ zeros(n, n) atol = + @test ricci_tensor(M, p, B_chart_p; backend=fd_diff) β‰ˆ zeros(n, n) atol = 1e-6 + @test ricci_curvature(M, p, B_chart_p; backend=fd_diff) β‰ˆ 0 atol = 1e-6 + @test gaussian_curvature(M, p, B_chart_p; backend=fd_diff) β‰ˆ 0 atol = 1e-6 + @test einstein_tensor(M, p, B_chart_p; backend=fd_diff) β‰ˆ zeros(n, n) atol = 1e-6 end end @@ -419,18 +458,9 @@ end MM2 = MetricManifold(M, g2) A = Manifolds.get_default_atlas(M) - @test (@inferred Manifolds.default_metric_dispatch(MM)) === - (@inferred Manifolds.default_metric_dispatch(base_manifold(MM), metric(MM))) - @test (@inferred Manifolds.default_metric_dispatch(MM2)) === - (@inferred Manifolds.default_metric_dispatch(base_manifold(MM2), metric(MM2))) - @test (@inferred Manifolds.default_metric_dispatch(MM2)) === Val(true) @test is_default_metric(MM) == is_default_metric(base_manifold(MM), metric(MM)) @test is_default_metric(MM2) == is_default_metric(base_manifold(MM2), metric(MM2)) @test is_default_metric(MM2) - @test Manifolds.default_decorator_dispatch(MM) === - Manifolds.default_metric_dispatch(MM) - @test Manifolds.default_decorator_dispatch(MM2) === - Manifolds.default_metric_dispatch(MM2) @test convert(typeof(MM2), M) == MM2 @test_throws ErrorException convert(typeof(MM), M) @@ -464,15 +494,14 @@ end @test exp!(MM, q, p, X) === exp!(M, q, p, X) @test retract!(MM, q, p, X) === retract!(M, q, p, X) @test retract!(MM, q, p, X, 1) === retract!(M, q, p, X, 1) + @test project!(MM, Y, p, X) === project!(M, Y, p, X) + @test project!(MM, q, p) === project!(M, q, p) # without a definition for the metric from the embedding, no projection possible - @test_throws ErrorException log!(MM, Y, p, q) === project!(M, Y, p, q) - @test_throws ErrorException project!(MM, Y, p, X) === project!(M, Y, p, X) - @test_throws ErrorException project!(MM, q, p) === project!(M, q, p) - @test_throws ErrorException vector_transport_to!(MM, Y, p, X, q) === - vector_transport_to!(M, Y, p, X, q) + @test_throws MethodError log!(MM, Y, p, q) === project!(M, Y, p, q) + @test_throws MethodError vector_transport_to!(MM, Y, p, X, q) === + vector_transport_to!(M, Y, p, X, q) # without DiffEq, these error - # @test_throws ErrorException exp(MM,x, X, 1:3) - # @test_throws ErrorException exp!(MM, q, p, X) + @test_throws MethodError exp(MM, p, X, 1:3) # these always fall back anyways. @test zero_vector!(MM, X, p) === zero_vector!(M, X, p) @@ -509,7 +538,17 @@ end @test project!(MM2, q, p) === project!(M, q, p) @test project!(MM2, Y, p, X) === project!(M, Y, p, X) + @test parallel_transport_to(MM2, p, X, q) == parallel_transport_to(M, q, X, p) + @test parallel_transport_to!(MM2, Y, p, X, q) == + parallel_transport_to!(M, Y, q, X, p) + @test project!(MM2, Y, p, X) === project!(M, Y, p, X) @test vector_transport_to!(MM2, Y, p, X, q) == vector_transport_to!(M, Y, p, X, q) + c = 2 * ones(3) + m = ParallelTransport() + @test vector_transport_along(MM2, p, X, c, m) == + vector_transport_along(M, p, X, c, m) + @test vector_transport_along!(MM2, Y, p, X, c, m) == + vector_transport_along!(M, Y, p, X, c, m) @test zero_vector!(MM2, X, p) === zero_vector!(M, X, p) @test injectivity_radius(MM2, p) === injectivity_radius(M, p) @test injectivity_radius(MM2) === injectivity_radius(M) @@ -533,7 +572,7 @@ end @test isapprox(a.distribution.ΞΌ, b.distribution.ΞΌ) @test get_basis(M, p, DefaultOrthonormalBasis()).data == get_basis(MM2, p, DefaultOrthonormalBasis()).data - @test_throws ErrorException get_basis(MM, p, DefaultOrthonormalBasis()) + @test_throws MethodError get_basis(MM, p, DefaultOrthonormalBasis()) fX = ManifoldsBase.TFVector(X, B_p) fY = ManifoldsBase.TFVector(Y, B_p) @@ -573,73 +612,17 @@ end fX2 = allocate(fX) sharp!(MM, fX2, p, cofX2) @test isapprox(fX2.data, fX.data) - end - psample = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]] - Y = pweights([0.5, 0.5]) - # test despatch with results from above - @test mean(M, psample, Y) β‰ˆ ones(3) - @test mean(MM2, psample, Y) β‰ˆ ones(3) - @test mean(MM, psample, Y) β‰ˆ 3 .* ones(3) + cofX3a = flat(MM2, p, fX) + cofX3b = allocate(cofX3a) + flat!(MM2, cofX3b, p, fX) + @test isapprox(cofX3a.data, cofX3b.data) - @test median(M, psample, Y) β‰ˆ 2 .* ones(3) - @test median(MM2, psample, Y) β‰ˆ 2 * ones(3) - @test median(MM, psample, Y) β‰ˆ 4 .* ones(3) - end - - @testset "Metric decorator dispatches" begin - M = BaseManifold{3}() - g = BaseManifoldMetric{3}() - MM = MetricManifold(M, g) - x = [1, 2, 3] - # nonmutating always go to parent for allocation - for f in [exp, flat, inverse_retract, log, mean, median, project] - @test Manifolds.decorator_transparent_dispatch(f, MM) === Val{:parent}() - end - for f in [sharp, retract, get_vector, get_coordinates] - @test Manifolds.decorator_transparent_dispatch(f, MM) === Val{:parent}() - end - for f in [vector_transport_along, vector_transport_direction, vector_transport_to] - @test Manifolds.decorator_transparent_dispatch(f, MM) === Val{:parent}() - end - for f in [get_basis, inner] - @test Manifolds.decorator_transparent_dispatch(f, MM) === Val{:intransparent}() - end - for f in [get_coordinates!, get_vector!] - @test Manifolds.decorator_transparent_dispatch(f, MM) === Val{:intransparent}() - end - - # mirroring ones are mostly intransparent despite for a few cases - e.g. dispatch/default last variables - for f in [exp!, flat!, inverse_retract!, log!, mean!, median!] - @test Manifolds.decorator_transparent_dispatch(f, MM) === Val{:intransparent}() - end - for f in [norm, project!, sharp!, retract!] - @test Manifolds.decorator_transparent_dispatch(f, MM) === Val{:intransparent}() - end - for f in [vector_transport_along!, vector_transport_to!] - @test Manifolds.decorator_transparent_dispatch(f, MM) === Val{:intransparent}() + fX3a = sharp(MM2, p, cofX) + fX3b = allocate(fX3a) + sharp!(MM2, fX3b, p, cofX) + @test isapprox(fX3a.data, fX3b.data) end - @test Manifolds.decorator_transparent_dispatch(vector_transport_direction!, MM) === - Val{:parent}() - - @test Manifolds.decorator_transparent_dispatch(exp!, MM, x, x, x, x) === - Val{:parent}() - @test Manifolds.decorator_transparent_dispatch( - inverse_retract!, - MM, - x, - x, - x, - LogarithmicInverseRetraction(), - ) === Val{:parent}() - @test Manifolds.decorator_transparent_dispatch( - retract!, - MM, - x, - x, - x, - ExponentialRetraction(), - ) === Val{:parent}() end @testset "change metric and representer" begin diff --git a/test/runtests.jl b/test/runtests.jl index 680510c669..46c4d02941 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,176 +1,188 @@ include("utils.jl") + +TEST_GROUP = get(ENV, "MANIFOLDS_TEST_GROUP", "all") + @info "Manifolds.jl Test settings:\n\n" * "Testing Float32: $(TEST_FLOAT32)\n" * "Testing Double64: $(TEST_DOUBLE64)\n" * "Testing Static: $(TEST_STATIC_SIZED)\n\n" * + "Test group: $(TEST_GROUP)\n\n" * "Check test/utils.jl if you wish to change these settings." @testset "Manifolds.jl" begin - include_test("differentiation.jl") - - include_test("ambiguities.jl") - - @testset "utils test" begin - Random.seed!(42) - @testset "usinc_from_cos" begin - @test Manifolds.usinc_from_cos(-1) == 0 - @test Manifolds.usinc_from_cos(-1.0) == 0.0 - end - @testset "log_safe!" begin - n = 8 - Q = qr(randn(n, n)).Q - A1 = Matrix(Hermitian(Q * Diagonal(rand(n)) * Q')) - @test exp(Manifolds.log_safe!(similar(A1), A1)) β‰ˆ A1 atol = 1e-6 - A1_fail = Matrix(Hermitian(Q * Diagonal([-1; rand(n - 1)]) * Q')) - @test_throws DomainError Manifolds.log_safe!(similar(A1_fail), A1_fail) - - T = triu!(randn(n, n)) - T[diagind(T)] .= rand.() - @test exp(Manifolds.log_safe!(similar(T), T)) β‰ˆ T atol = 1e-6 - T_fail = copy(T) - T_fail[1] = -1 - @test_throws DomainError Manifolds.log_safe!(similar(T_fail), T_fail) - - A2 = Q * T * Q' - @test exp(Manifolds.log_safe!(similar(A2), A2)) β‰ˆ A2 atol = 1e-6 - A2_fail = Q * T_fail * Q' - @test_throws DomainError Manifolds.log_safe!(similar(A2_fail), A2_fail) - - A3 = exp(SizedMatrix{n,n}(randn(n, n))) - @test A3 isa SizedMatrix - @test exp(Manifolds.log_safe!(similar(A3), A3)) β‰ˆ A3 atol = 1e-6 - @test exp(Manifolds.log_safe(A3)) β‰ˆ A3 atol = 1e-6 - - A3_fail = Float64[1 2; 3 1] - @test_throws DomainError Manifolds.log_safe!(similar(A3_fail), A3_fail) - - A4 = randn(ComplexF64, n, n) - @test exp(Manifolds.log_safe!(similar(A4), A4)) β‰ˆ A4 atol = 1e-6 - end - @testset "isnormal" begin - @test !Manifolds.isnormal([1.0 2.0; 3.0 4.0]) - @test !Manifolds.isnormal(complex.(reshape(1:4, 2, 2), reshape(5:8, 2, 2))) - - # diagonal - @test Manifolds.isnormal(diagm(randn(5))) - @test Manifolds.isnormal(diagm(randn(ComplexF64, 5))) - @test Manifolds.isnormal(Diagonal(randn(5))) - @test Manifolds.isnormal(Diagonal(randn(ComplexF64, 5))) - - # symmetric/hermitian - @test Manifolds.isnormal(Symmetric(randn(3, 3))) - @test Manifolds.isnormal(Hermitian(randn(3, 3))) - @test Manifolds.isnormal(Hermitian(randn(ComplexF64, 3, 3))) - x = Matrix(Symmetric(randn(3, 3))) - x[3, 1] += eps() - @test !Manifolds.isnormal(x) - @test Manifolds.isnormal(x; atol=sqrt(eps())) - - # skew-symmetric/skew-hermitian - skew(x) = x - x' - @test Manifolds.isnormal(skew(randn(3, 3))) - @test Manifolds.isnormal(skew(randn(ComplexF64, 3, 3))) - - # orthogonal/unitary - @test Manifolds.isnormal(Matrix(qr(randn(3, 3)).Q); atol=sqrt(eps())) - @test Manifolds.isnormal( - Matrix(qr(randn(ComplexF64, 3, 3)).Q); - atol=sqrt(eps()), - ) - end - @testset "realify/unrealify!" begin - # round trip real - x = randn(3, 3) - @test Manifolds.realify(x, ℝ) === x - @test Manifolds.unrealify!(similar(x), x, ℝ) == x - - # round trip complex - x2 = randn(ComplexF64, 3, 3) - x2r = Manifolds.realify(x2, β„‚) - @test eltype(x2r) <: Real - @test size(x2r) == (6, 6) - x2c = Manifolds.unrealify!(similar(x2), x2r, β„‚) - @test x2c β‰ˆ x2 - - # matrix multiplication is preserved - x3 = randn(ComplexF64, 3, 3) - x3r = Manifolds.realify(x3, β„‚) - @test x2 * x3 β‰ˆ Manifolds.unrealify!(similar(x2), x2r * x3r, β„‚) - end - @testset "allocation" begin - @test allocate([1 2; 3 4], Float64, Size(3, 3)) isa Matrix{Float64} - @test allocate(SA[1 2; 3 4], Float64, Size(3, 3)) isa MMatrix{3,3,Float64} - @test allocate(SA[1 2; 3 4], Size(3, 3)) isa MMatrix{3,3,Int} - end - @testset "eigen_safe" begin - @test Manifolds.eigen_safe(SA[1.0 0.0; 0.0 1.0]) isa - Eigen{Float64,Float64,<:SizedMatrix{2,2},<:SizedVector{2}} + if TEST_GROUP ∈ ["all", "test_manifolds"] + include_test("differentiation.jl") + + include_test("ambiguities.jl") + + @testset "utils test" begin + Random.seed!(42) + @testset "usinc_from_cos" begin + @test Manifolds.usinc_from_cos(-1) == 0 + @test Manifolds.usinc_from_cos(-1.0) == 0.0 + end + @testset "log_safe!" begin + n = 8 + Q = qr(randn(n, n)).Q + A1 = Matrix(Hermitian(Q * Diagonal(rand(n)) * Q')) + @test exp(Manifolds.log_safe!(similar(A1), A1)) β‰ˆ A1 atol = 1e-6 + A1_fail = Matrix(Hermitian(Q * Diagonal([-1; rand(n - 1)]) * Q')) + @test_throws DomainError Manifolds.log_safe!(similar(A1_fail), A1_fail) + + T = triu!(randn(n, n)) + T[diagind(T)] .= rand.() + @test exp(Manifolds.log_safe!(similar(T), T)) β‰ˆ T atol = 1e-6 + T_fail = copy(T) + T_fail[1] = -1 + @test_throws DomainError Manifolds.log_safe!(similar(T_fail), T_fail) + + A2 = Q * T * Q' + @test exp(Manifolds.log_safe!(similar(A2), A2)) β‰ˆ A2 atol = 1e-6 + A2_fail = Q * T_fail * Q' + @test_throws DomainError Manifolds.log_safe!(similar(A2_fail), A2_fail) + + A3 = exp(SizedMatrix{n,n}(randn(n, n))) + @test A3 isa SizedMatrix + @test exp(Manifolds.log_safe!(similar(A3), A3)) β‰ˆ A3 atol = 1e-6 + @test exp(Manifolds.log_safe(A3)) β‰ˆ A3 atol = 1e-6 + + A3_fail = Float64[1 2; 3 1] + @test_throws DomainError Manifolds.log_safe!(similar(A3_fail), A3_fail) + + A4 = randn(ComplexF64, n, n) + @test exp(Manifolds.log_safe!(similar(A4), A4)) β‰ˆ A4 atol = 1e-6 + end + @testset "isnormal" begin + @test !Manifolds.isnormal([1.0 2.0; 3.0 4.0]) + @test !Manifolds.isnormal(complex.(reshape(1:4, 2, 2), reshape(5:8, 2, 2))) + + # diagonal + @test Manifolds.isnormal(diagm(randn(5))) + @test Manifolds.isnormal(diagm(randn(ComplexF64, 5))) + @test Manifolds.isnormal(Diagonal(randn(5))) + @test Manifolds.isnormal(Diagonal(randn(ComplexF64, 5))) + + # symmetric/hermitian + @test Manifolds.isnormal(Symmetric(randn(3, 3))) + @test Manifolds.isnormal(Hermitian(randn(3, 3))) + @test Manifolds.isnormal(Hermitian(randn(ComplexF64, 3, 3))) + x = Matrix(Symmetric(randn(3, 3))) + x[3, 1] += eps() + @test !Manifolds.isnormal(x) + @test Manifolds.isnormal(x; atol=sqrt(eps())) + + # skew-symmetric/skew-hermitian + skew(x) = x - x' + @test Manifolds.isnormal(skew(randn(3, 3))) + @test Manifolds.isnormal(skew(randn(ComplexF64, 3, 3))) + + # orthogonal/unitary + @test Manifolds.isnormal(Matrix(qr(randn(3, 3)).Q); atol=sqrt(eps())) + @test Manifolds.isnormal( + Matrix(qr(randn(ComplexF64, 3, 3)).Q); + atol=sqrt(eps()), + ) + end + @testset "realify/unrealify!" begin + # round trip real + x = randn(3, 3) + @test Manifolds.realify(x, ℝ) === x + @test Manifolds.unrealify!(similar(x), x, ℝ) == x + + # round trip complex + x2 = randn(ComplexF64, 3, 3) + x2r = Manifolds.realify(x2, β„‚) + @test eltype(x2r) <: Real + @test size(x2r) == (6, 6) + x2c = Manifolds.unrealify!(similar(x2), x2r, β„‚) + @test x2c β‰ˆ x2 + + # matrix multiplication is preserved + x3 = randn(ComplexF64, 3, 3) + x3r = Manifolds.realify(x3, β„‚) + @test x2 * x3 β‰ˆ Manifolds.unrealify!(similar(x2), x2r * x3r, β„‚) + end + @testset "allocation" begin + @test allocate([1 2; 3 4], Float64, Size(3, 3)) isa Matrix{Float64} + @test allocate(SA[1 2; 3 4], Float64, Size(3, 3)) isa MMatrix{3,3,Float64} + @test allocate(SA[1 2; 3 4], Size(3, 3)) isa MMatrix{3,3,Int} + end + @testset "eigen_safe" begin + @test Manifolds.eigen_safe(SA[1.0 0.0; 0.0 1.0]) isa + Eigen{Float64,Float64,<:SizedMatrix{2,2},<:SizedVector{2}} + end end + + @test Manifolds.is_metric_function(flat) + @test Manifolds.is_metric_function(sharp) + + include_test("groups/group_utils.jl") + include_test("notation.jl") + # starting with tests of simple manifolds + include_test("manifolds/centered_matrices.jl") + include_test("manifolds/circle.jl") + include_test("manifolds/cholesky_space.jl") + include_test("manifolds/elliptope.jl") + include_test("manifolds/euclidean.jl") + include_test("manifolds/fixed_rank.jl") + include_test("manifolds/generalized_grassmann.jl") + include_test("manifolds/generalized_stiefel.jl") + include_test("manifolds/grassmann.jl") + include_test("manifolds/hyperbolic.jl") + include_test("manifolds/lorentz.jl") + include_test("manifolds/multinomial_doubly_stochastic.jl") + include_test("manifolds/multinomial_symmetric.jl") + include_test("manifolds/positive_numbers.jl") + include_test("manifolds/probability_simplex.jl") + include_test("manifolds/projective_space.jl") + include_test("manifolds/rotations.jl") + include_test("manifolds/skewhermitian.jl") + include_test("manifolds/spectrahedron.jl") + include_test("manifolds/sphere.jl") + include_test("manifolds/sphere_symmetric_matrices.jl") + include_test("manifolds/stiefel.jl") + include_test("manifolds/symmetric.jl") + include_test("manifolds/symmetric_positive_definite.jl") + include_test("manifolds/symmetric_positive_semidefinite_fixed_rank.jl") + include_test("manifolds/symplectic.jl") + include_test("manifolds/symplecticstiefel.jl") + include_test("manifolds/tucker.jl") + + include_test("manifolds/essential_manifold.jl") + include_test("manifolds/multinomial_matrices.jl") + include_test("manifolds/oblique.jl") + include_test("manifolds/torus.jl") + + #meta manifolds + include_test("manifolds/product_manifold.jl") + include_test("manifolds/power_manifold.jl") + include_test("manifolds/vector_bundle.jl") + include_test("manifolds/graph.jl") + + include_test("metric.jl") + include_test("statistics.jl") + include_test("approx_inverse_retraction.jl") end - include_test("groups/group_utils.jl") - include_test("notation.jl") - # starting with tests of simple manifolds - include_test("manifolds/centered_matrices.jl") - include_test("manifolds/circle.jl") - include_test("manifolds/cholesky_space.jl") - include_test("manifolds/elliptope.jl") - include_test("manifolds/euclidean.jl") - include_test("manifolds/fixed_rank.jl") - include_test("manifolds/generalized_grassmann.jl") - include_test("manifolds/generalized_stiefel.jl") - include_test("manifolds/grassmann.jl") - include_test("manifolds/hyperbolic.jl") - include_test("manifolds/lorentz.jl") - include_test("manifolds/multinomial_doubly_stochastic.jl") - include_test("manifolds/multinomial_symmetric.jl") - include_test("manifolds/positive_numbers.jl") - include_test("manifolds/probability_simplex.jl") - include_test("manifolds/projective_space.jl") - include_test("manifolds/rotations.jl") - include_test("manifolds/skewhermitian.jl") - include_test("manifolds/spectrahedron.jl") - include_test("manifolds/sphere.jl") - include_test("manifolds/sphere_symmetric_matrices.jl") - include_test("manifolds/stiefel.jl") - include_test("manifolds/symmetric.jl") - include_test("manifolds/symmetric_positive_definite.jl") - include_test("manifolds/symmetric_positive_semidefinite_fixed_rank.jl") - include_test("manifolds/symplectic.jl") - include_test("manifolds/symplecticstiefel.jl") - include_test("manifolds/tucker.jl") - - include_test("manifolds/essential_manifold.jl") - include_test("manifolds/multinomial_matrices.jl") - include_test("manifolds/oblique.jl") - include_test("manifolds/torus.jl") - - #meta manifolds - include_test("manifolds/product_manifold.jl") - include_test("manifolds/power_manifold.jl") - include_test("manifolds/vector_bundle.jl") - include_test("manifolds/graph.jl") - - include_test("metric.jl") - include_test("statistics.jl") - include_test("approx_inverse_retraction.jl") - - # Lie groups and actions - include_test("groups/groups_general.jl") - include_test("groups/validation_group.jl") - include_test("groups/circle_group.jl") - include_test("groups/translation_group.jl") - include_test("groups/general_linear.jl") - include_test("groups/special_linear.jl") - include_test("groups/special_orthogonal.jl") - include_test("groups/product_group.jl") - include_test("groups/semidirect_product_group.jl") - include_test("groups/special_euclidean.jl") - include_test("groups/group_operation_action.jl") - include_test("groups/rotation_action.jl") - include_test("groups/translation_action.jl") - include_test("groups/connections.jl") - include_test("groups/metric.jl") - - include_test("recipes.jl") + if TEST_GROUP ∈ ["test_lie_groups", "all"] + # Lie groups and actions + include_test("groups/groups_general.jl") + include_test("groups/validation_group.jl") + include_test("groups/circle_group.jl") + include_test("groups/translation_group.jl") + include_test("groups/general_linear.jl") + include_test("groups/special_linear.jl") + include_test("groups/special_orthogonal.jl") + include_test("groups/product_group.jl") + include_test("groups/semidirect_product_group.jl") + include_test("groups/special_euclidean.jl") + include_test("groups/group_operation_action.jl") + include_test("groups/rotation_action.jl") + include_test("groups/translation_action.jl") + include_test("groups/connections.jl") + include_test("groups/metric.jl") + end + if TEST_GROUP ∈ ["all", "test_manifolds"] + include_test("recipes.jl") + end end diff --git a/test/statistics.jl b/test/statistics.jl index e4d1bfc531..42d988f515 100644 --- a/test/statistics.jl +++ b/test/statistics.jl @@ -2,7 +2,15 @@ include("utils.jl") using StatsBase: AbstractWeights, pweights using Random: GLOBAL_RNG, seed! import ManifoldsBase: - manifold_dimension, exp!, log!, inner, zero_vector!, decorated_manifold, base_manifold + active_traits, + manifold_dimension, + exp!, + log!, + inner, + zero_vector!, + decorated_manifold, + base_manifold, + get_embedding using Manifolds: AbstractEstimationMethod, CyclicProximalPointEstimation, @@ -10,7 +18,7 @@ using Manifolds: GeodesicInterpolationWithinRadius, GradientDescentEstimation, WeiszfeldEstimation -import Manifolds: mean!, median!, var, mean_and_var +import Manifolds: mean, mean!, median, median!, var, mean_and_var, default_estimation_method struct TestStatsSphere{N} <: AbstractManifold{ℝ} end TestStatsSphere(N) = TestStatsSphere{N}() @@ -44,19 +52,28 @@ function zero_vector!(::TestStatsEuclidean{N}, v, x; kwargs...) where {N} return zero_vector!(Euclidean(N), v, x; kwargs...) end -struct TestStatsNotImplementedEmbeddedManifold <: - AbstractEmbeddedManifold{ℝ,TransparentIsometricEmbedding} end +struct TestStatsNotImplementedEmbeddedManifold <: AbstractDecoratorManifold{ℝ} end +function active_traits(f, ::TestStatsNotImplementedEmbeddedManifold, args...) + return merge_traits(IsEmbeddedSubmanifold()) +end decorated_manifold(::TestStatsNotImplementedEmbeddedManifold) = Sphere(2) +get_embedding(::TestStatsNotImplementedEmbeddedManifold) = Sphere(2) base_manifold(::TestStatsNotImplementedEmbeddedManifold) = Sphere(2) -struct TestStatsNotImplementedEmbeddedManifold2 <: - AbstractEmbeddedManifold{ℝ,DefaultIsometricEmbeddingType} end +struct TestStatsNotImplementedEmbeddedManifold2 <: AbstractDecoratorManifold{ℝ} end +function active_traits(f, ::TestStatsNotImplementedEmbeddedManifold2, args...) + return merge_traits(IsIsometricEmbeddedManifold()) +end decorated_manifold(::TestStatsNotImplementedEmbeddedManifold2) = Sphere(2) +get_embedding(::TestStatsNotImplementedEmbeddedManifold2) = Sphere(2) base_manifold(::TestStatsNotImplementedEmbeddedManifold2) = Sphere(2) -struct TestStatsNotImplementedEmbeddedManifold3 <: - AbstractEmbeddedManifold{ℝ,DefaultEmbeddingType} end +struct TestStatsNotImplementedEmbeddedManifold3 <: AbstractDecoratorManifold{ℝ} end +function active_traits(f, ::TestStatsNotImplementedEmbeddedManifold3, args...) + return merge_traits(IsEmbeddedManifold()) +end decorated_manifold(::TestStatsNotImplementedEmbeddedManifold3) = Sphere(2) +get_embedding(::TestStatsNotImplementedEmbeddedManifold3) = Sphere(2) base_manifold(::TestStatsNotImplementedEmbeddedManifold3) = Sphere(2) function test_mean(M, x, yexp=nothing, method...; kwargs...) @@ -86,6 +103,14 @@ function test_mean(M, x, yexp=nothing, method...; kwargs...) @test isapprox(M, mean_and_std(M, x, w; kwargs...)[1], y; atol=10^-7) end @test_throws DimensionMismatch mean(M, x, pweights(ones(n + 1)); kwargs...) + @test_throws DimensionMismatch mean!( + M, + y, + x, + pweights(ones(n + 1)), + Manifolds.default_estimation_method(M, mean); + kwargs..., + ) end return nothing end @@ -168,6 +193,13 @@ function test_var(M, x, vexp=nothing; kwargs...) var(M, x, w, m; kwargs...) * n / (n - 1) end @test_throws DimensionMismatch var(M, x, pweights(ones(n + 1)); kwargs...) + @test_throws DimensionMismatch mean_and_var( + M, + x, + pweights(ones(n + 1)), + GeodesicInterpolation(); + kwargs..., + ) end return nothing end @@ -263,8 +295,6 @@ function test_moments(M, x) end struct TestStatsOverload1 <: AbstractManifold{ℝ} end -struct TestStatsOverload2 <: AbstractManifold{ℝ} end -struct TestStatsOverload3 <: AbstractManifold{ℝ} end struct TestStatsMethod1 <: AbstractEstimationMethod end function mean!( @@ -276,38 +306,25 @@ function mean!( ) return fill!(y, 3) end -mean!(::TestStatsOverload2, y, ::AbstractVector, ::AbstractWeights) = fill!(y, 4) -function mean!( - ::TestStatsOverload2, - y, +function mean( + ::TestStatsOverload1, ::AbstractVector, ::AbstractWeights, ::GradientDescentEstimation, ) - return fill!(y, 3) -end -function mean!( - ::TestStatsOverload3, - y, - ::AbstractVector, - ::AbstractWeights, - ::TestStatsMethod1=TestStatsMethod1(), -) - return fill!(y, 5) + return fill(3, 1) end -function median!( +function median( ::TestStatsOverload1, - y, ::AbstractVector, ::AbstractWeights, ::CyclicProximalPointEstimation, ) - return fill!(y, 3) + return fill(3, 1) end -median!(::TestStatsOverload2, y, ::AbstractVector, ::AbstractWeights) = fill!(y, 4) function median!( - ::TestStatsOverload2, + ::TestStatsOverload1, y, ::AbstractVector, ::AbstractWeights, @@ -315,15 +332,6 @@ function median!( ) return fill!(y, 3) end -function median!( - ::TestStatsOverload3, - y, - ::AbstractVector, - ::AbstractWeights, - ::TestStatsMethod1=TestStatsMethod1(), -) - return fill!(y, 5) -end function var(::TestStatsOverload1, ::AbstractVector, ::AbstractWeights, m; corrected=false) return 4 + 5 * corrected @@ -354,18 +362,14 @@ end x = [[0.0]] @testset "mean" begin M = TestStatsOverload1() + y = similar(x[1]) @test mean(M, x) == [3.0] + @test mean!(M, y, x) == [3.0] @test mean(M, x, w) == [3.0] @test mean(M, x, w, GradientDescentEstimation()) == [3.0] + @test mean!(M, y, x, w, GradientDescentEstimation()) == [3.0] @test mean(M, x, GradientDescentEstimation()) == [3.0] - M = TestStatsOverload2() - @test mean(M, x) == [4.0] - @test mean(M, x, w) == [4.0] - @test mean(M, x, w, GradientDescentEstimation()) == [3.0] - @test mean(M, x, GradientDescentEstimation()) == [3.0] - M = TestStatsOverload3() - @test mean(M, x) == [5.0] - @test mean(M, x, w) == [5.0] + @test mean!(M, y, x, GradientDescentEstimation()) == [3.0] end @testset "median" begin @@ -374,21 +378,13 @@ end @test median(M, x, w) == [3.0] @test median(M, x, w, CyclicProximalPointEstimation()) == [3.0] @test median(M, x, CyclicProximalPointEstimation()) == [3.0] - M = TestStatsOverload2() - @test median(M, x) == [4.0] - @test median(M, x, w) == [4.0] - @test median(M, x, w, CyclicProximalPointEstimation()) == [3.0] - @test median(M, x, CyclicProximalPointEstimation()) == [3.0] - M = TestStatsOverload3() - @test median(M, x) == [5.0] - @test median(M, x, w) == [5.0] end @testset "var" begin M = TestStatsOverload1() - @test mean_and_var(M, x) == ([4.0], 9) + @test mean_and_var(M, x) == ([3.0], 9) @test mean_and_var(M, x, w) == ([4.0], 4) - @test mean_and_std(M, x) == ([4.0], 3.0) + @test mean_and_std(M, x) == ([3.0], 3.0) @test mean_and_std(M, x, w) == ([4.0], 2.0) @test var(M, x) == 9 @test var(M, x, 2) == 9 @@ -399,6 +395,8 @@ end @test std(M, x, w) == 2.0 @test std(M, x, w, 2) == 2.0 + @test Manifolds.default_estimation_method(M, mean_and_std) == + Manifolds.default_estimation_method(M, mean) @test mean_and_var(M, x, TestStatsMethod1()) == ([5.0], 16) @test mean_and_var(M, x, w, TestStatsMethod1()) == ([5.0], 9) @test mean_and_std(M, x, TestStatsMethod1()) == ([5.0], 4.0) @@ -407,24 +405,24 @@ end end @testset "decorator dispatch" begin + # equality tests are intentional to ensure correct dispatch + # (both calls eventually use the same method) ps = [normalize([1, 0, 0] .+ 0.1 .* randn(3)) for _ in 1:3] M1 = TestStatsNotImplementedEmbeddedManifold() @test mean!(M1, similar(ps[1]), ps) == mean!(Sphere(2), similar(ps[1]), ps) @test mean(M1, ps) == mean(Sphere(2), ps) - @test median!(M1, similar(ps[1]), ps) == median!(Sphere(2), similar(ps[1]), ps) - @test median(M1, ps) == median(Sphere(2), ps) M2 = TestStatsNotImplementedEmbeddedManifold2() - @test_throws ErrorException mean(M2, ps) - @test_throws ErrorException mean!(M2, similar(ps[1]), ps) - @test_throws ErrorException median(M2, ps) - @test_throws ErrorException median!(M2, similar(ps[1]), ps) + @test_throws MethodError mean(M2, ps) + @test_throws MethodError mean!(M2, similar(ps[1]), ps) + @test_throws MethodError median(M2, ps) + @test_throws MethodError median!(M2, similar(ps[1]), ps) M3 = TestStatsNotImplementedEmbeddedManifold3() - @test_throws ErrorException mean(M3, ps) - @test_throws ErrorException mean!(M3, similar(ps[1]), ps) - @test_throws ErrorException median(M3, ps) - @test_throws ErrorException median!(M3, similar(ps[1]), ps) + @test_throws MethodError mean(M3, ps) + @test_throws MethodError mean!(M3, similar(ps[1]), ps) + @test_throws MethodError median(M3, ps) + @test_throws MethodError median!(M3, similar(ps[1]), ps) end @testset "TestStatsSphere" begin @@ -527,7 +525,7 @@ end end @testset "vector" begin - x = [[1.0], [2.0], [3.0], [4.0]] + x = [fill(1.0), fill(2.0), fill(3.0), fill(4.0)] vx = vcat(x...) w = pweights(ones(length(x)) / length(x)) @test mean(M, x) β‰ˆ mean(x) @@ -795,6 +793,11 @@ end @test isapprox(S, m, mg) end + @testset "Covariance Default" begin + @test default_estimation_method(TestStatsSphere{2}(), cov) == + GradientDescentEstimation() + end + @testset "Covariance matrix, Euclidean" begin rng = MersenneTwister(47) M = Euclidean(3) diff --git a/test/utils.jl b/test/utils.jl index 4a952a592a..c281049f53 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -5,14 +5,13 @@ const TEST_STATIC_SIZED = false using Manifolds using ManifoldsBase using ManifoldsBase: number_of_coordinates +import ManifoldsBase: active_traits, merge_traits using LinearAlgebra using Distributions using DoubleFloats -using ForwardDiff using Quaternions using Random -using ReverseDiff using StaticArrays using Statistics using StatsBase @@ -43,6 +42,6 @@ end function has_type_in_signature(sig, T::Type) return any(map(Base.unwrap_unionall(sig.sig).parameters) do x xw = Base.rewrap_unionall(x, sig.sig) - return xw <: T + return (xw isa Type ? xw : xw.T) <: T end) end