Skip to content

Commit

Permalink
Showing 5 changed files with 38 additions and 15 deletions.
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -13,18 +13,21 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[weakdeps]
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
RecursiveArrayToolsFastBroadcastExt = "FastBroadcast"
RecursiveArrayToolsMeasurementsExt = "Measurements"
RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements"
RecursiveArrayToolsTrackerExt = "Tracker"
@@ -35,6 +38,7 @@ Adapt = "3, 4"
Aqua = "0.8"
ArrayInterface = "7"
DocStringExtensions = "0.8, 0.9"
FastBroadcast = "0.2.8"
ForwardDiff = "0.10"
GPUArraysCore = "0.1"
IteratorInterfaceExtensions = "1"
@@ -50,6 +54,7 @@ RecipesBase = "0.7, 0.8, 1.0"
Requires = "1.0"
SafeTestsets = "0.1"
SparseArrays = "1"
Static = "0.3, 0.4, 0.5, 0.6, 0.7, 0.8"
StaticArrays = "1.6"
StaticArraysCore = "1.1"
Statistics = "1"
@@ -64,6 +69,7 @@ julia = "1.9"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
@@ -81,4 +87,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["SafeTestsets", "Aqua", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote", "Measurements"]
test = ["SafeTestsets", "Aqua", "FastBroadcast", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote", "Measurements"]
22 changes: 22 additions & 0 deletions ext/RecursiveArrayToolsFastBroadcastExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
module RecursiveArrayToolsFastBroadcastExt

using RecursiveArrayTools
using FastBroadcast
using Static
using StaticArraysCore

const AbstractVectorOfSArray = AbstractVectorOfArray{T,N,<:AbstractVector{<:StaticArraysCore.SArray}} where {T,N}

@inline function FastBroadcast.fast_materialize!(::False, ::DB, dst::AbstractVectorOfSArray, bc::Broadcast.Broadcasted{S}) where {S,DB}
if FastBroadcast.use_fast_broadcast(S)
for i in 1:length(dst.u)
unpacked = RecursiveArrayTools.unpack_voa(bc, i)
dst.u[i] = StaticArraysCore.similar_type(dst.u[i])(unpacked[j] for j in eachindex(unpacked))
end
else
Broadcast.materialize!(dst, bc)
end
return dst
end

end # module
1 change: 1 addition & 0 deletions src/RecursiveArrayTools.jl
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@ using RecipesBase, StaticArraysCore, Statistics,
ArrayInterface, LinearAlgebra
using SymbolicIndexingInterface
using SparseArrays
import Static # so it isn't a stale dep, used in FastBroadcastExt

import Adapt

14 changes: 0 additions & 14 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
@@ -687,20 +687,6 @@ for (type, N_expr) in [
end
end

# @inline function Base.copyto!(dest::AbstractVectorOfArray,
# bc::Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle})
# bc = Broadcast.flatten(bc)
# @inbounds for i in 1:length(dest.u)
# if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i])
# copyto!(dest[:, i], unpack_voa(bc, i))
# else
# unpacked = unpack_voa(bc, i)
# dest[:, i] = StaticArraysCore.similar_type(dest[:, i])(unpacked[j] for j in eachindex(unpacked))
# end
# end
# dest
# end

## broadcasting utils

"""
8 changes: 8 additions & 0 deletions test/interface_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using RecursiveArrayTools, StaticArrays, Test
using FastBroadcast

t = 1:3
testva = VectorOfArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
@@ -158,3 +159,10 @@ function f2!(z)
end
f2!(z)
@test (@allocated f2!(z)) == 0

function f3!(z, zz)
@.. broadcast=false z = zz
end
f3!(z, zz)
@test z == VectorOfArray([fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})])
@test (@allocated f3!(z, zz)) == 0

0 comments on commit e14c761

Please sign in to comment.