diff --git a/Project.toml b/Project.toml index cbb0c6dc..aa4b761f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FillArrays" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.10.1" +version = "0.11.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 8bcea244..2f87aa3c 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -1,3 +1,10 @@ +const FillVector{F,A} = Fill{F,1,A} +const FillMatrix{F,A} = Fill{F,2,A} +const OnesVector{F,A} = Ones{F,1,A} +const OnesMatrix{F,A} = Ones{F,2,A} +const ZerosVector{F,A} = Zeros{F,1,A} +const ZerosMatrix{F,A} = Zeros{F,2,A} + ## vec vec(a::Ones{T}) where T = Ones{T}(length(a)) @@ -77,8 +84,10 @@ end *(a::Zeros{<:Any,1}, b::AbstractMatrix) = mult_zeros(a, b) *(a::Zeros{<:Any,2}, b::AbstractMatrix) = mult_zeros(a, b) +*(a::Zeros{<:Any,2}, b::AbstractTriangular) = mult_zeros(a, b) *(a::AbstractMatrix, b::Zeros{<:Any,1}) = mult_zeros(a, b) *(a::AbstractMatrix, b::Zeros{<:Any,2}) = mult_zeros(a, b) +*(a::AbstractTriangular, b::Zeros{<:Any,2}) = mult_zeros(a, b) *(a::Zeros{<:Any,1}, b::AbstractVector) = mult_zeros(a, b) *(a::Zeros{<:Any,2}, b::AbstractVector) = mult_zeros(a, b) *(a::AbstractVector, b::Zeros{<:Any,2}) = mult_zeros(a, b) @@ -87,36 +96,65 @@ end *(a::Zeros{<:Any,2}, b::Diagonal) = mult_zeros(a, b) *(a::Diagonal, b::Zeros{<:Any,1}) = mult_zeros(a, b) *(a::Diagonal, b::Zeros{<:Any,2}) = mult_zeros(a, b) -function *(a::Diagonal, b::AbstractFill{<:Any,2}) + +function *(a::Diagonal, b::AbstractFill{T,2}) where T size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))")) a.diag .* b # use special broadcast end -function *(a::AbstractFill{<:Any,2}, b::Diagonal) +function *(a::AbstractFill{T,2}, b::Diagonal) where T size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))")) a .* permutedims(b.diag) # use special broadcast end -*(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(conj.(parent(a)); dims=1) .* b.value, size(parent(a), 2)) -*(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(parent(a); dims=1) .* b.value, size(parent(a), 2)) -*(a::StridedMatrix{T}, b::Fill{T, 1}) where T = reshape(sum(a; dims=2) .* b.value, size(a, 1)) - -function *(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T - fB = similar(parent(a), size(b, 1), size(b, 2)) - fill!(fB, b.value) - return a*fB +function mult_sum2(x::AbstractMatrix, f::AbstractFill{T,2}) where T + axes(x, 2) ≠ axes(f, 1) && + throw(DimensionMismatch("Incompatible matrix multiplication dimensions")) + m = size(f, 2) + repeat(sum(x, dims=2) * getindex_value(f), 1, m) end -function *(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T - fB = similar(parent(a), size(b, 1), size(b, 2)) - fill!(fB, b.value) - return a*fB +function mult_sum1(f::AbstractFill{T,2}, x::AbstractMatrix) where T + axes(f, 2) ≠ axes(x, 1) && + throw(DimensionMismatch("Incompatible matrix multiplication dimensions")) + m = size(f, 1) + repeat(sum(x, dims=1) * getindex_value(f), m, 1) end -function *(a::StridedMatrix{T}, b::Fill{T, 2}) where T - fB = similar(a, size(b, 1), size(b, 2)) - fill!(fB, b.value) - return a*fB -end +*(x::AbstractMatrix, y::AbstractFill{<:Any,2}) = mult_sum2(x, y) +*(x::AbstractTriangular, y::AbstractFill{<:Any,2}) = mult_sum2(x, y) +*(x::AbstractFill{<:Any,2}, y::AbstractMatrix) = mult_sum1(x, y) +*(x::AbstractFill{<:Any,2}, y::AbstractTriangular) = mult_sum1(x, y) + + +### These methods are faster for small n ############# +# function *(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T +# fB = similar(parent(a), size(b, 1), size(b, 2)) +# fill!(fB, b.value) +# return a*fB +# end + +# function *(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T +# fB = similar(parent(a), size(b, 1), size(b, 2)) +# fill!(fB, b.value) +# return a*fB +# end + +# function *(a::StridedMatrix{T}, b::Fill{T, 2}) where T +# fB = similar(a, size(b, 1), size(b, 2)) +# fill!(fB, b.value) +# return a*fB +# end + +## Matrix-Vector multiplication + +*(a::Adjoint{T, <:StridedMatrix{T}}, b::AbstractFill{T, 1}) where T = + reshape(sum(conj.(parent(a)); dims=1) .* getindex_value(b), size(parent(a), 2)) +*(a::Transpose{T, <:StridedMatrix{T}}, b::AbstractFill{T, 1}) where T = + reshape(sum(parent(a); dims=1) .* getindex_value(b), size(parent(a), 2)) +*(a::StridedMatrix{T}, b::AbstractFill{T, 1}) where T = + reshape(sum(a; dims=2) .* getindex_value(b), size(a, 1)) + + function _adjvec_mul_zeros(a::Adjoint{T}, b::Zeros{S, 1}) where {T, S} la, lb = length(a), length(b) if la ≠ lb diff --git a/test/runtests.jl b/test/runtests.jl index 74f01d69..59e99c15 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1028,6 +1028,31 @@ end @test E*(1:5) ≡ 1.0:5.0 @test (1:5)'E == (1.0:5)' @test E*E ≡ E + + # Adjoint / Transpose / Triangular / Symmetric / Hermitian + for x in [transpose(rand(2, 2)), + adjoint(rand(2,2)), + UpperTriangular(rand(2,2)), + Symmetric(rand(2,2)), + Hermitian(rand(2,2))] + @test x * Ones(2, 2) isa Matrix + @test Ones(2, 2) * x isa Matrix + @test x * Ones(2) isa Vector + @test Ones(2)' * x isa Adjoint{Float64,<:Vector} + @test x * Zeros(2, 2) isa Zeros + @test Zeros(2, 2) * x isa Zeros + @test x * Zeros(2) isa Vector + @test x * Zeros(2) isa Zeros + @test Zeros(2)' * x isa Adjoint{Float64,<:Zeros} + @test x * Fill(1., 2, 2) isa Matrix + @test Fill(1., 2, 2) * x isa Matrix + @test x * Fill(1.,2) isa Vector + @test Fill(1.,2)' * x isa Adjoint{Float64,<:Vector} + @test x * Ones(2,2) == x * Fill(1., 2, 2) == x * ones(2,2) + @test Ones(2,2) * x == Fill(1., 2, 2) * x == Ones(2,2) * x + @test x * Ones(2) == x * Fill(1.,2) == x * ones(2) + @test Ones(2)'x == Fill(1.,2)'x == ones(2)'x + end end @testset "count" begin @@ -1154,4 +1179,4 @@ end @test FillArrays.getindex_value(transpose(a)) == FillArrays.unique_value(transpose(a)) == 2.0 @test convert(Fill, transpose(a)) ≡ Fill(2.0,1,5) end -end \ No newline at end of file +end