Skip to content

feaet: add ControlFunction #996

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ export ODEFunction, DiscreteFunction, ImplicitDiscreteFunction, SplitFunction, D
DDEFunction, SDEFunction, SplitSDEFunction, RODEFunction, SDDEFunction,
IncrementingODEFunction, NonlinearFunction, HomotopyNonlinearFunction,
IntervalNonlinearFunction, BVPFunction,
DynamicalBVPFunction, IntegralFunction, BatchIntegralFunction
DynamicalBVPFunction, IntegralFunction, BatchIntegralFunction, ODEInputFunction

export OptimizationFunction, MultiObjectiveOptimizationFunction

Expand Down
2 changes: 1 addition & 1 deletion src/problems/implicit_discrete_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ dt: the time step

### Constructors

- `ImplicitDiscreteProblem(f::ODEFunction,u0,tspan,p=NullParameters();kwargs...)` :
- `ImplicitDiscreteProblem(f::ImplicitDiscreteFunction,u0,tspan,p=NullParameters();kwargs...)` :
Defines the discrete problem with the specified functions.
- `ImplicitDiscreteProblem{isinplace,specialize}(f,u0,tspan,p=NullParameters();kwargs...)` :
Defines the discrete problem with the specified functions.
Expand Down
248 changes: 248 additions & 0 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2094,6 +2094,109 @@ struct MultiObjectiveOptimizationFunction{
initialization_data::ID
end

"""
$(TYPEDEF)
"""
abstract type AbstractODEInputFunction{iip} <: AbstractDiffEqFunction{iip} end

@doc doc"""
$(TYPEDEF)

A representation of a ODE function `f` with inputs, defined by:

```math
\frac{dx}{dt} = f(x, u, p, t)
```
where `x` are the states of the system and `u` are the inputs (which may represent
different things in different contexts, such as control variables in optimal control).

Includes all of its related functions, such as the Jacobian of `f`, its gradient
with respect to time, and more. For all cases, `u0` is the initial condition,
`p` are the parameters, and `t` is the independent variable.

```julia
ODEInputFunction{iip, specialize}(f;
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,
analytic = __has_analytic(f) ? f.analytic : nothing,
tgrad= __has_tgrad(f) ? f.tgrad : nothing,
jac = __has_jac(f) ? f.jac : nothing,
control_jac = __has_controljac(f) ? f.controljac : nothing,
jvp = __has_jvp(f) ? f.jvp : nothing,
vjp = __has_vjp(f) ? f.vjp : nothing,
jac_prototype = __has_jac_prototype(f) ? f.jac_prototype : nothing,
controljac_prototype = __has_controljac_prototype(f) ? f.controljac_prototype : nothing,
sparsity = __has_sparsity(f) ? f.sparsity : jac_prototype,
paramjac = __has_paramjac(f) ? f.paramjac : nothing,
syms = nothing,
indepsym = nothing,
paramsyms = nothing,
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
sys = __has_sys(f) ? f.sys : nothing)
```

`f` should be given as `f(x_out,x,u,p,t)` or `out = f(x,u,p,t)`.
See the section on `iip` for more details on in-place vs out-of-place handling.

- `mass_matrix`: the mass matrix `M` represented in the BVP function. Can be used
to determine that the equation is actually a BVP for differential algebraic equation (DAE)
if `M` is singular.
- `jac(J,dx,x,u,p,gamma,t)` or `J=jac(dx,x,u,p,gamma,t)`: returns ``\frac{df}{dx}``
- `control_jac(J,du,x,u,p,gamma,t)` or `J=control_jac(du,x,u,p,gamma,t)`: returns ``\frac{df}{du}``
- `jvp(Jv,v,du,x,u,p,gamma,t)` or `Jv=jvp(v,du,x,u,p,gamma,t)`: returns the directional
derivative ``\frac{df}{du} v``
- `vjp(Jv,v,du,x,u,p,gamma,t)` or `Jv=vjp(v,du,x,u,p,gamma,t)`: returns the adjoint
derivative ``\frac{df}{du}^\ast v``
- `jac_prototype`: a prototype matrix matching the type that matches the Jacobian. For example,
if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used
as the prototype and integrators will specialize on this structure where possible. Non-structured
sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian.
The default is `nothing`, which means a dense Jacobian.
- `controljac_prototype`: a prototype matrix matching the type that matches the Jacobian. For example,
if the Jacobian is tridiagonal, then an appropriately sized `Tridiagonal` matrix can be used
as the prototype and integrators will specialize on this structure where possible. Non-structured
sparsity patterns should use a `SparseMatrixCSC` with a correct sparsity pattern for the Jacobian.
The default is `nothing`, which means a dense Jacobian.
- `paramjac(pJ,x,u,p,t)`: returns the parameter Jacobian ``\frac{df}{dp}``.
- `colorvec`: a color vector according to the SparseDiffTools.jl definition for the sparsity
pattern of the `jac_prototype`. This specializes the Jacobian construction when using
finite differences and automatic differentiation to be computed in an accelerated manner
based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be
internally computed on demand when required. The cost of this operation is highly dependent
on the sparsity pattern.

## iip: In-Place vs Out-Of-Place
For more details on this argument, see the ODEFunction documentation.

## specialize: Controlling Compilation and Specialization
For more details on this argument, see the ODEFunction documentation.

## Fields
The fields of the ODEInputFunction type directly match the names of the inputs.
"""
struct ODEInputFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP,
JP, CJP, SP, TW, TWt, WP, TPJ, O, TCV,
SYS, ID} <: AbstractODEInputFunction{iip}
f::F
mass_matrix::TMM
analytic::Ta
tgrad::Tt
jac::TJ
controljac::CTJ
jvp::JVP
vjp::VJP
jac_prototype::JP
controljac_prototype::CJP
sparsity::SP
Wfact::TW
Wfact_t::TWt
W_prototype::WP
paramjac::TPJ
observed::O
colorvec::TCV
sys::SYS
initialization_data::ID
end

