Skip to content

Commit 19a4362

Browse files
fix: some bug fixes
1 parent 77591ca commit 19a4362

File tree

8 files changed

+39
-21
lines changed

8 files changed

+39
-21
lines changed

src/systems/abstractsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
165165
if !iscomplete(sys)
166166
error("A completed system is required. Call `complete` or `structural_simplify` on the system.")
167167
end
168-
p = reorder_parameters(sys, ps)
168+
p = reorder_parameters(sys, unwrap.(ps))
169169
isscalar = !(exprs isa AbstractArray)
170170
if wrap_code === nothing
171171
wrap_code = isscalar ? identity : (identity, identity)
@@ -1701,6 +1701,7 @@ function linearization_function(sys::AbstractSystem, inputs,
17011701
u0, _p, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
17021702
ps = parameters(sys)
17031703
if has_index_cache(sys) && get_index_cache(sys) !== nothing
1704+
@show p full_parameters(sys)
17041705
p = MTKParameters(sys, p)
17051706
else
17061707
p = _p

src/systems/index_cache.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ function reorder_parameters(sys::AbstractSystem, ps; kwargs...)
240240
end
241241

242242
function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
243+
isempty(ps) && return ()
243244
param_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
244245
for temp in ic.tunable_buffer_sizes)
245246
disc_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
@@ -281,6 +282,9 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
281282
end
282283
end
283284
end
285+
if all(isempty, result)
286+
return ()
287+
end
284288
return result
285289
end
286290

