Skip to content

Commit da26283

Browse files
feat: support inplace parameter observed
1 parent d4c430e commit da26283

File tree

3 files changed

+20
-6
lines changed

3 files changed

+20
-6
lines changed

src/systems/abstractsystem.jl

+12-4
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ function SymbolicIndexingInterface.is_timeseries_parameter(sys::AbstractSystem,
481481
end
482482

483483
function SymbolicIndexingInterface.timeseries_parameter_index(sys::AbstractSystem, sym)
484-
has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return false
484+
has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return nothing
485485
timeseries_parameter_index(ic, sym)
486486
end
487487

@@ -504,9 +504,17 @@ function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
504504
else
505505
ts_idx = nothing
506506
end
507-
obsfn = let raw_obs_fn = build_explicit_observed_function(
508-
sys, sym; param_only = true)
509-
f(p::MTKParameters, t) = raw_obs_fn(p..., t)
507+
rawobs = build_explicit_observed_function(
508+
sys, sym; param_only = true, return_inplace = true)
509+
if rawobs isa Tuple
510+
obsfn = let oop = rawobs[1], iip = rawobs[2]
511+
f1(p::MTKParameters, t) = oop(p..., t)
512+
f1(out, p::MTKParameters, t) = iip(out, p..., t)
513+
end
514+
else
515+
obsfn = let rawobs = rawobs
516+
f2(p::MTKParameters, t) = rawobs(p..., t)
517+
end
510518
end
511519
else
512520
ts_idx = nothing

src/systems/diffeqs/odesystem.jl

+8-1
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ function build_explicit_observed_function(sys, ts;
381381
drop_expr = drop_expr,
382382
ps = full_parameters(sys),
383383
param_only = false,
384+
return_inplace = false,
384385
op = Operator,
385386
throw = true)
386387
if (isscalar = symbolic_type(ts) !== NotSymbolic())
@@ -490,7 +491,13 @@ function build_explicit_observed_function(sys, ts;
490491
args = param_only ? [ipts, ps..., ivs...] : [dvs, ipts, ps..., ivs...]
491492
end
492493
pre = get_postprocess_fbody(sys)
493-
494+
res = build_function(isscalar ? ts[1] : ts, args...; get_postprocess_fbody = pre, wrap_code = wrap_array_vars(sys, isscalar ? ts[1] : ts; dvs = param_only ? [] : unknowns(sys)), expression = Val{expression})
495+
if isscalar || return_inplace
496+
return res
497+
else
498+
return res[1]
499+
end
500+
494501
ex = Func(args, [],
495502
pre(Let(obsexprs,
496503
isscalar ? ts[1] : MakeArray(ts, output_type),

src/systems/index_cache.jl

-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ function IndexCache(sys::AbstractSystem)
127127
if istree(sym) && operation(sym) == Shift(t, 1)
128128
sym = only(arguments(sym))
129129
end
130-
# is_parameter(sys, sym) || is_parameter(sys, Hold(sym)) || continue
131130
disc_clocks[sym] = i - 1
132131
disc_clocks[sym] = i - 1
133132
disc_clocks[default_toterm(sym)] = i - 1

0 commit comments

Comments
 (0)