Skip to content

add generic syrk/herk #1249

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ julia> rmul!([NaN], 0.0)
```
"""
function rmul!(X::AbstractArray, s::Number)
isone(s) && return X
@simd for I in eachindex(X)
@inbounds X[I] *= s
end
Expand Down Expand Up @@ -318,6 +319,7 @@ julia> lmul!(0.0, [Inf])
```
"""
function lmul!(s::Number, X::AbstractArray)
isone(s) && return X
@simd for I in eachindex(X)
@inbounds X[I] = s*X[I]
end
Expand Down
113 changes: 107 additions & 6 deletions src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -488,14 +488,13 @@

# THE one big BLAS dispatch. This is split into two methods to improve latency
#
# Degenerate cases are handled up front: if `A` or `B` has a zero dimension, or `α` is zero,
# the sizes are checked and only `_rmul_or_fill!(C, β)` is applied to `C`. Every other case is
# forwarded to `_syrk_herk_gemm_wrapper!`, after which `C` is returned.
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
                                    α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:Number}
    mA, nA = lapack_size(tA, A)
    mB, nB = lapack_size(tB, B)
    if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α)
        matmul_size_check(size(C), (mA, nA), (mB, nB))
        return _rmul_or_fill!(C, β)
    end
    # The 2x2/3x3 shortcut is not taken here; the `gemm_wrapper!` methods call
    # `matmul2x2or3x3_nonzeroalpha!` themselves.
    _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, val)
    return C
end
Expand Down Expand Up @@ -570,6 +569,70 @@
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
end

"""
    generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number}

Computes syrk/herk for generic number types. If `conjugate` is false computes syrk, i.e.,
``A transpose(A) α + C β`` if `aat` is true, and ``transpose(A) A α + C β`` otherwise.
If `conjugate` is true computes herk, i.e., ``A A' α + C β`` if `aat` is true, and
``A' A α + C β`` otherwise.

The whole of `C` is first passed through `_rmul_or_fill!(C, β)` (the ``C β`` term); the
product term is then accumulated into the upper triangle of `C` only, so the strict lower
triangle is not updated with the product. If `A` is empty the product term vanishes and
`C` is returned right after the `β` step.

