|
| 1 | +module StructArraysStaticArraysExt |
| 2 | + |
| 3 | +macro loadext(ex) |
| 4 | + if !isdefined(Base, :get_extension) |
| 5 | + if ex.args[1].head == :(:) |
| 6 | + pushfirst!(ex.args[1].args[1].args, :., :.) |
| 7 | + else |
| 8 | + for i in eachindex(ex.args) |
| 9 | + pushfirst!(ex.args[i].args, :., :.) |
| 10 | + end |
| 11 | + end |
| 12 | + end |
| 13 | + return esc(:($ex)) |
| 14 | +end |
| 15 | + |
| 16 | +@loadext using StructArrays |
| 17 | +@loadext using StaticArrays: StaticArray, FieldArray, tuple_prod |
| 18 | + |
| 19 | +""" |
| 20 | + StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T} |
| 21 | +
|
| 22 | +The `staticschema` of a `StaticArray` element type is the `staticschema` of the underlying `Tuple`. |
| 23 | +```julia |
| 24 | +julia> StructArrays.staticschema(SVector{2, Float64}) |
| 25 | +Tuple{Float64, Float64} |
| 26 | +``` |
| 27 | +The one exception to this rule is `<:StaticArrays.FieldArray`, since `FieldArray` is based on a |
| 28 | +struct. In this case, `staticschema(<:FieldArray)` returns the `staticschema` for the struct |
| 29 | +which subtypes `FieldArray`. |
| 30 | +""" |
| 31 | +@generated function StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T} |
| 32 | + return quote |
| 33 | + Base.@_inline_meta |
| 34 | + return NTuple{$(tuple_prod(S)), T} |
| 35 | + end |
| 36 | +end |
| 37 | +StructArrays.createinstance(::Type{T}, args...) where {T<:StaticArray} = T(args) |
| 38 | +StructArrays.component(s::StaticArray, i) = getindex(s, i) |
| 39 | + |
| 40 | +# invoke general fallbacks for a `FieldArray` type. |
| 41 | +@inline function StructArrays.staticschema(T::Type{<:FieldArray}) |
| 42 | + invoke(StructArrays.staticschema, Tuple{Type{<:Any}}, T) |
| 43 | +end |
| 44 | +StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i) |
| 45 | +StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(StructArrays.createinstance, Tuple{Type{<:Any}, Vararg}, T, args...) |
| 46 | + |
| 47 | +# Broadcast overload |
| 48 | +@loadext using StaticArrays: StaticArrayStyle, similar_type |
| 49 | +@loadext using StructArrays: isnonemptystructtype |
| 50 | +using Base.Broadcast: Broadcasted |
| 51 | + |
| 52 | +# StaticArrayStyle has no similar defined. |
| 53 | +# Overload `try_struct_copy` instead. |
| 54 | +@inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M} |
| 55 | + sa = copy(bc) |
| 56 | + ET = eltype(sa) |
| 57 | + isnonemptystructtype(ET) || return sa |
| 58 | + elements = Tuple(sa) |
| 59 | + @static if VERSION >= v"1.7" |
| 60 | + arrs = ntuple(Val(fieldcount(ET))) do i |
| 61 | + similar_type(sa, fieldtype(ET, i))(_getfields(elements, i)) |
| 62 | + end |
| 63 | + else |
| 64 | + _fieldtype(::Type{T}) where {T} = i -> fieldtype(T, i) |
| 65 | + __fieldtype = _fieldtype(ET) |
| 66 | + arrs = ntuple(Val(fieldcount(ET))) do i |
| 67 | + similar_type(sa, __fieldtype(i))(_getfields(elements, i)) |
| 68 | + end |
| 69 | + end |
| 70 | + return StructArray{ET}(arrs) |
| 71 | +end |
| 72 | + |
| 73 | +@inline function _getfields(x::Tuple, i::Int) |
| 74 | + if @generated |
| 75 | + return Expr(:tuple, (:(getfield(x[$j], i)) for j in 1:fieldcount(x))...) |
| 76 | + else |
| 77 | + return map(Base.Fix2(getfield, i), x) |
| 78 | + end |
| 79 | +end |
| 80 | + |
| 81 | +end |
0 commit comments