Skip to content

Commit efac89a

Browse files
fix: ArrayPartition arithmetic type-stability
1 parent ef55922 commit efac89a

File tree

1 file changed

+7
-15
lines changed

1 file changed

+7
-15
lines changed

src/array_partition.jl

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,11 @@ for op in (:*, :/)
153153
end
154154

155155
function Base.:*(A::Number, B::ArrayPartition)
156-
ArrayPartition(map(y -> Base.broadcast(*, A, y), B.x))
156+
ArrayPartition(map(y -> A .* y, B.x))
157157
end
158158

159159
function Base.:\(A::Number, B::ArrayPartition)
160-
ArrayPartition(map(y -> Base.broadcast(/, y, A), B.x))
160+
B / A
161161
end
162162

163163
Base.:(==)(A::ArrayPartition, B::ArrayPartition) = A.x == B.x
@@ -284,7 +284,7 @@ recursive_eltype(A::ArrayPartition) = recursive_eltype(first(A.x))
284284
Base.iterate(A::ArrayPartition) = iterate(Chain(A.x))
285285
Base.iterate(A::ArrayPartition, state) = iterate(Chain(A.x), state)
286286

287-
Base.length(A::ArrayPartition) = sum((length(x) for x in A.x))
287+
Base.length(A::ArrayPartition) = sum(broadcast(length, A.x))
288288
Base.size(A::ArrayPartition) = (length(A),)
289289

290290
# redefine first and last to avoid slow and not type-stable indexing
@@ -323,21 +323,13 @@ function Broadcast.BroadcastStyle(::ArrayPartitionStyle,
323323
Broadcast.DefaultArrayStyle{N}()
324324
end
325325

326-
combine_styles(args::Tuple{}) = Broadcast.DefaultArrayStyle{0}()
327-
@inline function combine_styles(args::Tuple{Any})
328-
Broadcast.result_style(Broadcast.BroadcastStyle(args[1]))
329-
end
330-
@inline function combine_styles(args::Tuple{Any, Any})
331-
Broadcast.result_style(Broadcast.BroadcastStyle(args[1]),
332-
Broadcast.BroadcastStyle(args[2]))
333-
end
334-
@inline function combine_styles(args::Tuple)
335-
Broadcast.result_style(Broadcast.BroadcastStyle(args[1]),
336-
combine_styles(Base.tail(args)))
326+
@generated function combine_styles(t)
327+
@show t
328+
return :($(reduce(Broadcast.result_style, Broadcast.BroadcastStyle.(t.parameters[1].parameters))))
337329
end
338330

339331
function Broadcast.BroadcastStyle(::Type{ArrayPartition{T, S}}) where {T, S}
340-
Style = combine_styles((S.parameters...,))
332+
Style = combine_styles(S)
341333
ArrayPartitionStyle(Style)
342334
end
343335

0 commit comments

Comments
 (0)