Skip to content

Commit ad1e627

Browse files
feat: add support for ParameterIndexingProxy
1 parent 350e3d8 commit ad1e627

12 files changed

+90
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ SciMLOperators = "0.3.7"
8181
StaticArrays = "1.7"
8282
StaticArraysCore = "1.4"
8383
Statistics = "1.9"
84-
SymbolicIndexingInterface = "0.3"
84+
SymbolicIndexingInterface = "0.3.3"
8585
Tables = "1.11"
8686
TruncatedStacktraces = "1.4"
8787
Zygote = "0.6.67"

src/integrator_interface.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,8 @@ function Base.getproperty(A::DEIntegrator, sym::Symbol)
445445
if sym === :destats && hasfield(typeof(A), :stats)
446446
@warn "destats has been deprecated for stats"
447447
getfield(A, :stats)
448+
elseif sym === :ps
449+
return ParameterIndexingProxy(A)
448450
else
449451
return getfield(A, sym)
450452
end

src/problems/basic_problems.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,8 @@ function Base.getproperty(prob::IntegralProblem, name::Symbol)
502502
domain = getfield(prob, :domain)
503503
lb, ub = domain
504504
return ub
505+
elseif name === :ps
506+
return ParameterIndexingProxy(prob)
505507
end
506508
return Base.getfield(prob, name)
507509
end

src/problems/problem_interface.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
Base.@propagate_inbounds function Base.getproperty(prob::AbstractSciMLProblem, sym::Symbol)
2+
if sym === :ps
3+
return ParameterIndexingProxy(prob)
4+
end
5+
return getfield(prob, sym)
6+
end
7+
18
SymbolicIndexingInterface.symbolic_container(prob::AbstractSciMLProblem) = prob.f
29
SymbolicIndexingInterface.parameter_values(prob::AbstractSciMLProblem) = prob.p
310

src/solutions/dae_solutions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractDAESolution, s::Sy
4747
if s === :destats
4848
Base.depwarn("`sol.destats` is deprecated. Use `sol.stats` instead.", "sol.destats")
4949
return getfield(x, :stats)
50+
elseif s === :ps
51+
return ParameterIndexingProxy(x)
5052
end
5153
return getfield(x, s)
5254
end

src/solutions/ode_solutions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Sy
121121
if s === :destats
122122
Base.depwarn("`sol.destats` is deprecated. Use `sol.stats` instead.", "sol.destats")
123123
return getfield(x, :stats)
124+
elseif s === :ps
125+
return ParameterIndexingProxy(x)
124126
end
125127
return getfield(x, s)
126128
end

src/solutions/optimization_solutions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractOptimizationSoluti
119119
Base.depwarn("`sol.prob` is deprecated. Use getters like `get_p` or `get_syms` on `sol` instead.",
120120
"sol.prob")
121121
return getfield(x, :cache)
122+
elseif s === :ps
123+
return ParameterIndexingProxy(x)
122124
end
123125
return getfield(x, s)
124126
end

src/solutions/rode_solutions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractRODESolution, s::S
5555
if s === :destats
5656
Base.depwarn("`sol.destats` is deprecated. Use `sol.stats` instead.", "sol.destats")
5757
return getfield(x, :stats)
58+
elseif s === :ps
59+
return ParameterIndexingProxy(x)
5860
end
5961
return getfield(x, s)
6062
end

src/solutions/solution_interface.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@ end
2929

3030
# SymbolicIndexingInterface.jl
3131
const AbstractSolution = Union{AbstractTimeseriesSolution,AbstractNoTimeSolution}
32+
33+
Base.@propagate_inbounds function Base.getproperty(A::AbstractSolution, sym::Symbol)
34+
if sym === :ps
35+
return ParameterIndexingProxy(A)
36+
else
37+
return getfield(A, sym)
38+
end
39+
end
40+
3241
SymbolicIndexingInterface.symbolic_container(A::AbstractSolution) = A.prob.f
3342
SymbolicIndexingInterface.parameter_values(A::AbstractSolution) = A.prob.p
3443

