diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..ba39cc53 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +Manifest.toml diff --git a/src/LinearMaps.jl b/src/LinearMaps.jl index bd030de8..45de6023 100644 --- a/src/LinearMaps.jl +++ b/src/LinearMaps.jl @@ -130,7 +130,7 @@ julia> A(x) """ function Base.:(*)(A::LinearMap, x::AbstractVector) check_dim_mul(A, x) - T = promote_type(eltype(A), eltype(x)) + T = typeof(oneunit(eltype(A)) * oneunit(eltype(x))) y = similar(x, T, axes(A)[1]) return @inbounds mul!(y, A, x) end @@ -308,25 +308,23 @@ function _generic_map_mul!(Y, A, X::AbstractMatrix, α, β) end return Y end -function _generic_map_mul!(Y, A, s::Number) - T = promote_type(eltype(A), typeof(s)) +function _generic_map_mul!(Y, A, s::S) where {S <: Number} ax2 = axes(A)[2] - xi = zeros(T, ax2) + xi = zeros(S, ax2) @inbounds for (i, Yi) in zip(ax2, eachcol(Y)) xi[i] = s mul!(Yi, A, xi) - xi[i] = zero(T) + xi[i] = zero(S) end return Y end -function _generic_map_mul!(Y, A, s::Number, α, β) - T = promote_type(eltype(A), typeof(s)) +function _generic_map_mul!(Y, A, s::S, α, β) where {S <: Number} ax2 = axes(A)[2] - xi = zeros(T, ax2) + xi = zeros(S, ax2) @inbounds for (i, Yi) in zip(ax2, eachcol(Y)) xi[i] = s mul!(Yi, A, xi, α, β) - xi[i] = zero(T) + xi[i] = zero(S) end return Y end diff --git a/src/composition.jl b/src/composition.jl index fe594c17..4717815d 100644 --- a/src/composition.jl +++ b/src/composition.jl @@ -1,3 +1,9 @@ +# appropriate type for product of two (possibly Unitful) quantities: +_multype(a, b) = typeof(oneunit(eltype(a)) * oneunit(eltype(b))) +# this variant is needed because reinterpret changes length of complex vectors +#_multype(a, b, iscomplex::Bool) = +# typeof((1+0im) * oneunit(eltype(a)) * oneunit(eltype(b))) + struct CompositeMap{T, As<:LinearMapTupleOrVector} <: LinearMap{T} maps::As # stored in order of application to vector function CompositeMap{T, As}(maps::As) where {T, As} @@ -5,17 +11,19 @@ struct CompositeMap{T, As<:LinearMapTupleOrVector} <: LinearMap{T} for n in 2:N check_dim_mul(maps[n], maps[n-1]) end - for TA in Base.Iterators.map(eltype, maps) - promote_type(T, TA) == T || - error("eltype $TA cannot be promoted to $T in CompositeMap constructor") - end + Tprod = typeof(*(map(oneunit∘eltype, maps)...)) # handles units + promote_type(T, Tprod) == T || + error("eltype $Tprod and $T incompatible in CompositeMap constructor") new{T, As}(maps) end end + CompositeMap{T}(maps::As) where {T, As<:LinearMapTupleOrVector} = CompositeMap{T, As}(maps) -Base.mapreduce(::typeof(identity), ::typeof(Base.mul_prod), maps::LinearMapTupleOrVector) = - CompositeMap{promote_type(map(eltype, maps)...)}(reverse(maps)) +function Base.mapreduce(::typeof(identity), ::typeof(Base.mul_prod), maps::LinearMapTupleOrVector) + Tprod = typeof(*(map(oneunit∘eltype, maps)...)) # handles units + return CompositeMap{Tprod}(reverse(maps)) +end Base.mapreduce(::typeof(identity), ::typeof(Base.mul_prod), maps::AbstractVector{<:LinearMap{T}}) where {T} = CompositeMap{T}(reverse(maps)) @@ -80,11 +88,11 @@ end # scalar multiplication and division (non-commutative case) function Base.:(*)(α::Number, A::LinearMap) - T = promote_type(typeof(α), eltype(A)) + T = _multype(α, A) return CompositeMap{T}(_combine(A, UniformScalingMap(α, size(A, 1)))) end function Base.:(*)(α::Number, A::CompositeMap) - T = promote_type(typeof(α), eltype(A)) + T = _multype(α, A) Alast = last(A.maps) if Alast isa UniformScalingMap return CompositeMap{T}(_combine(_front(A.maps), α * Alast)) @@ -94,15 +102,15 @@ function Base.:(*)(α::Number, A::CompositeMap) end # needed for disambiguation function Base.:(*)(α::RealOrComplex, A::CompositeMap{<:RealOrComplex}) - T = Base.promote_op(*, typeof(α), eltype(A)) + T = _multype(α, A) return ScaledMap{T}(α, A) end function Base.:(*)(A::LinearMap, α::Number) - T = promote_type(typeof(α), eltype(A)) + T = _multype(A, α) return CompositeMap{T}(_combine(UniformScalingMap(α, size(A, 2)), A)) end function Base.:(*)(A::CompositeMap, α::Number) - T = promote_type(typeof(α), eltype(A)) + T = _multype(A, α) Afirst = first(A.maps) if Afirst isa UniformScalingMap return CompositeMap{T}(_combine(Afirst * α, _tail(A.maps))) @@ -112,7 +120,7 @@ function Base.:(*)(A::CompositeMap, α::Number) end # needed for disambiguation function Base.:(*)(A::CompositeMap{<:RealOrComplex}, α::RealOrComplex) - T = Base.promote_op(*, typeof(α), eltype(A)) + T = _multype(A, α) return ScaledMap{T}(α, A) end @@ -137,19 +145,19 @@ julia> LinearMap(ones(Int, 3, 3)) * CS * I * rand(3, 3); ``` """ function Base.:(*)(A₁::LinearMap, A₂::LinearMap) - T = promote_type(eltype(A₁), eltype(A₂)) + T = _multype(A₁, A₂) return CompositeMap{T}(_combine(A₂, A₁)) end function Base.:(*)(A₁::LinearMap, A₂::CompositeMap) - T = promote_type(eltype(A₁), eltype(A₂)) + T = _multype(A₁, A₂) return CompositeMap{T}(_combine(A₂.maps, A₁)) end function Base.:(*)(A₁::CompositeMap, A₂::LinearMap) - T = promote_type(eltype(A₁), eltype(A₂)) + T = _multype(A₁, A₂) return CompositeMap{T}(_combine(A₂, A₁.maps)) end function Base.:(*)(A₁::CompositeMap, A₂::CompositeMap) - T = promote_type(eltype(A₁), eltype(A₂)) + T = _multype(A₁, A₂) return CompositeMap{T}(_combine(A₂.maps, A₁.maps)) end # needed for disambiguation @@ -217,6 +225,11 @@ function _compositemulN!(y, A::CompositeMap, x, src = nothing, dst = nothing) N = length(A.maps) # ≥ 3 + # caution: be careful if y is complex but intermediate products are not + # source = reinterpret(_multype(A.maps[1], x, !isreal(y)), source) # for units + # resize!(source, size(A.maps[1],1)) # trick due to complex case + # todo: build reinterpret into _unsafe_mul! instead? + # only necessary if either source or map has Number type instead of Real|Complex n = n0 = firstindex(A.maps) source = isnothing(src) ? convert(AbstractArray, A.maps[n] * x) : @@ -227,9 +240,12 @@ function _compositemulN!(y, A::CompositeMap, x, _unsafe_mul!(dst, A.maps[n], source) dest, source = source, dest # alternate dest and source for n in (n0+2):N-1 - dest = _resize(dest, (size(A.maps[n], 1), size(x)[2:end]...)) + # dest = _resize(dest, (size(A.maps[n], 1), size(x)[2:end]...)) + # dest = reinterpret(_multype(A.maps[n], source), dest) + dest = similar([], _multype(A.maps[n], source), (size(A.maps[n], 1), size(source)[2:end]...)) _unsafe_mul!(dest, A.maps[n], source) - dest, source = source, dest # alternate dest and source + # dest, source = source, dest # alternate dest and source + source = dest end _unsafe_mul!(y, last(A.maps), source) return y diff --git a/src/scaledmap.jl b/src/scaledmap.jl index 549bf7a8..d6f82eb9 100644 --- a/src/scaledmap.jl +++ b/src/scaledmap.jl @@ -7,14 +7,16 @@ struct ScaledMap{T, S<:RealOrComplex, L<:LinearMap} <: LinearMap{T} λ::S lmap::L function ScaledMap{T}(λ::S, A::L) where {T, S <: RealOrComplex, L <: LinearMap{<:RealOrComplex}} - @assert Base.promote_op(*, S, eltype(A)) == T "target type $T cannot hold products of $S and $(eltype(A)) objects" + Tprod = typeof(oneunit(S) * oneunit(eltype(A))) + promote_type(T, Tprod) == T || + error("target type $T vs product of $S and $(eltype(A))") new{T,S,L}(λ, A) end end # constructor ScaledMap(λ::RealOrComplex, lmap::LinearMap{<:RealOrComplex}) = - ScaledMap{Base.promote_op(*, typeof(λ), eltype(lmap))}(λ, lmap) + ScaledMap{typeof(oneunit(λ) * oneunit(eltype(lmap)))}(λ, lmap) # basic methods Base.size(A::ScaledMap) = size(A.lmap) diff --git a/test/Project.toml b/test/Project.toml index c3704248..9155a6df 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -11,6 +11,7 @@ Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [compat] Aqua = "0.5, 0.6" diff --git a/test/runtests.jl b/test/runtests.jl index bde3fd94..fc473236 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -42,6 +42,8 @@ include("inversemap.jl") include("rrules.jl") +include("units.jl") + include("khatrirao.jl") include("trace.jl") diff --git a/test/scaledmap.jl b/test/scaledmap.jl index 8a35f27e..bd06e157 100644 --- a/test/scaledmap.jl +++ b/test/scaledmap.jl @@ -41,7 +41,7 @@ using Test, LinearMaps, LinearAlgebra # complex case β = π + 2π * im C = @inferred β * A - @test_throws AssertionError LinearMaps.ScaledMap{Float64}(β, A) + @test_throws ErrorException LinearMaps.ScaledMap{Float64}(β, A) @inferred conj(β) * A' # needed in left-mul T = ComplexF64 xc = rand(T, N) diff --git a/test/units.jl b/test/units.jl new file mode 100644 index 00000000..f2c8119f --- /dev/null +++ b/test/units.jl @@ -0,0 +1,26 @@ +# test/units + +using Test: @test, @testset, @inferred, @test_throws +using LinearMaps: LinearMap +using Unitful: m, s, g + +@testset "units" begin + A = rand(4,3) * 1m + B = rand(3,2) * 1s + C = A * B + D = 1f0g * C + + Ma = @inferred LinearMap(A) + Mb = @inferred LinearMap(B) + Mc = Ma * Mb + Md = 1f0g * Mc + + @test Matrix(Ma) == A + @test Matrix(Mc) == C + @test Matrix(Md) == D + + x = randn(2) + @test B * x == Mb * x + @test C * x ≈ Mc * x + @test D * x ≈ Md * x +end