Skip to content

Commit e14c761

Browse files
fix: add FastBroadcastExt for Vector of SArrays
1 parent 71e987f commit e14c761

File tree

5 files changed

+38
-15
lines changed

5 files changed

+38
-15
lines changed

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,21 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1414
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1515
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
16+
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
1617
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1718
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1819
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
1920
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
2021

2122
[weakdeps]
23+
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
2224
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
2325
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
2426
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2527
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2628

2729
[extensions]
30+
RecursiveArrayToolsFastBroadcastExt = "FastBroadcast"
2831
RecursiveArrayToolsMeasurementsExt = "Measurements"
2932
RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements"
3033
RecursiveArrayToolsTrackerExt = "Tracker"
@@ -35,6 +38,7 @@ Adapt = "3, 4"
3538
Aqua = "0.8"
3639
ArrayInterface = "7"
3740
DocStringExtensions = "0.8, 0.9"
41+
FastBroadcast = "0.2.8"
3842
ForwardDiff = "0.10"
3943
GPUArraysCore = "0.1"
4044
IteratorInterfaceExtensions = "1"
@@ -50,6 +54,7 @@ RecipesBase = "0.7, 0.8, 1.0"
5054
Requires = "1.0"
5155
SafeTestsets = "0.1"
5256
SparseArrays = "1"
57+
Static = "0.3, 0.4, 0.5, 0.6, 0.7, 0.8"
5358
StaticArrays = "1.6"
5459
StaticArraysCore = "1.1"
5560
Statistics = "1"
@@ -64,6 +69,7 @@ julia = "1.9"
6469

6570
[extras]
6671
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
72+
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
6773
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
6874
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
6975
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
@@ -81,4 +87,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
8187
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
8288

8389
[targets]
84-
test = ["SafeTestsets", "Aqua", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote", "Measurements"]
90+
test = ["SafeTestsets", "Aqua", "FastBroadcast", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote", "Measurements"]
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
module RecursiveArrayToolsFastBroadcastExt
2+
3+
using RecursiveArrayTools
4+
using FastBroadcast
5+
using Static
6+
using StaticArraysCore
7+
8+
const AbstractVectorOfSArray = AbstractVectorOfArray{T,N,<:AbstractVector{<:StaticArraysCore.SArray}} where {T,N}
9+
10+
@inline function FastBroadcast.fast_materialize!(::False, ::DB, dst::AbstractVectorOfSArray, bc::Broadcast.Broadcasted{S}) where {S,DB}
11+
if FastBroadcast.use_fast_broadcast(S)
12+
for i in 1:length(dst.u)
13+
unpacked = RecursiveArrayTools.unpack_voa(bc, i)
14+
dst.u[i] = StaticArraysCore.similar_type(dst.u[i])(unpacked[j] for j in eachindex(unpacked))
15+
end
16+
else
17+
Broadcast.materialize!(dst, bc)
18+
end
19+
return dst
20+
end
21+
22+
end # module

src/RecursiveArrayTools.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using RecipesBase, StaticArraysCore, Statistics,
99
ArrayInterface, LinearAlgebra
1010
using SymbolicIndexingInterface
1111
using SparseArrays
12+
import Static # so it isn't a stale dep, used in FastBroadcastExt
1213

1314
import Adapt
1415

src/vector_of_array.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -687,20 +687,6 @@ for (type, N_expr) in [
687687
end
688688
end
689689

690-
# @inline function Base.copyto!(dest::AbstractVectorOfArray,
691-
# bc::Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle})
692-
# bc = Broadcast.flatten(bc)
693-
# @inbounds for i in 1:length(dest.u)
694-
# if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i])
695-
# copyto!(dest[:, i], unpack_voa(bc, i))
696-
# else
697-
# unpacked = unpack_voa(bc, i)
698-
# dest[:, i] = StaticArraysCore.similar_type(dest[:, i])(unpacked[j] for j in eachindex(unpacked))
699-
# end
700-
# end
701-
# dest
702-
# end
703-
704690
## broadcasting utils
705691

706692
"""

test/interface_tests.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using RecursiveArrayTools, StaticArrays, Test
2+
using FastBroadcast
23

34
t = 1:3
45
testva = VectorOfArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
@@ -158,3 +159,10 @@ function f2!(z)
158159
end
159160
f2!(z)
160161
@test (@allocated f2!(z)) == 0
162+
163+
function f3!(z, zz)
164+
@.. broadcast=false z = zz
165+
end
166+
f3!(z, zz)
167+
@test z == VectorOfArray([fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})])
168+
@test (@allocated f3!(z, zz)) == 0

0 commit comments

Comments
 (0)