Skip to content

WIP: get rid of A*_mul_B! #74

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
80 changes: 52 additions & 28 deletions src/LinearMaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,33 +72,15 @@ function Base.:(*)(A::LinearMap, x::AbstractVector)
size(A, 2) == length(x) || throw(DimensionMismatch("mul!"))
return @inbounds mul!(similar(x, promote_type(eltype(A), eltype(x)), size(A, 1)), A, x)
end
function LinearAlgebra.mul!(y::AbstractVector, A::LinearMap, x::AbstractVector, α::Number=true, β::Number=false)
Base.@propagate_inbounds function LinearAlgebra.mul!(y::AbstractVector, A::LinearMap, x::AbstractVector,
α::Number=true, β::Number=false)
@boundscheck check_dim_mul(y, A, x)
if isone(α)
iszero(β) && (A_mul_B!(y, A, x); return y)
isone(β) && (y .+= A * x; return y)
# β != 0, 1
rmul!(y, β)
y .+= A * x
return y
elseif iszero(α)
iszero(β) && (fill!(y, zero(eltype(y))); return y)
isone(β) && return y
# β != 0, 1
rmul!(y, β)
return y
else # α != 0, 1
iszero(β) && (A_mul_B!(y, A, x); rmul!(y, α); return y)
isone(β) && (y .+= rmul!(A * x, α); return y)
# β != 0, 1
rmul!(y, β)
y .+= rmul!(A * x, α)
return y
end
_muladd!(MulStyle(A), y, A, x, α, β)
end
# the following is of interest in, e.g., subspace-iteration methods
Base.@propagate_inbounds function LinearAlgebra.mul!(Y::AbstractMatrix, A::LinearMap, X::AbstractMatrix, α::Number=true, β::Number=false)
@boundscheck check_dim_mul(Y, A, X)
function LinearAlgebra.mul!(Y::AbstractMatrix, A::LinearMap, X::AbstractMatrix, α::Number=true, β::Number=false)
(size(Y, 1) == size(A, 1) && size(X, 1) == size(A, 2) && size(Y, 2) == size(X, 2)) || throw(DimensionMismatch("mul!"))
# TODO: reuse intermediate z if necessary
@inbounds @views for i = 1:size(X, 2)
mul!(Y[:, i], A, X[:, i], α, β)
end
Expand All @@ -109,13 +91,55 @@ Base.@propagate_inbounds function LinearAlgebra.mul!(Y::AbstractMatrix, A::Linea
return Y
end

A_mul_B!(y::AbstractVector, A::AbstractMatrix, x::AbstractVector) = mul!(y, A, x)
At_mul_B!(y::AbstractVector, A::AbstractMatrix, x::AbstractVector) = mul!(y, transpose(A), x)
Ac_mul_B!(y::AbstractVector, A::AbstractMatrix, x::AbstractVector) = mul!(y, adjoint(A), x)
# multiplication helper functions
@inline _muladd!(::FiveArg, y, A::MapOrMatrix, x, α, β) = mul!(y, A, x, α, β)
@inline _muladd!(::ThreeArg, y, A::MapOrMatrix, x, α, β) =
iszero(β) ? muladd!(y, A, x, α, β, nothing) : muladd!(y, A, x, α, β, similar(y))
@inline _muladd!(::FiveArg, y, A::MapOrMatrix, x, α, β, _) = mul!(y, A, x, α, β)
@inline _muladd!(::ThreeArg, y, A::MapOrMatrix, x, α, β, z) = muladd!(y, A, x, α, β, z)

"""
muladd!(y, A, x, α, β, z)

Compute 5-arg multiplication for `LinearMap`s that do not have a 5-arg `mul!`,
but only a 3-arg `mul!`. For storing the intermediate `A*x*α`, a vector (or
matrix, as appropriate) `z` is (already) provided.
"""
@inline function muladd!(y, A::MapOrMatrix, x, α, β, z)
if iszero(α)
iszero(β) && (fill!(y, zero(eltype(y))); return y)
isone(β) && return y
# β != 0, 1
rmul!(y, β)
return y
else
if iszero(β)
mul!(y, A, x)
!isone(α) && rmul!(y, α)
return y
else
mul!(z, A, x)
if isone(α)
if isone(β)
y .+= z
else
y .= y .* β .+ z
end
else
if isone(β)
y .+= z .* α
else
y .= y .* β .+ z .* α
end
end
return y
end
end
end

include("transpose.jl") # transposing linear maps
include("wrappedmap.jl") # wrap a matrix of linear map in a new type, thereby allowing to alter its properties
include("uniformscalingmap.jl") # the uniform scaling map, to be able to make linear combinations of LinearMap objects and multiples of I
include("transpose.jl") # transposing linear maps
include("linearcombination.jl") # defining linear combinations of linear maps
include("composition.jl") # composition of linear maps
include("functionmap.jl") # using a function as linear map
Expand Down
39 changes: 13 additions & 26 deletions src/blockmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -344,32 +344,17 @@ end
# multiplication with vectors & matrices
############

Base.@propagate_inbounds A_mul_B!(y::AbstractVector, A::BlockMap, x::AbstractVector) =
mul!(y, A, x)

Base.@propagate_inbounds A_mul_B!(y::AbstractVector, A::TransposeMap{<:Any,<:BlockMap}, x::AbstractVector) =
mul!(y, A, x)

Base.@propagate_inbounds At_mul_B!(y::AbstractVector, A::BlockMap, x::AbstractVector) =
mul!(y, transpose(A), x)

Base.@propagate_inbounds A_mul_B!(y::AbstractVector, A::AdjointMap{<:Any,<:BlockMap}, x::AbstractVector) =
mul!(y, A, x)

Base.@propagate_inbounds Ac_mul_B!(y::AbstractVector, A::BlockMap, x::AbstractVector) =
mul!(y, adjoint(A), x)

for Atype in (AbstractVector, AbstractMatrix)
@eval Base.@propagate_inbounds function LinearAlgebra.mul!(y::$Atype, A::BlockMap, x::$Atype,
α::Number=true, β::Number=false)
α::Number=true, β::Number=false)
require_one_based_indexing(y, x)
@boundscheck check_dim_mul(y, A, x)
return _blockmul!(y, A, x, α, β)
end

