@@ -230,7 +230,8 @@ function wrap_parameter_dependencies(sys::AbstractSystem, isscalar)
230
230
end
231
231
232
232
function wrap_array_vars (
233
- sys:: AbstractSystem , exprs; dvs = unknowns (sys), ps = parameters (sys), inputs = nothing )
233
+ sys:: AbstractSystem , exprs; dvs = unknowns (sys), ps = parameters (sys),
234
+ inputs = nothing , history = false )
234
235
isscalar = ! (exprs isa AbstractArray)
235
236
array_vars = Dict {Any, AbstractArray{Int}} ()
236
237
if dvs != = nothing
@@ -328,6 +329,19 @@ function wrap_array_vars(
328
329
array_parameters[p] = (idxs, buffer_idx, sz)
329
330
end
330
331
end
332
+
333
+ inputind = if history
334
+ uind + 2
335
+ else
336
+ uind + 1
337
+ end
338
+ params_offset = if history && hasinputs
339
+ uind + 2
340
+ elseif history || hasinputs
341
+ uind + 1
342
+ else
343
+ uind
344
+ end
331
345
if isscalar
332
346
function (expr)
333
347
Func (
@@ -336,10 +350,10 @@ function wrap_array_vars(
336
350
Let (
337
351
vcat (
338
352
[k ← :(view ($ (expr. args[uind]. name), $ v)) for (k, v) in array_vars],
339
- [k ← :(view ($ (expr. args[uind + hasinputs ]. name), $ v))
353
+ [k ← :(view ($ (expr. args[inputind ]. name), $ v))
340
354
for (k, v) in input_vars],
341
355
[k ← :(reshape (
342
- view ($ (expr. args[uind + hasinputs + buffer_idx]. name), $ idxs),
356
+ view ($ (expr. args[params_offset + buffer_idx]. name), $ idxs),
343
357
$ sz))
344
358
for (k, (idxs, buffer_idx, sz)) in array_parameters],
345
359
[k ← Code. MakeArray (v, symtype (k))
@@ -358,10 +372,10 @@ function wrap_array_vars(
358
372
Let (
359
373
vcat (
360
374
[k ← :(view ($ (expr. args[uind]. name), $ v)) for (k, v) in array_vars],
361
- [k ← :(view ($ (expr. args[uind + hasinputs ]. name), $ v))
375
+ [k ← :(view ($ (expr. args[inputind ]. name), $ v))
362
376
for (k, v) in input_vars],
363
377
[k ← :(reshape (
364
- view ($ (expr. args[uind + hasinputs + buffer_idx]. name), $ idxs),
378
+ view ($ (expr. args[params_offset + buffer_idx]. name), $ idxs),
365
379
$ sz))
366
380
for (k, (idxs, buffer_idx, sz)) in array_parameters],
367
381
[k ← Code. MakeArray (v, symtype (k))
@@ -380,10 +394,10 @@ function wrap_array_vars(
380
394
vcat (
381
395
[k ← :(view ($ (expr. args[uind + 1 ]. name), $ v))
382
396
for (k, v) in array_vars],
383
- [k ← :(view ($ (expr. args[uind + hasinputs + 1 ]. name), $ v))
397
+ [k ← :(view ($ (expr. args[inputind + 1 ]. name), $ v))
384
398
for (k, v) in input_vars],
385
399
[k ← :(reshape (
386
- view ($ (expr. args[uind + hasinputs + buffer_idx + 1 ]. name),
400
+ view ($ (expr. args[params_offset + buffer_idx + 1 ]. name),
387
401
$ idxs),
388
402
$ sz))
389
403
for (k, (idxs, buffer_idx, sz)) in array_parameters],
@@ -398,50 +412,76 @@ function wrap_array_vars(
398
412
end
399
413
end
400
414
401
- function wrap_mtkparameters (sys:: AbstractSystem , isscalar:: Bool )
415
+ const MTKPARAMETERS_ARG = Sym {Vector{Vector}} (:___mtkparameters___ )
416
+
417
+ """
418
+ wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2)
419
+
420
+ Return function(s) to be passed to the `wrap_code` keyword of `build_function` which
421
+ allow the compiled function to be called as `f(u, p, t)` where `p isa MTKParameters`
422
+ instead of `f(u, p..., t)`. `isscalar` denotes whether the function expression being
423
+ wrapped is for a scalar value. `p_start` is the index of the argument containing
424
+ the first parameter vector in the out-of-place version of the function. For example,
425
+ if a history function (DDEs) was passed before `p`, then the function before wrapping
426
+ would have the signature `f(u, h, p..., t)` and hence `p_start` would need to be `3`.
427
+
428
+ The returned function is `identity` if the system does not have an `IndexCache`.
429
+ """
430
+ function wrap_mtkparameters (sys:: AbstractSystem , isscalar:: Bool , p_start = 2 )
402
431
if has_index_cache (sys) && get_index_cache (sys) != = nothing
403
432
offset = Int (is_time_dependent (sys))
404
433
405
434
if isscalar
406
435
function (expr)
407
- p = gensym (:p )
436
+ param_args = expr. args[p_start: (end - offset)]
437
+ param_buffer_idxs = findall (x -> x isa DestructuredArgs, param_args)
438
+ param_buffer_args = param_args[param_buffer_idxs]
439
+ destructured_mtkparams = DestructuredArgs (
440
+ [x. name for x in param_buffer_args],
441
+ MTKPARAMETERS_ARG; inds = param_buffer_idxs)
408
442
Func (
409
443
[
410
- expr. args[1 ],
411
- DestructuredArgs (
412
- [arg. name for arg in expr. args[2 : (end - offset)]], p),
413
- (isone (offset) ? (expr. args[end ],) : ()). ..
444
+ expr. args[begin : (p_start - 1 )]. .. ,
445
+ destructured_mtkparams,
446
+ expr. args[(end - offset + 1 ): end ]. ..
414
447
],
415
448
[],
416
- Let (expr . args[ 2 : ( end - offset)] , expr. body, false )
449
+ Let (param_buffer_args , expr. body, false )
417
450
)
418
451
end
419
452
else
420
453
function (expr)
421
- p = gensym (:p )
454
+ param_args = expr. args[p_start: (end - offset)]
455
+ param_buffer_idxs = findall (x -> x isa DestructuredArgs, param_args)
456
+ param_buffer_args = param_args[param_buffer_idxs]
457
+ destructured_mtkparams = DestructuredArgs (
458
+ [x. name for x in param_buffer_args],
459
+ MTKPARAMETERS_ARG; inds = param_buffer_idxs)
422
460
Func (
423
461
[
424
- expr. args[1 ],
425
- DestructuredArgs (
426
- [arg. name for arg in expr. args[2 : (end - offset)]], p),
427
- (isone (offset) ? (expr. args[end ],) : ()). ..
462
+ expr. args[begin : (p_start - 1 )]. .. ,
463
+ destructured_mtkparams,
464
+ expr. args[(end - offset + 1 ): end ]. ..
428
465
],
429
466
[],
430
- Let (expr . args[ 2 : ( end - offset)] , expr. body, false )
467
+ Let (param_buffer_args , expr. body, false )
431
468
)
432
469
end ,
433
470
function (expr)
434
- p = gensym (:p )
471
+ param_args = expr. args[(p_start + 1 ): (end - offset)]
472
+ param_buffer_idxs = findall (x -> x isa DestructuredArgs, param_args)
473
+ param_buffer_args = param_args[param_buffer_idxs]
474
+ destructured_mtkparams = DestructuredArgs (
475
+ [x. name for x in param_buffer_args],
476
+ MTKPARAMETERS_ARG; inds = param_buffer_idxs)
435
477
Func (
436
478
[
437
- expr. args[1 ],
438
- expr. args[2 ],
439
- DestructuredArgs (
440
- [arg. name for arg in expr. args[3 : (end - offset)]], p),
441
- (isone (offset) ? (expr. args[end ],) : ()). ..
479
+ expr. args[begin : p_start]. .. ,
480
+ destructured_mtkparams,
481
+ expr. args[(end - offset + 1 ): end ]. ..
442
482
],
443
483
[],
444
- Let (expr . args[ 3 : ( end - offset)] , expr. body, false )
484
+ Let (param_buffer_args , expr. body, false )
445
485
)
446
486
end
447
487
end
@@ -669,25 +709,17 @@ function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
669
709
if rawobs isa Tuple
670
710
if is_time_dependent (sys)
671
711
obsfn = let oop = rawobs[1 ], iip = rawobs[2 ]
672
- f1a (p:: MTKParameters , t) = oop (p... , t)
673
- f1a (out, p:: MTKParameters , t) = iip (out, p... , t)
712
+ f1a (p, t) = oop (p, t)
713
+ f1a (out, p, t) = iip (out, p, t)
674
714
end
675
715
else
676
716
obsfn = let oop = rawobs[1 ], iip = rawobs[2 ]
677
- f1b (p:: MTKParameters ) = oop (p... )
678
- f1b (out, p:: MTKParameters ) = iip (out, p... )
717
+ f1b (p) = oop (p)
718
+ f1b (out, p) = iip (out, p)
679
719
end
680
720
end
681
721
else
682
- if is_time_dependent (sys)
683
- obsfn = let rawobs = rawobs
684
- f2a (p:: MTKParameters , t) = rawobs (p... , t)
685
- end
686
- else
687
- obsfn = let rawobs = rawobs
688
- f2b (p:: MTKParameters ) = rawobs (p... )
689
- end
690
- end
722
+ obsfn = rawobs
691
723
end
692
724
else
693
725
obsfn = build_explicit_observed_function (sys, sym; param_only = true )
@@ -802,17 +834,11 @@ function SymbolicIndexingInterface.observed(
802
834
_fn = build_explicit_observed_function (sys, sym; eval_expression, eval_module)
803
835
804
836
if is_time_dependent (sys)
805
- return let _fn = _fn
806
- fn1 (u, p, t) = _fn (u, p, t)
807
- fn1 (u, p:: MTKParameters , t) = _fn (u, p... , t)
808
- fn1
809
- end
837
+ return _fn
810
838
else
811
839
return let _fn = _fn
812
840
fn2 (u, p) = _fn (u, p)
813
- fn2 (u, p:: MTKParameters ) = _fn (u, p... )
814
841
fn2 (:: Nothing , p) = _fn ([], p)
815
- fn2 (:: Nothing , p:: MTKParameters ) = _fn ([], p... )
816
842
fn2
817
843
end
818
844
end
828
854
SymbolicIndexingInterface. is_time_dependent (:: AbstractTimeDependentSystem ) = true
829
855
SymbolicIndexingInterface. is_time_dependent (:: AbstractTimeIndependentSystem ) = false
830
856
857
+ SymbolicIndexingInterface. is_markovian (sys:: AbstractSystem ) = ! is_dde (sys)
858
+
831
859
SymbolicIndexingInterface. constant_structure (:: AbstractSystem ) = true
832
860
833
861
function SymbolicIndexingInterface. all_variable_symbols (sys:: AbstractSystem )
@@ -971,6 +999,7 @@ for prop in [:eqs
971
999
:solved_unknowns
972
1000
:split_idxs
973
1001
:parent
1002
+ :is_dde
974
1003
:index_cache
975
1004
:is_scalar_noise
976
1005
:isscheduled ]
@@ -2349,8 +2378,8 @@ function linearization_function(sys::AbstractSystem, inputs,
2349
2378
u_getter = u_getter
2350
2379
2351
2380
function (u, p, t)
2352
- p_setter! (oldps, p_getter (u, p... , t))
2353
- newu = u_getter (u, p... , t)
2381
+ p_setter! (oldps, p_getter (u, p, t))
2382
+ newu = u_getter (u, p, t)
2354
2383
return newu, oldps
2355
2384
end
2356
2385
end
@@ -2361,20 +2390,15 @@ function linearization_function(sys::AbstractSystem, inputs,
2361
2390
2362
2391
function (u, p, t)
2363
2392
state = ProblemState (; u, p, t)
2364
- return u_getter (state), p_getter (state)
2393
+ return u_getter (
2394
+ state_values (state), parameter_values (state), current_time (state)),
2395
+ p_getter (state)
2365
2396
end
2366
2397
end
2367
2398
end
2368
2399
initfn = NonlinearFunction (initsys; eval_expression, eval_module)
2369
2400
initprobmap = build_explicit_observed_function (
2370
2401
initsys, unknowns (sys); eval_expression, eval_module)
2371
- if has_index_cache (sys) && get_index_cache (sys) != = nothing
2372
- initprobmap = let inner = initprobmap
2373
- fn (u, p:: MTKParameters ) = inner (u, p... )
2374
- fn (u, p) = inner (u, p)
2375
- fn
2376
- end
2377
- end
2378
2402
ps = parameters (sys)
2379
2403
h = build_explicit_observed_function (sys, outputs; eval_expression, eval_module)
2380
2404
lin_fun = let diff_idxs = diff_idxs,
@@ -2421,7 +2445,7 @@ function linearization_function(sys::AbstractSystem, inputs,
2421
2445
fg_xz = ForwardDiff. jacobian (uf, u)
2422
2446
h_xz = ForwardDiff. jacobian (
2423
2447
let p = p, t = t
2424
- xz -> p isa MTKParameters ? h (xz, p ... , t) : h (xz, p, t)
2448
+ xz -> h (xz, p, t)
2425
2449
end , u)
2426
2450
pf = SciMLBase. ParamJacobianWrapper (fun, t, u)
2427
2451
fg_u = jacobian_wrt_vars (pf, p, input_idxs, chunk)
@@ -2433,7 +2457,6 @@ function linearization_function(sys::AbstractSystem, inputs,
2433
2457
end
2434
2458
hp = let u = u, t = t
2435
2459
_hp (p) = h (u, p, t)
2436
- _hp (p:: MTKParameters ) = h (u, p... , t)
2437
2460
_hp
2438
2461
end
2439
2462
h_u = jacobian_wrt_vars (hp, p, input_idxs, chunk)
@@ -2486,7 +2509,7 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
2486
2509
dx = fun (sts, p... , t)
2487
2510
2488
2511
h = build_explicit_observed_function (sys, outputs; eval_expression, eval_module)
2489
- y = h (sts, p... , t)
2512
+ y = h (sts, p, t)
2490
2513
2491
2514
fg_xz = Symbolics. jacobian (dx, sts)
2492
2515
fg_u = Symbolics. jacobian (dx, inputs)
@@ -2955,6 +2978,9 @@ function compose(sys::AbstractSystem, systems::AbstractArray; name = nameof(sys)
2955
2978
nsys == 0 && return sys
2956
2979
@set! sys. name = name
2957
2980
@set! sys. systems = [get_systems (sys); systems]
2981
+ if has_is_dde (sys)
2982
+ @set! sys. is_dde = _check_if_dde (equations (sys), get_iv (sys), get_systems (sys))
2983
+ end
2958
2984
return sys
2959
2985
end
2960
2986
function compose (syss... ; name = nameof (first (syss)))
0 commit comments