Skip to content

Commit 23b7b2e

Browse files
Merge pull request #3053 from AayushSabharwal/as/ddesol-obs
feat: support observed function generation for DDEs
2 parents 07701c4 + 14789d5 commit 23b7b2e

11 files changed

+270
-90
lines changed

Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,15 @@ PrecompileTools = "1"
112112
RecursiveArrayTools = "3.26"
113113
Reexport = "0.2, 1"
114114
RuntimeGeneratedFunctions = "0.5.9"
115-
SciMLBase = "2.52.1"
115+
SciMLBase = "2.55"
116116
SciMLStructures = "1.0"
117117
Serialization = "1"
118118
Setfield = "0.7, 0.8, 1"
119119
SimpleNonlinearSolve = "0.1.0, 1"
120120
SparseArrays = "1"
121121
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
122122
StaticArrays = "0.10, 0.11, 0.12, 1.0"
123-
SymbolicIndexingInterface = "0.3.29"
123+
SymbolicIndexingInterface = "0.3.31"
124124
SymbolicUtils = "3.7"
125125
Symbolics = "6.12"
126126
URIs = "1"

src/systems/abstractsystem.jl

+86-60
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,8 @@ function wrap_parameter_dependencies(sys::AbstractSystem, isscalar)
230230
end
231231

