Skip to content

Clean up herk_wrapper! and add 5-arg tests #1254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -725,8 +725,9 @@ syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, _add

# the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it
# to be concretely inferred
Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}},
α::Number, β::Number) where {T<:BlasReal}
Base.@constprop :aggressive function herk_wrapper!(C::StridedMatrix{TC}, tA::AbstractChar, A::StridedVecOrMat{TC},
α::Number, β::Number) where {TC<:BlasComplex}
T = real(TC)
nC = checksquare(C)
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
if tA_uc == 'C'
Expand All @@ -740,13 +741,10 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St
throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))"))
end

# Result array does not need to be initialized as long as beta==0
# C = Matrix{T}(undef, mA, mA)

# BLAS.herk! only updates hermitian C, alpha and beta need to be real
if iszero(β) || ishermitian(C)
alpha, beta = promote(α, β, zero(T))
if (alpha isa Union{Bool,T} &&
beta isa Union{Bool,T} &&
if (alpha isa T && beta isa T &&
stride(A, 1) == stride(C, 1) == 1 &&
_fullstride2(A) && _fullstride2(C))
return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true)
Expand Down
12 changes: 11 additions & 1 deletion test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,17 @@ end

A5x5, A6x5 = Matrix{Float64}.(undef, ((5, 5), (6, 5)))
@test_throws DimensionMismatch LinearAlgebra.syrk_wrapper!(A5x5, 'N', A6x5)
@test_throws DimensionMismatch LinearAlgebra.herk_wrapper!(A5x5, 'N', A6x5)
@test_throws DimensionMismatch LinearAlgebra.herk_wrapper!(complex(A5x5), 'N', complex(A6x5))
end

@testset "5-arg syrk! & herk!" begin
    # Exercise the 5-argument `mul!(C, X, Y, α, β)` form (C ← α*X*Y + β*C) on
    # products of an array with its own adjoint, for real and complex element
    # types and for both vector and square-matrix inputs.
    # NOTE(review): going by the testset name, these products are presumably
    # routed to the syrk/herk wrappers; that dispatch is not visible here.
    eltypes = (Float32, Float64, ComplexF32, ComplexF64)
    for T in eltypes
        operands = (randn(T, 5), randn(T, 5, 5))
        for A in operands
            # A' * A is a scalar when A is a vector; wrap it in a 1×1 matrix so
            # it can serve as the destination. Otherwise materialize a dense,
            # exactly Hermitian copy of the product.
            B = A' * A
            C = if B isa Number
                [B;;]
            else
                Matrix(Hermitian(B))
            end
            # With α = true and β = 2 the result is A'A + 2C = 3C.
            @test mul!(copy(C), A', A, true, 2) ≈ 3C

            # Same check for A * A' with β = 3, so the result is 4D.
            D = Matrix(Hermitian(A * A'))
            @test mul!(copy(D), A, A', true, 3) ≈ 4D
        end
    end
end

@testset "matmul for types w/o sizeof (issue #1282)" begin
Expand Down