diff --git a/src/triangular.jl b/src/triangular.jl index a7037e7e..97b00901 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -144,6 +144,8 @@ const LowerOrUnitLowerTriangular{T,S<:AbstractMatrix{T}} = Union{LowerTriangular const UpperOrLowerTriangular{T,S<:AbstractMatrix{T}} = Union{UpperOrUnitUpperTriangular{T,S}, LowerOrUnitLowerTriangular{T,S}} const UnitUpperOrUnitLowerTriangular{T,S<:AbstractMatrix{T}} = Union{UnitUpperTriangular{T,S}, UnitLowerTriangular{T,S}} +const UpperOrLowerTriangularStrided{T,S<:StridedMatrix{T}} = UpperOrLowerTriangular{T,S} + uppertriangular(M) = UpperTriangular(M) lowertriangular(M) = LowerTriangular(M) @@ -1116,11 +1118,20 @@ for (TA, TB) in ((:AbstractTriangular, :AbstractMatrix), if isone(alpha) && iszero(beta) return _trimul!(C, A, B) else - return generic_matmatmul!(C, 'N', 'N', A, B, alpha, beta) + return generic_matmatmul_NN!(C, A, B, alpha, beta) end end end +generic_matmatmul_NN!(C, A, B, alpha, beta) = generic_matmatmul!(C, 'N', 'N', A, B, alpha, beta) +# Optimization for strided matrices, where we know that _generic_matmatmul! will be taken +for (TA, TB) in ((:UpperOrLowerTriangularStrided, :StridedMatrix), + (:StridedMatrix, :UpperOrLowerTriangularStrided), + (:UpperOrLowerTriangularStrided, :UpperOrLowerTriangularStrided) + ) + @eval generic_matmatmul_NN!(C, A::$TA, B::$TB, alpha, beta) = _generic_matmatmul!(C, A, B, alpha, beta) +end + ldiv!(C::AbstractVecOrMat, A::AbstractTriangular, B::AbstractVecOrMat) = _ldiv!(C, A, B) # generic fallback for AbstractTriangular, directs to 2-arg [l/r]div! _ldiv!(C::AbstractVecOrMat, A::AbstractTriangular, B::AbstractVecOrMat) =