test/downstream/integrator_indexing.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,13 @@ integrator = init(oprob, Rodas4())
2222
@test_throws Exception integrator[population_model.a]
2323
@test_throws Exception integrator[:a]
2424
@test getp(oprob, a)(integrator) == getp(oprob, population_model.a)(integrator) == getp(oprob, :a)(integrator) == 2.0
25+
@test integrator.ps[a] == integrator.ps[population_model.a] == integrator.ps[:a] == 2.0
2526
@test getp(oprob, b)(integrator) == getp(oprob, population_model.b)(integrator) == getp(oprob, :b)(integrator) == 1.0
27+
@test integrator.ps[b] == integrator.ps[population_model.b] == integrator.ps[:b] == 1.0
2628
@test getp(oprob, c)(integrator) == getp(oprob, population_model.c)(integrator) == getp(oprob, :c)(integrator) == 1.0
29+
@test integrator.ps[d] == integrator.ps[population_model.d] == integrator.ps[:d] == 1.0
2730
@test getp(oprob, d)(integrator) == getp(oprob, population_model.d)(integrator) == getp(oprob, :d)(integrator) == 1.0
31+
@test integrator.ps[d] == integrator.ps[population_model.d] == integrator.ps[:d] == 1.0
2832

2933
@test integrator[s1] == integrator[population_model.s1] == integrator[:s1] == 2.0
3034
@test integrator[s2] == integrator[population_model.s2] == integrator[:s2] == 1.0
@@ -42,10 +46,15 @@ step!(integrator, 100.0, true)
4246

4347
setp(oprob, a)(integrator, 10.0)
4448
@test getp(integrator, a)(integrator) == getp(integrator, population_model.a)(integrator) == getp(integrator, :a)(integrator) == 10.0
49+
@test integrator.ps[a] == integrator.ps[population_model.a] == integrator.ps[:a] == 10.0
4550
setp(population_model, population_model.b)(integrator, 20.0)
4651
@test getp(integrator, b)(integrator) == getp(integrator, population_model.b)(integrator) == getp(integrator, :b)(integrator) == 20.0
52+
@test integrator.ps[b] == integrator.ps[population_model.b] == integrator.ps[:b] == 20.0
4753
setp(integrator, c)(integrator, 30.0)
4854
@test getp(integrator, c)(integrator) == getp(integrator, population_model.c)(integrator) == getp(integrator, :c)(integrator) == 30.0
55+
@test integrator.ps[c] == integrator.ps[population_model.c] == integrator.ps[:c] == 30.0
56+
integrator.ps[d] = 40.0
57+
@test integrator.ps[d] == integrator.ps[population_model.d] == integrator.ps[:d] == 40.0
4958

5059
integrator[s1] = 10.0
5160
@test integrator[s1] == integrator[population_model.s1] == integrator[:s1] == 10.0
@@ -303,3 +312,8 @@ prob = ODEProblem(sys, [], (0, 1.0))
303312
integrator = init(prob, Tsit5())
304313
@test integrator[x] isa Vector{Float64}
305314
@test integrator[@nonamespace sys.x] isa Vector{Float64}
315+
@test getp(sys, p)(integrator) == integrator.ps[p] == [1, 2, 3]
316+
setp(sys, p)(integrator, [4, 5, 6])
317+
@test getp(sys, p)(integrator) == integrator.ps[p] == [4, 5, 6]
318+
integrator.ps[p] = [7, 8, 9]
319+
@test getp(sys, p)(integrator) == integrator.ps[p] == [7, 8, 9]

