Skip to content

Commit bb5a2c8

Browse files
Merge pull request #317 from AayushSabharwal/as/all-symbols
feat: add support for new *_symbols methods in SII
2 parents 63421ad + 501387d commit bb5a2c8

File tree

5 files changed

+25
-11
lines changed

5 files changed

+25
-11
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ StaticArrays = "1.6"
5757
StaticArraysCore = "1.1"
5858
Statistics = "1"
5959
StructArrays = "0.6"
60-
SymbolicIndexingInterface = "0.3"
60+
SymbolicIndexingInterface = "0.3.1"
6161
Tables = "1"
6262
Test = "1"
6363
Tracker = "0.2"

src/vector_of_array.jl

+10-4
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,7 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymb
284284
return getindex.(A.u, variable_index.((A,), (sym,), eachindex(A.t)))
285285
end
286286
elseif is_parameter(A, sym)
287-
Base.depwarn("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.", :parameter_getindex)
288-
return getp(A, sym)(A)
287+
error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.")
289288
elseif is_observed(A, sym)
290289
return observed(A, sym).(A.u, (parameter_values(A),), A.t)
291290
else
@@ -325,8 +324,7 @@ end
325324

326325
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, sym::Union{Tuple,AbstractArray})
327326
if all(x -> is_parameter(A, x), sym)
328-
Base.depwarn("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.", :parameter_getindex)
329-
return getp(A, sym)(A)
327+
error("Indexing with parameters is deprecated. Use `getp(A, $sym)` for parameter indexing.")
330328
else
331329
return [getindex.((A,), sym, i) for i in eachindex(A.t)]
332330
end
@@ -336,6 +334,14 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymb
336334
return reduce(vcat, map(s -> A[s, args...]', sym))
337335
end
338336

337+
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, ::SymbolicIndexingInterface.SolvedVariables, args...)
338+
return getindex(A, variable_symbols(A), args...)
339+
end
340+
341+
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic, ::SymbolicIndexingInterface.AllVariables, args...)
342+
return getindex(A, all_variable_symbols(A), args...)
343+
end
344+
339345
Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg, args...)
340346
symtype = symbolic_type(_arg)
341347
elsymtype = symbolic_type(eltype(_arg))

test/downstream/symbol_indexing.jl

+7-2
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,13 @@ sol_new = DiffEqArray(sol.u[1:10],
2222
@test sol_new[t] sol_new.t
2323
@test sol_new[t, 1:5] sol_new.t[1:5]
2424
@test getp(sol, τ)(sol) == getp(sol_new, τ)(sol_new) == 3.0
25-
@test_deprecated sol[τ]
26-
@test_deprecated sol_new[τ]
25+
@test variable_symbols(sol) == variable_symbols(sol_new) == [x]
26+
@test all_variable_symbols(sol) == all_variable_symbols(sol_new) == [x, RHS]
27+
@test all_symbols(sol) == all_symbols(sol_new) == [x, RHS, τ, t]
28+
@test sol[solvedvariables, 1:10] == sol_new[solvedvariables] == sol_new[[x]]
29+
@test sol[allvariables, 1:10] == sol_new[allvariables] == sol_new[[x, RHS]]
30+
@test_throws Exception sol[τ]
31+
@test_throws Exception sol_new[τ]
2732

2833
# Tables interface
2934
test_tables_interface(sol_new, [:timestamp, Symbol("x(t)")], hcat(sol_new[t], sol_new[x]))

test/qa.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using RecursiveArrayTools, Aqua
33
Aqua.find_persistent_tasks_deps(RecursiveArrayTools)
44
ambs = Aqua.detect_ambiguities(RecursiveArrayTools; recursive = true)
55
@warn "Number of method ambiguities: $(length(ambs))"
6-
@test length(ambs) <= 1
6+
@test length(ambs) <= 11
77
Aqua.test_deps_compat(RecursiveArrayTools)
88
Aqua.test_piracies(RecursiveArrayTools)
99
Aqua.test_project_extras(RecursiveArrayTools)

test/symbolic_indexing_interface_test.jl

+6-3
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@ dx = DiffEqArray([[f(x), f2(x)] for x in t],
2121
@test dx[(:a, :b)] == [(f(x), f2(x)) for x in t]
2222
@test dx[[:a, :b], 3] [f(t[3]), f2(t[3])]
2323
@test dx[[:a, :b], 4:5] vcat(f.(t[4:5])', f2.(t[4:5])')
24+
@test dx[solvedvariables] == dx[allvariables] == dx[[:a, :b]]
25+
@test dx[solvedvariables, 3] == dx[allvariables, 3] == dx[[:a, :b], 3]
2426
@test getp(dx, [:p, :q])(dx) == [1.0, 2.0]
2527
@test getp(dx, :p)(dx) == 1.0
2628
@test getp(dx, :q)(dx) == 2.0
27-
@test_deprecated dx[:p]
28-
@test_deprecated dx[[:p, :q]]
29+
@test_throws Exception dx[:p]
30+
@test_throws Exception dx[[:p, :q]]
2931
@test dx[:t] == t
3032

3133
@test symbolic_container(dx) isa SymbolCache
@@ -35,11 +37,12 @@ dx = DiffEqArray([[f(x), f2(x)] for x in t],
3537
@test is_parameter.((dx,), [:a, :b, :p, :q, :t]) == [false, false, true, true, false]
3638
@test parameter_index.((dx,), [:a, :b, :p, :q, :t]) == [nothing, nothing, 1, 2, nothing]
3739
@test is_independent_variable.((dx,), [:a, :b, :p, :q, :t]) == [false, false, false, false, true]
38-
@test variable_symbols(dx) == [:a, :b]
40+
@test variable_symbols(dx) == all_variable_symbols(dx) == [:a, :b]
3941
@test parameter_symbols(dx) == [:p, :q]
4042
@test independent_variable_symbols(dx) == [:t]
4143
@test is_time_dependent(dx)
4244
@test constant_structure(dx)
45+
@test all_symbols(dx) == [:a, :b, :p, :q, :t]
4346

4447
dx = DiffEqArray([[f(x), f2(x)] for x in t], t; variables = [:a, :b])
4548
@test_throws Exception dx[nothing] # make sure it isn't storing [nothing] as indepsym

0 commit comments

Comments
 (0)