Skip to content

Commit

Permalink
Simplify VectorOfArray indexing
Browse files Browse the repository at this point in the history
Relying on stack removes the need to `adapt`, which should make the GPUs more efficient. With `stack`, those extra dispatches were unnecessary. They were rarely hit and one had a bug too! So they can just be removed. Seems to fix downstream.
  • Loading branch information
ChrisRackauckas committed Oct 30, 2024
1 parent b7de81e commit c207bc0
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 26 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "RecursiveArrayTools"
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
authors = ["Chris Rackauckas <[email protected]>"]
version = "3.27.1"
version = "3.27.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
25 changes: 0 additions & 25 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,32 +321,7 @@ end
@deprecate Base.getindex(A::AbstractDiffEqArray, i::Int) Base.getindex(A, :, i) false

__parameterless_type(T) = Base.typename(T).wrapper
Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray{T, N},
::NotSymbolic, I::Colon...) where {T, N}
@assert length(I) == ndims(A.u[1]) + 1
vecs = if N == 1
A.u
else
vec.(A.u)
end
return Adapt.adapt(__parameterless_type(T),
reshape(reduce(hcat, vecs), size(A.u[1])..., length(A.u)))
end
Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray{T, N},
::NotSymbolic, I::Colon...) where {T <: Number, N}
@assert length(I) == ndims(A.u)
return A.u[I...]
end

Base.@propagate_inbounds function _getindex(A::AbstractVectorOfArray{T, N},
::NotSymbolic, I::AbstractArray{Bool},
J::Colon...) where {T, N}
@assert length(J) == ndims(A.u[1]) + 1 - ndims(I)
@assert size(I) == size(A)[1:(ndims(A) - length(J))]
return A[ntuple(x -> Colon(), ndims(A))...][I, J...]
end

# Need two of each methods to avoid ambiguities
Base.@propagate_inbounds function _getindex(
A::AbstractVectorOfArray, ::NotSymbolic, ::Colon, I::Int)
A.u[I]
Expand Down

0 comments on commit c207bc0

Please sign in to comment.