diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 40a48233..4f8a368a 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -406,13 +406,20 @@ Base.@propagate_inbounds function _getindex( if all(x -> is_parameter(A, x), sym) error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.") else - return [getindex.((A,), sym, i) for i in eachindex(A.t)] + return A[sym, eachindex(A.t)] end end Base.@propagate_inbounds function _getindex( A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple, AbstractArray}, args...) - return reduce(vcat, map(s -> A[s, args...]', sym)) + u = A.u[args...] + t = A.t[args...] + observed_fn = observed(A, sym) + if t isa AbstractArray + return observed_fn.(u, (parameter_values(A),), t) + else + return observed_fn(u, parameter_values(A), t) + end end Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index e384f714..bcb8dced 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -35,7 +35,7 @@ sol_new = DiffEqArray(sol.u[1:10], @test_throws Exception sol_new[τ] gs, = Zygote.gradient(sol) do sol - sum(sol[fol_separate.x]) + sum(sol[fol_separate.x]) end @test "Symbolic Indexing ADjoint" all(all.(isone, gs.u))