Skip to content

Commit 3e6d723

Browse files
fixup! test: rewrite hybrid system tests to not use MTK
1 parent af0d5ed commit 3e6d723

File tree

1 file changed

+87
-29
lines changed

1 file changed

+87
-29
lines changed

test/downstream/comprehensive_indexing.jl

Lines changed: 87 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using ModelingToolkit, JumpProcesses, LinearAlgebra, NonlinearSolve, Optimization,
22
OptimizationOptimJL, OrdinaryDiffEq, RecursiveArrayTools, SciMLBase,
33
SteadyStateDiffEq, StochasticDiffEq, SymbolicIndexingInterface,
4-
DiffeqCallbacks, Test
4+
DiffEqCallbacks, Test
55
using ModelingToolkit: t_nounits as t, D_nounits as D
66

77
# Sets rnd number.
@@ -536,7 +536,8 @@ end
536536
end
537537
SymbolicIndexingInterface.symbolic_container(s::NumSymbolCache) = s.sc
538538
function SymbolicIndexingInterface.is_observed(s::NumSymbolCache, x)
539-
return symbolic_type(x) != NotSymbolic() && !is_variable(s, x) && !is_parameter(s, x) && !is_independent_variable(s, x)
539+
return symbolic_type(x) != NotSymbolic() && !is_variable(s, x) &&
540+
!is_parameter(s, x) && !is_independent_variable(s, x)
540541
end
541542
function SymbolicIndexingInterface.observed(s::NumSymbolCache, x)
542543
res = ModelingToolkit.build_function(x,
@@ -584,7 +585,8 @@ end
584585
end
585586
end
586587
end
587-
function SymbolicIndexingInterface.with_updated_parameter_timeseries_values(p::Vector{Float64}, args...)
588+
function SymbolicIndexingInterface.with_updated_parameter_timeseries_values(
589+
::NumSymbolCache, p::Vector{Float64}, args...)
588590
for (idx, buf) in args
589591
if idx == 1
590592
p[1:2] .= buf
@@ -604,7 +606,7 @@ end
604606
dea2 = DiffEqArray(Vector{Float64}[], Float64[])
605607
return ParameterTimeseriesCollection((dea1, dea2), deepcopy(ps))
606608
end
607-
function SciMLBase.get_saveable_values(p::Vector{Float64}, tsidx)
609+
function SciMLBase.get_saveable_values(::NumSymbolCache, p::Vector{Float64}, tsidx)
608610
if tsidx == 1
609611
return p[1:2]
610612
else
@@ -614,38 +616,96 @@ end
614616

615617
@variables x(t) ud1(t) ud2(t) xd1(t) xd2(t)
616618
@parameters kp
617-
sc = SymbolCache([x], Dict(ud1 => 1, xd1 => 2, ud2 => 3, xd2 => 4, kp => 5), t; timeseries_parameters = Dict(ud1 => ParameterTimeseriesIndex(1, 1), xd1 => ParameterTimeseriesIndex(1, 2), ud2 => ParameterTimeseriesIndex(2, 1), xd2 => ParameterTimeseriesIndex(2, 2)))
619+
sc = SymbolCache([x],
620+
Dict(ud1 => 1, xd1 => 2, ud2 => 3, xd2 => 4, kp => 5),
621+
t;
622+
timeseries_parameters = Dict(
623+
ud1 => ParameterTimeseriesIndex(1, 1), xd1 => ParameterTimeseriesIndex(1, 2),
624+
ud2 => ParameterTimeseriesIndex(2, 1), xd2 => ParameterTimeseriesIndex(2, 2)))
618625
sys = NumSymbolCache(sc)
619626

620627
function f!(du, u, p, t)
621628
du .= u .* t .+ p[5] * sum(u)
622629
end
623630
fn = ODEFunction(f!; sys = sys)
624631
prob = ODEProblem(fn, [1.0], (0.0, 1.0), [1.0, 2.0, 3.0, 4.0, 5.0])
625-
cb1 = PeriodicCallback(0.1; initial_affect = true, final_affect = true, save_positions = (false, false)) do integ
632+
cb1 = PeriodicCallback(0.1; initial_affect = true, final_affect = true,
633+
save_positions = (false, false)) do integ
626634
integ.p[1:2] .+= exp(-integ.t)
627635
SciMLBase.save_discretes!(integ, 1)
628636
end
629637
function affect2!(integ)
630638
integ.p[3:4] .+= only(integ.u)
631639
SciMLBase.save_discretes!(integ, 2)
632640
end
633-
cb2 = DiscreteCallback((args...) -> true, affect2!, save_positions = (false, false), initialize = (c, u, t, integ) -> affect2!(integ))
641+
cb2 = DiscreteCallback((args...) -> true, affect2!, save_positions = (false, false),
642+
initialize = (c, u, t, integ) -> affect2!(integ))
634643
sol = solve(deepcopy(prob), Tsit5(); callback = CallbackSet(cb1, cb2))
635644

636645
ud1val = getindex.(sol.discretes.collection[1].u, 1)
637646
xd1val = getindex.(sol.discretes.collection[1].u, 2)
638647
ud2val = getindex.(sol.discretes.collection[2].u, 1)
639648
xd2val = getindex.(sol.discretes.collection[2].u, 2)
640649

641-
for (sym, timeseries_index, val, buffer, isobs, check_inference) in [
642-
(ud1, 1, ud1val, zeros(length(ud1val)), false, true)
643-
([ud1, xd1], 1, vcat.(ud1val, xd1val), map(_ -> zeros(2), ud1val), false, true)
644-
((ud2, xd2), 2, tuple.(ud2val, xd2val), map(_ -> zeros(2), ud2val), false, true)
645-
(ud2 + xd2, 2, ud2val .+ xd2val, zeros(length(ud2val)), true, true)
646-
([ud2 + xd2, ud2 * xd2], 2, vcat.(ud2val .+ xd2val, ud2val .* xd2val), map(_ -> zeros(2), ud2val), true, true)
647-
((ud1 + xd1, ud1 * xd1), 1, tuple.(ud1val .+ xd1val, ud1val .* xd1val), map(_ -> zeros(2), ud1val), true, true)
648-
]
650+
for (sym, timeseries_index, val, buffer, isobs, check_inference) in [(ud1,
651+
1,
652+
ud1val,
653+
zeros(length(ud1val)),
654+
false,
655+
true)
656+
([ud1, xd1],
657+
1,
658+
vcat.(ud1val,
659+
xd1val),
660+
map(
661+
_ -> zeros(2),
662+
ud1val),
663+
false,
664+
true)
665+
((ud2, xd2),
666+
2,
667+
tuple.(ud2val,
668+
xd2val),
669+
map(
670+
_ -> zeros(2),
671+
ud2val),
672+
false,
673+
true)
674+
(ud2 + xd2,
675+
2,
676+
ud2val .+
677+
xd2val,
678+
zeros(length(ud2val)),
679+
true,
680+
true)
681+
(
682+
[ud2 + xd2,
683+
ud2 * xd2],
684+
2,
685+
vcat.(
686+
ud2val .+
687+
xd2val,
688+
ud2val .*
689+
xd2val),
690+
map(
691+
_ -> zeros(2),
692+
ud2val),
693+
true,
694+
true)
695+
(
696+
(ud1 + xd1,
697+
ud1 * xd1),
698+
1,
699+
tuple.(
700+
ud1val .+
701+
xd1val,
702+
ud1val .*
703+
xd1val),
704+
map(
705+
_ -> zeros(2),
706+
ud1val),
707+
true,
708+
true)]
649709
getter = getp(sys, sym)
650710
if check_inference
651711
@inferred getter(sol)
@@ -678,7 +738,8 @@ end
678738
end
679739
end
680740

681-
for subidx in [1, CartesianIndex(2), :, rand(Bool, length(val)), rand(eachindex(val), 4), 2:5]
741+
for subidx in [
742+
1, CartesianIndex(2), :, rand(Bool, length(val)), rand(eachindex(val), 4), 2:5]
682743
if check_inference
683744
@inferred getter(sol, subidx)
684745
if !isa(val[subidx], Number)
@@ -706,9 +767,7 @@ end
706767
(ud2, xd1, xd2),
707768
ud1 + ud2,
708769
[ud1 + ud2, ud1 * xd1],
709-
(ud1 + ud2, ud1 * xd1),
710-
711-
]
770+
(ud1 + ud2, ud1 * xd1)]
712771
getter = getp(sys, sym)
713772
@test_throws Exception getter(sol)
714773
@test_throws Exception getter([], sol)
@@ -732,7 +791,7 @@ end
732791
((kp, x), true, tuple.(kpval, xval), false),
733792
(2ud2, true, 2 .* ud2val, true),
734793
([kp, 2ud1], true, vcat.(kpval, 2 .* ud1val), false),
735-
((kp, 2ud1), true, tuple.(kpval, 2 .* ud1val), false),
794+
((kp, 2ud1), true, tuple.(kpval, 2 .* ud1val), false)
736795
]
737796
getter = getu(sys, sym)
738797
if check_inference
@@ -741,8 +800,8 @@ end
741800
@test getter(sol) == val
742801
reference = val_is_timeseries ? val : xval
743802
for subidx in [
744-
1, CartesianIndex(2), : ,rand(Bool, length(reference)),
745-
rand(eachindex(reference), 4), 2:6
803+
1, CartesianIndex(2), :, rand(Bool, length(reference)),
804+
rand(eachindex(reference), 4), 2:6
746805
]
747806
if check_inference
748807
@inferred getter(sol, subidx)
@@ -767,7 +826,7 @@ end
767826
((x, ud1), (_xval, _ud1val), true),
768827
(x + ud2, _xval + _ud2val, true),
769828
([2x, 3xd1], [2_xval, 3_xd1val], true),
770-
((2x, 3xd2), (2_xval, 3_xd2val), true),
829+
((2x, 3xd2), (2_xval, 3_xd2val), true)
771830
]
772831
getter = getu(sys, sym)
773832
@test_throws Exception getter(sol)
@@ -781,8 +840,8 @@ end
781840
@test getter(integ) == val
782841
end
783842

784-
xinterp = sol(0.1:0.1:0.3, idxs=x)
785-
xinterp2 = sol(sol.discretes.collection[2].t[2:4], idxs=x)
843+
xinterp = sol(0.1:0.1:0.3, idxs = x)
844+
xinterp2 = sol(sol.discretes.collection[2].t[2:4], idxs = x)
786845
ud1interp = ud1val[2:4]
787846
ud2interp = ud2val[2:4]
788847

@@ -796,13 +855,12 @@ end
796855
(x, c2[2], xinterp2[1]),
797856
(x, c2[2:4], xinterp2),
798857
([x, ud2], c2[2], [xinterp2[1], ud2interp[1]]),
799-
([x, ud2], c2[2:4], vcat.(xinterp2, ud2interp)),
858+
([x, ud2], c2[2:4], vcat.(xinterp2, ud2interp))
800859
]
801-
res = sol(t, idxs=sym)
860+
res = sol(t, idxs = sym)
802861
if res isa DiffEqArray
803862
res = res.u
804863
end
805864
@test res == val
806-
end
865+
end
807866
end
808-

0 commit comments

Comments
 (0)