for (maptype, transform) in ((:(TransposeMap{<:Any,<:BlockMap}), :transpose), (:(AdjointMap{<:Any,<:BlockMap}), :adjoint))
@eval Base.@propagate_inbounds function LinearAlgebra.mul!(y::$Atype, wrapA::$maptype, x::$Atype,
α::Number=true, β::Number=false)
α::Number=true, β::Number=false)
require_one_based_indexing(y, x)
@boundscheck check_dim_mul(y, wrapA, x)
return _transblockmul!(y, wrapA.lmap, x, α, β, $transform)
Expand Down Expand Up @@ -457,15 +442,6 @@ LinearAlgebra.transpose(A::BlockDiagonalMap{T}) where {T} = BlockDiagonalMap{T}(

Base.:(==)(A::BlockDiagonalMap, B::BlockDiagonalMap) = (eltype(A) == eltype(B) && A.maps == B.maps)

Base.@propagate_inbounds A_mul_B!(y::AbstractVector, A::BlockDiagonalMap, x::AbstractVector) =
mul!(y, A, x, true, false)

Base.@propagate_inbounds At_mul_B!(y::AbstractVector, A::BlockDiagonalMap, x::AbstractVector) =
mul!(y, transpose(A), x, true, false)

Base.@propagate_inbounds Ac_mul_B!(y::AbstractVector, A::BlockDiagonalMap, x::AbstractVector) =
mul!(y, adjoint(A), x, true, false)

for Atype in (AbstractVector, AbstractMatrix)
@eval Base.@propagate_inbounds function LinearAlgebra.mul!(y::$Atype, A::BlockDiagonalMap, x::$Atype,
α::Number=true, β::Number=false)
Expand All @@ -483,3 +459,14 @@ end
end
return y
end

Base.@propagate_inbounds function muladd!(y, A::BlockDiagonalMap, x, α, β, z)
require_one_based_indexing(y, x, z)
@boundscheck check_dim_mul(y, A, x)
@boundscheck size(y) == size(z)
maps, yinds, xinds = A.maps, A.rowranges, A.colranges
@views @inbounds for i in 1:length(maps)
_muladd!(MulStyle(maps[i]), selectdim(y, 1, yinds[i]), maps[i], selectdim(x, 1, xinds[i]), α, β, selectdim(z, 1, yinds[i]))
end
return y
end
37 changes: 26 additions & 11 deletions src/composition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ struct CompositeMap{T, As<:Tuple{Vararg{LinearMap}}} <: LinearMap{T}
end
CompositeMap{T}(maps::As) where {T, As<:Tuple{Vararg{LinearMap}}} = CompositeMap{T, As}(maps)

# treat CompositeMaps as FiveArg's because the 5-arg mul! is optimized
MulStyle(::CompositeMap) = FiveArg()

# basic methods
Base.size(A::CompositeMap) = (size(A.maps[end], 1), size(A.maps[1], 2))
Base.isreal(A::CompositeMap) = all(isreal, A.maps) # sufficient but not necessary
Expand Down Expand Up @@ -120,17 +123,21 @@ LinearAlgebra.adjoint(A::CompositeMap{T}) where {T} = CompositeMap{T}(map(adjo
Base.:(==)(A::CompositeMap, B::CompositeMap) = (eltype(A) == eltype(B) && A.maps == B.maps)

# multiplication with vectors
function A_mul_B!(y::AbstractVector, A::CompositeMap, x::AbstractVector)
Base.@propagate_inbounds function LinearAlgebra.mul!(y::AbstractVector, A::CompositeMap, x::AbstractVector,
α::Number=true, β::Number=false)
# no size checking, will be done by individual maps
N = length(A.maps)
A1 = first(A.maps)
if N==1
A_mul_B!(y, A.maps[1], x)
mul!(y, A1, x, α, β)
else
dest = similar(y, size(A.maps[1], 1))
A_mul_B!(dest, A.maps[1], x)
_muladd!(MulStyle(A1), dest, A1, x, α, false)
source = dest
if N>2
dest = similar(y, size(A.maps[2], 1))
if N>2 || (MulStyle(A.maps[2]) === ThreeArg() && !iszero(β))
# we need a second intermediate vector if there are more than two maps
# or if the second out of two is a ThreeArg and we need to do muladd
dest = Array{T}(undef, size(A.maps[2], 1))
end
for n in 2:N-1
try
Expand All @@ -142,14 +149,22 @@ function A_mul_B!(y::AbstractVector, A::CompositeMap, x::AbstractVector)
rethrow(err)
end
end
A_mul_B!(dest, A.maps[n], source)
mul!(dest, A.maps[n], source)
dest, source = source, dest # alternate dest and source
end
A_mul_B!(y, A.maps[N], source)
Alast = last(A.maps)
if (MulStyle(Alast) === ThreeArg() && !iszero(β))
try
resize!(dest, size(Alast, 1))
catch err
if err == ErrorException("cannot resize array with shared data")
dest = Array{T}(undef, size(Alast, 1))
else
rethrow(err)
end
end
end
_muladd!(MulStyle(Alast), y, Alast, source, true, β, dest)
end
return y
end

At_mul_B!(y::AbstractVector, A::CompositeMap, x::AbstractVector) = A_mul_B!(y, transpose(A), x)

Ac_mul_B!(y::AbstractVector, A::CompositeMap, x::AbstractVector) = A_mul_B!(y, adjoint(A), x)
18 changes: 10 additions & 8 deletions src/functionmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,16 @@ function Base.:(*)(A::TransposeMap{<:Any,<:FunctionMap}, x::AbstractVector)
end
end

function A_mul_B!(y::AbstractVector, A::FunctionMap, x::AbstractVector)
(length(x) == A.N && length(y) == A.M) || throw(DimensionMismatch("A_mul_B!"))
Base.@propagate_inbounds function LinearAlgebra.mul!(y::AbstractVector, A::FunctionMap, x::AbstractVector)
@boundscheck check_dim_mul(y, A, x)
ismutating(A) ? A.f(y, x) : copyto!(y, A.f(x))
return y
end

function At_mul_B!(y::AbstractVector, A::FunctionMap, x::AbstractVector)
(issymmetric(A) || (isreal(A) && ishermitian(A))) && return A_mul_B!(y, A, x)
(length(x) == A.M && length(y) == A.N) || throw(DimensionMismatch("At_mul_B!"))
Base.@propagate_inbounds function LinearAlgebra.mul!(y::AbstractVector, transA::TransposeMap{<:Any,<:FunctionMap}, x::AbstractVector)
A = transA.lmap
issymmetric(A) && return mul!(y, A, x)
@boundscheck check_dim_mul(y, transA, x)
if A.fc !== nothing
if !isreal(A)
x = conj(x)
Expand All @@ -135,9 +136,10 @@ function At_mul_B!(y::AbstractVector, A::FunctionMap, x::AbstractVector)
end
end

function Ac_mul_B!(y::AbstractVector, A::FunctionMap, x::AbstractVector)
ishermitian(A) && return A_mul_B!(y, A, x)
(length(x) == A.M && length(y) == A.N) || throw(DimensionMismatch("Ac_mul_B!"))
Base.@propagate_inbounds function LinearAlgebra.mul!(y::AbstractVector, adjA::AdjointMap{<:Any,<:FunctionMap}, x::AbstractVector)
A = adjA.lmap
ishermitian(A) && return mul!(y, A, x)
@boundscheck check_dim_mul(y, adjA, x)
if A.fc !== nothing
ismutating(A) ? A.fc(y, x) : copyto!(y, A.fc(x))
return y
Expand Down
56 changes: 31 additions & 25 deletions src/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ Base.:(==)(A::KroneckerMap, B::KroneckerMap) = (eltype(A) == eltype(B) && A.maps
temp2 = similar(y, nb)
@views @inbounds for i in 1:ma
v[i] = one(T)
A_mul_B!(temp1, At, v)
A_mul_B!(temp2, X, temp1)
A_mul_B!(y[((i-1)*mb+1):i*mb], B, temp2)
mul!(temp1, At, v)
mul!(temp2, X, temp1)
mul!(y[((i-1)*mb+1):i*mb], B, temp2)
v[i] = zero(T)
end
return y
Expand All @@ -131,15 +131,15 @@ end
# multiplication with vectors
#################

Base.@propagate_inbounds function A_mul_B!(y::AbstractVector, L::KroneckerMap{T,<:NTuple{2,LinearMap}}, x::AbstractVector) where {T}
Base.@propagate_inbounds function LinearAlgebra.mul!(y::AbstractVector, L::KroneckerMap{T,<:NTuple{2,LinearMap}}, x::AbstractVector) where {T}
require_one_based_indexing(y)
@boundscheck check_dim_mul(y, L, x)
A, B = L.maps
X = LinearMap(reshape(x, (size(B, 2), size(A, 2))); issymmetric=false, ishermitian=false, isposdef=false)
_kronmul!(y, B, X, transpose(A), T)
return y
end
Base.@propagate_inbounds function A_mul_B!(y::AbstractVector, L::KroneckerMap{T}, x::AbstractVector) where {T}
Base.@propagate_inbounds function LinearAlgebra.mul!(y::AbstractVector, L::KroneckerMap{T}, x::AbstractVector) where {T}
require_one_based_indexing(y)
@boundscheck check_dim_mul(y, L, x)
A = first(L.maps)
Expand All @@ -150,35 +150,29 @@ Base.@propagate_inbounds function A_mul_B!(y::AbstractVector, L::KroneckerMap{T}
end
# mixed-product rule, prefer the right if possible
# (A₁ ⊗ A₂ ⊗ ... ⊗ Aᵣ) * (B₁ ⊗ B₂ ⊗ ... ⊗ Bᵣ) = (A₁B₁) ⊗ (A₂B₂) ⊗ ... ⊗ (AᵣBᵣ)
Base.@propagate_inbounds function A_mul_B!(y::AbstractVector, L::CompositeMap{<:Any,<:Tuple{KroneckerMap,KroneckerMap}}, x::AbstractVector)
Base.@propagate_inbounds function LinearAlgebra.mul!(y::AbstractVector, L::CompositeMap{<:Any,<:Tuple{KroneckerMap,KroneckerMap}}, x::AbstractVector)
B, A = L.maps
if length(A.maps) == length(B.maps) && all(M -> check_dim_mul(M[1], M[2]), zip(A.maps, B.maps))
A_mul_B!(y, kron(map(*, A.maps, B.maps)...), x)
mul!(y, kron(map(*, A.maps, B.maps)...), x)
else
A_mul_B!(y, LinearMap(A)*B, x)
mul!(y, LinearMap(A)*B, x)
end
end
# mixed-product rule, prefer the right if possible
# (A₁ ⊗ B₁)*(A₂⊗B₂)*...*(Aᵣ⊗Bᵣ) = (A₁*A₂*...*Aᵣ) ⊗ (B₁*B₂*...*Bᵣ)
Base.@propagate_inbounds function A_mul_B!(y::AbstractVector, L::CompositeMap{T,<:Tuple{Vararg{KroneckerMap{<:Any,<:Tuple{LinearMap,LinearMap}}}}}, x::AbstractVector) where {T}
Base.@propagate_inbounds function LinearAlgebra.mul!(y::AbstractVector, L::CompositeMap{T,<:Tuple{Vararg{KroneckerMap{<:Any,<:Tuple{LinearMap,LinearMap}}}}}, x::AbstractVector) where {T}
As = map(AB -> AB.maps[1], L.maps)
Bs = map(AB -> AB.maps[2], L.maps)
As1, As2 = Base.front(As), Base.tail(As)
Bs1, Bs2 = Base.front(Bs), Base.tail(Bs)
apply = all(A -> check_dim_mul(A...), zip(As1, As2)) && all(A -> check_dim_mul(A...), zip(Bs1, Bs2))
if apply
A_mul_B!(y, kron(prod(As), prod(Bs)), x)
mul!(y, kron(prod(As), prod(Bs)), x)
else
A_mul_B!(y, CompositeMap{T}(map(LinearMap, L.maps)), x)
mul!(y, CompositeMap{T}(map(LinearMap, L.maps)), x)
end
end

Base.@propagate_inbounds At_mul_B!(y::AbstractVector, A::KroneckerMap, x::AbstractVector) =
A_mul_B!(y, transpose(A), x)

Base.@propagate_inbounds Ac_mul_B!(y::AbstractVector, A::KroneckerMap, x::AbstractVector) =
A_mul_B!(y, adjoint(A), x)

###############
# KroneckerSumMap
###############
Expand Down Expand Up @@ -249,6 +243,8 @@ where `A` can be a square `AbstractMatrix` or a `LinearMap`.

Base.:(^)(A::MapOrMatrix, ::KronSumPower{p}) where {p} = kronsum(ntuple(n -> convert_to_lmaps_(A), Val(p))...)

MulStyle(A::KronSumPower) = MulStyle(Matrix{Float64}(undef, 0, 0), A.maps...)

Base.size(A::KroneckerSumMap, i) = prod(size.(A.maps, i))
Base.size(A::KroneckerSumMap) = (size(A, 1), size(A, 2))

Expand All @@ -261,20 +257,30 @@ LinearAlgebra.transpose(A::KroneckerSumMap{T}) where {T} = KroneckerSumMap{T}(ma

Base.:(==)(A::KroneckerSumMap, B::KroneckerSumMap) = (eltype(A) == eltype(B) && A.maps == B.maps)

Base.@propagate_inbounds function A_mul_B!(y::AbstractVector, L::KroneckerSumMap, x::AbstractVector)
Base.@propagate_inbounds function LinearAlgebra.mul!(y::AbstractVector, L::KroneckerSumMap, x::AbstractVector,
α::Number=true, β::Number=false)
@boundscheck check_dim_mul(y, L, x)
A, B = L.maps
ma, na = size(A)
mb, nb = size(B)
X = reshape(x, (nb, na))
Y = reshape(y, (nb, na))
mul!(Y, X, convert(AbstractMatrix, transpose(A)))
mul!(Y, B, X, true, true)
_muladd!(MulStyle(X), Y, X, convert(AbstractMatrix, transpose(A)), true, β)
_muladd!(MulStyle(B), Y, B, X, α, true)
return y
end

Base.@propagate_inbounds At_mul_B!(y::AbstractVector, A::KroneckerSumMap, x::AbstractVector) =
A_mul_B!(y, transpose(A), x)

Base.@propagate_inbounds Ac_mul_B!(y::AbstractVector, A::KroneckerSumMap, x::AbstractVector) =
A_mul_B!(y, adjoint(A), x)
Base.@propagate_inbounds function muladd!(y::AbstractVector, L::KroneckerSumMap, x::AbstractVector,
α, β, z::AbstractVector)
@boundscheck check_dim_mul(y, L, x)
@boundscheck size(y) == size(z)
A, B = L.maps
ma, na = size(A)
mb, nb = size(B)
X = reshape(x, (nb, na))
Y = reshape(y, (nb, na))
Z = reshape(z, (nb, na))
_muladd!(MulStyle(X), Y, X, convert(AbstractMatrix, transpose(A)), true, β, Z)
_muladd!(MulStyle(B), Y, B, X, α, true, Z)
return y
end
Loading