Skip to content

Commit 95da1c6

Browse files
committed
move current staticarray support to Ext
Update StructArraysStaticArraysExt.jl
1 parent 569c70e commit 95da1c6

File tree

4 files changed

+98
-5
lines changed

4 files changed

+98
-5
lines changed

Diff for: Project.toml

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
11
name = "StructArrays"
22
uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
3-
version = "0.6.14"
3+
version = "0.6.15"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
77
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
88
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
9-
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
9+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1010
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1111

1212
[compat]
1313
Adapt = "1, 2, 3"
1414
DataAPI = "1"
1515
GPUArraysCore = "0.1.2"
16+
Requires = "1"
1617
StaticArrays = "1.5.6"
17-
StaticArraysCore = "1.3"
1818
Tables = "1"
1919
julia = "1.6"
2020

21+
[extensions]
22+
StructArraysStaticArraysExt = "StaticArrays"
23+
2124
[extras]
2225
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
2326
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
@@ -31,3 +34,6 @@ WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"
3134

3235
[targets]
3336
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays"]
37+
38+
[weakdeps]
39+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

Diff for: ext/StructArraysStaticArraysExt.jl

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
module StructArraysStaticArraysExt
2+
3+
macro loadext(ex)
4+
if !isdefined(Base, :get_extension)
5+
if ex.args[1].head == :(:)
6+
pushfirst!(ex.args[1].args[1].args, :., :.)
7+
else
8+
for i in eachindex(ex.args)
9+
pushfirst!(ex.args[i].args, :., :.)
10+
end
11+
end
12+
end
13+
return esc(:($ex))
14+
end
15+
16+
@loadext using StructArrays
17+
@loadext using StaticArrays: StaticArray, FieldArray, tuple_prod
18+
19+
"""
20+
StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
21+
22+
The `staticschema` of a `StaticArray` element type is the `staticschema` of the underlying `Tuple`.
23+
```julia
24+
julia> StructArrays.staticschema(SVector{2, Float64})
25+
Tuple{Float64, Float64}
26+
```
27+
The one exception to this rule is `<:StaticArrays.FieldArray`, since `FieldArray` is based on a
28+
struct. In this case, `staticschema(<:FieldArray)` returns the `staticschema` for the struct
29+
which subtypes `FieldArray`.
30+
"""
31+
@generated function StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
32+
return quote
33+
Base.@_inline_meta
34+
return NTuple{$(tuple_prod(S)), T}
35+
end
36+
end
37+
StructArrays.createinstance(::Type{T}, args...) where {T<:StaticArray} = T(args)
38+
StructArrays.component(s::StaticArray, i) = getindex(s, i)
39+
40+
# invoke general fallbacks for a `FieldArray` type.
41+
@inline function StructArrays.staticschema(T::Type{<:FieldArray})
42+
invoke(StructArrays.staticschema, Tuple{Type{<:Any}}, T)
43+
end
44+
StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i)
45+
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(StructArrays.createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)
46+
47+
# Broadcast overload
48+
@loadext using StaticArrays: StaticArrayStyle, similar_type
49+
@loadext using StructArrays: isnonemptystructtype
50+
using Base.Broadcast: Broadcasted
51+
52+
# StaticArrayStyle has no similar defined.
53+
# Overload `try_struct_copy` instead.
54+
@inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M}
55+
sa = copy(bc)
56+
ET = eltype(sa)
57+
isnonemptystructtype(ET) || return sa
58+
elements = Tuple(sa)
59+
@static if VERSION >= v"1.7"
60+
arrs = ntuple(Val(fieldcount(ET))) do i
61+
similar_type(sa, fieldtype(ET, i))(_getfields(elements, i))
62+
end
63+
else
64+
_fieldtype(::Type{T}) where {T} = i -> fieldtype(T, i)
65+
__fieldtype = _fieldtype(ET)
66+
arrs = ntuple(Val(fieldcount(ET))) do i
67+
similar_type(sa, __fieldtype(i))(_getfields(elements, i))
68+
end
69+
end
70+
return StructArray{ET}(arrs)
71+
end
72+
73+
@inline function _getfields(x::Tuple, i::Int)
74+
if @generated
75+
return Expr(:tuple, (:(getfield(x[$j], i)) for j in 1:fieldcount(x))...)
76+
else
77+
return map(Base.Fix2(getfield, i), x)
78+
end
79+
end
80+
81+
end

Diff for: src/StructArrays.jl

+7-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ include("collect.jl")
1313
include("sort.jl")
1414
include("lazy.jl")
1515
include("tables.jl")
16-
include("staticarrays_support.jl")
1716

1817
# Implement refarray and refvalue to deal with pooled arrays and weakrefstrings effectively
1918
import DataAPI: refarray, refvalue
@@ -40,4 +39,11 @@ function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
4039
end
4140
always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true
4241

42+
import Requires
43+
@static if !isdefined(Base, :get_extension)
44+
function __init__()
45+
Requires.@require StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" begin include("../ext/StructArraysStaticArraysExt.jl") end
46+
end
47+
end
48+
4349
end # module

Diff for: src/structarray.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,7 @@ See also [`always_struct_broadcast`](@ref).
551551
"""
552552
try_struct_copy(bc::Broadcasted) = copy(bc)
553553

554-
function Base.copy(bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
554+
@inline function Base.copy(bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
555555
if always_struct_broadcast(S())
556556
return invoke(copy, Tuple{Broadcasted}, bc)
557557
else

0 commit comments

Comments
 (0)