Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,25 @@ CommonSolve = "0.2"
ControllerFormats = "0.1 - 0.2"
LazySets = "5 - 6"
NeuralNetworkReachability = "0.1"
OrdinaryDiffEq = "6"
Parameters = "0.12"
Plots = "1"
ReachabilityAnalysis = "0.29 - 0.30"
ReachabilityBase = "0.3"
Reexport = "0.2, 1"
Requires = "0.5, 1"
TaylorModels = "0.7 - 0.9"
julia = "1.6"

[extensions]
OrdinaryDiffEqExt = "OrdinaryDiffEq"
OrdinaryDiffEqPlotsExt = ["OrdinaryDiffEq", "Plots"]

[extras]
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"

[weakdeps]
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
47 changes: 47 additions & 0 deletions ext/OrdinaryDiffEqExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
module OrdinaryDiffEqExt

import ClosedLoopReachability
import ClosedLoopReachability.ReachabilityAnalysis as RA

@static if isdefined(Base, :get_extension)
import OrdinaryDiffEq
else
import .OrdinaryDiffEq
end
const ODE = OrdinaryDiffEq

if isdefined(OrdinaryDiffEq, :controls)
# before v7, DE had deps importing ModelingToolkit, which exports `controls`
@static if isdefined(Base, :get_extension)
import OrdinaryDiffEq: controls
else
import .OrdinaryDiffEq: controls
end
end

function ClosedLoopReachability._initialize_simulation_container(iterations::Int)
return Vector{ODE.EnsembleSolution}(undef, iterations)
end

# simulation of multiple trajectories for an ODE system and a time span
# currently we can't use this method from RA because the sampling should be made from outside the function
function ClosedLoopReachability._solve_ensemble(ivp, X0_samples::Vector, tspan;
trajectories_alg=ODE.Tsit5(),
ensemble_alg=ODE.EnsembleThreads(),
inplace=true,
kwargs...)
if inplace
field = RA.inplace_field!(ivp) # NOTE: this is an internal function
else
field = RA.outofplace_field(ivp) # NOTE: this is an internal function
end

# the third argument `repeat` is not needed here
_prob_func(prob, i, _) = ODE.remake(prob; u0=X0_samples[i])
ensemble_prob = ODE.EnsembleProblem(ODE.ODEProblem(field, first(X0_samples), tspan);
prob_func=_prob_func)
return ODE.solve(ensemble_prob, trajectories_alg, ensemble_alg;
trajectories=length(X0_samples))
end

end # module
20 changes: 17 additions & 3 deletions src/init_OrdinaryDiffEq_Plots.jl → ext/OrdinaryDiffEqPlotsExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
export plot_simulation!
module OrdinaryDiffEqPlotsExt

import ClosedLoopReachability
using ClosedLoopReachability: EnsembleSimulationSolution, trajectories

@static if isdefined(Base, :get_extension)
import OrdinaryDiffEq
import Plots
else
import .OrdinaryDiffEq
import .Plots
end

# convenience function for plotting simulation results
# use `output_map` to plot a linear combination of the state variables
function plot_simulation!(fig, sim::EnsembleSimulationSolution; vars=nothing, output_map=nothing,
kwargs...)
function ClosedLoopReachability.plot_simulation!(fig, sim::EnsembleSimulationSolution; vars=nothing,
output_map=nothing, kwargs...)
# The main problem is that plotting trajectories one by one changes the plot
# limits. Hence we store the plot limits from an existing figure and restore
# them after plotting all trajectories.
Expand Down Expand Up @@ -50,6 +61,7 @@ function _plot_simulation_vars!(fig, sim, vars; color, label)
label = "" # overwrite to have exactly one label
end
end
return nothing
end

function _plot_simulation_output_map!(fig, sim, output_map; color, label)
Expand All @@ -64,3 +76,5 @@ function _plot_simulation_output_map!(fig, sim, output_map; color, label)
end
return fig
end

end # module
15 changes: 13 additions & 2 deletions src/ClosedLoopReachability.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,18 @@ export BoxSplitter,
IndexedSplitter,
SignSplitter

# solvers
export solve, simulate
# solve
export solve

# simulation
export simulate,
trajectory,
trajectories,
controls,
disturbances,
solutions

# plotting
export plot_simulation!

end # module
17 changes: 10 additions & 7 deletions src/init.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using Requires: @require
@static if !isdefined(Base, :get_extension)
using Requires: @require
end
using Reexport: @reexport

using Base: isempty
Expand Down Expand Up @@ -31,12 +33,13 @@ const IA = IntervalArithmetic
import CommonSolve: solve

# optional dependencies
function __init__()
@require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" begin
include("init_OrdinaryDiffEq.jl")
@require Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" begin
import .Plots
include("init_OrdinaryDiffEq_Plots.jl")
@static if !isdefined(Base, :get_extension)
function __init__()
@require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" begin
include("../ext/OrdinaryDiffEqExt.jl")
@require Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" begin
include("../ext/OrdinaryDiffEqPlotsExt.jl")
end
end
end
end
82 changes: 0 additions & 82 deletions src/init_OrdinaryDiffEq.jl

