diff --git a/.gitignore b/.gitignore index 3f02ca74..a56200b8 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ *.jl.*.cov *.jl.mem Manifest.toml + +.vscode diff --git a/Project.toml b/Project.toml index c672e08d..38d0e927 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,16 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +[weakdeps] +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" + +[extensions] +StructArraysGPUArraysCoreExt = "GPUArraysCore" +StructArraysStaticArraysCoreExt = "StaticArraysCore" +StructArraysTablesExt = "Tables" + [compat] Adapt = "1, 2, 3" DataAPI = "1" @@ -20,14 +30,17 @@ julia = "1.6" [extras] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" [targets] -test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays"] +test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays", "GPUArraysCore", "StaticArraysCore", "Tables"] diff --git a/ext/StructArraysGPUArraysCoreExt.jl b/ext/StructArraysGPUArraysCoreExt.jl new file mode 100644 index 00000000..b05d3082 --- /dev/null +++ b/ext/StructArraysGPUArraysCoreExt.jl @@ -0,0 +1,21 @@ +module StructArraysGPUArraysCoreExt + +using StructArrays +using StructArrays: map_params, array_types + +using Base: tail + +import GPUArraysCore + +# for GPU broadcast +import GPUArraysCore +function GPUArraysCore.backend(::Type{T}) where {T<:StructArray} + backends = map_params(GPUArraysCore.backend, array_types(T)) + backend, others = backends[1], tail(backends) + isconsistent = mapfoldl(isequal(backend), &, others; init=true) + isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend")) + return backend +end +StructArrays.always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true + +end # module diff --git a/src/staticarrays_support.jl b/ext/StructArraysStaticArraysCoreExt.jl similarity index 89% rename from src/staticarrays_support.jl rename to ext/StructArraysStaticArraysCoreExt.jl index 1af186e8..28e73762 100644 --- a/src/staticarrays_support.jl +++ b/ext/StructArraysStaticArraysCoreExt.jl @@ -1,3 +1,10 @@ +module StructArraysStaticArraysCoreExt + +using StructArrays +using StructArrays: StructArrayStyle, createinstance, replace_structarray, isnonemptystructtype + +using Base.Broadcast: Broadcasted + using StaticArraysCore: StaticArray, FieldArray, tuple_prod """ @@ -40,7 +47,7 @@ Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) = axes(rep # StaticArrayStyle has no similar defined. # Overload `Base.copy` instead. -@inline function try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M} +@inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M} sa = copy(bc) ET = eltype(sa) isnonemptystructtype(ET) || return sa @@ -66,3 +73,5 @@ end return map(Base.Fix2(getfield, i), x) end end + +end # module diff --git a/src/tables.jl b/ext/StructArraysTablesExt.jl similarity index 92% rename from src/tables.jl rename to ext/StructArraysTablesExt.jl index d6ac2248..14420879 100644 --- a/src/tables.jl +++ b/ext/StructArraysTablesExt.jl @@ -1,3 +1,8 @@ +module StructArraysTablesExt + +using StructArrays +using StructArrays: components, hasfields, foreachfield, staticschema + import Tables Tables.isrowtable(::Type{<:StructArray}) = true @@ -38,3 +43,5 @@ for (f, g) in zip((:append!, :prepend!), (:push!, :pushfirst!)) end end end + +end # module diff --git a/src/StructArrays.jl b/src/StructArrays.jl index 129dcd82..b91007d2 100644 --- a/src/StructArrays.jl +++ b/src/StructArrays.jl @@ -12,8 +12,6 @@ include("utils.jl") include("collect.jl") include("sort.jl") include("lazy.jl") -include("tables.jl") -include("staticarrays_support.jl") # Implement refarray and refvalue to deal with pooled arrays and weakrefstrings effectively import DataAPI: refarray, refvalue @@ -29,15 +27,10 @@ end import Adapt Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s) -# for GPU broadcast -import GPUArraysCore -function GPUArraysCore.backend(::Type{T}) where {T<:StructArray} - backends = map_params(GPUArraysCore.backend, array_types(T)) - backend, others = backends[1], tail(backends) - isconsistent = mapfoldl(isequal(backend), &, others; init=true) - isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend")) - return backend +@static if !isdefined(Base, :get_extension) + include("../ext/StructArraysGPUArraysCoreExt.jl") + include("../ext/StructArraysTablesExt.jl") + include("../ext/StructArraysStaticArraysCoreExt.jl") end -always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true end # module