From c1249b38ce16b0c61b9ad55858f126e22761fdfc Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 28 Dec 2023 19:12:52 +0530 Subject: [PATCH] fix: GPU tests, CuArray conversion, autodiff --- ext/RecursiveArrayToolsZygoteExt.jl | 8 ++++---- src/RecursiveArrayTools.jl | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 5f93bf96..168b7e71 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -95,8 +95,8 @@ end VectorOfArray(u), y -> begin y isa Ref && (y = VectorOfArray(y[].u)) - (VectorOfArray([y[ntuple(x -> Colon(), ndims(y.u) - 1)..., i] - for i in 1:size(y.u)[end]]),) + (VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] + for i in 1:size(y)[end]]),) end end @@ -104,8 +104,8 @@ end DiffEqArray(u, t), y -> begin y isa Ref && (y = VectorOfArray(y[].u)) - (DiffEqArray([y[ntuple(x -> Colon(), ndims(y.u) - 1)..., i] - for i in 1:size(y.u)[end]], + (DiffEqArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] + for i in 1:size(y)[end]], t), nothing) end end diff --git a/src/RecursiveArrayTools.jl b/src/RecursiveArrayTools.jl index 9ee0947d..b3ee8fc8 100644 --- a/src/RecursiveArrayTools.jl +++ b/src/RecursiveArrayTools.jl @@ -28,6 +28,7 @@ end import GPUArraysCore Base.convert(T::Type{<:GPUArraysCore.AnyGPUArray}, VA::AbstractVectorOfArray) = stack(VA.u) +(T::Type{<:GPUArraysCore.AnyGPUArray})(VA::AbstractVectorOfArray) = T(Array(VA)) import Requires @static if !isdefined(Base, :get_extension)