test/downstream/problem_interface.jl

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,15 @@ oprob = ODEProblem(sys, u0, tspan, p, jac = true)
3333
getσ1 = getp(sys, σ)
3434
getσ2 = getp(sys, sys.σ)
3535
getσ3 = getp(sys, )
36-
@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == 28.0
36+
@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == oprob.ps[σ] == oprob.ps[sys.σ] == oprob.ps[] == 28.0
3737
getρ1 = getp(sys, ρ)
3838
getρ2 = getp(sys, sys.ρ)
3939
getρ3 = getp(sys, )
40-
@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == 10.0
40+
@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == oprob.ps[ρ] == oprob.ps[sys.ρ] == oprob.ps[] == 10.0
4141
getβ1 = getp(sys, β)
4242
getβ2 = getp(sys, sys.β)
4343
getβ3 = getp(sys, )
44-
@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == 8 / 3
44+
@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == oprob.ps[β] == oprob.ps[sys.β] == oprob.ps[] == 8 / 3
4545

4646
@test oprob[x] == oprob[sys.x] == oprob[:x] == 1.0
4747
@test oprob[y] == oprob[sys.y] == oprob[:y] == 0.0
@@ -51,13 +51,20 @@ getβ3 = getp(sys, :β)
5151

5252
setσ = setp(sys, σ)
5353
setσ(oprob, 10.0)
54-
@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == 10.0
54+
@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == oprob.ps[σ] == oprob.ps[sys.σ] == oprob.ps[] == 10.0
5555
setρ = setp(sys, sys.ρ)
5656
setρ(oprob, 20.0)
57-
@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == 20.0
57+
@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == oprob.ps[ρ] == oprob.ps[sys.ρ] == oprob.ps[] == 20.0
5858
setβ = setp(sys, )
5959
setβ(oprob, 30.0)
60-
@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == 30.0
60+
@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == oprob.ps[β] == oprob.ps[sys.β] == oprob.ps[] == 30.0
61+
62+
oprob.ps[σ] = 11.0
63+
@test getσ1(oprob) == getσ2(oprob) == getσ3(oprob) == oprob.ps[σ] == oprob.ps[sys.σ] == oprob.ps[] == 11.0
64+
oprob.ps[sys.ρ] = 21.0
65+
@test getρ1(oprob) == getρ2(oprob) == getρ3(oprob) == oprob.ps[ρ] == oprob.ps[sys.ρ] == oprob.ps[] == 21.0
66+
oprob.ps[] = 31.0
67+
@test getβ1(oprob) == getβ2(oprob) == getβ3(oprob) == oprob.ps[β] == oprob.ps[sys.β] == oprob.ps[] == 31.0
6168