This file was deleted.

76 changes: 73 additions & 3 deletions src/simulate.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,51 @@
struct SimulationSolution{TT,CT,IT}
trajectory::TT # trajectory pieces for each control cycle
controls::CT # control inputs for each control cycle
disturbances::IT # disturbances for each control cycle
end

Base.length(sol::SimulationSolution) = length(sol.trajectory)
function Base.getindex(sol::SimulationSolution, i)
return SimulationSolution(sol.trajectory[i],
sol.controls[i],
sol.disturbances[i])
end
trajectory(sol::SimulationSolution) = sol.trajectory
controls(sol::SimulationSolution) = sol.controls
disturbances(sol::SimulationSolution) = sol.disturbances

struct EnsembleSimulationSolution{TT,CT,IT}
solutions::Vector{SimulationSolution{TT,CT,IT}}
end

# constructor from a bulk input
function EnsembleSimulationSolution(simulations, controls, disturbances)
n = length(simulations) # number of pieces
m = length(simulations[1]) # number of trajectories
@assert n == length(controls) == length(disturbances) "incompatible lengths"
@assert all(m == length(piece) for piece in simulations)

simulations_new = @inbounds [[simulations[i].u[j] for i in 1:n] for j in 1:m]
controls_new = @inbounds [[controls[i][j] for i in 1:n] for j in 1:m]
disturbances_new = @inbounds [[(isassigned(disturbances, i) ? disturbances[i][j] : nothing)
for i in 1:n] for j in 1:m]
solutions = @inbounds [SimulationSolution(simulations_new[j],
controls_new[j], disturbances_new[j]) for j in 1:m]
return EnsembleSimulationSolution(solutions)
end

Base.length(ess::EnsembleSimulationSolution) = length(ess.solutions)
Base.getindex(ess::EnsembleSimulationSolution, i) = ess.solutions[i]
function solutions(ess::EnsembleSimulationSolution, i)
return EnsembleSimulationSolution([sol[i] for sol in ess.solutions])
end
trajectory(ess::EnsembleSimulationSolution, i) = trajectory(solution(ess, i))
trajectories(ess::EnsembleSimulationSolution) = trajectory.(ess.solutions)
controls(ess::EnsembleSimulationSolution, i) = controls(solution(ess, i))
controls(ess::EnsembleSimulationSolution) = controls.(ess.solutions)
disturbances(ess::EnsembleSimulationSolution, i) = disturbances(solution(ess, i))
disturbances(ess::EnsembleSimulationSolution) = disturbances.(ess.solutions)

"""
simulate(cp::AbstractControlProblem, args...; kwargs...)

Expand All @@ -18,8 +66,6 @@ This function uses the ensemble simulations feature from
[`OrdinaryDiffEq.jl`](https://github.com/SciML/OrdinaryDiffEq.jl).
"""
function simulate(cp::AbstractControlProblem, args...; kwargs...)
require(@__MODULE__, :OrdinaryDiffEq; fun_name="simulate")

ivp = plant(cp)
network = controller(cp)
st_vars = states(cp)
Expand All @@ -45,7 +91,7 @@ function simulate(cp::AbstractControlProblem, args...; kwargs...)

# preallocate
extended = Vector{Vector{Float64}}(undef, trajectories)
simulations = Vector{ODE.EnsembleSolution}(undef, iterations)
simulations = _initialize_simulation_container(iterations)
all_controls = Vector{Vector{Vector{Float64}}}(undef, iterations)
all_disturbances = Vector{Vector{Vector{Float64}}}(undef, iterations)

Expand Down Expand Up @@ -95,3 +141,27 @@ function simulate(cp::AbstractControlProblem, args...; kwargs...)

return EnsembleSimulationSolution(simulations, all_controls, all_disturbances)
end

# defined in `OrdinaryDiffEqExt.jl`
function _initialize_simulation_container(iterations)
mod = isdefined(Base, :get_extension) ?
Base.get_extension(@__MODULE__, :OrdinaryDiffEqExt) : @__MODULE__
require(mod, :OrdinaryDiffEq; fun_name="simulate")
return nothing
end

# defined in `OrdinaryDiffEqExt.jl`
function _solve_ensemble(ivp, extended, tspan; kwargs...)
mod = isdefined(Base, :get_extension) ?
Base.get_extension(@__MODULE__, :OrdinaryDiffEqExt) : @__MODULE__
require(mod, :OrdinaryDiffEq; fun_name="simulate")
return nothing
end

# defined in `OrdinaryDiffEqPlotsExt.jl`
function plot_simulation!(fig, sim; vars=nothing, output_map=nothing, kwargs...)
mod = isdefined(Base, :get_extension) ?
Base.get_extension(@__MODULE__, :OrdinaryDiffEqPlotsExt) : @__MODULE__
require(mod, [:OrdinaryDiffEq, :Plots]; fun_name="plot_simulation!")
return nothing
end
Loading