"""
$(TYPEDEF)
"""
Expand Down Expand Up @@ -2493,6 +2596,7 @@ end
(f::ImplicitDiscreteFunction)(args...) = f.f(args...)
(f::DAEFunction)(args...) = f.f(args...)
(f::DDEFunction)(args...) = f.f(args...)
(f::ODEInputFunction)(args...) = f.f(args...)

function (f::DynamicalDDEFunction)(u, h, p, t)
ArrayPartition(f.f1(u.x[1], u.x[2], h, p, t), f.f2(u.x[1], u.x[2], h, p, t))
Expand Down Expand Up @@ -4595,6 +4699,149 @@ function BatchIntegralFunction(f, integrand_prototype; kwargs...)
BatchIntegralFunction{calculated_iip}(f, integrand_prototype; kwargs...)
end

function ODEInputFunction{iip, specialize}(f;
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix :
I,
analytic = __has_analytic(f) ? f.analytic : nothing,
tgrad = __has_tgrad(f) ? f.tgrad : nothing,
jac = __has_jac(f) ? f.jac : nothing,
controljac = __has_controljac(f) ? f.controljac : nothing,
jvp = __has_jvp(f) ? f.jvp : nothing,
vjp = __has_vjp(f) ? f.vjp : nothing,
jac_prototype = __has_jac_prototype(f) ?
f.jac_prototype :
nothing,
controljac_prototype = __has_controljac_prototype(f) ?
f.controljac_prototype :
nothing,
sparsity = __has_sparsity(f) ? f.sparsity :
jac_prototype,
Wfact = __has_Wfact(f) ? f.Wfact : nothing,
Wfact_t = __has_Wfact_t(f) ? f.Wfact_t : nothing,
W_prototype = __has_W_prototype(f) ? f.W_prototype : nothing,
paramjac = __has_paramjac(f) ? f.paramjac : nothing,
syms = nothing,
indepsym = nothing,
paramsyms = nothing,
observed = __has_observed(f) ? f.observed :
DEFAULT_OBSERVED,
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
sys = __has_sys(f) ? f.sys : nothing,
initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing,
update_initializeprob! = __has_update_initializeprob!(f) ?
f.update_initializeprob! : nothing,
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
initialization_data = __has_initialization_data(f) ? f.initialization_data :
nothing,
nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing
) where {iip,
specialize
}
if mass_matrix === I && f isa Tuple
mass_matrix = ((I for i in 1:length(f))...,)
end

