Skip to content

Commit

Permalink
Merge pull request #372 from AayushSabharwal/as/parameter-timeseries
Browse files Browse the repository at this point in the history
feat: add parameter timeseries support to `AbstractDiffEqArray`
  • Loading branch information
ChrisRackauckas authored Jun 14, 2024
2 parents 320b577 + 33b2fe3 commit c531cb2
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 45 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ StaticArrays = "1.6"
StaticArraysCore = "1.4"
Statistics = "1.10"
StructArrays = "0.6.11"
SymbolicIndexingInterface = "0.3.20"
SymbolicIndexingInterface = "0.3.23"
Tables = "1.11"
Test = "1"
Tracker = "0.2.15"
Expand Down
88 changes: 44 additions & 44 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,13 @@ A[1, :] # all time periods for f(t)
A.t
```
"""
mutable struct DiffEqArray{T, N, A, B, F, S} <: AbstractDiffEqArray{T, N, A}
mutable struct DiffEqArray{T, N, A, B, F, S, D <: Union{Nothing, ParameterTimeseriesCollection}} <:
AbstractDiffEqArray{T, N, A}
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
t::B
p::F
sys::S
discretes::D
end
### Abstract Interface
struct AllObserved
Expand Down Expand Up @@ -174,29 +176,32 @@ function DiffEqArray(vec::AbstractVector{T},
ts::AbstractVector,
::NTuple{N, Int},
p = nothing,
sys = nothing) where {T, N}
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys)}(vec,
sys = nothing; discretes = nothing) where {T, N}
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec,
ts,
p,
sys)
sys,
discretes)
end

# ambiguity resolution
function DiffEqArray(vec::AbstractVector{VT},
ts::AbstractVector,
::NTuple{N, Int}) where {T, N, VT <: AbstractArray{T, N}}
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing}(vec,
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing, Nothing}(vec,
ts,
nothing,
nothing,
nothing)
end
function DiffEqArray(vec::AbstractVector{VT},
ts::AbstractVector,
::NTuple{N, Int}, p) where {T, N, VT <: AbstractArray{T, N}}
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing}(vec,
::NTuple{N, Int}, p; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}}
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(vec,
ts,
p,
nothing)
nothing,
discretes)
end
# Assume that the first element is representative of all other elements

Expand All @@ -206,7 +211,8 @@ function DiffEqArray(vec::AbstractVector,
sys = nothing;
variables = nothing,
parameters = nothing,
independent_variables = nothing)
independent_variables = nothing,
discretes = nothing)
sys = something(sys,
SymbolCache(something(variables, []),
something(parameters, []),
Expand All @@ -219,11 +225,13 @@ function DiffEqArray(vec::AbstractVector,
typeof(vec),
typeof(ts),
typeof(p),
typeof(sys)
typeof(sys),
typeof(discretes)
}(vec,
ts,
p,
sys)
sys,
discretes)
end

function DiffEqArray(vec::AbstractVector{VT},
Expand All @@ -232,7 +240,8 @@ function DiffEqArray(vec::AbstractVector{VT},
sys = nothing;
variables = nothing,
parameters = nothing,
independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}}
independent_variables = nothing,
discretes = nothing) where {T, N, VT <: AbstractArray{T, N}}
sys = something(sys,
SymbolCache(something(variables, []),
something(parameters, []),
Expand All @@ -243,18 +252,30 @@ function DiffEqArray(vec::AbstractVector{VT},
typeof(vec),
typeof(ts),
typeof(p),
typeof(sys)
typeof(sys),
typeof(discretes),
}(vec,
ts,
p,
sys)
sys,
discretes)
end

has_discretes(::T) where {T <: AbstractDiffEqArray} = hasfield(T, :discretes)
get_discretes(x) = getfield(x, :discretes)

SymbolicIndexingInterface.is_timeseries(::Type{<:AbstractVectorOfArray}) = Timeseries()
function SymbolicIndexingInterface.is_parameter_timeseries(::Type{DiffEqArray{T, N, A, B,
F, S, D}}) where {T, N, A, B, F, S, D <: ParameterIndexingProxy}
Timeseries()
end
SymbolicIndexingInterface.state_values(A::AbstractDiffEqArray) = A.u
SymbolicIndexingInterface.current_time(A::AbstractDiffEqArray) = A.t
SymbolicIndexingInterface.parameter_values(A::AbstractDiffEqArray) = A.p
SymbolicIndexingInterface.symbolic_container(A::AbstractDiffEqArray) = A.sys
function SymbolicIndexingInterface.get_parameter_timeseries_collection(A::AbstractDiffEqArray)
return get_discretes(A)
end

Base.IndexStyle(A::AbstractVectorOfArray) = Base.IndexStyle(typeof(A))
Base.IndexStyle(::Type{<:AbstractVectorOfArray}) = IndexCartesian()
Expand Down Expand Up @@ -363,39 +384,18 @@ end

# Symbolic Indexing Methods
for (symtype, elsymtype, valtype, errcheck) in [
(ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym))),
(ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym))),
(ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))),
(ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))),
(NotSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Union{<:Tuple, <:AbstractArray},
:(all(x -> is_parameter(A, x), sym))),
:(all(x -> is_parameter(A, x) && !is_timeseries_parameter(A, x), sym))),
]
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
::$elsymtype, sym::$valtype)
if $errcheck
throw(ParameterIndexingError(sym))
end
getu(A, sym)(A)
end
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
::$elsymtype, sym::$valtype, arg)
if $errcheck
throw(ParameterIndexingError(sym))
end
getu(A, sym)(A, arg)
end
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
::$elsymtype, sym::$valtype, arg::Union{AbstractArray{Int}, AbstractArray{Bool}})
if $errcheck
throw(ParameterIndexingError(sym))
end
getu(A, sym).((A,), arg)
end
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
::$elsymtype, sym::$valtype, ::Colon)
if $errcheck
throw(ParameterIndexingError(sym))
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
::$elsymtype, sym::$valtype, arg...)
if $errcheck
throw(ParameterIndexingError(sym))
end
getu(A, sym)(A, arg...)
end
getu(A, sym)(A)
end
end

Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic,
Expand Down

0 comments on commit c531cb2

Please sign in to comment.