Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 41 additions & 11 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ SimpleOrSpecialOrAdjMat{T, M} = Union{M,
AdjOrTranspMat{T, <:SpecialMat{T,<:M}},
SpecialMat{T,<:AdjOrTranspMat{T,<:M}}}

const SpecialMatrices = (LowerTriangular, UpperTriangular,
UnitLowerTriangular, UnitUpperTriangular,
Symmetric, Hermitian)

# unwraps matrix A from Adjoint/Transpose transform
unwrap_trans(A::AbstractMatrix) = A
unwrap_trans(A::Union{Adjoint, Transpose}) = unwrap_trans(parent(A))
Expand Down Expand Up @@ -219,17 +223,45 @@ end
# sparse * sparse overloads, have to be more specific than
# the ones in SparseArrays.jl to avoid ambiguity

(*)(A::SparseMat{T}, B::SparseMat{T}) where T =
spmatmul_sparse(A, B)
for Amat in (nothing, SpecialMatrices...), Bmat in (nothing, SpecialMatrices...)
Atype = !isnothing(Amat) ? :($Amat{T,S}) : :S
tAtype = !isnothing(Amat) ? :($Amat{T, <:AdjOrTranspMat{T, S}}) : nothing
Btype = !isnothing(Bmat) ? :($Bmat{T,S}) : :S
tBtype = !isnothing(Bmat) ? :($Bmat{T, <:AdjOrTranspMat{T, S}}) : nothing

@eval (*)(A::$Atype, B::$Btype) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)

@eval (*)(A::AdjOrTranspMat{T, $Atype}, B::$Btype) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)

(*)(A::AdjOrTranspMat{T, S}, B::S) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)
@eval (*)(A::$Atype, B::AdjOrTranspMat{T, $Btype}) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)

(*)(A::S, B::AdjOrTranspMat{T, S}) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)
@eval (*)(A::AdjOrTranspMat{T, $Atype}, B::AdjOrTranspMat{T, $Btype}) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)

if tAtype !== nothing
@eval (*)(A::$tAtype, B::$Btype) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)

@eval (*)(A::$tAtype, B::AdjOrTranspMat{T, $Btype}) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)
end

(*)(A::AdjOrTranspMat{T, S}, B::AdjOrTranspMat{T, S}) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)
if tBtype !== nothing
@eval (*)(A::$Atype, B::$tBtype) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)

@eval (*)(A::AdjOrTranspMat{T, $Atype}, B::$tBtype) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)
end

if tAtype !== nothing && tBtype !== nothing
@eval (*)(A::$tAtype, B::$tBtype) where {T <: BlasFloat, S <: SparseMat{T}} =
spmatmul_sparse(A, B)
end
end

if VERSION < v"1.11" # in 1.11 these wrappers are already defined in LinearAlgebra

Expand All @@ -245,9 +277,7 @@ function (\)(A::Union{S, AdjOrTranspMat{T, S}}, B::StridedMatrix{T}) where {T <:
return ldiv!(C, A, B)
end

for mat in (LowerTriangular, UpperTriangular,
UnitLowerTriangular, UnitUpperTriangular,
Symmetric, Hermitian)
for mat in SpecialMatrices

@eval function (\)(A::Union{$mat{T, S}, AdjOrTranspMat{T, $mat{T, S}}, $mat{T, <:AdjOrTranspMat{T, S}}},
x::StridedVector{T}
Expand Down
4 changes: 2 additions & 2 deletions test/test_BLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -728,8 +728,8 @@ end
n = rand(10:50)
spf = 0.1 + 0.8 * rand()

spA = convert_to_Aclass(sparserandn(SPMT{T, IT}, n, n, spf))
spB = convert_to_Bclass(sparserandn(SPMT{T, IT}, n, n, spf))
spA = Aclass(convert_to_Aclass(sparserandn(SPMT{T, IT}, n, n, spf)))
spB = Bclass(convert_to_Bclass(sparserandn(SPMT{T, IT}, n, n, spf)))
A = convert(Matrix, spA)
B = convert(Matrix, spB)

Expand Down
Loading