Skip to content

Commit 06cdae2

Browse files
feat: add parameter timeseries support to AbstractDiffEqArray
1 parent 80a2a24 commit 06cdae2

File tree

1 file changed

+44
-44
lines changed

1 file changed

+44
-44
lines changed

src/vector_of_array.jl

+44-44
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,13 @@ A[1, :] # all time periods for f(t)
6060
A.t
6161
```
6262
"""
63-
mutable struct DiffEqArray{T, N, A, B, F, S} <: AbstractDiffEqArray{T, N, A}
63+
mutable struct DiffEqArray{T, N, A, B, F, S, D <: Union{Nothing, ParameterTimeseriesCollection}} <:
64+
AbstractDiffEqArray{T, N, A}
6465
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
6566
t::B
6667
p::F
6768
sys::S
69+
discretes::D
6870
end
6971
### Abstract Interface
7072
struct AllObserved
@@ -174,29 +176,32 @@ function DiffEqArray(vec::AbstractVector{T},
174176
ts::AbstractVector,
175177
::NTuple{N, Int},
176178
p = nothing,
177-
sys = nothing) where {T, N}
178-
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys)}(vec,
179+
sys = nothing; discretes = nothing) where {T, N}
180+
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec,
179181
ts,
180182
p,
181-
sys)
183+
sys,
184+
discretes)
182185
end
183186

184187
# ambiguity resolution
185188
function DiffEqArray(vec::AbstractVector{VT},
186189
ts::AbstractVector,
187190
::NTuple{N, Int}) where {T, N, VT <: AbstractArray{T, N}}
188-
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing}(vec,
191+
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing, Nothing}(vec,
189192
ts,
190193
nothing,
194+
nothing,
191195
nothing)
192196
end
193197
function DiffEqArray(vec::AbstractVector{VT},
194198
ts::AbstractVector,
195-
::NTuple{N, Int}, p) where {T, N, VT <: AbstractArray{T, N}}
196-
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing}(vec,
199+
::NTuple{N, Int}, p; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}}
200+
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(vec,
197201
ts,
198202
p,
199-
nothing)
203+
nothing,
204+
discretes)
200205
end
201206
# Assume that the first element is representative of all other elements
202207

@@ -206,7 +211,8 @@ function DiffEqArray(vec::AbstractVector,
206211
sys = nothing;
207212
variables = nothing,
208213
parameters = nothing,
209-
independent_variables = nothing)
214+
independent_variables = nothing,
215+
discretes = nothing)
210216
sys = something(sys,
211217
SymbolCache(something(variables, []),
212218
something(parameters, []),
@@ -219,11 +225,13 @@ function DiffEqArray(vec::AbstractVector,
219225
typeof(vec),
220226
typeof(ts),
221227
typeof(p),
222-
typeof(sys)
228+
typeof(sys),
229+
typeof(discretes)
223230
}(vec,
224231
ts,
225232
p,
226-
sys)
233+
sys,
234+
discretes)
227235
end
228236

229237
function DiffEqArray(vec::AbstractVector{VT},
@@ -232,7 +240,8 @@ function DiffEqArray(vec::AbstractVector{VT},
232240
sys = nothing;
233241
variables = nothing,
234242
parameters = nothing,
235-
independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}}
243+
independent_variables = nothing,
244+
discretes = nothing) where {T, N, VT <: AbstractArray{T, N}}
236245
sys = something(sys,
237246
SymbolCache(something(variables, []),
238247
something(parameters, []),
@@ -243,18 +252,30 @@ function DiffEqArray(vec::AbstractVector{VT},
243252
typeof(vec),
244253
typeof(ts),
245254
typeof(p),
246-
typeof(sys)
255+
typeof(sys),
256+
typeof(discretes),
247257
}(vec,
248258
ts,
249259
p,
250-
sys)
260+
sys,
261+
discretes)
251262
end
252263

264+
has_discretes(::T) where {T <: AbstractDiffEqArray} = hasfield(T, :discretes)
265+
get_discretes(x) = getfield(x, :discretes)
266+
253267
SymbolicIndexingInterface.is_timeseries(::Type{<:AbstractVectorOfArray}) = Timeseries()
268+
function SymbolicIndexingInterface.is_parameter_timeseries(::Type{DiffEqArray{T, N, A, B,
269+
F, S, D}}) where {T, N, A, B, F, S, D <: ParameterIndexingProxy}
270+
Timeseries()
271+
end
254272
SymbolicIndexingInterface.state_values(A::AbstractDiffEqArray) = A.u
255273
SymbolicIndexingInterface.current_time(A::AbstractDiffEqArray) = A.t
256274
SymbolicIndexingInterface.parameter_values(A::AbstractDiffEqArray) = A.p
257275
SymbolicIndexingInterface.symbolic_container(A::AbstractDiffEqArray) = A.sys
276+
function SymbolicIndexingInterface.get_parameter_timeseries_collection(A::AbstractDiffEqArray)
277+
return get_discretes(A)
278+
end
258279

259280
Base.IndexStyle(A::AbstractVectorOfArray) = Base.IndexStyle(typeof(A))
260281
Base.IndexStyle(::Type{<:AbstractVectorOfArray}) = IndexCartesian()
@@ -363,39 +384,18 @@ end
363384

364385
# Symbolic Indexing Methods
365386
for (symtype, elsymtype, valtype, errcheck) in [
366-
(ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym))),
367-
(ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym))),
387+
(ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))),
388+
(ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))),
368389
(NotSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Union{<:Tuple, <:AbstractArray},
369-
:(all(x -> is_parameter(A, x), sym))),
390+
:(all(x -> is_parameter(A, x) && !is_timeseries_parameter(A, x), sym))),
370391
]
371-
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
372-
::$elsymtype, sym::$valtype)
373-
if $errcheck
374-
throw(ParameterIndexingError(sym))
375-
end
376-
getu(A, sym)(A)
377-
end
378-
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
379-
::$elsymtype, sym::$valtype, arg)
380-
if $errcheck
381-
throw(ParameterIndexingError(sym))
382-
end
383-
getu(A, sym)(A, arg)
384-
end
385-
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
386-
::$elsymtype, sym::$valtype, arg::Union{AbstractArray{Int}, AbstractArray{Bool}})
387-
if $errcheck
388-
throw(ParameterIndexingError(sym))
389-
end
390-
getu(A, sym).((A,), arg)
391-
end
392-
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
393-
::$elsymtype, sym::$valtype, ::Colon)
394-
if $errcheck
395-
throw(ParameterIndexingError(sym))
392+
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
393+
::$elsymtype, sym::$valtype, arg...)
394+
if $errcheck
395+
throw(ParameterIndexingError(sym))
396+
end
397+
getu(A, sym)(A, arg...)
396398
end
397-
getu(A, sym)(A)
398-
end
399399
end
400400

401401
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic,

0 commit comments

Comments
 (0)