Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify and better test recursivecopy! #308

Merged
merged 1 commit into from
Dec 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "RecursiveArrayTools"
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
version = "3.2.1"
version = "3.2.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
58 changes: 33 additions & 25 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -42,39 +42,47 @@ like `copy!` on arrays of scalars.
"""
function recursivecopy! end

for type in [AbstractArray, AbstractVectorOfArray]
@eval function recursivecopy!(b::$type{T, N},
a::$type{T2, N}) where {T <: StaticArraysCore.StaticArray,
T2 <: StaticArraysCore.StaticArray,
N}
@inbounds for i in eachindex(a)
# TODO: Check for `setindex!`` and use `copy!(b[i],a[i])` or `b[i] = a[i]`, see #19
b[i] = copy(a[i])
end
function recursivecopy!(b::AbstractArray{T, N}, a::AbstractArray{T2, N}) where {T <: StaticArraysCore.StaticArray,
T2 <: StaticArraysCore.StaticArray,
N}
@inbounds for i in eachindex(a)
# TODO: Check for `setindex!`` and use `copy!(b[i],a[i])` or `b[i] = a[i]`, see #19
b[i] = copy(a[i])
end
end

@eval function recursivecopy!(b::$type{T, N},
a::$type{T2, N}) where {T <: Enum, T2 <: Enum, N}
copyto!(b, a)
end
function recursivecopy!(b::AbstractArray{T, N},
a::AbstractArray{T2, N}) where {T <: Enum, T2 <: Enum, N}
copyto!(b, a)
end

function recursivecopy!(b::AbstractArray{T, N},
a::AbstractArray{T2, N}) where {T <: Number, T2 <: Number, N}
copyto!(b, a)
end

@eval function recursivecopy!(b::$type{T, N},
a::$type{T2, N}) where {T <: Number, T2 <: Number, N}
function recursivecopy!(b::AbstractArray{T, N},
a::AbstractArray{T2, N}) where {T <: Union{AbstractArray, AbstractVectorOfArray},
T2 <: Union{AbstractArray, AbstractVectorOfArray}, N}
if ArrayInterface.ismutable(T)
@inbounds for i in eachindex(b, a)
recursivecopy!(b[i], a[i])
end
else
copyto!(b, a)
end
return b
end

@eval function recursivecopy!(b::$type{T, N},
a::$type{T2, N}) where {T <: Union{AbstractArray, AbstractVectorOfArray},
T2 <: Union{AbstractArray, AbstractVectorOfArray}, N}
if ArrayInterface.ismutable(T)
@inbounds for i in eachindex(b, a)
recursivecopy!(b[i], a[i])
end
else
copyto!(b, a)
function recursivecopy!(b::AbstractVectorOfArray, a::AbstractVectorOfArray)
if ArrayInterface.ismutable(eltype(b.u))
@inbounds for i in eachindex(b.u, a.u)
recursivecopy!(b.u[i], a.u[i])
end
return b
else
copyto!(b.u, a.u)
end
return b
end

"""
8 changes: 8 additions & 0 deletions test/utils_test.jl
Original file line number Diff line number Diff line change
@@ -122,4 +122,12 @@ end
@test u1.u[2] == [2.0,2.0]
@test u1.u[1] isa MVector
@test u1.u[2] isa MVector

u1 = VectorOfArray([fill(2, SVector{2, Float64}), ones(SVector{2, Float64})])
u2 = VectorOfArray([fill(4, SVector{2, Float64}), 2 .* ones(SVector{2, Float64})])
recursivecopy!(u1,u2)
@test u1.u[1] == [4.0,4.0]
@test u1.u[2] == [2.0,2.0]
@test u1.u[1] isa SVector
@test u1.u[2] isa SVector
end