From efac89a8692d65a252f87c8a89d900f33a168613 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 8 Jan 2024 18:30:12 +0530 Subject: [PATCH] fix: ArrayPartition arithmetic type-stability --- src/array_partition.jl | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/src/array_partition.jl b/src/array_partition.jl index 5be01442..7940fe83 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -153,11 +153,11 @@ for op in (:*, :/) end function Base.:*(A::Number, B::ArrayPartition) - ArrayPartition(map(y -> Base.broadcast(*, A, y), B.x)) + ArrayPartition(map(y -> A .* y, B.x)) end function Base.:\(A::Number, B::ArrayPartition) - ArrayPartition(map(y -> Base.broadcast(/, y, A), B.x)) + B / A end Base.:(==)(A::ArrayPartition, B::ArrayPartition) = A.x == B.x @@ -284,7 +284,7 @@ recursive_eltype(A::ArrayPartition) = recursive_eltype(first(A.x)) Base.iterate(A::ArrayPartition) = iterate(Chain(A.x)) Base.iterate(A::ArrayPartition, state) = iterate(Chain(A.x), state) -Base.length(A::ArrayPartition) = sum((length(x) for x in A.x)) +Base.length(A::ArrayPartition) = sum(broadcast(length, A.x)) Base.size(A::ArrayPartition) = (length(A),) # redefine first and last to avoid slow and not type-stable indexing @@ -323,21 +323,13 @@ function Broadcast.BroadcastStyle(::ArrayPartitionStyle, Broadcast.DefaultArrayStyle{N}() end -combine_styles(args::Tuple{}) = Broadcast.DefaultArrayStyle{0}() -@inline function combine_styles(args::Tuple{Any}) - Broadcast.result_style(Broadcast.BroadcastStyle(args[1])) -end -@inline function combine_styles(args::Tuple{Any, Any}) - Broadcast.result_style(Broadcast.BroadcastStyle(args[1]), - Broadcast.BroadcastStyle(args[2])) -end -@inline function combine_styles(args::Tuple) - Broadcast.result_style(Broadcast.BroadcastStyle(args[1]), - combine_styles(Base.tail(args))) +@generated function combine_styles(t) + @show t + return :($(reduce(Broadcast.result_style, Broadcast.BroadcastStyle.(t.parameters[1].parameters)))) end function Broadcast.BroadcastStyle(::Type{ArrayPartition{T, S}}) where {T, S} - Style = combine_styles((S.parameters...,)) + Style = combine_styles(S) ArrayPartitionStyle(Style) end