diff --git a/src/sparse/array.jl b/src/sparse/array.jl
index cfe7874c1..72122c32f 100644
--- a/src/sparse/array.jl
+++ b/src/sparse/array.jl
@@ -406,22 +406,36 @@ ROCSparseMatrixCSR{T}(Mat::Adjoint{Tv, <:SparseMatrixCSC}) where {T, Tv} =
     ROCSparseMatrixCSR{T}(
         ROCVector{Cint}(parent(Mat).colptr), ROCVector{Cint}(parent(Mat).rowval),
         ROCVector{T}(conj.(parent(Mat).nzval)), size(Mat))
+ROCSparseMatrixCSC{T}(Mat::Union{Transpose{Tv, <:SparseMatrixCSC}, Adjoint{Tv, <:SparseMatrixCSC}}) where {T, Tv} = ROCSparseMatrixCSC(ROCSparseMatrixCSR{T}(Mat))
 ROCSparseMatrixCSR{T}(Mat::SparseMatrixCSC) where {T} = ROCSparseMatrixCSR(ROCSparseMatrixCSC{T}(Mat))
 ROCSparseMatrixBSR{T}(Mat::SparseMatrixCSC, blockdim) where {T} = ROCSparseMatrixBSR(ROCSparseMatrixCSR{T}(Mat), blockdim)
 ROCSparseMatrixCOO{T}(Mat::SparseMatrixCSC) where {T} = ROCSparseMatrixCOO(ROCSparseMatrixCSR{T}(Mat))
+ROCSparseMatrixCOO{T}(Mat::Transpose{Tv, <:SparseMatrixCSC}) where {T, Tv} = ROCSparseMatrixCOO(ROCSparseMatrixCSR{T}(Mat))
+ROCSparseMatrixCOO{T}(Mat::Adjoint{Tv, <:SparseMatrixCSC}) where {T, Tv} = ROCSparseMatrixCOO(ROCSparseMatrixCSR{T}(Mat))
+
 
 # untyped variants
 ROCSparseVector(x::AbstractSparseArray{T}) where {T} = ROCSparseVector{T}(x)
 ROCSparseMatrixCSC(x::AbstractSparseArray{T}) where {T} = ROCSparseMatrixCSC{T}(x)
 ROCSparseMatrixCSR(x::AbstractSparseArray{T}) where {T} = ROCSparseMatrixCSR{T}(x)
 ROCSparseMatrixBSR(x::AbstractSparseArray{T}, blockdim) where {T} = ROCSparseMatrixBSR{T}(x, blockdim)
 ROCSparseMatrixCOO(x::AbstractSparseArray{T}) where {T} = ROCSparseMatrixCOO{T}(x)
+
+# adjoint / transpose
 ROCSparseMatrixCSR(x::Transpose{T}) where {T} = ROCSparseMatrixCSR{T}(x)
 ROCSparseMatrixCSR(x::Adjoint{T}) where {T} = ROCSparseMatrixCSR{T}(x)
 ROCSparseMatrixCSC(x::Transpose{T}) where {T} = ROCSparseMatrixCSC{T}(x)
 ROCSparseMatrixCSC(x::Adjoint{T}) where {T} = ROCSparseMatrixCSC{T}(x)
