diff --git a/src/bases_dirac.jl b/src/bases_dirac.jl index 75d7918..54ece3e 100644 --- a/src/bases_dirac.jl +++ b/src/bases_dirac.jl @@ -38,8 +38,8 @@ function Base.IteratorSize(::Type{<:DiracBasis{T,S}}) where {T,S} end function Base.size(db::DiracBasis) - @assert Base.haslength(object(db)) - return size(object(db)) + # Some object such as `PermGroup` do not define `size` + return (length(db),) end function Base.length(db::DiracBasis) @@ -53,6 +53,7 @@ Base.iterate(db::DiracBasis, st) = iterate(object(db), st) Base.in(g, db::DiracBasis) = g in object(db) Base.haskey(db::DiracBasis, g) = in(g, db) +Base.firstindex(db::DiracBasis) = Base.first(object(db)) function Base.getindex(db::DiracBasis{T}, x::T) where {T} @assert x in object(db) return x diff --git a/src/bases_fixed.jl b/src/bases_fixed.jl index 5273279..76ac973 100644 --- a/src/bases_fixed.jl +++ b/src/bases_fixed.jl @@ -75,6 +75,8 @@ end FixedBasis(basis::AbstractBasis{T}; n::Integer) where {T} = FixedBasis{T,typeof(n)}(basis; n) Base.in(x, b::FixedBasis) = haskey(b.relts, x) +Base.firstindex(b::FixedBasis) = firstindex(b.elts) +Base.lastindex(b::FixedBasis) = lastindex(b.elts) Base.getindex(b::FixedBasis{T}, x::T) where {T} = b.relts[x] Base.getindex(b::FixedBasis, i::Integer) = b.elts[i] @@ -138,6 +140,8 @@ end Base.in(x::T, b::SubBasis{T}) where T = !isnothing(get(b, x, nothing)) Base.haskey(b::SubBasis, i::Integer) = i in eachindex(b.keys) +Base.firstindex(b::SubBasis) = firstindex(b.keys) +Base.lastindex(b::SubBasis) = lastindex(b.keys) Base.getindex(b::SubBasis, i::Integer) = parent(b)[b.keys[i]] function Base.getindex(b::SubBasis{T,I}, x::T) where {T,I} i = get(b, x, nothing) diff --git a/test/quadratic_form.jl b/test/quadratic_form.jl index 85f1b0c..8afcb69 100644 --- a/test/quadratic_form.jl +++ b/test/quadratic_form.jl @@ -110,7 +110,9 @@ end mstr = ChebyMStruct(implicit) mt = SA.MTable(implicit, mstr, (0, 0)) sub = SA.SubBasis(implicit, 1:3) + test_vector_interface(sub) fixed = SA.FixedBasis(implicit; n = 3) + test_vector_interface(fixed) a = ChebyPoly(2) b = ChebyPoly(3) expected = SA.SparseCoefficients((ChebyPoly(1), ChebyPoly(5)), (1 // 2, 1 // 2)) diff --git a/test/runtests.jl b/test/runtests.jl index 07f31bd..b27be3c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,6 +28,23 @@ for file in readdir(joinpath(@__DIR__, "..", "examples")) end end +# Test that `basis` implements the vector interface +# returns the same as `vector` +function test_vector_interface(basis, vector = collect(basis)) + @test length(basis) == length(vector) + @test size(basis) == size(vector) + @test eltype(basis) == eltype(vector) + @test eltype(basis) == eltype(typeof(basis)) + @test last(basis) == last(vector) + i = lastindex(basis) + @test i == lastindex(vector) + @test getindex(basis, i) == last(vector) + @test first(basis) == first(vector) + i = firstindex(basis) + @test i == firstindex(vector) + @test getindex(basis, i) == first(vector) +end + @testset "StarAlgebras" begin include("basic.jl") # proof of concept