Skip to content

Commit dc573f0

Browse files
fix: rework broadcasting copyto!
1 parent f654568 commit dc573f0

File tree

1 file changed

+37
-28
lines changed

1 file changed

+37
-28
lines changed

src/vector_of_array.jl

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -658,40 +658,49 @@ Broadcast.broadcastable(x::AbstractVectorOfArray) = x
658658
end)
659659
end
660660

661-
@inline function Base.copyto!(dest::AbstractVectorOfArray,
662-
bc::Broadcast.Broadcasted{<:VectorOfArrayStyle})
663-
bc = Broadcast.flatten(bc)
664-
N = narrays(bc)
665-
@inbounds for i in 1:N
666-
if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i])
667-
copyto!(dest[:, i], unpack_voa(bc, i))
668-
else
669-
unpacked = unpack_voa(bc, i)
670-
dest[:, i] = unpacked.f(unpacked.args...)
671-
end
672-
end
673-
dest
674-
end
675-
676-
@inline function Base.copyto!(dest::AbstractVectorOfArray,
677-
bc::Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle})
678-
bc = Broadcast.flatten(bc)
679-
@inbounds for i in 1:length(dest.u)
680-
if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i])
681-
copyto!(dest[:, i], unpack_voa(bc, i))
682-
else
683-
unpacked = unpack_voa(bc, i)
684-
value = unpacked.f(unpacked.args...)
685-
dest[:, i] = if value isa Number && dest[:, i] isa AbstractArray
686-
fill(value, StaticArraysCore.similar_type(dest[:, i]))
661+
for (type, N_expr) in [
662+
(Broadcast.Broadcasted{<:VectorOfArrayStyle}, :(narrays(bc))),
663+
(Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle}, :(length(dest.u)))
664+
]
665+
@eval @inline function Base.copyto!(dest::AbstractVectorOfArray,
666+
bc::$type)
667+
bc = Broadcast.flatten(bc)
668+
N = $N_expr
669+
@inbounds for i in 1:N
670+
if dest[:, i] isa AbstractArray
671+
if ArrayInterface.ismutable(dest[:, i])
672+
copyto!(dest[:, i], unpack_voa(bc, i))
673+
else
674+
unpacked = unpack_voa(bc, i)
675+
arr_type = StaticArraysCore.similar_type(dest[:, i])
676+
dest[:, i] = if length(unpacked) == 1
677+
fill(copy(unpacked), arr_type)
678+
else
679+
arr_type(unpacked[j] for j in eachindex(unpacked))
680+
end
681+
end
687682
else
688-
value
683+
dest[:, i] = copy(unpack_voa(bc, i))
689684
end
690685
end
686+
dest
691687
end
692-
dest
693688
end
694689

690+
# @inline function Base.copyto!(dest::AbstractVectorOfArray,
691+
# bc::Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle})
692+
# bc = Broadcast.flatten(bc)
693+
# @inbounds for i in 1:length(dest.u)
694+
# if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i])
695+
# copyto!(dest[:, i], unpack_voa(bc, i))
696+
# else
697+
# unpacked = unpack_voa(bc, i)
698+
# dest[:, i] = StaticArraysCore.similar_type(dest[:, i])(unpacked[j] for j in eachindex(unpacked))
699+
# end
700+
# end
701+
# dest
702+
# end
703+
695704
## broadcasting utils
696705

697706
"""

0 commit comments

Comments
 (0)