diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index 519640b6..0ba79820 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -25,6 +25,12 @@ jobs: - {user: SciML, repo: OrdinaryDiffEq.jl, group: Core} - {user: SciML, repo: OrdinaryDiffEq.jl, group: Interface} - {user: SciML, repo: DelayDiffEq.jl, group: Interface} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core1} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core2} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core3} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core4} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core5} + - {user: SciML, repo: SciMLSensitivity.jl, group: Core6} steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 diff --git a/Project.toml b/Project.toml index 6a33022e..62d8ed85 100644 --- a/Project.toml +++ b/Project.toml @@ -23,11 +23,13 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" [extensions] RecursiveArrayToolsFastBroadcastExt = "FastBroadcast" RecursiveArrayToolsMeasurementsExt = "Measurements" RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements" +RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"] RecursiveArrayToolsTrackerExt = "Tracker" RecursiveArrayToolsZygoteExt = "Zygote" @@ -49,6 +51,7 @@ OrdinaryDiffEq = "6.62" Pkg = "1" Random = "1" RecipesBase = "1.1" +ReverseDiff = "1.15" SafeTestsets = "0.1" SparseArrays = "1.10" StaticArrays = "1.6" diff --git a/ext/RecursiveArrayToolsReverseDiffExt.jl b/ext/RecursiveArrayToolsReverseDiffExt.jl new file mode 100644 index 00000000..115949a1 --- /dev/null +++ b/ext/RecursiveArrayToolsReverseDiffExt.jl @@ -0,0 +1,25 @@ +module RecursiveArrayToolsReverseDiffExt + +using RecursiveArrayTools +using ReverseDiff +using Zygote: @adjoint + +function trackedarraycopyto!(dest, src) + for (i, slice) in zip(eachindex(dest.u), eachslice(src, dims=ndims(src))) + if dest.u[i] isa AbstractArray + dest.u[i] = reshape(reduce(vcat, slice), size(dest.u[i])) + else + trackedarraycopyto!(dest.u[i], slice) + end + end +end + +@adjoint function Array(VA::AbstractVectorOfArray{<:ReverseDiff.TrackedReal}) + function Array_adjoint(y) + VA = recursivecopy(VA) + trackedarraycopyto!(VA, y) + return (VA,) + end + return Array(VA), Array_adjoint +end +end # module diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index c4611137..0b75593f 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -110,7 +110,7 @@ end @adjoint function Base.Array(VA::AbstractVectorOfArray) adj = let VA=VA function Array_adjoint(y) - VA = copy(VA) + VA = recursivecopy(VA) copyto!(VA, y) return (VA,) end @@ -118,15 +118,21 @@ end Array(VA), adj end +@adjoint function Base.view(A::AbstractVectorOfArray, I::Colon...) + function adjoint(y) + (recursivecopy(parent(y)), map(_ -> nothing, I)...) + end + return view(A, I...), adjoint +end + @adjoint function Base.view(A::AbstractVectorOfArray, I...) - adj = let A = A, I = I - function view_adjoint(y) - A = zero(A) - view(A, I...) .= y - return (A, map(_ -> nothing, I)...) - end + function view_adjoint(y) + A = recursivecopy(parent(y)) + recursivefill!(A, zero(eltype(A))) + A[I...] .= y + return (A, map(_ -> nothing, I)...) end - view(A, I...), adj + view(A, I...), view_adjoint end ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a))) diff --git a/src/utils.jl b/src/utils.jl index 658e8418..4af362c9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -28,7 +28,7 @@ end function recursivecopy(a::AbstractVectorOfArray) b = copy(a) - b.u = recursivecopy.(a.u) + b.u .= recursivecopy.(a.u) return b end diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index d348cac4..b4f6f2e7 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -585,16 +585,26 @@ end function Base.checkbounds(VA::AbstractVectorOfArray, idx...) checkbounds(Bool, VA, idx...) || throw(BoundsError(VA, idx)) end -function Base.copyto!(dest::AbstractVectorOfArray{T,N}, src::AbstractVectorOfArray{T,N}) where {T,N} - copyto!.(dest.u, src.u) +function Base.copyto!(dest::AbstractVectorOfArray{T,N}, src::AbstractVectorOfArray{T2,N}) where {T, T2, N} + for (i, j) in zip(eachindex(dest.u), eachindex(src.u)) + if ArrayInterface.ismutable(dest.u[i]) || dest.u[i] isa AbstractVectorOfArray + copyto!(dest.u[i], src.u[j]) + else + dest.u[i] = StaticArraysCore.similar_type(dest.u[i])(src.u[j]) + end + end end -function Base.copyto!(dest::AbstractVectorOfArray{T, N}, src::AbstractArray{T, N}) where {T, N} - for (i, slice) in enumerate(eachslice(src, dims = ndims(src))) - copyto!(dest.u[i], slice) +function Base.copyto!(dest::AbstractVectorOfArray{T, N}, src::AbstractArray{T2, N}) where {T, T2, N} + for (i, slice) in zip(eachindex(dest.u), eachslice(src, dims = ndims(src))) + if ArrayInterface.ismutable(dest.u[i]) || dest.u[i] isa AbstractVectorOfArray + copyto!(dest.u[i], slice) + else + dest.u[i] = StaticArraysCore.similar_type(dest.u[i])(slice) + end end dest end -function Base.copyto!(dest::AbstractVectorOfArray{T, N, <:AbstractVector{T}}, src::AbstractVector{T}) where {T, N} +function Base.copyto!(dest::AbstractVectorOfArray{T, N, <:AbstractVector{T}}, src::AbstractVector{T2}) where {T, T2, N} copyto!(dest.u, src) dest end