diff --git a/Project.toml b/Project.toml index 179f88f6..4f5fe764 100644 --- a/Project.toml +++ b/Project.toml @@ -19,10 +19,25 @@ CommonSolve = "0.2" ControllerFormats = "0.1 - 0.2" LazySets = "5 - 6" NeuralNetworkReachability = "0.1" +OrdinaryDiffEq = "6" Parameters = "0.12" +Plots = "1" ReachabilityAnalysis = "0.29 - 0.30" ReachabilityBase = "0.3" Reexport = "0.2, 1" Requires = "0.5, 1" TaylorModels = "0.7 - 0.9" julia = "1.6" + +[extensions] +OrdinaryDiffEqExt = "OrdinaryDiffEq" +OrdinaryDiffEqPlotsExt = ["OrdinaryDiffEq", "Plots"] + +[extras] +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" + +[weakdeps] +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" diff --git a/ext/OrdinaryDiffEqExt.jl b/ext/OrdinaryDiffEqExt.jl new file mode 100644 index 00000000..e50de3f9 --- /dev/null +++ b/ext/OrdinaryDiffEqExt.jl @@ -0,0 +1,47 @@ +module OrdinaryDiffEqExt + +import ClosedLoopReachability +import ClosedLoopReachability.ReachabilityAnalysis as RA + +@static if isdefined(Base, :get_extension) + import OrdinaryDiffEq +else + import .OrdinaryDiffEq +end +const ODE = OrdinaryDiffEq + +if isdefined(OrdinaryDiffEq, :controls) + # before v7, DE had deps importing ModelingToolkit, which exports `controls` + @static if isdefined(Base, :get_extension) + import OrdinaryDiffEq: controls + else + import .OrdinaryDiffEq: controls + end +end + +function ClosedLoopReachability._initialize_simulation_container(iterations::Int) + return Vector{ODE.EnsembleSolution}(undef, iterations) +end + +# simulation of multiple trajectories for an ODE system and a time span +# currently we can't use this method from RA because the sampling should be made from outside the function +function ClosedLoopReachability._solve_ensemble(ivp, X0_samples::Vector, tspan; + trajectories_alg=ODE.Tsit5(), + ensemble_alg=ODE.EnsembleThreads(), + inplace=true, + kwargs...) + if inplace + field = RA.inplace_field!(ivp) # NOTE: this is an internal function + else + field = RA.outofplace_field(ivp) # NOTE: this is an internal function + end + + # the third argument `repeat` is not needed here + _prob_func(prob, i, _) = ODE.remake(prob; u0=X0_samples[i]) + ensemble_prob = ODE.EnsembleProblem(ODE.ODEProblem(field, first(X0_samples), tspan); + prob_func=_prob_func) + return ODE.solve(ensemble_prob, trajectories_alg, ensemble_alg; + trajectories=length(X0_samples)) +end + +end # module diff --git a/src/init_OrdinaryDiffEq_Plots.jl b/ext/OrdinaryDiffEqPlotsExt.jl similarity index 80% rename from src/init_OrdinaryDiffEq_Plots.jl rename to ext/OrdinaryDiffEqPlotsExt.jl index e29f922d..35e1ec47 100644 --- a/src/init_OrdinaryDiffEq_Plots.jl +++ b/ext/OrdinaryDiffEqPlotsExt.jl @@ -1,9 +1,20 @@ -export plot_simulation! +module OrdinaryDiffEqPlotsExt + +import ClosedLoopReachability +using ClosedLoopReachability: EnsembleSimulationSolution, trajectories + +@static if isdefined(Base, :get_extension) + import OrdinaryDiffEq + import Plots +else + import .OrdinaryDiffEq + import .Plots +end # convenience function for plotting simulation results # use `output_map` to plot a linear combination of the state variables -function plot_simulation!(fig, sim::EnsembleSimulationSolution; vars=nothing, output_map=nothing, - kwargs...) +function ClosedLoopReachability.plot_simulation!(fig, sim::EnsembleSimulationSolution; vars=nothing, + output_map=nothing, kwargs...) # The main problem is that plotting trajectories one by one changes the plot # limits. Hence we store the plot limits from an existing figure and restore # them after plotting all trajectories. @@ -50,6 +61,7 @@ function _plot_simulation_vars!(fig, sim, vars; color, label) label = "" # overwrite to have exactly one label end end + return nothing end function _plot_simulation_output_map!(fig, sim, output_map; color, label) @@ -64,3 +76,5 @@ function _plot_simulation_output_map!(fig, sim, output_map; color, label) end return fig end + +end # module diff --git a/src/ClosedLoopReachability.jl b/src/ClosedLoopReachability.jl index da36420d..86c55418 100644 --- a/src/ClosedLoopReachability.jl +++ b/src/ClosedLoopReachability.jl @@ -17,7 +17,18 @@ export BoxSplitter, IndexedSplitter, SignSplitter -# solvers -export solve, simulate +# solve +export solve + +# simulation +export simulate, + trajectory, + trajectories, + controls, + disturbances, + solutions + +# plotting +export plot_simulation! end # module diff --git a/src/init.jl b/src/init.jl index f8a76039..689a6cab 100644 --- a/src/init.jl +++ b/src/init.jl @@ -1,4 +1,6 @@ -using Requires: @require +@static if !isdefined(Base, :get_extension) + using Requires: @require +end using Reexport: @reexport using Base: isempty @@ -31,12 +33,13 @@ const IA = IntervalArithmetic import CommonSolve: solve # optional dependencies -function __init__() - @require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" begin - include("init_OrdinaryDiffEq.jl") - @require Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" begin - import .Plots - include("init_OrdinaryDiffEq_Plots.jl") +@static if !isdefined(Base, :get_extension) + function __init__() + @require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" begin + include("../ext/OrdinaryDiffEqExt.jl") + @require Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" begin + include("../ext/OrdinaryDiffEqPlotsExt.jl") + end end end end diff --git a/src/init_OrdinaryDiffEq.jl b/src/init_OrdinaryDiffEq.jl deleted file mode 100644 index def17685..00000000 --- a/src/init_OrdinaryDiffEq.jl +++ /dev/null @@ -1,82 +0,0 @@ -import .OrdinaryDiffEq -const ODE = OrdinaryDiffEq - -if isdefined(OrdinaryDiffEq, :controls) - # before v7, DE had deps importing ModelingToolkit, which exports `controls` - import .OrdinaryDiffEq: controls -end - -export trajectory, - trajectories, - controls, - disturbances, - solutions - -struct SimulationSolution{TT,CT,IT} - trajectory::TT # trajectory pieces for each control cycle - controls::CT # control inputs for each control cycle - disturbances::IT # disturbances for each control cycle -end - -Base.length(sol::SimulationSolution) = length(sol.trajectory) -function Base.getindex(sol::SimulationSolution, i) - return SimulationSolution(sol.trajectory[i], - sol.controls[i], - sol.disturbances[i]) -end -trajectory(sol::SimulationSolution) = sol.trajectory -controls(sol::SimulationSolution) = sol.controls -disturbances(sol::SimulationSolution) = sol.disturbances - -struct EnsembleSimulationSolution{TT,CT,IT} - solutions::Vector{SimulationSolution{TT,CT,IT}} -end - -# constructor from a bulk input -function EnsembleSimulationSolution(simulations, controls, disturbances) - n = length(simulations) # number of pieces - m = length(simulations[1]) # number of trajectories - @assert n == length(controls) == length(disturbances) "incompatible lengths" - @assert all(m == length(piece) for piece in simulations) - - simulations_new = @inbounds [[simulations[i].u[j] for i in 1:n] for j in 1:m] - controls_new = @inbounds [[controls[i][j] for i in 1:n] for j in 1:m] - disturbances_new = @inbounds [[(isassigned(disturbances, i) ? disturbances[i][j] : nothing) - for i in 1:n] for j in 1:m] - solutions = @inbounds [SimulationSolution(simulations_new[j], - controls_new[j], disturbances_new[j]) for j in 1:m] - return EnsembleSimulationSolution(solutions) -end - -Base.length(ess::EnsembleSimulationSolution) = length(ess.solutions) -Base.getindex(ess::EnsembleSimulationSolution, i) = ess.solutions[i] -function solutions(ess::EnsembleSimulationSolution, i) - return EnsembleSimulationSolution([sol[i] for sol in ess.solutions]) -end -trajectory(ess::EnsembleSimulationSolution, i) = trajectory(solution(ess, i)) -trajectories(ess::EnsembleSimulationSolution) = trajectory.(ess.solutions) -controls(ess::EnsembleSimulationSolution, i) = controls(solution(ess, i)) -controls(ess::EnsembleSimulationSolution) = controls.(ess.solutions) -disturbances(ess::EnsembleSimulationSolution, i) = disturbances(solution(ess, i)) -disturbances(ess::EnsembleSimulationSolution) = disturbances.(ess.solutions) - -# simulation of multiple trajectories for an ODE system and a time span -# currently we can't use this method from RA because the sampling should be made from outside the function -function _solve_ensemble(ivp, X0_samples, tspan; - trajectories_alg=ODE.Tsit5(), - ensemble_alg=ODE.EnsembleThreads(), - inplace=true, - kwargs...) - if inplace - field = ReachabilityAnalysis.inplace_field!(ivp) # NOTE: this is an internal function - else - field = ReachabilityAnalysis.outofplace_field(ivp) # NOTE: this is an internal function - end - - # the third argument `repeat` is not needed here - _prob_func(prob, i, _) = ODE.remake(prob; u0=X0_samples[i]) - ensemble_prob = ODE.EnsembleProblem(ODE.ODEProblem(field, first(X0_samples), tspan); - prob_func=_prob_func) - return ODE.solve(ensemble_prob, trajectories_alg, ensemble_alg; - trajectories=length(X0_samples)) -end diff --git a/src/simulate.jl b/src/simulate.jl index 757d4312..3e013b4c 100644 --- a/src/simulate.jl +++ b/src/simulate.jl @@ -1,3 +1,51 @@ +struct SimulationSolution{TT,CT,IT} + trajectory::TT # trajectory pieces for each control cycle + controls::CT # control inputs for each control cycle + disturbances::IT # disturbances for each control cycle +end + +Base.length(sol::SimulationSolution) = length(sol.trajectory) +function Base.getindex(sol::SimulationSolution, i) + return SimulationSolution(sol.trajectory[i], + sol.controls[i], + sol.disturbances[i]) +end +trajectory(sol::SimulationSolution) = sol.trajectory +controls(sol::SimulationSolution) = sol.controls +disturbances(sol::SimulationSolution) = sol.disturbances + +struct EnsembleSimulationSolution{TT,CT,IT} + solutions::Vector{SimulationSolution{TT,CT,IT}} +end + +# constructor from a bulk input +function EnsembleSimulationSolution(simulations, controls, disturbances) + n = length(simulations) # number of pieces + m = length(simulations[1]) # number of trajectories + @assert n == length(controls) == length(disturbances) "incompatible lengths" + @assert all(m == length(piece) for piece in simulations) + + simulations_new = @inbounds [[simulations[i].u[j] for i in 1:n] for j in 1:m] + controls_new = @inbounds [[controls[i][j] for i in 1:n] for j in 1:m] + disturbances_new = @inbounds [[(isassigned(disturbances, i) ? disturbances[i][j] : nothing) + for i in 1:n] for j in 1:m] + solutions = @inbounds [SimulationSolution(simulations_new[j], + controls_new[j], disturbances_new[j]) for j in 1:m] + return EnsembleSimulationSolution(solutions) +end + +Base.length(ess::EnsembleSimulationSolution) = length(ess.solutions) +Base.getindex(ess::EnsembleSimulationSolution, i) = ess.solutions[i] +function solutions(ess::EnsembleSimulationSolution, i) + return EnsembleSimulationSolution([sol[i] for sol in ess.solutions]) +end +trajectory(ess::EnsembleSimulationSolution, i) = trajectory(solution(ess, i)) +trajectories(ess::EnsembleSimulationSolution) = trajectory.(ess.solutions) +controls(ess::EnsembleSimulationSolution, i) = controls(solution(ess, i)) +controls(ess::EnsembleSimulationSolution) = controls.(ess.solutions) +disturbances(ess::EnsembleSimulationSolution, i) = disturbances(solution(ess, i)) +disturbances(ess::EnsembleSimulationSolution) = disturbances.(ess.solutions) + """ simulate(cp::AbstractControlProblem, args...; kwargs...) @@ -18,8 +66,6 @@ This function uses the ensemble simulations feature from [`OrdinaryDiffEq.jl`](https://github.com/SciML/OrdinaryDiffEq.jl). """ function simulate(cp::AbstractControlProblem, args...; kwargs...) - require(@__MODULE__, :OrdinaryDiffEq; fun_name="simulate") - ivp = plant(cp) network = controller(cp) st_vars = states(cp) @@ -45,7 +91,7 @@ function simulate(cp::AbstractControlProblem, args...; kwargs...) # preallocate extended = Vector{Vector{Float64}}(undef, trajectories) - simulations = Vector{ODE.EnsembleSolution}(undef, iterations) + simulations = _initialize_simulation_container(iterations) all_controls = Vector{Vector{Vector{Float64}}}(undef, iterations) all_disturbances = Vector{Vector{Vector{Float64}}}(undef, iterations) @@ -95,3 +141,27 @@ function simulate(cp::AbstractControlProblem, args...; kwargs...) return EnsembleSimulationSolution(simulations, all_controls, all_disturbances) end + +# defined in `OrdinaryDiffEqExt.jl` +function _initialize_simulation_container(iterations) + mod = isdefined(Base, :get_extension) ? + Base.get_extension(@__MODULE__, :OrdinaryDiffEqExt) : @__MODULE__ + require(mod, :OrdinaryDiffEq; fun_name="simulate") + return nothing +end + +# defined in `OrdinaryDiffEqExt.jl` +function _solve_ensemble(ivp, extended, tspan; kwargs...) + mod = isdefined(Base, :get_extension) ? + Base.get_extension(@__MODULE__, :OrdinaryDiffEqExt) : @__MODULE__ + require(mod, :OrdinaryDiffEq; fun_name="simulate") + return nothing +end + +# defined in `OrdinaryDiffEqPlotsExt.jl` +function plot_simulation!(fig, sim; vars=nothing, output_map=nothing, kwargs...) + mod = isdefined(Base, :get_extension) ? + Base.get_extension(@__MODULE__, :OrdinaryDiffEqPlotsExt) : @__MODULE__ + require(mod, [:OrdinaryDiffEq, :Plots]; fun_name="plot_simulation!") + return nothing +end