`C` must be square (checked with `checksquare`). A `DimensionMismatch` is thrown if its
order differs from `size(A, 1)` when `aat` is true, or from `size(A, 2)` otherwise.
Returns `C`.
"""
function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number}
    require_one_based_indexing(C, A)
    nC = checksquare(C)
    m, n = size(A, 1), size(A, 2)
    mA = aat ? m : n
    if nC != mA
        throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))"))
    end

    _rmul_or_fill!(C, β)
    # The `A'A`-type kernels below seed their accumulator with `A[1, ·]` under `@inbounds`,
    # which is only valid when `A` has at least one row. An empty `A` contributes nothing
    # to the product, so stop here.
    isempty(A) && return C
    @inbounds if !conjugate
        if aat
            # C += A * transpose(A) * α, accumulated one column of A at a time
            for k ∈ 1:n, j ∈ 1:m
                αA_jk = A[j, k] * α
                for i ∈ 1:j
                    C[i, j] += A[i, k] * αA_jk
                end
            end
        else
            # C += transpose(A) * A * α, one dot product per upper-triangular entry;
            # the accumulator starts from the k = 1 term so no `zero` of the product
            # type is needed
            for j ∈ 1:n, i ∈ 1:j
                temp = A[1, i] * A[1, j]
                for k ∈ 2:m
                    temp += A[k, i] * A[k, j]
                end
                C[i, j] += temp * α
            end
        end
    else
        if aat
            # C += A * A' * α; the diagonal uses `abs2` instead of a product with the conjugate
            for k ∈ 1:n, j ∈ 1:m
                αA_jk_bar = conj(A[j, k]) * α
                for i ∈ 1:j-1
                    C[i, j] += A[i, k] * αA_jk_bar
                end
                C[j, j] += abs2(A[j, k]) * α
            end
        else
            # C += A' * A * α; off-diagonal and diagonal entries are accumulated separately
            for j ∈ 1:n
                for i ∈ 1:j-1
                    temp = conj(A[1, i]) * A[1, j]
                    for k ∈ 2:m
                        temp += conj(A[k, i]) * A[k, j]
                    end
                    C[i, j] += temp * α
                end
                temp = abs2(A[1, j])
                for k ∈ 2:m
                    temp += abs2(A[k, j])
                end
                C[j, j] += temp * α
            end
        end
    end
    return C
end

# legacy method
Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
_add::MulAddMul = MulAddMul()) where {T<:BlasFloat} =
Expand Down Expand Up @@ -713,12 +776,27 @@
if (alpha isa Union{Bool,T} &&
beta isa Union{Bool,T} &&
stride(A, 1) == stride(C, 1) == 1 &&
_fullstride2(A) && _fullstride2(C))
return copytri!(BLAS.syrk!('U', tA, alpha, A, beta, C), 'U')
_fullstride2(A) && _fullstride2(C)) &&
max(nA, mA) ≥ 4
BLAS.syrk!('U', tA, alpha, A, beta, C)
else
generic_syrk!(C, A, false, tA_uc == 'N', alpha, beta)
end
return copytri!(C, 'U')
end
return gemm_wrapper!(C, tA, tAt, A, A, α, β)
end
# syrk for generic element types: when `T` is real or complex and `β` is zero or `C` is
# symmetric, the upper triangle is computed with `generic_syrk!` and mirrored with `copytri!`.
# Otherwise the product is delegated to `_generic_matmatmul!`.
Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T},
                                                   α::Number, β::Number) where {T<:Number}
    # uppercasing potentially strips a WrapperChar
    use_aat = uppercase(tA) == 'N'
    symmetric_result = T <: Union{Real,Complex} && (iszero(β) || issymmetric(C))
    if !symmetric_result
        other = use_aat ? 'T' : 'N'
        return _generic_matmatmul!(C, wrap(A, tA), wrap(A, other), α, β)
    end
    generic_syrk!(C, A, false, use_aat, α, β)
    return copytri!(C, 'U')
end
# legacy method: unpacks `alpha` and `beta` from the `MulAddMul` argument and forwards them
function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T},
                       _add::MulAddMul = MulAddMul()) where {T<:BlasFloat}
    return syrk_wrapper!(C, tA, A, _add.alpha, _add.beta)
end
Expand Down Expand Up @@ -748,12 +826,27 @@
if (alpha isa Union{Bool,T} &&
beta isa Union{Bool,T} &&
stride(A, 1) == stride(C, 1) == 1 &&
_fullstride2(A) && _fullstride2(C))
return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true)
_fullstride2(A) && _fullstride2(C)) &&
max(nA, mA) ≥ 4
BLAS.herk!('U', tA, alpha, A, beta, C)
else
generic_syrk!(C, A, true, tA_uc == 'N', alpha, beta)
end
return copytri!(C, 'U', true)
end
return gemm_wrapper!(C, tA, tAt, A, A, α, β)
end
# herk for generic element types: when `α` and `β` are real and `β` is zero or `C` is
# Hermitian, the upper triangle is computed with `generic_syrk!` (conjugating) and mirrored
# with the conjugating `copytri!`. Otherwise the product is delegated to `_generic_matmatmul!`.
Base.@constprop :aggressive function herk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T},
                                                   α::Number, β::Number) where {T<:Number}
    # uppercasing potentially strips a WrapperChar
    use_aat = uppercase(tA) == 'N'
    hermitian_result = isreal(α) && isreal(β) && (iszero(β) || ishermitian(C))
    if !hermitian_result
        other = use_aat ? 'C' : 'N'
        return _generic_matmatmul!(C, wrap(A, tA), wrap(A, other), α, β)
    end
    generic_syrk!(C, A, true, use_aat, α, β)
    return copytri!(C, 'U', true)
end
# legacy method
herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}},
_add::MulAddMul = MulAddMul()) where {T<:BlasReal} =
Expand Down Expand Up @@ -787,6 +880,7 @@
mB, nB = lapack_size(tB, B)

matmul_size_check(size(C), (mA, nA), (mB, nB))
matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β) && return C

if C === A || B === C
throw(ArgumentError("output matrix must not be aliased with input matrix"))
Expand All @@ -805,6 +899,13 @@
# Entry point taking a `MulAddMul`: unpacks its `alpha` and `beta` fields and forwards them
function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
                       A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
                       _add::MulAddMul = MulAddMul()) where {T<:BlasFloat}
    return gemm_wrapper!(C, tA, tB, A, B, _add.alpha, _add.beta)
end
# fallback for generic types: try the 2x2/3x3 shortcut first, and otherwise hand the wrapped
# operands to `_generic_matmatmul!`
Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
                                                   A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
                                                   α::Number, β::Number) where {T<:Number}
    if matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β)
        return C
    end
    lhs = wrap(A, tA)
    rhs = wrap(B, tB)
    return _generic_matmatmul!(C, lhs, rhs, α, β)
end

# Aggressive constprop helps propagate the values of tA and tB into wrap, which
# makes the calls concretely inferred
Expand Down
36 changes: 36 additions & 0 deletions test/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,42 @@ end
@test_throws DimensionMismatch axpy!(α, x, Vector(1:3), y, Vector(1:5))
end

@testset "generic syrk & herk" begin
    for T ∈ (BigFloat, Complex{BigFloat}, Quaternion{Float64})
        α = randn(T)
        a = randn(T, 3, 4)
        csmall, csmall_fallback = similar(a, 3, 3), similar(a, 3, 3)
        cbig, cbig_fallback = similar(a, 4, 4), similar(a, 4, 4)

        # (output, reference output, left factor, right factor) for each product shape
        herk_cases = ((csmall, csmall_fallback, a, a'),
                      (cbig, cbig_fallback, a', a))
        syrk_cases = ((csmall, csmall_fallback, a, transpose(a)),
                      (cbig, cbig_fallback, transpose(a), a))

        # adjoint products with a real scaling factor: compare against the generic
        # fallback and check that the result is Hermitian
        for (c, c_fallback, x, y) ∈ herk_cases
            mul!(c, x, y, real(α), false)
            LinearAlgebra._generic_matmatmul!(c_fallback, x, y, real(α), false)
            @test ishermitian(c)
            @test c ≈ c_fallback
        end

        # transpose products with the full scaling factor
        for (c, c_fallback, x, y) ∈ syrk_cases
            mul!(c, x, y, α, false)
            LinearAlgebra._generic_matmatmul!(c_fallback, x, y, α, false)
            @test c ≈ c_fallback
        end
        if T <: Union{Real, Complex}
            @test issymmetric(csmall)
            @test issymmetric(cbig)
        end

        #make sure generic herk is not called for non-real α
        for (c, c_fallback, x, y) ∈ herk_cases
            mul!(c, x, y, α, false)
            LinearAlgebra._generic_matmatmul!(c_fallback, x, y, α, false)
            @test c ≈ c_fallback
        end
    end
end

# a non-square (5×3) matrix is reported as neither symmetric nor Hermitian
@test !issymmetric(fill(1,5,3))
@test !ishermitian(fill(1,5,3))
# the cross product of a 3-vector with itself is the zero vector
@test (x = fill(1,3); cross(x,x) == zeros(3))
Expand Down