Skip to content

Commit b131762

Browse files
committed
Fully omit extra allocation in staticstructbroadcast.
Now we get elements via `StaticArrays.__broadcast`
1 parent 95da1c6 commit b131762

File tree

2 files changed

+31
-15
lines changed

2 files changed

+31
-15
lines changed

ext/StructArraysStaticArraysExt.jl

+29-15
Original file line numberDiff line numberDiff line change
@@ -45,29 +45,43 @@ StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{
4545
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(StructArrays.createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)
4646

4747
# Broadcast overload
48-
@loadext using StaticArrays: StaticArrayStyle, similar_type
48+
@loadext using StaticArrays: StaticArrayStyle, similar_type, Size, SOneTo
49+
@loadext using StaticArrays: broadcast_flatten, broadcast_sizes, first_statictype, __broadcast
4950
@loadext using StructArrays: isnonemptystructtype
5051
using Base.Broadcast: Broadcasted
5152

5253
# StaticArrayStyle has no similar defined.
5354
# Overload `try_struct_copy` instead.
5455
@inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M}
55-
sa = copy(bc)
56-
ET = eltype(sa)
57-
isnonemptystructtype(ET) || return sa
58-
elements = Tuple(sa)
59-
@static if VERSION >= v"1.7"
60-
arrs = ntuple(Val(fieldcount(ET))) do i
61-
similar_type(sa, fieldtype(ET, i))(_getfields(elements, i))
62-
end
56+
flat = broadcast_flatten(bc); as = flat.args; f = flat.f
57+
argsizes = broadcast_sizes(as...)
58+
ax = axes(bc)
59+
ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. Please file a bug.")
60+
return _broadcast(f, Size(map(length, ax)), argsizes, as...)
61+
end
62+
63+
@inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where {newsize}
64+
first_staticarray = first_statictype(a...)
65+
elements, ET = if prod(newsize) == 0
66+
# Use inference to get eltype in empty case (see also comments in _map)
67+
eltys = Tuple{map(eltype, a)...}
68+
(), Core.Compiler.return_type(f, eltys)
6369
else
64-
_fieldtype(::Type{T}) where {T} = i -> fieldtype(T, i)
65-
__fieldtype = _fieldtype(ET)
66-
arrs = ntuple(Val(fieldcount(ET))) do i
67-
similar_type(sa, __fieldtype(i))(_getfields(elements, i))
68-
end
70+
temp = __broadcast(f, sz, s, a...)
71+
temp, eltype(temp)
72+
end
73+
if isnonemptystructtype(ET)
74+
@static if VERSION >= v"1.7"
75+
arrs = ntuple(Val(fieldcount(ET))) do i
76+
@inbounds similar_type(first_staticarray, fieldtype(ET, i), sz)(_getfields(elements, i))
77+
end
78+
else
79+
similarET(::Type{SA}, ::Type{T}) where {SA, T} = i -> @inbounds similar_type(SA, fieldtype(T, i), sz)(_getfields(elements, i))
80+
arrs = ntuple(similarET(first_staticarray, ET), Val(fieldcount(ET)))
81+
end
82+
return StructArray{ET}(arrs)
6983
end
70-
return StructArray{ET}(arrs)
84+
@inbounds return similar_type(first_staticarray, ET, sz)(elements)
7185
end
7286

7387
@inline function _getfields(x::Tuple, i::Int)

test/runtests.jl

+2
Original file line numberDiff line numberDiff line change
@@ -1297,8 +1297,10 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
12971297

12981298
@testset "allocation test" begin
12991299
a = StructArray{ComplexF64}(undef, 1)
1300+
sa = StructArray{ComplexF64}((SizedVector{1}(a.re), SizedVector{1}(a.re)))
13001301
allocated(a) = @allocated a .+ 1
13011302
@test allocated(a) == 2allocated(a.re)
1303+
@test allocated(sa) == 2allocated(sa.re)
13021304
allocated2(a) = @allocated a .= complex.(a.im, a.re)
13031305
@test allocated2(a) == 0
13041306
end

0 commit comments

Comments
 (0)