diff --git a/Project.toml b/Project.toml index c6cf5fac..6a4a3dc1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RecursiveArrayTools" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" authors = ["Chris Rackauckas "] -version = "3.26.0" +version = "3.27.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -11,7 +11,6 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" @@ -22,6 +21,7 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -32,6 +32,7 @@ RecursiveArrayToolsForwardDiffExt = "ForwardDiff" RecursiveArrayToolsMeasurementsExt = "Measurements" RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements" RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"] +RecursiveArrayToolsSparseArraysExt = ["SparseArrays"] RecursiveArrayToolsTrackerExt = "Tracker" RecursiveArrayToolsZygoteExt = "Zygote" @@ -78,6 +79,7 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -86,4 +88,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["SafeTestsets", "Aqua", "FastBroadcast", "ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote", "Measurements"] +test = ["SafeTestsets", "Aqua", "FastBroadcast", "SparseArrays", "ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote", "Measurements"] diff --git a/ext/RecursiveArrayToolsSparseArraysExt.jl b/ext/RecursiveArrayToolsSparseArraysExt.jl new file mode 100644 index 00000000..f9261bac --- /dev/null +++ b/ext/RecursiveArrayToolsSparseArraysExt.jl @@ -0,0 +1,20 @@ +module RecursiveArrayToolsSparseArraysExt + +import SparseArrays +import RecursiveArrayTools + +function Base.copyto!(dest::SparseArrays.AbstractCompressedVector, A::RecursiveArrayTools.ArrayPartition) + @assert length(dest) == length(A) + cur = 1 + @inbounds for i in 1:length(A.x) + if A.x[i] isa Number + dest[cur:(cur + length(A.x[i]) - 1)] .= A.x[i] + else + dest[cur:(cur + length(A.x[i]) - 1)] .= vec(A.x[i]) + end + cur += length(A.x[i]) + end + dest +end + +end \ No newline at end of file diff --git a/src/RecursiveArrayTools.jl b/src/RecursiveArrayTools.jl index 71cc01f9..9e251ac0 100644 --- a/src/RecursiveArrayTools.jl +++ b/src/RecursiveArrayTools.jl @@ -8,7 +8,6 @@ using DocStringExtensions using RecipesBase, StaticArraysCore, Statistics, ArrayInterface, LinearAlgebra using SymbolicIndexingInterface -using SparseArrays import Adapt diff --git a/src/array_partition.jl b/src/array_partition.jl index 5b3a1c09..7bd8bb52 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -176,7 +176,7 @@ Base.all(f, A::ArrayPartition) = all(f, (all(f, x) for x in A.x)) Base.all(f::Function, A::ArrayPartition) = all((all(f, x) for x in A.x)) Base.all(A::ArrayPartition) = all(identity, A) -for type in [AbstractArray, SparseArrays.AbstractCompressedVector, PermutedDimsArray] +for type in [AbstractArray, PermutedDimsArray] @eval function Base.copyto!(dest::$(type), A::ArrayPartition) @assert length(dest) == length(A) cur = 1