232232
function wrap_array_vars(
233-
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys), inputs = nothing)
233+
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys),
234+
inputs = nothing, history = false)
234235
isscalar = !(exprs isa AbstractArray)
235236
array_vars = Dict{Any, AbstractArray{Int}}()
236237
if dvs !== nothing
@@ -328,6 +329,19 @@ function wrap_array_vars(
328329
array_parameters[p] = (idxs, buffer_idx, sz)
329330
end
330331
end
332+
333+
inputind = if history
334+
uind + 2
335+
else
336+
uind + 1
337+
end
338+
params_offset = if history && hasinputs
339+
uind + 2
340+
elseif history || hasinputs
341+
uind + 1
342+
else
343+
uind
344+
end
331345
if isscalar
332346
function (expr)
333347
Func(
@@ -336,10 +350,10 @@ function wrap_array_vars(
336350
Let(
337351
vcat(
338352
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
339-
[k :(view($(expr.args[uind + hasinputs].name), $v))
353+
[k :(view($(expr.args[inputind].name), $v))
340354
for (k, v) in input_vars],
341355
[k :(reshape(
342-
view($(expr.args[uind + hasinputs + buffer_idx].name), $idxs),
356+
view($(expr.args[params_offset + buffer_idx].name), $idxs),
343357
$sz))
344358
for (k, (idxs, buffer_idx, sz)) in array_parameters],
345359
[k Code.MakeArray(v, symtype(k))
@@ -358,10 +372,10 @@ function wrap_array_vars(
358372
Let(
359373
vcat(
360374
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
361-
[k :(view($(expr.args[uind + hasinputs].name), $v))
375+
[k :(view($(expr.args[inputind].name), $v))
362376
for (k, v) in input_vars],
363377
[k :(reshape(
364-
view($(expr.args[uind + hasinputs + buffer_idx].name), $idxs),
378+
view($(expr.args[params_offset + buffer_idx].name), $idxs),
365379
$sz))
366380
for (k, (idxs, buffer_idx, sz)) in array_parameters],
367381
[k Code.MakeArray(v, symtype(k))
@@ -380,10 +394,10 @@ function wrap_array_vars(
380394
vcat(
381395
[k :(view($(expr.args[uind + 1].name), $v))
382396
for (k, v) in array_vars],
383-
[k :(view($(expr.args[uind + hasinputs + 1].name), $v))
397+
[k :(view($(expr.args[inputind + 1].name), $v))
384398
for (k, v) in input_vars],
385399
[k :(reshape(
386-
view($(expr.args[uind + hasinputs + buffer_idx + 1].name),
400+
view($(expr.args[params_offset + buffer_idx + 1].name),
387401
$idxs),
388402
$sz))
389403
for (k, (idxs, buffer_idx, sz)) in array_parameters],
@@ -398,50 +412,76 @@ function wrap_array_vars(
398412
end
399413
end
400414

401-
function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool)
415+
const MTKPARAMETERS_ARG = Sym{Vector{Vector}}(:___mtkparameters___)
416+
417+
"""
418+
wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2)
419+
420+
Return function(s) to be passed to the `wrap_code` keyword of `build_function` which
421+
allow the compiled function to be called as `f(u, p, t)` where `p isa MTKParameters`
422+
instead of `f(u, p..., t)`. `isscalar` denotes whether the function expression being
423+
wrapped is for a scalar value. `p_start` is the index of the argument containing
424+
the first parameter vector in the out-of-place version of the function. For example,
425+
if a history function (DDEs) was passed before `p`, then the function before wrapping
426+
would have the signature `f(u, h, p..., t)` and hence `p_start` would need to be `3`.
427+
428+
The returned function is `identity` if the system does not have an `IndexCache`.
429+
"""
430+
function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2)
402431
if has_index_cache(sys) && get_index_cache(sys) !== nothing
403432
offset = Int(is_time_dependent(sys))
404433

405434
if isscalar
406435
function (expr)
407-
p = gensym(:p)
436+
param_args = expr.args[p_start:(end - offset)]
437+
param_buffer_idxs = findall(x -> x isa DestructuredArgs, param_args)
438+
param_buffer_args = param_args[param_buffer_idxs]
439+
destructured_mtkparams = DestructuredArgs(
440+
[x.name for x in param_buffer_args],
441+
MTKPARAMETERS_ARG; inds = param_buffer_idxs)
408442
Func(
409443
[
410-
expr.args[1],
411-
DestructuredArgs(
412-
[arg.name for arg in expr.args[2:(end - offset)]], p),
413-
(isone(offset) ? (expr.args[end],) : ())...
444+
expr.args[begin:(p_start - 1)]...,
445+
destructured_mtkparams,
446+
expr.args[(end - offset + 1):end]...
414447
],
415448
[],
416-
Let(expr.args[2:(end - offset)], expr.body, false)
449+
Let(param_buffer_args, expr.body, false)
417450
)
418451
end
419452
else
420453
function (expr)
421-
p = gensym(:p)
454+
param_args = expr.args[p_start:(end - offset)]
455+
param_buffer_idxs = findall(x -> x isa DestructuredArgs, param_args)
456+
param_buffer_args = param_args[param_buffer_idxs]
457+
destructured_mtkparams = DestructuredArgs(
458+
[x.name for x in param_buffer_args],
459+
MTKPARAMETERS_ARG; inds = param_buffer_idxs)
422460
Func(
423461
[
424-
expr.args[1],
425-
DestructuredArgs(
426-
[arg.name for arg in expr.args[2:(end - offset)]], p),
427-
(isone(offset) ? (expr.args[end],) : ())...
462+
expr.args[begin:(p_start - 1)]...,
463+
destructured_mtkparams,
464+
expr.args[(end - offset + 1):end]...
428465
],
429466
[],
430-
Let(expr.args[2:(end - offset)], expr.body, false)
467+
Let(param_buffer_args, expr.body, false)
431468
)
432469
end,
433470
function (expr)
434-
p = gensym(:p)
471+
param_args = expr.args[(p_start + 1):(end - offset)]
472+
param_buffer_idxs = findall(x -> x isa DestructuredArgs, param_args)
473+
param_buffer_args = param_args[param_buffer_idxs]
474+
destructured_mtkparams = DestructuredArgs(
475+
[x.name for x in param_buffer_args],
476+
MTKPARAMETERS_ARG; inds = param_buffer_idxs)
435477
Func(
436478
[
437-
expr.args[1],
438-
expr.args[2],
439-
DestructuredArgs(
440-
[arg.name for arg in expr.args[3:(end - offset)]], p),
441-
(isone(offset) ? (expr.args[end],) : ())...
479+
expr.args[begin:p_start]...,
480+
destructured_mtkparams,
481+
expr.args[(end - offset + 1):end]...
442482
],
443483
[],
444-
Let(expr.args[3:(end - offset)], expr.body, false)
484+
Let(param_buffer_args, expr.body, false)
445485
)
446486
end
447487
end
@@ -669,25 +709,17 @@ function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
669709
if rawobs isa Tuple
670710
if is_time_dependent(sys)
671711
obsfn = let oop = rawobs[1], iip = rawobs[2]
672-
f1a(p::MTKParameters, t) = oop(p..., t)
673-
f1a(out, p::MTKParameters, t) = iip(out, p..., t)
712+
f1a(p, t) = oop(p, t)
713+
f1a(out, p, t) = iip(out, p, t)
674714
end
675715
else
676716
obsfn = let oop = rawobs[1], iip = rawobs[2]
677-
f1b(p::MTKParameters) = oop(p...)
678-
f1b(out, p::MTKParameters) = iip(out, p...)
717+
f1b(p) = oop(p)
718+
f1b(out, p) = iip(out, p)
679719
end
680720
end
681721
else
682-
if is_time_dependent(sys)
683-
obsfn = let rawobs = rawobs
684-
f2a(p::MTKParameters, t) = rawobs(p..., t)
685-
end
686-
else
687-
obsfn = let rawobs = rawobs
688-
f2b(p::MTKParameters) = rawobs(p...)
689-
end
690-
end
722+
obsfn = rawobs
691723
end
692724
else
693725
obsfn = build_explicit_observed_function(sys, sym; param_only = true)
@@ -802,17 +834,11 @@ function SymbolicIndexingInterface.observed(
802834
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module)
803835

804836
if is_time_dependent(sys)
805-
return let _fn = _fn
806-
fn1(u, p, t) = _fn(u, p, t)
807-
fn1(u, p::MTKParameters, t) = _fn(u, p..., t)
808-
fn1
809-
end
837+
return _fn
810838
else
811839
return let _fn = _fn
812840
fn2(u, p) = _fn(u, p)
813-
fn2(u, p::MTKParameters) = _fn(u, p...)
814841
fn2(::Nothing, p) = _fn([], p)
815-
fn2(::Nothing, p::MTKParameters) = _fn([], p...)
816842
fn2
817843
end
818844
end
@@ -828,6 +854,8 @@ end
828854
SymbolicIndexingInterface.is_time_dependent(::AbstractTimeDependentSystem) = true
829855
SymbolicIndexingInterface.is_time_dependent(::AbstractTimeIndependentSystem) = false
830856

857+
SymbolicIndexingInterface.is_markovian(sys::AbstractSystem) = !is_dde(sys)
858+
831859
SymbolicIndexingInterface.constant_structure(::AbstractSystem) = true
832860

833861
function SymbolicIndexingInterface.all_variable_symbols(sys::AbstractSystem)
@@ -971,6 +999,7 @@ for prop in [:eqs
971999
:solved_unknowns
9721000
:split_idxs
9731001
:parent
1002+
:is_dde
9741003
:index_cache
9751004
:is_scalar_noise
9761005
:isscheduled]
@@ -2349,8 +2378,8 @@ function linearization_function(sys::AbstractSystem, inputs,
23492378
u_getter = u_getter
23502379

23512380
function (u, p, t)
2352-
p_setter!(oldps, p_getter(u, p..., t))
2353-
newu = u_getter(u, p..., t)
2381+
p_setter!(oldps, p_getter(u, p, t))
2382+
newu = u_getter(u, p, t)
23542383
return newu, oldps
23552384
end
23562385
end
@@ -2361,20 +2390,15 @@ function linearization_function(sys::AbstractSystem, inputs,
23612390

23622391
function (u, p, t)
23632392
state = ProblemState(; u, p, t)
2364-
return u_getter(state), p_getter(state)
2393+
return u_getter(
2394+
state_values(state), parameter_values(state), current_time(state)),
2395+
p_getter(state)
23652396
end
23662397
end
23672398
end
23682399
initfn = NonlinearFunction(initsys; eval_expression, eval_module)
23692400
initprobmap = build_explicit_observed_function(
23702401
initsys, unknowns(sys); eval_expression, eval_module)
2371-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
2372-
initprobmap = let inner = initprobmap
2373-
fn(u, p::MTKParameters) = inner(u, p...)
2374-
fn(u, p) = inner(u, p)
2375-
fn
2376-
end
2377-
end
23782402
ps = parameters(sys)
23792403
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
23802404
lin_fun = let diff_idxs = diff_idxs,
@@ -2421,7 +2445,7 @@ function linearization_function(sys::AbstractSystem, inputs,
24212445
fg_xz = ForwardDiff.jacobian(uf, u)
24222446
h_xz = ForwardDiff.jacobian(
24232447
let p = p, t = t
2424-
xz -> p isa MTKParameters ? h(xz, p..., t) : h(xz, p, t)
2448+
xz -> h(xz, p, t)
24252449
end, u)
24262450
pf = SciMLBase.ParamJacobianWrapper(fun, t, u)
24272451
fg_u = jacobian_wrt_vars(pf, p, input_idxs, chunk)
@@ -2433,7 +2457,6 @@ function linearization_function(sys::AbstractSystem, inputs,
24332457
end
24342458
hp = let u = u, t = t
24352459
_hp(p) = h(u, p, t)
2436-
_hp(p::MTKParameters) = h(u, p..., t)
24372460
_hp
24382461
end
24392462
h_u = jacobian_wrt_vars(hp, p, input_idxs, chunk)
@@ -2486,7 +2509,7 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
24862509
dx = fun(sts, p..., t)
24872510

24882511
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
2489-
y = h(sts, p..., t)
2512+
y = h(sts, p, t)
24902513

24912514
fg_xz = Symbolics.jacobian(dx, sts)
24922515
fg_u = Symbolics.jacobian(dx, inputs)
@@ -2955,6 +2978,9 @@ function compose(sys::AbstractSystem, systems::AbstractArray; name = nameof(sys)
29552978
nsys == 0 && return sys
29562979
@set! sys.name = name
29572980
@set! sys.systems = [get_systems(sys); systems]
2981+
if has_is_dde(sys)
2982+
@set! sys.is_dde = _check_if_dde(equations(sys), get_iv(sys), get_systems(sys))
2983+
end
29582984
return sys
29592985
end
29602986
function compose(syss...; name = nameof(first(syss)))

src/systems/diffeqs/abstractodesystem.jl

+35-9
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,29 @@ struct Schedule
33
dummy_sub::Any
44
end
55

6+
"""
7+
is_dde(sys::AbstractSystem)
8+
9+
Return a boolean indicating whether a system represents a set of delay
10+
differential equations.
11+
"""
12+
is_dde(sys::AbstractSystem) = has_is_dde(sys) && get_is_dde(sys)
13+
14+
function _check_if_dde(eqs, iv, subsystems)
15+
is_dde = any(ModelingToolkit.is_dde, subsystems)
16+
if !is_dde
17+
vs = Set()
18+
for eq in eqs
19+
vars!(vs, eq)
20+
is_dde = any(vs) do sym
21+
isdelay(unwrap(sym), iv)
22+
end
23+
is_dde && break
24+
end
25+
end
26+
return is_dde
27+
end
28+
629
function filter_kwargs(kwargs)
730
kwargs = Dict(kwargs)
831
for key in keys(kwargs)
@@ -219,29 +242,32 @@ function isdelay(var, iv)
219242
return false
220243
end
221244
const DDE_HISTORY_FUN = Sym{Symbolics.FnType{Tuple{Any, <:Real}, Vector{Real}}}(:___history___)
222-
function delay_to_function(sys::AbstractODESystem, eqs = full_equations(sys))
245+
const DEFAULT_PARAMS_ARG = Sym{Any}(:ˍ₋arg3)
246+
function delay_to_function(
247+
sys::AbstractODESystem, eqs = full_equations(sys); history_arg = DEFAULT_PARAMS_ARG)
223248
delay_to_function(eqs,
224249
get_iv(sys),
225250
Dict{Any, Int}(operation(s) => i for (i, s) in enumerate(unknowns(sys))),
226251
parameters(sys),
227-
DDE_HISTORY_FUN)
252+
DDE_HISTORY_FUN; history_arg)
228253
end
229-
function delay_to_function(eqs::Vector, iv, sts, ps, h)
230-
delay_to_function.(eqs, (iv,), (sts,), (ps,), (h,))
254+
function delay_to_function(eqs::Vector, iv, sts, ps, h; history_arg = DEFAULT_PARAMS_ARG)
255+
delay_to_function.(eqs, (iv,), (sts,), (ps,), (h,); history_arg)
231256
end
232-
function delay_to_function(eq::Equation, iv, sts, ps, h)
233-
delay_to_function(eq.lhs, iv, sts, ps, h) ~ delay_to_function(eq.rhs, iv, sts, ps, h)
257+
function delay_to_function(eq::Equation, iv, sts, ps, h; history_arg = DEFAULT_PARAMS_ARG)
258+
delay_to_function(eq.lhs, iv, sts, ps, h; history_arg) ~ delay_to_function(
259+
eq.rhs, iv, sts, ps, h; history_arg)
234260
end
235-
function delay_to_function(expr, iv, sts, ps, h)
261+
function delay_to_function(expr, iv, sts, ps, h; history_arg = DEFAULT_PARAMS_ARG)
236262
if isdelay(expr, iv)
237263
v = operation(expr)
238264
time = arguments(expr)[1]
239265
idx = sts[v]
240-
return term(getindex, h(Sym{Any}(:ˍ₋arg3), time), idx, type = Real) # BIG BIG HACK
266+
return term(getindex, h(history_arg, time), idx, type = Real) # BIG BIG HACK
241267
elseif iscall(expr)
242268
return maketerm(typeof(expr),
243269
operation(expr),
244-
map(x -> delay_to_function(x, iv, sts, ps, h), arguments(expr)),
270+
map(x -> delay_to_function(x, iv, sts, ps, h; history_arg), arguments(expr)),
245271
metadata(expr))
246272
else
247273
return expr

0 commit comments

Comments
 (0)