-
-# TODO adjoint / transpose: GPUArrays._sptranspose
+ROCSparseMatrixCOO(x::Transpose{T}) where {T} = ROCSparseMatrixCOO{T}(x)
+ROCSparseMatrixCOO(x::Adjoint{T}) where {T} = ROCSparseMatrixCOO{T}(x)
+
+# lazy wrappers around GPU sparse matrices: materialize the transpose / adjoint, then convert
+ROCSparseMatrixCSR(x::Transpose{T,<:Union{ROCSparseMatrixCSC, ROCSparseMatrixCSR, ROCSparseMatrixCOO}}) where {T} = ROCSparseMatrixCSR(_sptranspose(parent(x)))
+ROCSparseMatrixCSC(x::Transpose{T,<:Union{ROCSparseMatrixCSC, ROCSparseMatrixCSR, ROCSparseMatrixCOO}}) where {T} = ROCSparseMatrixCSC(_sptranspose(parent(x)))
+ROCSparseMatrixCOO(x::Transpose{T,<:Union{ROCSparseMatrixCSC, ROCSparseMatrixCSR, ROCSparseMatrixCOO}}) where {T} = ROCSparseMatrixCOO(_sptranspose(parent(x)))
+ROCSparseMatrixCSR(x::Adjoint{T,<:Union{ROCSparseMatrixCSC, ROCSparseMatrixCSR, ROCSparseMatrixCOO}}) where {T} = ROCSparseMatrixCSR(_spadjoint(parent(x)))
+ROCSparseMatrixCSC(x::Adjoint{T,<:Union{ROCSparseMatrixCSC, ROCSparseMatrixCSR, ROCSparseMatrixCOO}}) where {T} = ROCSparseMatrixCSC(_spadjoint(parent(x)))
+ROCSparseMatrixCOO(x::Adjoint{T,<:Union{ROCSparseMatrixCSC, ROCSparseMatrixCSR, ROCSparseMatrixCOO}}) where {T} = ROCSparseMatrixCOO(_spadjoint(parent(x)))
 
 # gpu to cpu
 SparseVector(x::ROCSparseVector) = SparseVector(length(x), Array(SparseArrays.nonzeroinds(x)), Array(SparseArrays.nonzeros(x)))
diff --git a/src/sparse/interfaces.jl b/src/sparse/interfaces.jl
index fe77ab835..6b4839ea9 100644
--- a/src/sparse/interfaces.jl
+++ b/src/sparse/interfaces.jl
@@ -1,5 +1,33 @@
 # interfacing with other packages
 
