From fbc7d72c9e5caafc14bd8dada4a291b4e2fa5121 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Mon, 20 May 2024 15:25:44 +0100 Subject: [PATCH] muladd! with adjoints/transposes (#233) --- Project.toml | 2 +- src/ArrayLayouts.jl | 4 ++++ src/muladd.jl | 2 +- test/test_muladd.jl | 8 ++++++++ 4 files changed, 14 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 46e8881..6b8be30 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ArrayLayouts" uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" authors = ["Sheehan Olver "] -version = "1.9.2" +version = "1.9.3" [deps] FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" diff --git a/src/ArrayLayouts.jl b/src/ArrayLayouts.jl index e22ccba..d0d323d 100644 --- a/src/ArrayLayouts.jl +++ b/src/ArrayLayouts.jl @@ -285,6 +285,10 @@ Base.permutedims(D::Diagonal{<:Any,<:LayoutVector}) = D zero!(A) = zero!(MemoryLayout(A), A) zero!(_, A) = fill!(A,zero(eltype(A))) +function zero!(::DualLayout, A) + zero!(parent(A)) + A +end function zero!(_, A::AbstractArray{<:AbstractArray}) for a in A zero!(a) diff --git a/src/muladd.jl b/src/muladd.jl index bcb4ee4..5f36c78 100644 --- a/src/muladd.jl +++ b/src/muladd.jl @@ -77,7 +77,7 @@ materialize(M::MulAdd) = copy(instantiate(M)) copy(M::MulAdd) = copyto!(similar(M), M) _fill_copyto!(dest, C) = copyto!(dest, C) -_fill_copyto!(dest, C::Zeros) = zero!(dest) # exploit special fill! overload +_fill_copyto!(dest, C::Union{Zeros,AdjOrTrans{<:Any,<:Zeros}}) = zero!(dest) # exploit special fill! overload @inline copyto!(dest::AbstractArray{T}, M::MulAdd) where T = muladd!(M.α, unalias(dest,M.A), unalias(dest,M.B), M.β, _fill_copyto!(dest, M.C); Czero = M.Czero) diff --git a/test/test_muladd.jl b/test/test_muladd.jl index 49ce7eb..24a6eaa 100644 --- a/test/test_muladd.jl +++ b/test/test_muladd.jl @@ -844,6 +844,14 @@ Random.seed!(0) @test copy(M) ≈ b * D * α + c * β end end + + @testset "dual" begin + a = randn(5) + X = randn(5,6) + @test copyto!(similar(a,6)', MulAdd(2.0, a', X, 3.0, Zeros(6)')) ≈ 2a'*X + @test copyto!(transpose(similar(a,6)), MulAdd(2.0, a', X, 3.0, Zeros(6)')) ≈ 2a'*X + @test copyto!(transpose(similar(a,6)), MulAdd(2.0, transpose(a), X, 3.0, transpose(Zeros(6)))) ≈ 2a'*X + end end end