Skip to content

Commit 50d362c

Browse files
Merge pull request #394 from AayushSabharwal/revert-372-as/parameter-timeseries
Revert "feat: add parameter timeseries support to `AbstractDiffEqArray`"
2 parents f41afbb + 7517384 commit 50d362c

File tree

2 files changed

+45
-45
lines changed

2 files changed

+45
-45
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ StaticArrays = "1.6"
5959
StaticArraysCore = "1.4"
6060
Statistics = "1.10"
6161
StructArrays = "0.6.11"
62-
SymbolicIndexingInterface = "0.3.23"
62+
SymbolicIndexingInterface = "0.3.20"
6363
Tables = "1.11"
6464
Test = "1"
6565
Tracker = "0.2.15"

src/vector_of_array.jl

+44-44
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,11 @@ A[1, :] # all time periods for f(t)
6060
A.t
6161
```
6262
"""
63-
mutable struct DiffEqArray{T, N, A, B, F, S, D <: Union{Nothing, ParameterTimeseriesCollection}} <:
64-
AbstractDiffEqArray{T, N, A}
63+
mutable struct DiffEqArray{T, N, A, B, F, S} <: AbstractDiffEqArray{T, N, A}
6564
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
6665
t::B
6766
p::F
6867
sys::S
69-
discretes::D
7068
end
7169
### Abstract Interface
7270
struct AllObserved
@@ -176,32 +174,29 @@ function DiffEqArray(vec::AbstractVector{T},
176174
ts::AbstractVector,
177175
::NTuple{N, Int},
178176
p = nothing,
179-
sys = nothing; discretes = nothing) where {T, N}
180-
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec,
177+
sys = nothing) where {T, N}
178+
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys)}(vec,
181179
ts,
182180
p,
183-
sys,
184-
discretes)
181+
sys)
185182
end
186183

187184
# ambiguity resolution
188185
function DiffEqArray(vec::AbstractVector{VT},
189186
ts::AbstractVector,
190187
::NTuple{N, Int}) where {T, N, VT <: AbstractArray{T, N}}
191-
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing, Nothing}(vec,
188+
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing}(vec,
192189
ts,
193190
nothing,
194-
nothing,
195191
nothing)
196192
end
197193
function DiffEqArray(vec::AbstractVector{VT},
198194
ts::AbstractVector,
199-
::NTuple{N, Int}, p; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}}
200-
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(vec,
195+
::NTuple{N, Int}, p) where {T, N, VT <: AbstractArray{T, N}}
196+
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing}(vec,
201197
ts,
202198
p,
203-
nothing,
204-
discretes)
199+
nothing)
205200
end
206201
# Assume that the first element is representative of all other elements
207202

@@ -211,8 +206,7 @@ function DiffEqArray(vec::AbstractVector,
211206
sys = nothing;
212207
variables = nothing,
213208
parameters = nothing,
214-
independent_variables = nothing,
215-
discretes = nothing)
209+
independent_variables = nothing)
216210
sys = something(sys,
217211
SymbolCache(something(variables, []),
218212
something(parameters, []),
@@ -225,13 +219,11 @@ function DiffEqArray(vec::AbstractVector,
225219
typeof(vec),
226220
typeof(ts),
227221
typeof(p),
228-
typeof(sys),
229-
typeof(discretes)
222+
typeof(sys)
230223
}(vec,
231224
ts,
232225
p,
233-
sys,
234-
discretes)
226+
sys)
235227
end
236228

237229
function DiffEqArray(vec::AbstractVector{VT},
@@ -240,8 +232,7 @@ function DiffEqArray(vec::AbstractVector{VT},
240232
sys = nothing;
241233
variables = nothing,
242234
parameters = nothing,
243-
independent_variables = nothing,
244-
discretes = nothing) where {T, N, VT <: AbstractArray{T, N}}
235+
independent_variables = nothing) where {T, N, VT <: AbstractArray{T, N}}
245236
sys = something(sys,
246237
SymbolCache(something(variables, []),
247238
something(parameters, []),
@@ -252,30 +243,18 @@ function DiffEqArray(vec::AbstractVector{VT},
252243
typeof(vec),
253244
typeof(ts),
254245
typeof(p),
255-
typeof(sys),
256-
typeof(discretes),
246+
typeof(sys)
257247
}(vec,
258248
ts,
259249
p,
260-
sys,
261-
discretes)
250+
sys)
262251
end
263252

264-
has_discretes(::T) where {T <: AbstractDiffEqArray} = hasfield(T, :discretes)
265-
get_discretes(x) = getfield(x, :discretes)
266-
267253
SymbolicIndexingInterface.is_timeseries(::Type{<:AbstractVectorOfArray}) = Timeseries()
268-
function SymbolicIndexingInterface.is_parameter_timeseries(::Type{DiffEqArray{T, N, A, B,
269-
F, S, D}}) where {T, N, A, B, F, S, D <: ParameterIndexingProxy}
270-
Timeseries()
271-
end
272254
SymbolicIndexingInterface.state_values(A::AbstractDiffEqArray) = A.u
273255
SymbolicIndexingInterface.current_time(A::AbstractDiffEqArray) = A.t
274256
SymbolicIndexingInterface.parameter_values(A::AbstractDiffEqArray) = A.p
275257
SymbolicIndexingInterface.symbolic_container(A::AbstractDiffEqArray) = A.sys
276-
function SymbolicIndexingInterface.get_parameter_timeseries_collection(A::AbstractDiffEqArray)
277-
return get_discretes(A)
278-
end
279258

280259
Base.IndexStyle(A::AbstractVectorOfArray) = Base.IndexStyle(typeof(A))
281260
Base.IndexStyle(::Type{<:AbstractVectorOfArray}) = IndexCartesian()
@@ -384,18 +363,39 @@ end
384363

385364
# Symbolic Indexing Methods
386365
for (symtype, elsymtype, valtype, errcheck) in [
387-
(ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))),
388-
(ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))),
366+
(ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym))),
367+
(ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym))),
389368
(NotSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Union{<:Tuple, <:AbstractArray},
390-
:(all(x -> is_parameter(A, x) && !is_timeseries_parameter(A, x), sym))),
369+
:(all(x -> is_parameter(A, x), sym))),
391370
]
392-
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
393-
::$elsymtype, sym::$valtype, arg...)
394-
if $errcheck
395-
throw(ParameterIndexingError(sym))
396-
end
397-
getu(A, sym)(A, arg...)
371+
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
372+
::$elsymtype, sym::$valtype)
373+
if $errcheck
374+
throw(ParameterIndexingError(sym))
398375
end
376+
getu(A, sym)(A)
377+
end
378+
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
379+
::$elsymtype, sym::$valtype, arg)
380+
if $errcheck
381+
throw(ParameterIndexingError(sym))
382+
end
383+
getu(A, sym)(A, arg)
384+
end
385+
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
386+
::$elsymtype, sym::$valtype, arg::Union{AbstractArray{Int}, AbstractArray{Bool}})
387+
if $errcheck
388+
throw(ParameterIndexingError(sym))
389+
end
390+
getu(A, sym).((A,), arg)
391+
end
392+
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
393+
::$elsymtype, sym::$valtype, ::Colon)
394+
if $errcheck
395+
throw(ParameterIndexingError(sym))
396+
end
397+
getu(A, sym)(A)
398+
end
399399
end
400400

401401
Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::ScalarSymbolic,

0 commit comments

Comments
 (0)