+# Materialized adjoint / transpose of GPU sparse matrices. For CSR and CSC the existing
+# buffers are relabelled as the opposite format with reversed dimensions (values are
+# conjugated for the adjoint) and the result is converted back to the input format.
+function _spadjoint(A::ROCSparseMatrixCSR)
+    Aᴴ = ROCSparseMatrixCSC(A.rowPtr, A.colVal, conj(A.nzVal), reverse(size(A)))
+    return ROCSparseMatrixCSR(Aᴴ)
+end
+function _sptranspose(A::ROCSparseMatrixCSR)
+    Aᵀ = ROCSparseMatrixCSC(A.rowPtr, A.colVal, A.nzVal, reverse(size(A)))
+    return ROCSparseMatrixCSR(Aᵀ)
+end
+function _spadjoint(A::ROCSparseMatrixCSC)
+    Aᴴ = ROCSparseMatrixCSR(A.colPtr, A.rowVal, conj(A.nzVal), reverse(size(A)))
+    return ROCSparseMatrixCSC(Aᴴ)
+end
+function _sptranspose(A::ROCSparseMatrixCSC)
+    Aᵀ = ROCSparseMatrixCSR(A.colPtr, A.rowVal, A.nzVal, reverse(size(A)))
+    return ROCSparseMatrixCSC(Aᵀ)
+end
+function _spadjoint(A::ROCSparseMatrixCOO)
+    # we use sparse instead of ROCSparseMatrixCOO because we want to sort the matrix.
+    return sparse(A.colInd, A.rowInd, conj(A.nzVal), reverse(size(A))..., fmt = :coo)
+end
+function _sptranspose(A::ROCSparseMatrixCOO)
+    # we use sparse instead of ROCSparseMatrixCOO because we want to sort the matrix.
+    return sparse(A.colInd, A.rowInd, A.nzVal, reverse(size(A))..., fmt = :coo)
+end
+
 function mv_wrapper(
     transa::SparseChar, alpha::Number, A::ROCSparseMatrix, X::DenseROCVector{T},
     beta::Number, Y::ROCVector{T},
@@ -28,6 +56,19 @@ LinearAlgebra.dot(x::DenseROCVector{T}, y::ROCSparseVector{T}) where {T <: BlasR
 LinearAlgebra.dot(x::ROCSparseVector{T}, y::DenseROCVector{T}) where {T <: BlasComplex} = vv!('C', x, y, 'O')
 LinearAlgebra.dot(x::DenseROCVector{T}, y::ROCSparseVector{T}) where {T <: BlasComplex} = conj(dot(y,x))
 
+# (type wrapper, materializer) pairs used to generate methods for plain, transposed and
+# adjoint sparse matrices; the materializer builds the expression of an eager matrix.
+adjtrans_wrappers = ((identity, identity),
+                     (M -> :(Transpose{T, <:$M}), M -> :(_sptranspose(parent($M)))),
+                     (M -> :(Adjoint{T, <:$M}), M -> :(_spadjoint(parent($M)))))
+
+# (type wrapper, operation character selector, unwrapper) triples used to generate methods
+# that accept lazily wrapped arguments.
+op_wrappers = ((identity, T -> 'N', identity),
+               (T -> :(Transpose{T, <:$T}), T -> 'T', A -> :(parent($A))),
+               (T -> :(Adjoint{T, <:$T}), T -> T <: Real ? 'T' : 'C', A -> :(parent($A))),
+               (T -> :(HermOrSym{T, <:$T}), T -> 'N', A -> :(parent($A))))
+
 # legacy methods with final MulAddMul argument
 LinearAlgebra.generic_matvecmul!(C::ROCVector{T}, tA::AbstractChar, A::ROCSparseMatrix{T}, B::DenseROCVector{T}, _add::MulAddMul) where T <: BlasFloat =
     LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
@@ -52,6 +93,18 @@ function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::ROCSparseM
     mm_wrapper(tA, tB, alpha, A, B, beta, C)
 end
 
+# sparse matrix (optionally wrapped) times sparse vector, written into a new dense vector
+for (wrapa, _, _) in op_wrappers
+    TypeA = wrapa(:(ROCSparseMatrix{T}))
+
+    @eval function LinearAlgebra.:(*)(A::$TypeA, x::ROCSparseVector{T}) where {T <: Union{Float16, ComplexF16, BlasFloat}}
+        m, n = size(A)
+        length(x) == n || throw(DimensionMismatch())
+        y = ROCVector{T}(undef, m)
+        return mul!(y, A, x, true, false)
+    end
+end
+
 # legacy methods with final MulAddMul argument
 LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMatrix{T}, B::ROCSparseMatrixCSC{T}, _add::MulAddMul) where T <: BlasFloat =
     LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
@@ -76,74 +129,102 @@ function LinearAlgebra.generic_matmatmul!(C::ROCMatrix{T}, tA, tB, A::DenseROCMa
     mm!(tA, tB, alpha, A, B, beta, C, 'O')
 end
 
-Base.:(+)(A::ROCSparseMatrixCSR, B::ROCSparseMatrixCSR) = geam(one(eltype(A)), A, one(eltype(A)), B, 'O')
-Base.:(-)(A::ROCSparseMatrixCSR, B::ROCSparseMatrixCSR) = geam(one(eltype(A)), A, -one(eltype(A)), B, 'O')
+# dense matrix (optionally wrapped) times sparse vector / sparse matrix, written into a new dense array
+for (wrapa, _, _) in op_wrappers
+    TypeA = wrapa(:(DenseROCMatrix{T}))
 
-Base.:(+)(A::ROCSparseMatrixCSR, B::Adjoint{T,<:ROCSparseMatrixCSR}) where T = A + Transpose(conj(B.parent))
-Base.:(-)(A::ROCSparseMatrixCSR, B::Adjoint{T,<:ROCSparseMatrixCSR}) where T = A - Transpose(conj(B.parent))
-Base.:(+)(A::Adjoint{T,<:ROCSparseMatrixCSR}, B::ROCSparseMatrixCSR) where T = Transpose(conj(A.parent)) + B
-Base.:(-)(A::Adjoint{T,<:ROCSparseMatrixCSR}, B::ROCSparseMatrixCSR) where T = Transpose(conj(A.parent)) - B
-Base.:(+)(A::Adjoint{T,<:ROCSparseMatrixCSR}, B::Adjoint{T,<:ROCSparseMatrixCSR}) where T =
-    Transpose(conj(A.parent)) + B
-Base.:(-)(A::Adjoint{T,<:ROCSparseMatrixCSR}, B::Adjoint{T,<:ROCSparseMatrixCSR}) where T =
-    Transpose(conj(A.parent)) - B
+    @eval function Base.:(*)(A::$TypeA, x::ROCSparseVector{T}) where {T <: BlasFloat}
+        m, n = size(A)
+        length(x) == n || throw(DimensionMismatch())
+        y = ROCVector{T}(undef, m)
+        return mul!(y, A, x, true, false)
+    end
 
-function Base.:(+)(A::ROCSparseMatrixCSR, B::Transpose{T,<:ROCSparseMatrixCSR}) where T
-    cscB = ROCSparseMatrixCSC(B.parent)
-    transB = ROCSparseMatrixCSR(cscB.colPtr, cscB.rowVal, cscB.nzVal, size(cscB))
-    return geam(one(T), A, one(T), transB, 'O')
+    for (wrapb, _, _) in op_wrappers
+        for SparseMatrixType in (:(ROCSparseMatrixCSC{T}), :(ROCSparseMatrixCSR{T}), :(ROCSparseMatrixCOO{T}))
+            TypeB = wrapb(SparseMatrixType)
+
+            @eval function Base.:(*)(A::$TypeA, B::$TypeB) where {T <: Union{Float16, ComplexF16, BlasFloat}}
+                m, n = size(A)
+                k, p = size(B)
+                n == k || throw(DimensionMismatch())
+                C = ROCMatrix{T}(undef, m, p)
+                return mul!(C, A, B, true, false)
+            end
+        end
+    end
 end
 
-function Base.:(-)(A::ROCSparseMatrixCSR, B::Transpose{T,<:ROCSparseMatrixCSR}) where T
-    cscB = ROCSparseMatrixCSC(B.parent)
-    transB = ROCSparseMatrixCSR(cscB.colPtr, cscB.rowVal, cscB.nzVal, size(cscB))
-    return geam(one(T), A, -one(T), transB, 'O')
-end
+# legacy methods with final MulAddMul argument
+LinearAlgebra.generic_matmatmul!(C::ROCSparseMatrixCSC{T}, tA, tB, A::ROCSparseMatrixCSC{T}, B::ROCSparseMatrixCSC{T}, _add::MulAddMul) where {T <: BlasFloat} =
+    LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
+LinearAlgebra.generic_matmatmul!(C::ROCSparseMatrixCSR{T}, tA, tB, A::ROCSparseMatrixCSR{T}, B::ROCSparseMatrixCSR{T}, _add::MulAddMul) where {T <: BlasFloat} =
+    LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
+LinearAlgebra.generic_matmatmul!(C::ROCSparseMatrixCOO{T}, tA, tB, A::ROCSparseMatrixCOO{T}, B::ROCSparseMatrixCOO{T}, _add::MulAddMul) where {T <: BlasFloat} =
+    LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
 
-function Base.:(+)(A::Transpose{T,<:ROCSparseMatrixCSR}, B::ROCSparseMatrixCSR) where T
-    cscA = ROCSparseMatrixCSC(A.parent)
-    transA = ROCSparseMatrixCSR(cscA.colPtr, cscA.rowVal, cscA.nzVal, size(cscA))
-    geam(one(T), transA, one(T), B, 'O')
+function LinearAlgebra.generic_matmatmul!(C::ROCSparseMatrixCSC{T}, tA, tB, A::ROCSparseMatrixCSC{T}, B::ROCSparseMatrixCSC{T}, alpha::Number, beta::Number) where {T <: BlasFloat}
+    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
+    tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
+    return gemm!(tA, tB, alpha, A, B, beta, C, 'O')
 end
-
-function Base.:(-)(A::Transpose{T,<:ROCSparseMatrixCSR}, B::ROCSparseMatrixCSR) where T
-    cscA = ROCSparseMatrixCSC(A.parent)
-    transA = ROCSparseMatrixCSR(cscA.colPtr, cscA.rowVal, cscA.nzVal, size(cscA))
-    geam(one(T), transA, -one(T), B, 'O')
+function LinearAlgebra.generic_matmatmul!(C::ROCSparseMatrixCSR{T}, tA, tB, A::ROCSparseMatrixCSR{T}, B::ROCSparseMatrixCSR{T}, alpha::Number, beta::Number) where {T <: BlasFloat}
+    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
+    tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
+    return gemm!(tA, tB, alpha, A, B, beta, C, 'O')
 end
-
-function Base.:(+)(A::Transpose{T,<:ROCSparseMatrixCSR}, B::Transpose{T,<:ROCSparseMatrixCSR}) where T
-    C = geam(one(T), A.parent, one(T), B.parent, 'O')
-    cscC = ROCSparseMatrixCSC(C)
-    return ROCSparseMatrixCSR(cscC.colPtr, cscC.rowVal, cscC.nzVal, size(cscC))
+function LinearAlgebra.generic_matmatmul!(C::ROCSparseMatrixCOO{T}, tA, tB, A::ROCSparseMatrixCOO{T}, B::ROCSparseMatrixCOO{T}, alpha::Number, beta::Number) where {T <: BlasFloat}
+    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
+    tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
+    # compute in CSR and copy the product back into the COO destination
+    A_csr = ROCSparseMatrixCSR(A)
+    B_csr = ROCSparseMatrixCSR(B)
+    C_csr = ROCSparseMatrixCSR(C)
+    LinearAlgebra.generic_matmatmul!(C_csr, tA, tB, A_csr, B_csr, alpha, beta)
+    copyto!(C, ROCSparseMatrixCOO(C_csr))
+    return C
 end
 
-function Base.:(-)(A::Transpose{T,<:ROCSparseMatrixCSR}, B::Transpose{T,<:ROCSparseMatrixCSR}) where T
-    C = geam(one(T), A.parent, -one(T), B.parent, 'O')
-    cscC = ROCSparseMatrixCSC(C)
-    return ROCSparseMatrixCSR(cscC.colPtr, cscC.rowVal, cscC.nzVal, size(cscC))
+for SparseMatrixType in (:ROCSparseMatrixCSC, :ROCSparseMatrixCSR)
+    @eval function LinearAlgebra.:(*)(A::$SparseMatrixType{T}, B::$SparseMatrixType{T}) where {T <: BlasFloat}
+        return gemm('N', 'N', one(T), A, B, 'O')
+    end
 end
 
-function Base.:(+)(A::ROCSparseMatrixCSR, B::ROCSparseMatrix)
-    csrB = ROCSparseMatrixCSR(B)
-    return geam(one(eltype(A)), A, one(eltype(A)), csrB, 'O')
+function LinearAlgebra.:(*)(A::ROCSparseMatrixCOO{T}, B::ROCSparseMatrixCOO{T}) where {T <: BlasFloat}
+    A_csr = ROCSparseMatrixCSR(A)
+    B_csr = ROCSparseMatrixCSR(B)
+    return ROCSparseMatrixCOO(A_csr * B_csr)
 end
 
-function Base.:(-)(A::ROCSparseMatrixCSR, B::ROCSparseMatrix)
-    csrB = ROCSparseMatrixCSR(B)
-    return geam(one(eltype(A)), A, -one(eltype(A)), csrB, 'O')
+for (wrapa, unwrapa) in adjtrans_wrappers, (wrapb, unwrapb) in adjtrans_wrappers
+    # the plain * plain products are defined explicitly above
+    wrapa == identity && wrapb == identity && continue
+    for SparseMatrixType in (:(ROCSparseMatrixCSC{T}), :(ROCSparseMatrixCSR{T}), :(ROCSparseMatrixCOO{T}))
+        TypeA = wrapa(SparseMatrixType)
+        TypeB = wrapb(SparseMatrixType)
+        @eval Base.:(*)(A::$TypeA, B::$TypeB) where {T <: BlasFloat} = $(unwrapa(:A)) * $(unwrapb(:B))
+    end
 end
 
-function Base.:(+)(A::ROCSparseMatrix, B::ROCSparseMatrixCSR)
-    csrA = ROCSparseMatrixCSR(A)
-    return geam(one(eltype(A)), csrA, one(eltype(A)), B, 'O')
-end
+# addition / subtraction; wrapped operands are materialized before the two matrices are combined
+for op in (:(+), :(-))
+    for (wrapa, unwrapa) in adjtrans_wrappers, (wrapb, unwrapb) in adjtrans_wrappers
+        for SparseMatrixType in (:(ROCSparseMatrixCSC{T}), :(ROCSparseMatrixCSR{T}))
+            TypeA = wrapa(SparseMatrixType)
+            TypeB = wrapb(SparseMatrixType)
+            @eval Base.$op(A::$TypeA, B::$TypeB) where {T <: BlasFloat} = geam(one(T), $(unwrapa(:A)), $(op)(one(T)), $(unwrapb(:B)), 'O')
+        end
+    end
 
-function Base.:(-)(A::ROCSparseMatrix, B::ROCSparseMatrixCSR)
-    csrA = ROCSparseMatrixCSR(A)
-    return geam(one(eltype(A)), csrA, -one(eltype(A)), B, 'O')
+    @eval begin
+        Base.$op(A::ROCSparseVector{T}, B::ROCSparseVector{T}) where {T <: BlasFloat} = axpby(one(T), A, $(op)(one(T)), B, 'O')
+        Base.$op(A::Union{ROCSparseMatrixCOO{T}, Transpose{T,<:ROCSparseMatrixCOO}, Adjoint{T,<:ROCSparseMatrixCOO}},
+                 B::Union{ROCSparseMatrixCOO{T}, Transpose{T,<:ROCSparseMatrixCOO}, Adjoint{T,<:ROCSparseMatrixCOO}}) where {T <: BlasFloat} =
+            ROCSparseMatrixCOO($(op)(ROCSparseMatrixCSR(A), ROCSparseMatrixCSR(B)))
+    end
 end
 
 # triangular
 
 for SparseMatrixType in (:ROCSparseMatrixBSR,)
@@ -223,24 +304,73 @@ function _sparse_identity(
     ROCSparseMatrixCSC{Tv,Ti}(colPtr, rowVal, nzVal, dims)
 end
 
-# TODO COO
+# COO variant: stores `I.λ` on the main diagonal of a matrix of size `dims`.
+function _sparse_identity(
+    ::Type{<:ROCSparseMatrixCOO{Tv,Ti}}, I::UniformScaling, dims::Dims
+) where {Tv,Ti}
+    len = min(dims[1], dims[2])
+    rowInd = ROCVector{Ti}(1:len)
+    colInd = ROCVector{Ti}(1:len)
+    # convert λ up front so the values always have the requested element type
+    nzVal = AMDGPU.fill(convert(Tv, I.λ), len)
+    return ROCSparseMatrixCOO{Tv,Ti}(rowInd, colInd, nzVal, dims)
+end
 
-Base.:(+)(A::Union{ROCSparseMatrixCSR,ROCSparseMatrixCSC}, J::UniformScaling) =
-    A .+ _sparse_identity(typeof(A), J, size(A))
-
-Base.:(-)(J::UniformScaling, A::Union{ROCSparseMatrixCSR,ROCSparseMatrixCSC}) =
-    _sparse_identity(typeof(A), J, size(A)) .- A
-
-# TODO: let Broadcast handle this automatically (a la SparseArrays.PromoteToSparse)
-for SparseMatrixType in [:ROCSparseMatrixCSC, :ROCSparseMatrixCSR], op in [:(+), :(-)]
-    @eval begin
-        function Base.$op(lhs::Diagonal{T,<:ROCArray}, rhs::$SparseMatrixType{T}) where T
-            return $op($SparseMatrixType(lhs), rhs)
+for (wrapa, unwrapa) in adjtrans_wrappers
+    for SparseMatrixType in (:(ROCSparseMatrixCSC{T}), :(ROCSparseMatrixCSR{T}), :(ROCSparseMatrixCOO{T}))
+        TypeA = wrapa(SparseMatrixType)
+        # Materialize wrapped operands once, so that the identity is requested for the
+        # concrete sparse type instead of the lazy wrapper type.
+        for op in (:(+), :(-))
+            @eval begin
+                function Base.$op(A::$TypeA, J::UniformScaling) where {T}
+                    M = $(unwrapa(:A))
+                    return $op(M, _sparse_identity(typeof(M), J, size(M)))
+                end
+                function Base.$op(J::UniformScaling, A::$TypeA) where {T}
+                    M = $(unwrapa(:A))
+                    return $op(_sparse_identity(typeof(M), J, size(M)), M)
+                end
+            end
+        end
+
+        # Broadcasting is not yet supported for COO matrices
+        if SparseMatrixType != :(ROCSparseMatrixCOO{T})
+            @eval begin
+                Base.:(*)(A::$TypeA, J::UniformScaling) where {T} = $(unwrapa(:A)) * J.λ
+                Base.:(*)(J::UniformScaling, A::$TypeA) where {T} = J.λ * $(unwrapa(:A))
+            end
+        else
+            # use a square identity whose order matches the dimension being contracted
+            @eval begin
+                function Base.:(*)(A::$TypeA, J::UniformScaling) where {T}
+                    M = $(unwrapa(:A))
+                    n = size(M, 2)
+                    return M * _sparse_identity(typeof(M), J, (n, n))
+                end
+                function Base.:(*)(J::UniformScaling, A::$TypeA) where {T}
+                    M = $(unwrapa(:A))
+                    m = size(M, 1)
+                    return _sparse_identity(typeof(M), J, (m, m)) * M
+                end
+            end
         end
-        function Base.$op(lhs::$SparseMatrixType{T}, rhs::Diagonal{T,<:ROCArray}) where T
-            return $op(lhs, $SparseMatrixType(rhs))
+    end
+end
+
+# TODO: let Broadcast handle this automatically (a la SparseArrays.PromoteToSparse)
+for (wrapa, unwrapa) in adjtrans_wrappers, op in (:(+), :(-), :(*))
+    for SparseMatrixType in (:(ROCSparseMatrixCSC{T}), :(ROCSparseMatrixCSR{T}), :(ROCSparseMatrixCOO{T}))
+        TypeA = wrapa(SparseMatrixType)
+        @eval begin
+            function Base.$op(lhs::Diagonal, rhs::$TypeA) where {T}
+                return $op($SparseMatrixType(lhs), $(unwrapa(:rhs)))
+            end
+            function Base.$op(lhs::$TypeA, rhs::Diagonal) where {T}
+                return $op($(unwrapa(:lhs)), $SparseMatrixType(rhs))
+            end
         end
     end
 end
 
 # TODO _sptranspose / _spadjoint