if (specialize === FunctionWrapperSpecialize) &&
!(f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
error("FunctionWrapperSpecialize must be used on the problem constructor for access to u0, p, and t types!")
end

if jac === nothing && isa(jac_prototype, AbstractSciMLOperator)
if iip
jac = (J, x, u, p, t) -> update_coefficients!(J, x, p, t) #(J,x,u,p,t)
else
jac = (x, u, p, t) -> update_coefficients(deepcopy(jac_prototype), x, p, t)
end
end

if controljac === nothing && isa(controljac_prototype, AbstractSciMLOperator)
if iip_bc
controljac = (J, x, u, p, t) -> update_coefficients!(J, u, p, t) #(J,x,u,p,t)
else
controljac = (x, u, p, t) -> update_coefficients(deepcopy(controljac_prototype), u, p, t)
end
end

if jac_prototype !== nothing && colorvec === nothing &&
ArrayInterface.fast_matrix_colors(jac_prototype)
_colorvec = ArrayInterface.matrix_colors(jac_prototype)
else
_colorvec = colorvec
end

jaciip = jac !== nothing ? isinplace(jac, 5, "jac", iip) : iip
controljaciip = controljac !== nothing ? isinplace(controljac, 5, "controljac", iip) : iip
tgradiip = tgrad !== nothing ? isinplace(tgrad, 5, "tgrad", iip) : iip
jvpiip = jvp !== nothing ? isinplace(jvp, 6, "jvp", iip) : iip
vjpiip = vjp !== nothing ? isinplace(vjp, 6, "vjp", iip) : iip
Wfactiip = Wfact !== nothing ? isinplace(Wfact, 6, "Wfact", iip) : iip
Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 6, "Wfact_t", iip) : iip
paramjaciip = paramjac !== nothing ? isinplace(paramjac, 5, "paramjac", iip) : iip

nonconforming = (jaciip, tgradiip, jvpiip, vjpiip, Wfactiip, Wfact_tiip,
paramjaciip) .!= iip
if any(nonconforming)
nonconforming = findall(nonconforming)
functions = ["jac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t", "paramjac"][nonconforming]
throw(NonconformingFunctionsError(functions))
end

_f = prepare_function(f)

sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
initdata = reconstruct_initialization_data(
initialization_data, initializeprob, update_initializeprob!,
initializeprobmap, initializeprobpmap)

if specialize === NoSpecialize
ODEInputFunction{iip, specialize,
Any, Any, Any, Any,
Any, Any, Any, Any, typeof(jac_prototype), typeof(controljac_prototype),
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
Any,
typeof(_colorvec),
typeof(sys), Union{Nothing, OverrideInitData}}(
_f, mass_matrix, analytic, tgrad, jac, controljac,
jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initdata)
elseif specialize === false
ODEInputFunction{iip, FunctionWrapperSpecialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
typeof(jac), typeof(controljac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(controljac_prototype),
typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype),
typeof(paramjac),
typeof(observed),
typeof(_colorvec),
typeof(sys), typeof(initdata)}(_f, mass_matrix,
analytic, tgrad, jac, controljac,
jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initdata)
else
ODEInputFunction{iip, specialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
typeof(jac), typeof(controljac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(controljac_prototype),
typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype),
typeof(paramjac),
typeof(observed),
typeof(_colorvec),
typeof(sys), typeof(initdata)}(
_f, mass_matrix, analytic, tgrad,
jac, controljac, jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initdata)
end
end

function ODEInputFunction{iip}(f; kwargs...) where {iip}
ODEInputFunction{iip, FullSpecialize}(f; kwargs...)
end
ODEInputFunction{iip}(f::ODEInputFunction; kwargs...) where {iip} = f
ODEInputFunction(f; kwargs...) = ODEInputFunction{isinplace(f, 5), FullSpecialize}(f; kwargs...)
ODEInputFunction(f::ODEInputFunction; kwargs...) = f

########## Utility functions

function sys_or_symbolcache(sys, syms, paramsyms, indepsym = nothing)
Expand Down Expand Up @@ -4628,6 +4875,7 @@ __has_Wfact_t(f) = isdefined(f, :Wfact_t)
__has_W_prototype(f) = isdefined(f, :W_prototype)
__has_paramjac(f) = isdefined(f, :paramjac)
__has_jac_prototype(f) = isdefined(f, :jac_prototype)
__has_controljac_prototype(f) = isdefined(f, :controljac_prototype)
__has_sparsity(f) = isdefined(f, :sparsity)
__has_mass_matrix(f) = isdefined(f, :mass_matrix)
__has_syms(f) = isdefined(f, :syms)
Expand Down
Empty file added src/solutions/solution_utils.jl
Empty file.
Loading