src/systems/optimization/constraints_system.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,10 +164,11 @@ function calculate_jacobian(sys::ConstraintsSystem; sparse = false, simplify = f
164164
end
165165

166166
function generate_jacobian(
167-
sys::ConstraintsSystem, vs = unknowns(sys), ps = parameters(sys);
167+
sys::ConstraintsSystem, vs = unknowns(sys), ps = full_parameters(sys);
168168
sparse = false, simplify = false, kwargs...)
169169
jac = calculate_jacobian(sys, sparse = sparse, simplify = simplify)
170-
return build_function(jac, vs, ps; kwargs...)
170+
p = reorder_parameters(sys, ps)
171+
return build_function(jac, vs, p...; kwargs...)
171172
end
172173

173174
function calculate_hessian(sys::ConstraintsSystem; sparse = false, simplify = false)
@@ -181,19 +182,20 @@ function calculate_hessian(sys::ConstraintsSystem; sparse = false, simplify = fa
181182
return hess
182183
end
183184

184-
function generate_hessian(sys::ConstraintsSystem, vs = unknowns(sys), ps = parameters(sys);
185+
function generate_hessian(sys::ConstraintsSystem, vs = unknowns(sys), ps = full_parameters(sys);
185186
sparse = false, simplify = false, kwargs...)
186187
hess = calculate_hessian(sys, sparse = sparse, simplify = simplify)
187-
return build_function(hess, vs, ps; kwargs...)
188+
p = reorder_parameters(sys, ps)
189+
return build_function(hess, vs, p...; kwargs...)
188190
end
189191

190192
function generate_function(sys::ConstraintsSystem, dvs = unknowns(sys),
191-
ps = parameters(sys);
193+
ps = full_parameters(sys);
192194
kwargs...)
193195
lhss = generate_canonical_form_lhss(sys)
194196
pre, sol_states = get_substitutions_and_solved_unknowns(sys)
195-
196-
func = build_function(lhss, value.(dvs), value.(ps); postprocess_fbody = pre,
197+
p = reorder_parameters(sys, value.(ps))
198+
func = build_function(lhss, value.(dvs), p...; postprocess_fbody = pre,
197199
states = sol_states, kwargs...)
198200

199201
cstr = constraints(sys)

src/systems/optimization/optimizationsystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
361361

362362
if length(cstr) > 0
363363
@named cons_sys = ConstraintsSystem(cstr, dvs, ps)
364+
cons_sys = complete(cons_sys)
364365
cons, lcons_, ucons_ = generate_function(cons_sys, checkbounds = checkbounds,
365366
linenumbers = linenumbers,
366367
expression = Val{false})

src/systems/parameter_buffer.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals
1414
else
1515
error("Cannot create MTKParameters if system does not have index_cache")
1616
end
17-
all_ps = Set(unwrap.(parameters(sys)))
18-
union!(all_ps, default_toterm.(unwrap.(parameters(sys))))
17+
all_ps = Set(unwrap.(full_parameters(sys)))
18+
union!(all_ps, default_toterm.(unwrap.(full_parameters(sys))))
1919
if p isa Vector && !(eltype(p) <: Pair) && !isempty(p)
20-
ps = parameters(sys)
20+
ps = full_parameters(sys)
2121
length(p) == length(ps) || error("Invalid parameters")
2222
p = ps .=> p
2323
end
@@ -38,7 +38,7 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals
3838

3939
for (sym, _) in p
4040
if istree(sym) && operation(sym) === getindex &&
41-
is_parameter(sys, first(arguments(sym)))
41+
first(arguments(sym)) in all_ps
4242
error("Scalarized parameter values ($sym) are not supported. Instead of `[p[1] => 1.0, p[2] => 2.0]` use `[p => [1.0, 2.0]]`")
4343
end
4444
end
@@ -99,7 +99,7 @@ function MTKParameters(sys::AbstractSystem, p; tofloat = false, use_union = fals
9999
i, j = ic.dependent_idx[sym]
100100
dep_exprs.x[i][j] = wrap(val)
101101
end
102-
p = reorder_parameters(ic, parameters(sys))
102+
p = reorder_parameters(ic, full_parameters(sys))
103103
oop, iip = build_function(dep_exprs, p...)
104104
update_function_iip, update_function_oop = RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(iip),
105105
RuntimeGeneratedFunctions.@RuntimeGeneratedFunction(oop)

src/systems/systems.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,16 @@ function structural_simplify(
3232
@set! newsys.parent = complete(sys; split)
3333
end
3434
newsys = complete(newsys; split)
35+
if has_defaults(newsys) && (defs = get_defaults(newsys)) !== nothing
36+
ks = collect(keys(defs))
37+
for k in ks
38+
if Symbolics.isarraysymbolic(k) && Symbolics.shape(k) !== Symbolics.Unknown()
39+
for i in eachindex(k)
40+
defs[k[i]] = defs[k][i]
41+
end
42+
end
43+
end
44+
end
3545
if newsys′ isa Tuple
3646
idxs = [parameter_index(newsys, i) for i in io[1]]
3747
return newsys, idxs

test/initializationsystem.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ prob = ODEProblem(sys, [sys.dx => 1], (0, 1)) # OK
355355
prob = ODEProblem(sys, [sys.ddx => -2], (0, 1), guesses = [sys.dx => 1])
356356
sol = solve(prob, Tsit5())
357357
@test SciMLBase.successful_retcode(sol)
358-
@test sol[1] == [1.0]
358+
@test sol.u[1] == [1.0]
359359

360360
## Late binding initialization_eqs
361361

@@ -376,7 +376,7 @@ end
376376
prob = ODEProblem(sys, [], (0, 1), guesses = [sys.dx => 1])
377377
sol = solve(prob, Tsit5())
378378
@test SciMLBase.successful_retcode(sol)
379-
@test sol[1] == [1.0]
379+
@test sol.u[1] == [1.0]
380380

381381
# Steady state initialization
382382

@@ -417,16 +417,16 @@ tspan = (0.0, 10.0)
417417

418418
prob = ODEProblem(simpsys, [D(x) => 0.0, y => 0.0], tspan, guesses = [x => 0.0])
419419
sol = solve(prob, Tsit5())
420-
@test sol[1] == [0.0, 0.0]
420+
@test sol.u[1] == [0.0, 0.0]
421421

422422
# Initialize with an observed variable
423423
prob = ODEProblem(simpsys, [z => 0.0], tspan, guesses = [x => 2.0, y => 4.0])
424424
sol = solve(prob, Tsit5())
425-
@test sol[1] == [0.0, 0.0]
425+
@test sol.u[1] == [0.0, 0.0]
426426

427427
prob = ODEProblem(simpsys, [z => 1.0, y => 1.0], tspan, guesses = [x => 2.0])
428428
sol = solve(prob, Tsit5())
429-
@test sol[1] == [0.0, 1.0]
429+
@test sol.u[1] == [0.0, 1.0]
430430

431-
@test_broken @test_warn "Initialization system is underdetermined. 1 equations for 3 unknowns. Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false." prob=ODEProblem(
432-
simpsys, [], tspan, guesses = [x => 2.0])
431+
# This should warn, but logging tests can't be marked as broken
432+
@test_logs prob=ODEProblem(simpsys, [], tspan, guesses = [x => 2.0])

test/odesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ sts = @variables x(t)[1:3]=[1, 2, 3.0] y(t)=1.0
472472
ps = @parameters p[1:3] = [1, 2, 3]
473473
eqs = [collect(D.(x) .~ x)
474474
D(y) ~ norm(collect(x)) * y - x[1]]
475-
@named sys = ODESystem(eqs, t, [sts...;], [ps...;])
475+
@named sys = ODESystem(eqs, t, sts, ps)
476476
sys = structural_simplify(sys)
477477
@test isequal(@nonamespace(sys.x), x)
478478
@test isequal(@nonamespace(sys.y), y)

0 commit comments

Comments
 (0)