Skip to content

Commit 88b7aea

Browse files
Merge pull request #307 from AayushSabharwal/as/utils
fix: add utils methods for AbstractVectorOfArray, fix method overwriting on 1.10
2 parents 6c4aa29 + 5c0742b commit 88b7aea

File tree

5 files changed

+74
-50
lines changed

5 files changed

+74
-50
lines changed

src/array_partition.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ ArrayInterface.zeromatrix(A::ArrayPartition) = ArrayInterface.zeromatrix(Vector(
423423

424424
function __get_subtypes_in_module(mod, supertype; include_supertype = true, all=false, except=[])
425425
return filter([getproperty(mod, name) for name in names(mod; all) if !in(name, except)]) do value
426-
return value isa Type && (value <: supertype) && (include_supertype || value != supertype) && !in(value, except)
426+
return value != Union{} && value isa Type && (value <: supertype) && (include_supertype || value != supertype) && !in(value, except)
427427
end
428428
end
429429

src/utils.jl

Lines changed: 59 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
22
```julia
3-
recursivecopy(b::AbstractArray{T, N}, a::AbstractArray{T, N})
3+
recursivecopy(a::Union{AbstractArray{T, N}, AbstractVectorOfArray{T,N}})
44
```
55
66
A recursive `copy` function. Acts like a `deepcopy` on arrays of arrays, but
@@ -26,6 +26,12 @@ function recursivecopy(a::AbstractArray{T, N}) where {T <: AbstractArray, N}
2626
end
2727
end
2828

29+
function recursivecopy(a::AbstractVectorOfArray)
30+
b = copy(a)
31+
b.u = recursivecopy.(a.u)
32+
return b
33+
end
34+
2935
"""
3036
```julia
3137
recursivecopy!(b::AbstractArray{T, N}, a::AbstractArray{T, N})
@@ -36,37 +42,39 @@ like `copy!` on arrays of scalars.
3642
"""
3743
function recursivecopy! end
3844

39-
function recursivecopy!(b::AbstractArray{T, N},
40-
a::AbstractArray{T2, N}) where {T <: StaticArraysCore.StaticArray,
41-
T2 <: StaticArraysCore.StaticArray,
42-
N}
43-
@inbounds for i in eachindex(a)
44-
# TODO: Check for `setindex!`` and use `copy!(b[i],a[i])` or `b[i] = a[i]`, see #19
45-
b[i] = copy(a[i])
45+
for type in [AbstractArray, AbstractVectorOfArray]
46+
@eval function recursivecopy!(b::$type{T, N},
47+
a::$type{T2, N}) where {T <: StaticArraysCore.StaticArray,
48+
T2 <: StaticArraysCore.StaticArray,
49+
N}
50+
@inbounds for i in eachindex(a)
51+
# TODO: Check for `setindex!`` and use `copy!(b[i],a[i])` or `b[i] = a[i]`, see #19
52+
b[i] = copy(a[i])
53+
end
4654
end
47-
end
4855

49-
function recursivecopy!(b::AbstractArray{T, N},
50-
a::AbstractArray{T2, N}) where {T <: Enum, T2 <: Enum, N}
51-
copyto!(b, a)
52-
end
56+
@eval function recursivecopy!(b::$type{T, N},
57+
a::$type{T2, N}) where {T <: Enum, T2 <: Enum, N}
58+
copyto!(b, a)
59+
end
5360

54-
function recursivecopy!(b::AbstractArray{T, N},
55-
a::AbstractArray{T2, N}) where {T <: Number, T2 <: Number, N}
56-
copyto!(b, a)
57-
end
61+
@eval function recursivecopy!(b::$type{T, N},
62+
a::$type{T2, N}) where {T <: Number, T2 <: Number, N}
63+
copyto!(b, a)
64+
end
5865

59-
function recursivecopy!(b::AbstractArray{T, N},
60-
a::AbstractArray{T2, N}) where {T <: AbstractArray,
61-
T2 <: AbstractArray, N}
62-
if ArrayInterface.ismutable(T)
63-
@inbounds for i in eachindex(b, a)
64-
recursivecopy!(b[i], a[i])
66+
@eval function recursivecopy!(b::$type{T, N},
67+
a::$type{T2, N}) where {T <: Union{AbstractArray, AbstractVectorOfArray},
68+
T2 <: Union{AbstractArray, AbstractVectorOfArray}, N}
69+
if ArrayInterface.ismutable(T)
70+
@inbounds for i in eachindex(b, a)
71+
recursivecopy!(b[i], a[i])
72+
end
73+
else
74+
copyto!(b, a)
6575
end
66-
else
67-
copyto!(b, a)
76+
return b
6877
end
69-
return b
7078
end
7179

7280
"""
@@ -110,32 +118,36 @@ function recursivefill!(bs::AbstractVectorOfArray{T, N},
110118
end
111119
end
112120

113-
function recursivefill!(b::AbstractArray{T, N}, a::T2) where {T <: Enum, T2 <: Enum, N}
114-
fill!(b, a)
115-
end
121+
for type in [AbstractArray, AbstractVectorOfArray]
122+
@eval function recursivefill!(b::$type{T, N}, a::T2) where {T <: Enum, T2 <: Enum, N}
123+
fill!(b, a)
124+
end
116125

117-
function recursivefill!(b::AbstractArray{T, N},
118-
a::T2) where {T <: Union{Number, Bool}, T2 <: Union{Number, Bool}, N
119-
}
120-
fill!(b, a)
121-
end
126+
@eval function recursivefill!(b::$type{T, N},
127+
a::T2) where {T <: Union{Number, Bool}, T2 <: Union{Number, Bool}, N
128+
}
129+
fill!(b, a)
130+
end
122131

123-
function recursivefill!(b::AbstractArray{T, N}, a) where {T <: StaticArraysCore.MArray, N}
124-
@inbounds for i in eachindex(b)
125-
if isassigned(b, i)
126-
recursivefill!(b[i], a)
127-
else
128-
b[i] = zero(eltype(b))
129-
recursivefill!(b[i], a)
132+
for type2 in [Any, StaticArraysCore.StaticArray]
133+
@eval function recursivefill!(b::$type{T, N}, a::$type2) where {T <: StaticArraysCore.MArray, N}
134+
@inbounds for i in eachindex(b)
135+
if isassigned(b, i)
136+
recursivefill!(b[i], a)
137+
else
138+
b[i] = zero(eltype(b))
139+
recursivefill!(b[i], a)
140+
end
141+
end
130142
end
131143
end
132-
end
133-
134-
function recursivefill!(b::AbstractArray{T, N}, a) where {T <: AbstractArray, N}
135-
@inbounds for i in eachindex(b)
136-
recursivefill!(b[i], a)
144+
145+
@eval function recursivefill!(b::$type{T, N}, a) where {T <: AbstractArray, N}
146+
@inbounds for i in eachindex(b)
147+
recursivefill!(b[i], a)
148+
end
149+
return b
137150
end
138-
return b
139151
end
140152

141153
# Deprecated

src/vector_of_array.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,9 @@ end
488488
function Base.checkbounds(VA::AbstractVectorOfArray, idx...)
489489
checkbounds(Bool, VA, idx...) || throw(BoundsError(VA, idx))
490490
end
491+
function Base.copyto!(dest::AbstractVectorOfArray{T,N}, src::AbstractVectorOfArray{T,N}) where {T,N}
492+
copyto!.(dest.u, src.u)
493+
end
491494

492495
# Operations
493496
function Base.isapprox(A::AbstractVectorOfArray,
@@ -544,7 +547,6 @@ end
544547
@inline function Base.similar(VA::VectorOfArray, ::Type{T} = eltype(VA)) where {T}
545548
VectorOfArray([similar(VA[:, i], T) for i in eachindex(VA.u)])
546549
end
547-
recursivecopy(VA::VectorOfArray) = VectorOfArray(copy.(VA.u))
548550

549551
# fill!
550552
# For DiffEqArray it ignores ts and fills only u

test/qa.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using RecursiveArrayTools, Aqua
33
Aqua.find_persistent_tasks_deps(RecursiveArrayTools)
44
ambs = Aqua.detect_ambiguities(RecursiveArrayTools; recursive = true)
55
@warn "Number of method ambiguities: $(length(ambs))"
6-
@test length(ambs) <= 2
6+
@test length(ambs) <= 1
77
Aqua.test_deps_compat(RecursiveArrayTools)
88
Aqua.test_piracies(RecursiveArrayTools)
99
Aqua.test_project_extras(RecursiveArrayTools)

test/utils_test.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,13 @@ recursivefill!(x, true)
113113
recursivefill!(y_voa, ones(Vec3))
114114
@test all(y_voa[:, n] == fill(ones(Vec3), n) for n in 1:4)
115115
end
116+
117+
@testset "VectorOfArray recursivecopy!" begin
118+
u1 = VectorOfArray([fill(2, MVector{2, Float64}), ones(MVector{2, Float64})])
119+
u2 = VectorOfArray([fill(4, MVector{2, Float64}), 2 .* ones(MVector{2, Float64})])
120+
recursivecopy!(u1,u2)
121+
@test u1.u[1] == [4.0,4.0]
122+
@test u1.u[2] == [2.0,2.0]
123+
@test u1.u[1] isa MVector
124+
@test u1.u[2] isa MVector
125+
end

0 commit comments

Comments
 (0)