6269
oprob[x] = 10.0
6370
@test oprob[x] == oprob[sys.x] == oprob[:x] == 10.0
@@ -74,26 +81,47 @@ noiseeqs = [0.1 * x,
7481
sprob = SDEProblem(noise_sys, u0, (0.0, 100.0), p)
7582
u0
7683

77-
@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == 28.0
78-
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 10.0
79-
@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == 8 / 3
84+
@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == sprob.ps[σ] == sprob.ps[sys.σ] == sprob.ps[] == 28.0
85+
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == sprob.ps[ρ] == sprob.ps[sys.ρ] == sprob.ps[] == 10.0
86+
@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == sprob.ps[β] == sprob.ps[sys.β] == sprob.ps[] == 8 / 3
8087

8188
@test sprob[x] == sprob[noise_sys.x] == sprob[:x] == 1.0
8289
@test sprob[y] == sprob[noise_sys.y] == sprob[:y] == 0.0
8390
@test sprob[z] == sprob[noise_sys.z] == sprob[:z] == 0.0
8491

8592
setσ(sprob, 10.0)
86-
@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == 10.0
93+
@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == sprob.ps[σ] == sprob.ps[sys.σ] == sprob.ps[] == 10.0
8794
setρ(sprob, 20.0)
88-
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 20.0
95+
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == sprob.ps[ρ] == sprob.ps[sys.ρ] == sprob.ps[] == 20.0
8996
setp(noise_sys, noise_sys.ρ)(sprob, 25.0)
90-
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == 25.0
97+
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == sprob.ps[ρ] == sprob.ps[sys.ρ] == sprob.ps[] == 25.0
9198
setβ(sprob, 30.0)
92-
@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == 30.0
99+
@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == sprob.ps[β] == sprob.ps[sys.β] == sprob.ps[] == 30.0
100+
sprob.ps[σ] = 11.0
101+
@test getσ1(sprob) == getσ2(sprob) == getσ3(sprob) == sprob.ps[σ] == sprob.ps[sys.σ] == sprob.ps[] == 11.0
102+
sprob.ps[sys.ρ] = 21.0
103+
@test getρ1(sprob) == getρ2(sprob) == getρ3(sprob) == sprob.ps[ρ] == sprob.ps[sys.ρ] == sprob.ps[] == 21.0
104+
sprob.ps[] = 31.0
105+
@test getβ1(sprob) == getβ2(sprob) == getβ3(sprob) == sprob.ps[β] == sprob.ps[sys.β] == sprob.ps[] == 31.0
93106

94107
sprob[x] = 10.0
95108
@test sprob[x] == sprob[noise_sys.x] == sprob[:x] == 10.0
96109
sprob[noise_sys.y] = 10.0
97110
@test sprob[y] == sprob[noise_sys.y] == sprob[:y] == 10.0
98111
sprob[:z] = 1.0
99112
@test sprob[z] == sprob[noise_sys.z] == sprob[:z] == 1.0
113+
114+
using LinearAlgebra
115+
@variables t
116+
sts = @variables x(t)[1:3]=[1, 2, 3.0] y(t)=1.0
117+
ps = @parameters p[1:3] = [1, 2, 3]
118+
D = Differential(t)
119+
eqs = [collect(D.(x) .~ x)
120+
D(y) ~ norm(x) * y - x[1]]
121+
@named sys = ODESystem(eqs, t, [sts...;], [ps...;])
122+
prob = ODEProblem(sys, [], (0, 1.0))
123+
@test getp(sys, p)(prob) == prob.ps[p] == [1, 2, 3]
124+
setp(sys, p)(prob, [4, 5, 6])
125+
@test getp(sys, p)(prob) == prob.ps[p] == [4, 5, 6]
126+
prob.ps[p] = [7, 8, 9]
127+
@test getp(sys, p)(prob) == prob.ps[p] == [7, 8, 9]

test/downstream/symbol_indexing.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ sol = solve(prob, Rodas4())
7272
@test sol[α, 3] isa Float64
7373
@test length(sol[α, 5:10]) == 6
7474
@test getp(prob, γ)(sol) isa Real
75-
@test getp(prob, γ)(sol) == getp(prob, )(sol) == 2.0
75+
@test sol.ps[γ] isa Real
76+
@test getp(prob, γ)(sol) == getp(prob, )(sol) == sol.ps[γ] == sol.ps[] == 2.0
7677
@test getp(prob, (lorenz1.σ, lorenz1.ρ))(sol) isa Tuple
78+
@test sol.ps[(lorenz1.σ, lorenz1.ρ)] isa Tuple
7779

7880
@test sol[[lorenz1.x, lorenz2.x]] isa Vector{Vector{Float64}}
7981
@test length(sol[[lorenz1.x, lorenz2.x]]) == length(sol)
@@ -196,6 +198,7 @@ prob = ODEProblem(sys, [], (0, 1.0))
196198
sol = solve(prob, Tsit5())
197199
@test sol[x] isa Vector{<:Vector}
198200
@test sol[@nonamespace sys.x] isa Vector{<:Vector}
201+
@test sol.ps[p] == [1, 2, 3]
199202

200203
# accessing parameters
201204
@variables t x(t)
@@ -225,4 +228,6 @@ sol = solve(prob, Tsit5())
225228
@test sol[y]1 atol=1e-3
226229
@test getp(sys, a)(sol) 1
227230
@test getp(sys, b)(sol) 100
231+
@test sol.ps[a] 1
232+
@test sol.ps[b] 100
228233
end

0 commit comments

Comments
 (0)